fix bufferization and weight emission after new gemm patterns
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@@ -10,12 +11,24 @@
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ResolvedWeightView {
|
||||
mlir::memref::GlobalOp globalOp;
|
||||
llvm::SmallVector<int64_t> shape;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
|
||||
bool operator==(const ResolvedWeightView& other) const {
|
||||
return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
|
||||
}
|
||||
};
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
/// Tags an op as producing a value that should stay materialized as a reusable
|
||||
@@ -32,24 +45,21 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
/// passes can identify globals that must remain weight-backed.
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
template <typename CoreLikeOpTy>
|
||||
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
|
||||
llvm::SmallVector<unsigned, 8> indices;
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreLikeOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreLikeOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight());
|
||||
if (weightIndex && !llvm::is_contained(indices, *weightIndex))
|
||||
indices.push_back(*weightIndex);
|
||||
});
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user