refactorone
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-20 19:06:41 +02:00
parent f56c4159b5
commit a50e77ff38
50 changed files with 3420 additions and 1187 deletions
+25 -16
View File
@@ -21,12 +21,13 @@ namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
found |= mvmOp.getWeight() == weightArg;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
found |= vmmOp.getWeight() == weightArg;
});
return found;
}
@@ -35,13 +36,18 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
auto walkWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
if (parentOp.getWeightArgument(weightIndex) != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
break;
}
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}
} // namespace
@@ -90,18 +96,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
callback(coreOp->getOpOperand(weightIndex));
break;
}
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
callback(coreBatchOp->getOpOperand(weightIndex));
break;
}
});
});
}