Refactor PIM/Common (splitting in files, adding helpers, adding brief
Some checks failed
Validate Operations / validate-operations (push) Failing after 18m36s
Some checks failed
Validate Operations / validate-operations (push) Failing after 18m36s
docs)
This commit is contained in:
101
src/PIM/Common/IR/WeightUtils.cpp
Normal file
101
src/PIM/Common/IR/WeightUtils.cpp
Normal file
@@ -0,0 +1,101 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||
|
||||
void markWeightAlways(mlir::Operation* op) {
|
||||
assert(op && "expected valid op");
|
||||
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
};
|
||||
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
|
||||
mlir::Operation* user = use.getOwner();
|
||||
unsigned operandIndex = use.getOperandNumber();
|
||||
|
||||
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
|
||||
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||
return false;
|
||||
|
||||
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
||||
}
|
||||
|
||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||
llvm::SmallPtrSet<mlir::Value, 8> visited;
|
||||
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
|
||||
if (!visited.insert(currentValue).second)
|
||||
return true;
|
||||
if (currentValue.use_empty())
|
||||
return false;
|
||||
|
||||
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
|
||||
if (isSpatialMvmVmmWeightUse(use))
|
||||
return true;
|
||||
|
||||
mlir::Operation* user = use.getOwner();
|
||||
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
|
||||
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
|
||||
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
|
||||
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
|
||||
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
|
||||
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
||||
|
||||
return false;
|
||||
});
|
||||
};
|
||||
|
||||
return walkUses(value, walkUses);
|
||||
}
|
||||
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback); });
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto weights = coreBatchOp.getWeights();
|
||||
for (auto weight : weights)
|
||||
for (mlir::OpOperand& use : weight.getUses())
|
||||
if (use.getOwner() == coreBatchOp.getOperation())
|
||||
callback(use);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user