This commit is contained in:
@@ -21,12 +21,13 @@ namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||
found |= mvmOp.getWeight() == weightArg;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||
found |= vmmOp.getWeight() == weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
@@ -35,13 +36,18 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||
if (parentOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -90,18 +96,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weights = coreOp.getWeights();
|
||||
unsigned weightIndex = vmmOp.getWeightIndex();
|
||||
if (weightIndex < weights.size())
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto weights = coreBatchOp.getWeights();
|
||||
for (auto weight : weights)
|
||||
for (mlir::OpOperand& use : weight.getUses())
|
||||
if (use.getOwner() == coreBatchOp.getOperation())
|
||||
callback(use);
|
||||
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreBatchOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user