fix much stuff
This commit is contained in:
@@ -21,13 +21,15 @@ namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeight() == weightArg;
|
||||
found |= mvmOp.getWeight() == *weightArg;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeight() == weightArg;
|
||||
found |= vmmOp.getWeight() == *weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
@@ -38,7 +40,8 @@ void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpO
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||
if (parentOp.getWeightArgument(weightIndex) != weight)
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg || *weightArg != weight)
|
||||
continue;
|
||||
if (visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
|
||||
Reference in New Issue
Block a user