fix bufferization and weight emission after new gemm patterns
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -19,6 +24,50 @@ void markWeightAlways(mlir::Operation* op) {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Wraps an integer literal in a freshly allocated `Constant` expression node.
CompiledIndexExpr makeConstantExpr(int64_t constant) {
  auto node = std::make_shared<CompiledIndexExprNode>();
  node->kind = CompiledIndexExprNode::Kind::Constant;
  node->constant = constant;
  return CompiledIndexExpr(std::move(node));
}
|
||||
|
||||
/// Allocates an expression node of the requested `kind` whose two operands are
/// `lhs` and `rhs`, in that order.
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
  auto node = std::make_shared<CompiledIndexExprNode>();
  node->kind = kind;
  node->operands = {std::move(lhs), std::move(rhs)};
  return CompiledIndexExpr(std::move(node));
}
|
||||
|
||||
/// Builds the expression `lhs + rhs`.
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
  const auto sumKind = CompiledIndexExprNode::Kind::Add;
  return makeBinaryExpr(sumKind, std::move(lhs), std::move(rhs));
}
|
||||
|
||||
/// Builds the expression `lhs * rhs`, where the right-hand side is a literal
/// that gets wrapped in a constant node first.
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
  CompiledIndexExpr factor = makeConstantExpr(rhs);
  return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), std::move(factor));
}
|
||||
|
||||
/// Follows `value` backwards through memref.subview, memref.cast,
/// memref.collapse_shape and memref.expand_shape producers and returns the
/// first value that is not produced by one of those ops (this includes values
/// that have no defining op at all).
mlir::Value stripWeightViewOps(mlir::Value value) {
  for (mlir::Operation* producer = value.getDefiningOp(); producer; producer = value.getDefiningOp()) {
    if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(producer))
      value = subview.getSource();
    else if (auto cast = mlir::dyn_cast<mlir::memref::CastOp>(producer))
      value = cast.getSource();
    else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(producer))
      value = collapse.getSrc();
    else if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(producer))
      value = expand.getSrc();
    else
      break; // Not a view-like op: this is as far as we can peel.
  }
  return value;
}
|
||||
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
@@ -96,35 +145,31 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight()))
|
||||
callback(coreOp->getOpOperand(*weightIndex));
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreBatchOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight()))
|
||||
callback(coreBatchOp->getOpOperand(*weightIndex));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
|
||||
weight = stripWeightViewOps(weight);
|
||||
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
if (coreOp.getWeightArgument(weightIndex) == weight)
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == weight)
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
@@ -132,4 +177,121 @@ std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::Pi
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/// Convenience overload: resolves the weight operand of `vmmOp` against the
/// weights of `weightOwner` by forwarding to the value-based overload.
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) {
  mlir::Value vmmWeight = vmmOp.getWeight();
  return resolveWeightIndex(weightOwner, vmmWeight);
}
|
||||
|
||||
/// Resolves `weight` to a strided window of the memref global that backs it.
///
/// Starting from `weight`, the producer chain is followed through
/// memref.subview / memref.collapse_shape / memref.expand_shape / memref.cast
/// ops. A value without a defining op is matched against the weight arguments
/// of `weightOwner` (through resolveWeightIndex) and the walk continues from
/// the corresponding weight operand. Once a memref.get_global is reached, the
/// recorded view ops are replayed, starting with the one closest to the
/// global, on top of the shape of the global's dense initial value.
///
/// \param weightOwner pim core / core-batch op owning the weight arguments;
///                    also used to find the enclosing module.
/// \param weight      value whose backing storage should be resolved.
/// \param knowledge   static facts used to evaluate dynamic subview offsets.
/// \returns the resolved view (global, shape, strides, offset) or failure when
///          - the chain hits an op that is none of the ones listed above,
///          - a block argument cannot be mapped to a weight of `weightOwner`,
///          - the global is missing or has no dense initial value,
///          - a subview has a rank different from the view tracked so far, or
///            has dynamic sizes/strides,
///          - a reshape is applied to a view whose strides differ from
///            computeRowMajorStrides of its shape,
///          - the accumulated offset cannot be evaluated with `knowledge`.
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
  // View ops met on the way from `weight` to its storage; the op closest to
  // `weight` comes first, so they are replayed in reverse below.
  llvm::SmallVector<mlir::Operation*> viewOps;
  mlir::Value current = weight;

  while (true) {
    if (auto defOp = current.getDefiningOp()) {
      if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
        auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
        auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
        if (!globalOp || !globalOp.getInitialValue())
          return mlir::failure();

        auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
        if (!denseAttr)
          return mlir::failure();

        // Start from the full global and narrow it down view op by view op.
        ResolvedWeightView view;
        view.globalOp = globalOp;
        view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
        view.strides = computeRowMajorStrides(view.shape);

        CompiledIndexExpr offsetExpr = makeConstantExpr(0);
        for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
          if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
            auto mixedOffsets = subview.getMixedOffsets();
            auto staticStrides = subview.getStaticStrides();
            auto staticSizes = subview.getStaticSizes();

            // The per-dimension walk below needs one offset/size/stride per
            // tracked dimension. A mismatch (e.g. a subview applied on top of
            // a rank-reducing one) is reported as failure instead of tripping
            // the zip_equal assertion.
            const auto sourceRank = view.strides.size();
            if (mixedOffsets.size() != sourceRank || staticStrides.size() != sourceRank
                || staticSizes.size() != sourceRank)
              return mlir::failure();

            // Dynamic sizes/strides are encoded with a sentinel value that must
            // not be folded into the resolved shape or multiplied into strides.
            auto isDynamic = [](int64_t extent) { return mlir::ShapedType::isDynamic(extent); };
            if (llvm::any_of(staticStrides, isDynamic) || llvm::any_of(staticSizes, isDynamic))
              return mlir::failure();

            llvm::SmallVector<int64_t> nextStrides;
            nextStrides.reserve(sourceRank);
            for (auto [offset, stride, sourceStride] : llvm::zip_equal(mixedOffsets, staticStrides, view.strides)) {
              CompiledIndexExpr offsetValue = makeConstantExpr(0);
              if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
                // Static offset: must be an integer attribute.
                auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
                if (!intAttr)
                  return mlir::failure();
                offsetValue = makeConstantExpr(intAttr.getInt());
              }
              else if (auto value = mlir::dyn_cast<mlir::Value>(offset)) {
                // Dynamic offset: keep it symbolic until `knowledge` is applied.
                auto compiledOffset = compileIndexExpr(value);
                if (failed(compiledOffset))
                  return mlir::failure();
                offsetValue = *compiledOffset;
              }
              else {
                return mlir::failure();
              }
              // offset += subviewOffset[d] * strideOfSource[d]
              offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
              nextStrides.push_back(stride * sourceStride);
            }
            view.shape.assign(staticSizes.begin(), staticSizes.end());
            view.strides = std::move(nextStrides);
            continue;
          }

          if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
            // Reshapes are only followed when the current strides are exactly
            // the row-major strides of the current shape.
            if (view.strides != computeRowMajorStrides(view.shape))
              return mlir::failure();
            auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
            view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
            view.strides = computeRowMajorStrides(view.shape);
            continue;
          }

          if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
            if (view.strides != computeRowMajorStrides(view.shape))
              return mlir::failure();
            auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
            view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
            view.strides = computeRowMajorStrides(view.shape);
          }
        }

        auto resolvedOffset = offsetExpr.evaluate(knowledge);
        if (failed(resolvedOffset))
          return mlir::failure();
        view.offset = *resolvedOffset;
        return view;
      }

      if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
        viewOps.push_back(defOp);
        if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
          current = subview.getSource();
        else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
          current = collapse.getSrc();
        else
          current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
        continue;
      }

      // Casts are looked through without being recorded for replay.
      if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
        current = castOp.getSource();
        continue;
      }

      return mlir::failure();
    }

    // No defining op: try to map the value to a weight of the owner and keep
    // walking from the operand that feeds that weight.
    auto weightIndex = resolveWeightIndex(weightOwner, current);
    if (!weightIndex)
      return mlir::failure();

    if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
      current = coreOp.getWeights()[*weightIndex];
      continue;
    }
    if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
      current = coreBatchOp.getWeights()[*weightIndex];
      continue;
    }
    return mlir::failure();
  }
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@@ -10,12 +11,24 @@
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ResolvedWeightView {
|
||||
mlir::memref::GlobalOp globalOp;
|
||||
llvm::SmallVector<int64_t> shape;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
|
||||
bool operator==(const ResolvedWeightView& other) const {
|
||||
return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
|
||||
}
|
||||
};
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
/// Tags an op as producing a value that should stay materialized as a reusable
|
||||
@@ -32,24 +45,21 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
/// passes can identify globals that must remain weight-backed.
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
template <typename CoreLikeOpTy>
|
||||
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
|
||||
llvm::SmallVector<unsigned, 8> indices;
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreLikeOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreLikeOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight());
|
||||
if (weightIndex && !llvm::is_contained(indices, *weightIndex))
|
||||
indices.push_back(*weightIndex);
|
||||
});
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user