cf93caecd5
Validate Operations / validate-operations (push) Has been cancelled
fix codegen symlinks overwrite remove deprecated pim memcp_hd_batch op
273 lines
11 KiB
C++
273 lines
11 KiB
C++
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
#include "llvm/ADT/SmallSet.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Returns true when `op` is non-null and carries an attribute named
/// `PimWeightAlwaysAttrName`; a null `op` yields false.
bool hasWeightAlways(mlir::Operation* op) {
  if (!op)
    return false;
  return op->getAttr(PimWeightAlwaysAttrName) != nullptr;
}
|
|
|
|
/// Tags `op` with a unit attribute named `PimWeightAlwaysAttrName`.
/// `op` must be non-null; this is enforced with an assertion only.
void markWeightAlways(mlir::Operation* op) {
  assert(op && "expected valid op");
  mlir::MLIRContext* context = op->getContext();
  auto marker = mlir::UnitAttr::get(context);
  op->setAttr(PimWeightAlwaysAttrName, marker);
}
|
|
|
|
namespace {
|
|
|
|
/// Builds an index expression consisting of a single `Constant` node whose
/// value is `constant`.
CompiledIndexExpr makeConstantExpr(int64_t constant) {
  CompiledIndexExprNode node;
  node.constant = constant;
  node.kind = CompiledIndexExprNode::Kind::Constant;

  auto sharedNode = std::make_shared<CompiledIndexExprNode>(std::move(node));
  return CompiledIndexExpr(std::move(sharedNode));
}
|
|
|
|
/// Builds an index expression node of the given `kind` whose operands are
/// `lhs` and `rhs`, in that order.
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
  CompiledIndexExprNode node;
  node.operands = {std::move(lhs), std::move(rhs)};
  node.kind = kind;

  auto sharedNode = std::make_shared<CompiledIndexExprNode>(std::move(node));
  return CompiledIndexExpr(std::move(sharedNode));
}
|
|
|
|
/// Builds the `Add` expression node with operands `lhs` and `rhs`.
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
  const auto kind = CompiledIndexExprNode::Kind::Add;
  return makeBinaryExpr(kind, std::move(lhs), std::move(rhs));
}
|
|
|
|
/// Builds the `Mul` expression node with operands `lhs` and the constant `rhs`.
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
  CompiledIndexExpr factor = makeConstantExpr(rhs);
  return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), std::move(factor));
}
|
|
|
|
template <typename VMMOpTy, typename ParentOpTy>
|
|
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
|
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
|
if (!weightArg)
|
|
return false;
|
|
bool found = false;
|
|
parentOp.walk([&](mlir::Operation* op) {
|
|
if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
|
found |= vmmOp.getWeight() == *weightArg;
|
|
});
|
|
return found;
|
|
}
|
|
|
|
template <typename VMMOpTy, typename ParentOpTy>
|
|
void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
|
auto weights = parentOp.getWeights();
|
|
llvm::SmallSet<unsigned, 8> visited;
|
|
auto walkWeight = [&](mlir::Value weight) {
|
|
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
|
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
|
if (!weightArg || *weightArg != weight)
|
|
continue;
|
|
if (visited.insert(weightIndex).second)
|
|
callback(parentOp->getOpOperand(weightIndex));
|
|
break;
|
|
}
|
|
};
|
|
|
|
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// Returns true when `use` is an operand of a `spatial::SpatCompute`, its
/// operand number is below the op's weight count, and a nested
/// `spatial::SpatVMMOp` consumes the corresponding weight argument.
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
  auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(use.getOwner());
  if (!computeOp)
    return false;

  // The operand number doubles as the weight index, so anything at or past
  // the weight count cannot be a weight operand.
  const unsigned operandIndex = use.getOperandNumber();
  if (operandIndex >= computeOp.getWeights().size())
    return false;

  return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
}
|
|
|
|
/// Returns true when every use of `value` is a spatial VMM weight use, either
/// directly or through a chain of `tensor.extract_slice`,
/// `tensor.expand_shape`, `tensor.collapse_shape` and `linalg.transpose` ops
/// that take the traced value as their source/input.
///
/// A traced value with no uses at all makes the result false. Values reached
/// more than once are only inspected the first time.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
  llvm::SmallPtrSet<mlir::Value, 8> seen;

  // Yields the result of `user` when it is one of the pass-through ops and
  // `source` is its source/input operand; yields a null value otherwise.
  auto forwardedResult = [](mlir::Operation* user, mlir::Value source) -> mlir::Value {
    if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user)) {
      if (extractSliceOp.getSource() == source)
        return extractSliceOp.getResult();
      return mlir::Value();
    }
    if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user)) {
      if (expandShapeOp.getSrc() == source)
        return expandShapeOp.getResult();
      return mlir::Value();
    }
    if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user)) {
      if (collapseShapeOp.getSrc() == source)
        return collapseShapeOp.getResult();
      return mlir::Value();
    }
    if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user)) {
      if (transposeOp.getInput() == source)
        return transposeOp.getResult()[0];
      return mlir::Value();
    }
    return mlir::Value();
  };

  auto checkValue = [&](mlir::Value tracedValue, auto& recurse) -> bool {
    // Already handled on another path: nothing new to learn from it.
    if (!seen.insert(tracedValue).second)
      return true;
    if (tracedValue.use_empty())
      return false;

    for (mlir::OpOperand& use : tracedValue.getUses()) {
      if (isSpatialMvmVmmWeightUse(use))
        continue;

      mlir::Value forwarded = forwardedResult(use.getOwner(), tracedValue);
      if (!forwarded)
        return false;
      if (!recurse(forwarded, recurse))
        return false;
    }
    return true;
  };

  return checkValue(value, checkValue);
}
|
|
|
|
/// Visits every `pim::PimVMMOp` nested in a `pim::PimCoreOp`, then every one
/// nested in a `pim::PimCoreBatchOp`, below `root`. For each VMM whose weight
/// resolves to a weight index of the enclosing core op, `callback` receives
/// that core op's operand at the resolved index.
///
/// `root` must be non-null; this is enforced with an assertion only.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
  assert(root && "expected valid root op");

  // Shared handling for both core op kinds: only the op class differs.
  auto reportWeightUses = [&](auto ownerOp) {
    mlir::Operation* owner = ownerOp.getOperation();
    ownerOp.walk([&](pim::PimVMMOp vmmOp) {
      auto weightIndex = resolveWeightIndex(owner, vmmOp.getWeight());
      if (!weightIndex)
        return;
      callback(owner->getOpOperand(*weightIndex));
    });
  };

  // Two separate passes keep the reporting order: all core ops first, then
  // all core batch ops.
  root->walk([&](pim::PimCoreOp coreOp) { reportWeightUses(coreOp); });
  root->walk([&](pim::PimCoreBatchOp coreBatchOp) { reportWeightUses(coreBatchOp); });
}
|
|
|
|
/// Finds which weight of `weightOwner` the value `weight` refers to.
///
/// `weight` is first passed through `stripMemRefAddressingOps`, then compared
/// against each `getWeightArgument(i)` of the owner.
///
/// \return the first matching index, or `std::nullopt` when nothing matches
///         or `weightOwner` is neither a `pim::PimCoreOp` nor a
///         `pim::PimCoreBatchOp` (including null).
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
  weight = stripMemRefAddressingOps(weight);

  // Same lookup for both supported owner kinds.
  auto findIndex = [&](auto ownerOp) -> std::optional<unsigned> {
    for (unsigned index = 0; index < ownerOp.getWeights().size(); ++index) {
      if (ownerOp.getWeightArgument(index) == weight)
        return index;
    }
    return std::nullopt;
  };

  if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner))
    return findIndex(coreOp);
  if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner))
    return findIndex(coreBatchOp);
  return std::nullopt;
}
|
|
|
|
/// Traces `weight` back to the `memref.global` it is derived from and
/// describes the resulting view (global, shape, strides, offset).
///
/// The chain is followed backwards through `memref.subview`,
/// `memref.collapse_shape`, `memref.expand_shape` and `memref.cast`, and
/// through weight arguments of `weightOwner` (mapped to the matching entry of
/// its `getWeights()`). Once a `memref.get_global` with a
/// `DenseElementsAttr` initial value is reached, the recorded view ops are
/// replayed from the global outwards to compute the layout.
///
/// \param weightOwner Core op owning the weight arguments; also used to find
///                    the enclosing module for the global lookup.
/// \param weight      Value to resolve.
/// \param knowledge   Passed to `CompiledIndexExpr::evaluate` to fold the
///                    accumulated offset expression.
/// \return the resolved view, or failure when
///         - the chain hits an unsupported defining op,
///         - a value without defining op is not a weight argument of a
///           supported `weightOwner`,
///         - the global is missing or has no dense initial value,
///         - a subview has a non-static stride or a rank that differs from
///           the tracked layout,
///         - a subview offset cannot be compiled,
///         - a reshape is applied to a layout that is not row-major, or
///         - the offset expression cannot be evaluated.
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
  // View ops collected while walking from `weight` towards the global, i.e.
  // innermost (closest to the use) first.
  llvm::SmallVector<mlir::Operation*> viewOps;
  mlir::Value current = weight;

  while (true) {
    if (auto defOp = current.getDefiningOp()) {
      if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
        auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
        auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
        if (!globalOp || !globalOp.getInitialValue())
          return mlir::failure();

        auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
        if (!denseAttr)
          return mlir::failure();

        // Start from the global's own layout.
        ResolvedWeightView view;
        view.globalOp = globalOp;
        view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
        view.strides = computeRowMajorStrides(view.shape);

        // Replay the view ops outermost-first, accumulating the offset as a
        // symbolic expression so that dynamic subview offsets can be folded
        // later with `knowledge`.
        CompiledIndexExpr offsetExpr = makeConstantExpr(0);
        for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
          if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
            auto mixedOffsets = subview.getMixedOffsets();
            auto staticStrides = subview.getStaticStrides();

            // `zip_equal` only asserts on a length mismatch. The tracked
            // layout keeps the unreduced rank of earlier subviews, so it can
            // disagree with this subview's rank; treat that as unresolvable
            // instead of iterating mismatched ranges.
            if (mixedOffsets.size() != view.strides.size() || staticStrides.size() != view.strides.size())
              return mlir::failure();

            llvm::SmallVector<int64_t> nextStrides;
            nextStrides.reserve(mixedOffsets.size());
            for (auto [offset, stride, sourceStride] : llvm::zip_equal(mixedOffsets, staticStrides, view.strides)) {
              // A dynamic stride is stored as a sentinel; multiplying it into
              // the layout would produce a meaningless stride.
              if (mlir::ShapedType::isDynamic(stride))
                return mlir::failure();

              CompiledIndexExpr offsetValue = makeConstantExpr(0);
              if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
                auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
                if (!intAttr)
                  return mlir::failure();
                offsetValue = makeConstantExpr(intAttr.getInt());
              }
              else if (auto value = mlir::dyn_cast<mlir::Value>(offset)) {
                auto compiledOffset = compileIndexExpr(value);
                if (failed(compiledOffset))
                  return mlir::failure();
                offsetValue = *compiledOffset;
              }
              else {
                return mlir::failure();
              }

              // The offset is expressed in units of the source layout, while
              // the subview's own stride scales the source stride.
              offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
              nextStrides.push_back(stride * sourceStride);
            }
            view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
            view.strides = std::move(nextStrides);
            continue;
          }

          if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
            // Reshapes are only handled on a row-major layout, where the new
            // strides follow directly from the new shape.
            if (view.strides != computeRowMajorStrides(view.shape))
              return mlir::failure();
            auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
            view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
            view.strides = computeRowMajorStrides(view.shape);
            continue;
          }

          if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
            if (view.strides != computeRowMajorStrides(view.shape))
              return mlir::failure();
            auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
            view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
            view.strides = computeRowMajorStrides(view.shape);
            continue;
          }
        }

        auto resolvedOffset = offsetExpr.evaluate(knowledge);
        if (failed(resolvedOffset))
          return mlir::failure();
        view.offset = *resolvedOffset;
        return view;
      }

      // Layout-changing view ops are recorded for the replay above.
      if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
        viewOps.push_back(defOp);
        if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
          current = subview.getSource();
        else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
          current = collapse.getSrc();
        else
          current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
        continue;
      }

      // Casts are looked through without being recorded.
      if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
        current = castOp.getSource();
        continue;
      }

      return mlir::failure();
    }

    // No defining op: `current` must be a weight argument of `weightOwner`.
    // Continue from the value passed in for that weight.
    auto weightIndex = resolveWeightIndex(weightOwner, current);
    if (!weightIndex)
      return mlir::failure();

    if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
      current = coreOp.getWeights()[*weightIndex];
      continue;
    }
    if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
      current = coreBatchOp.getWeights()[*weightIndex];
      continue;
    }
    return mlir::failure();
  }
}
|
|
|
|
} // namespace onnx_mlir
|