// Weight-usage helpers for the PIM accelerator (namespace onnx_mlir).
//
// NOTE(review): as it appears here, this translation unit is collapsed onto
// two physical lines and the angle-bracket template arguments look stripped
// (e.g. `template bool ...`, `mlir::dyn_cast(op)`, `llvm::function_ref
// callback`, `llvm::SmallSet visited`). Restore them from the original file
// before compiling; the notes below describe only what the visible text shows.
//
// hasWeightAlways(op)
//   Returns true when `op` is non-null and carries an attribute named
//   PimWeightAlwaysAttrName; a null `op` yields false.
//
// markWeightAlways(op)
//   Asserts `op` is non-null, then sets PimWeightAlwaysAttrName on it to a
//   UnitAttr created in the op's context.
//
// hasMvmVmmWeightUse(parentOp, weightIndex)            [anonymous namespace]
//   Reads parentOp.getWeightArgument(weightIndex), walks every operation
//   nested under `parentOp`, and returns true if any op matched by either of
//   the two dyn_casts (bound to `mvmOp` / `vmmOp`) has getWeight() equal to
//   that argument. The walk is never interrupted: it visits all nested ops
//   even after a match has been found.
//   NOTE(review): the dyn_cast target types are elided in this text; they are
//   presumably the MVM/VMM op template parameters - confirm.
//
// walkMvmVmmWeightUses(parentOp, callback)             [anonymous namespace]
//   Walks all MVMOpTy ops under `parentOp`, then all VMMOpTy ops. For each,
//   scans weight indices in increasing order for the first one whose
//   getWeightArgument() equals the op's weight, and invokes `callback` with
//   parentOp->getOpOperand(index). The `visited` set ensures each index is
//   reported at most once. Ops whose weight matches no weight argument are
//   silently skipped.
//
// isSpatialMvmVmmWeightUse(use)
//   dyn_casts the owner of `use` (target type elided in this text) and returns
//   false when the cast fails or the operand number is not below
//   getWeights().size(). Otherwise returns hasMvmVmmWeightUse() for that
//   operand number.
//   NOTE(review): the operand number is used directly as the weight index, so
//   this presumably relies on weights being the leading operands - confirm.
//
// hasOnlySpatialMvmVmmWeightUses(value)
//   Recursive check driven by a self-passing generic lambda:
//     * a value that was already visited counts as satisfied (returns true);
//     * a value with no uses returns false;
//     * otherwise every use must either satisfy isSpatialMvmVmmWeightUse(), or
//       be owned by one of four forwarding ops (bound to `extractSliceOp`,
//       `expandShapeOp`, `collapseShapeOp`, `transposeOp`) that takes the value
//       as its source/data operand and whose result recursively passes the
//       same check.
//   Any other kind of user makes the result false.
//
// walkPimMvmVmmWeightUses(root, callback)
//   Asserts `root` is non-null. For every pim::PimCoreOp under `root`, and then
//   for every pim::PimCoreBatchOp, visits each nested pim::PimVMMOp and calls
//   `callback` with the enclosing op's operand at the first weight index whose
//   weight argument equals the VMM op's weight. Unlike walkMvmVmmWeightUses
//   there is no de-duplication, so the same operand is reported once per VMM
//   op that uses it.
//   NOTE(review): only pim::PimVMMOp is walked here despite the "MvmVmm" in
//   the name - confirm that no PIM-level MVM op needs handling.
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" namespace onnx_mlir { bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; } void markWeightAlways(mlir::Operation* op) { assert(op && "expected valid op"); op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext())); } namespace { template bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { mlir::Value weightArg = parentOp.getWeightArgument(weightIndex); bool found = false; parentOp.walk([&](mlir::Operation* op) { if (auto mvmOp = mlir::dyn_cast(op)) found |= mvmOp.getWeight() == weightArg; else if (auto vmmOp = mlir::dyn_cast(op)) found |= vmmOp.getWeight() == weightArg; }); return found; } template void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref callback) { auto weights = parentOp.getWeights(); llvm::SmallSet visited; auto walkWeight = [&](mlir::Value weight) { for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) { if (parentOp.getWeightArgument(weightIndex) != weight) continue; if (visited.insert(weightIndex).second) callback(parentOp->getOpOperand(weightIndex)); break; } }; parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); }); parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); }); } } // namespace bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) { mlir::Operation* user = use.getOwner(); unsigned operandIndex = use.getOperandNumber(); auto computeOp = mlir::dyn_cast(user); if (!computeOp || operandIndex >= computeOp.getWeights().size()) return false; return hasMvmVmmWeightUse(computeOp, operandIndex); } bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) { llvm::SmallPtrSet visited; auto 
walkUses = [&](mlir::Value currentValue, auto& self) -> bool { if (!visited.insert(currentValue).second) return true; if (currentValue.use_empty()) return false; return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) { if (isSpatialMvmVmmWeightUse(use)) return true; mlir::Operation* user = use.getOwner(); if (auto extractSliceOp = mlir::dyn_cast(user)) return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self); if (auto expandShapeOp = mlir::dyn_cast(user)) return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self); if (auto collapseShapeOp = mlir::dyn_cast(user)) return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self); if (auto transposeOp = mlir::dyn_cast(user)) return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self); return false; }); }; return walkUses(value, walkUses); } void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback) { assert(root && "expected valid root op"); root->walk([&](pim::PimCoreOp coreOp) { coreOp.walk([&](pim::PimVMMOp vmmOp) { for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) { callback(coreOp->getOpOperand(weightIndex)); break; } }); }); root->walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOp.walk([&](pim::PimVMMOp vmmOp) { for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex) if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) { callback(coreBatchOp->getOpOperand(weightIndex)); break; } }); }); } } // namespace onnx_mlir