Files
Raptor/src/PIM/Common/IR/WeightUtils.cpp
T
NiccoloN a50e77ff38
Validate Operations / validate-operations (push) Has been cancelled
refactorone
2026-05-20 19:06:41 +02:00

118 lines
4.3 KiB
C++

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
/// Returns true when `op` is non-null and carries an attribute named
/// `PimWeightAlwaysAttrName`; returns false otherwise (including for a null op).
bool hasWeightAlways(mlir::Operation* op) {
  if (!op)
    return false;
  mlir::Attribute marker = op->getAttr(PimWeightAlwaysAttrName);
  return static_cast<bool>(marker);
}
/// Attaches a unit attribute named `PimWeightAlwaysAttrName` to `op`.
///
/// `op` must be non-null; this is enforced with an assertion only.
void markWeightAlways(mlir::Operation* op) {
  assert(op && "expected valid op");
  mlir::MLIRContext* context = op->getContext();
  mlir::Attribute marker = mlir::UnitAttr::get(context);
  op->setAttr(PimWeightAlwaysAttrName, marker);
}
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeight() == weightArg;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeight() == weightArg;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
if (parentOp.getWeightArgument(weightIndex) != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
break;
}
};
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}
} // namespace
/// Returns true when `use` is an operand of a `spatial::SpatCompute` op, its operand
/// number is below the number of weights of that op, and the corresponding weight
/// argument is used as the weight of a nested `SpatMVMOp` or `SpatVMMOp`.
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
  auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(use.getOwner());
  if (!computeOp)
    return false;

  // Only operands whose index falls within the weight count are considered.
  const unsigned operandIndex = use.getOperandNumber();
  if (operandIndex >= computeOp.getWeights().size())
    return false;

  return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
}
/// Checks that every use of `value` is a spatial MVM/VMM weight use, either directly
/// or through a chain of `tensor.extract_slice`, `tensor.expand_shape`,
/// `tensor.collapse_shape` and ONNX transpose ops.
///
/// A value without any use makes the check fail. A value that was already visited
/// during the traversal is treated as passing, so it is not examined twice.
///
/// @return true if all transitive uses end in a spatial MVM/VMM weight use.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
  llvm::SmallPtrSet<mlir::Value, 8> seenValues;

  auto checkValue = [&](mlir::Value current, auto& recurse) -> bool {
    const bool firstVisit = seenValues.insert(current).second;
    if (!firstVisit)
      return true;
    if (current.use_empty())
      return false;

    for (mlir::OpOperand& use : current.getUses()) {
      if (isSpatialMvmVmmWeightUse(use))
        continue;

      // Otherwise the user must be one of the forwarding ops below, consuming
      // `current` as its source/data operand; its result is then checked in turn.
      mlir::Operation* user = use.getOwner();
      mlir::Value forwarded;
      if (auto sliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user)) {
        if (sliceOp.getSource() == current)
          forwarded = sliceOp.getResult();
      }
      else if (auto expandOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user)) {
        if (expandOp.getSrc() == current)
          forwarded = expandOp.getResult();
      }
      else if (auto collapseOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user)) {
        if (collapseOp.getSrc() == current)
          forwarded = collapseOp.getResult();
      }
      else if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user)) {
        if (transposeOp.getData() == current)
          forwarded = transposeOp.getResult();
      }

      if (!forwarded || !recurse(forwarded, recurse))
        return false;
    }
    return true;
  };

  return checkValue(value, checkValue);
}
/// Invokes `callback` for the weight operands of PIM core ops under `root` that are
/// consumed by nested `pim::PimVMMOp` operations.
///
/// All `pim::PimCoreOp` ops are processed first, then all `pim::PimCoreBatchOp` ops.
/// For every nested VMM op, the first weight index `i` whose
/// `getWeightArgument(i)` equals the VMM's weight is located and `callback` receives
/// the enclosing op's operand `i`. No de-duplication is performed: the same operand
/// is reported once per VMM op that uses it.
///
/// `root` must be non-null; this is enforced with an assertion only.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
  assert(root && "expected valid root op");

  // Shared by both core op kinds, which expose the same accessors used here.
  auto reportVmmWeights = [&](auto coreLikeOp) {
    coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
      const auto weightCount = coreLikeOp.getWeights().size();
      mlir::Value usedWeight = vmmOp.getWeight();
      for (unsigned index = 0; index < weightCount; ++index) {
        if (coreLikeOp.getWeightArgument(index) != usedWeight)
          continue;
        callback(coreLikeOp->getOpOperand(index));
        return;
      }
    });
  };

  root->walk([&](pim::PimCoreOp coreOp) { reportVmmWeights(coreOp); });
  root->walk([&](pim::PimCoreBatchOp coreBatchOp) { reportVmmWeights(coreBatchOp); });
}
} // namespace onnx_mlir