#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"

#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

namespace onnx_mlir {

/// Returns true when `op` is non-null and carries an attribute named
/// `PimWeightAlwaysAttrName` (the attribute's value is not inspected).
bool hasWeightAlways(mlir::Operation* op) {
  return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr;
}

/// Attaches a `UnitAttr` named `PimWeightAlwaysAttrName` to `op`.
/// `op` must be non-null (asserted).
void markWeightAlways(mlir::Operation* op) {
  assert(op && "expected valid op");
  op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
}

namespace {

/// Builds a leaf index expression holding the literal `constant`.
CompiledIndexExpr makeConstantExpr(int64_t constant) {
  CompiledIndexExprNode expr;
  expr.kind = CompiledIndexExprNode::Kind::Constant;
  expr.constant = constant;
  return CompiledIndexExpr(std::make_shared(std::move(expr)));
}

/// Builds an interior index-expression node of `kind` whose operands are
/// `lhs` and `rhs`, in that order.
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
  CompiledIndexExprNode expr;
  expr.kind = kind;
  expr.operands = {std::move(lhs), std::move(rhs)};
  return CompiledIndexExpr(std::make_shared(std::move(expr)));
}

/// Returns the expression `lhs + rhs`.
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
  return makeBinaryExpr(CompiledIndexExprNode::Kind::Add, std::move(lhs), std::move(rhs));
}

/// Returns the expression `lhs * rhs` where `rhs` is a compile-time constant.
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
  return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
}

/// Returns true when the value `parentOp.getWeightArgument(weightIndex)` is
/// used directly as the weight (`getWeight()`) of some VMM op nested inside
/// `parentOp`. Returns false when `parentOp` exposes no argument for that
/// index.
template bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
  auto weightArg = parentOp.getWeightArgument(weightIndex);
  if (!weightArg)
    return false;

  // One match is enough: interrupt the walk instead of visiting the rest of
  // the body.
  return parentOp
      .walk([&](mlir::Operation* op) {
        if (auto vmmOp = mlir::dyn_cast(op); vmmOp && vmmOp.getWeight() == *weightArg)
          return mlir::WalkResult::interrupt();
        return mlir::WalkResult::advance();
      })
      .wasInterrupted();
}

/// Walks every `VMMOpTy` nested in `parentOp`. For each one whose weight equals
/// one of `parentOp`'s weight arguments, invokes `callback` on the operand of
/// `parentOp` with the same index. Each weight index is reported at most once.
///
/// NOTE(review): uses the weight index directly as an operand number, which
/// assumes the weights are the leading operands of `parentOp` — confirm
/// against the op definition.
template void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref callback) {
  auto weights = parentOp.getWeights();
  llvm::SmallSet visited;

  auto walkWeight = [&](mlir::Value weight) {
    for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
      auto weightArg = parentOp.getWeightArgument(weightIndex);
      if (!weightArg || *weightArg != weight)
        continue;
      if (visited.insert(weightIndex).second)
        callback(parentOp->getOpOperand(weightIndex));
      break;
    }
  };

  parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}

} // namespace

/// Returns true when `use` targets one of the first `getWeights().size()`
/// operands of a spatial compute op and the corresponding weight argument
/// feeds a nested VMM op as its weight.
///
/// NOTE(review): treating the operand number as a weight index assumes the
/// weights are the leading operands of the compute op — confirm.
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
  mlir::Operation* user = use.getOwner();
  unsigned operandIndex = use.getOperandNumber();
  auto computeOp = mlir::dyn_cast(user);
  if (!computeOp || operandIndex >= computeOp.getWeights().size())
    return false;
  return hasVmmWeightUse(computeOp, operandIndex);
}

/// Returns true when every transitive use of `value` ends in a use accepted by
/// `isSpatialMvmVmmWeightUse`.
///
/// The walk looks through extract-slice, expand-shape, collapse-shape and
/// transpose users, but only when `value` is their source/input operand.
/// Any other kind of use makes the result false. A value with no uses at all,
/// including an unused intermediate view, also makes the result false.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
  llvm::SmallPtrSet visited;

  auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
    // A value reached a second time is already being checked by an earlier
    // frame; answering true here keeps the recursion finite.
    if (!visited.insert(currentValue).second)
      return true;
    if (currentValue.use_empty())
      return false;

    return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
      if (isSpatialMvmVmmWeightUse(use))
        return true;

      mlir::Operation* user = use.getOwner();
      if (auto extractSliceOp = mlir::dyn_cast(user))
        return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
      if (auto expandShapeOp = mlir::dyn_cast(user))
        return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
      if (auto collapseShapeOp = mlir::dyn_cast(user))
        return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
      if (auto transposeOp = mlir::dyn_cast(user))
        // The result range may be empty (e.g. a buffer-semantics transpose);
        // indexing [0] on it would be out of range.
        return transposeOp.getInput() == currentValue && !transposeOp.getResult().empty() &&
               self(transposeOp.getResult()[0], self);
      return false;
    });
  };

  return walkUses(value, walkUses);
}

/// For every `pim::PimCoreOp` and `pim::PimCoreBatchOp` under `root`
/// (including `root` itself), visits each nested `pim::PimVMMOp`. When
/// `resolveWeightIndex` maps the VMM's weight to a weight index of that core,
/// invokes `callback` on the core's operand with that index. All cores are
/// visited first, then all core batches.
///
/// NOTE(review): as above, this relies on weights being the leading operands
/// of the core ops — confirm. A weight used by several VMMs is reported once
/// per VMM.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback) {
  assert(root && "expected valid root op");

  root->walk([&](pim::PimCoreOp coreOp) {
    coreOp.walk([&](pim::PimVMMOp vmmOp) {
      if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight()))
        callback(coreOp->getOpOperand(*weightIndex));
    });
  });

  root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
    coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
      if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight()))
        callback(coreBatchOp->getOpOperand(*weightIndex));
    });
  });
}

/// Strips memref addressing ops from `weight` via `stripMemRefAddressingOps`.
/// For a core or core-batch `weightOwner`, returns the first index `i` with
/// `getWeightArgument(i) == weight`. Returns `std::nullopt` when no index
/// matches, when `weightOwner` is null, or when it is neither op kind.
std::optional resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
  weight = stripMemRefAddressingOps(weight);

  if (auto coreOp = mlir::dyn_cast_or_null(weightOwner)) {
    for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
      if (coreOp.getWeightArgument(weightIndex) == weight)
        return weightIndex;
    return std::nullopt;
  }

  if (auto coreBatchOp = mlir::dyn_cast_or_null(weightOwner)) {
    for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
      if (coreBatchOp.getWeightArgument(weightIndex) == weight)
        return weightIndex;
    return std::nullopt;
  }

  return std::nullopt;
}

/// Traces `weight` back to the global it was read from and describes the
/// visible window of that global's dense initial value.
///
/// The trace starts at `weight` and follows subview, collapse-shape,
/// expand-shape and cast ops back towards their sources. When it reaches a
/// block argument, the argument is mapped through `resolveWeightIndex` to the
/// corresponding weight operand of `weightOwner`, and the trace continues from
/// there.
///
/// On reaching a get-global op, the recorded view ops are replayed from the
/// global outwards. Each view op updates the shape, the strides (in elements)
/// and a symbolic element offset. The offset is finally evaluated against
/// `knowledge`.
///
/// Returns failure when any of the following holds:
/// - the chain contains an unsupported op;
/// - the global has no dense initial value;
/// - a subview has dynamic sizes or strides, or a per-dimension list whose
///   length does not match the current view rank;
/// - a collapse/expand is applied to a non-row-major view or has a dynamic
///   result shape;
/// - the offset cannot be evaluated.
llvm::FailureOr resolveWeightView(mlir::Operation* weightOwner,
                                  mlir::Value weight,
                                  const StaticValueKnowledge& knowledge) {
  // View ops in the order they were met while walking back from `weight`
  // (i.e. closest to `weight` first).
  llvm::SmallVector viewOps;
  mlir::Value current = weight;

  while (true) {
    if (auto defOp = current.getDefiningOp()) {
      if (auto getGlobalOp = mlir::dyn_cast(defOp)) {
        auto moduleOp = weightOwner ? weightOwner->getParentOfType() : mlir::ModuleOp{};
        auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
        if (!globalOp || !globalOp.getInitialValue())
          return mlir::failure();
        auto denseAttr = mlir::dyn_cast(*globalOp.getInitialValue());
        if (!denseAttr)
          return mlir::failure();

        ResolvedWeightView view;
        view.globalOp = globalOp;
        auto globalShape = denseAttr.getType().getShape();
        view.shape.assign(globalShape.begin(), globalShape.end());
        view.strides = computeRowMajorStrides(view.shape);
        CompiledIndexExpr offsetExpr = makeConstantExpr(0);

        // Replay from the global outwards to `weight`.
        for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
          if (auto subview = mlir::dyn_cast(viewOp)) {
            auto mixedOffsets = subview.getMixedOffsets();
            auto staticSizes = subview.getStaticSizes();
            auto staticStrides = subview.getStaticStrides();
            // Each per-dimension list must match the current view's rank;
            // zip_equal below asserts on a length mismatch.
            if (mixedOffsets.size() != view.strides.size() || staticSizes.size() != view.strides.size() ||
                staticStrides.size() != view.strides.size())
              return mlir::failure();
            // Unit dimensions removed by a rank-reducing subview. They still
            // shift the offset, but must not appear in the resulting view.
            auto droppedDims = subview.getDroppedDims();

            llvm::SmallVector nextStrides;
            nextStrides.reserve(mixedOffsets.size());
            view.shape.clear();

            unsigned dim = 0;
            for (auto [offset, size, stride, sourceStride] :
                 llvm::zip_equal(mixedOffsets, staticSizes, staticStrides, view.strides)) {
              // Dynamic entries are encoded as ShapedType::kDynamic; using that
              // sentinel arithmetically would corrupt the view.
              if (mlir::ShapedType::isDynamic(size) || mlir::ShapedType::isDynamic(stride))
                return mlir::failure();

              CompiledIndexExpr offsetValue = makeConstantExpr(0);
              if (auto attr = mlir::dyn_cast(offset)) {
                auto intAttr = mlir::dyn_cast(attr);
                if (!intAttr)
                  return mlir::failure();
                offsetValue = makeConstantExpr(intAttr.getInt());
              }
              else if (auto value = mlir::dyn_cast(offset)) {
                auto compiledOffset = compileIndexExpr(value);
                if (failed(compiledOffset))
                  return mlir::failure();
                offsetValue = *compiledOffset;
              }
              else {
                return mlir::failure();
              }

              // Offsets are in units of the source view's elements along this dim.
              offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
              if (!droppedDims.test(dim)) {
                view.shape.push_back(size);
                nextStrides.push_back(stride * sourceStride);
              }
              ++dim;
            }

            view.strides = std::move(nextStrides);
            continue;
          }

          if (auto collapse = mlir::dyn_cast(viewOp)) {
            // Reshaping is only modelled for contiguous row-major views.
            if (view.strides != computeRowMajorStrides(view.shape))
              return mlir::failure();
            auto resultType = mlir::cast(collapse.getResult().getType());
            if (!resultType.hasStaticShape())
              return mlir::failure();
            view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
            view.strides = computeRowMajorStrides(view.shape);
            continue;
          }

          if (auto expand = mlir::dyn_cast(viewOp)) {
            // Reshaping is only modelled for contiguous row-major views.
            if (view.strides != computeRowMajorStrides(view.shape))
              return mlir::failure();
            auto resultType = mlir::cast(expand.getResult().getType());
            if (!resultType.hasStaticShape())
              return mlir::failure();
            view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
            view.strides = computeRowMajorStrides(view.shape);
          }
        }

        auto resolvedOffset = offsetExpr.evaluate(knowledge);
        if (failed(resolvedOffset))
          return mlir::failure();
        view.offset = *resolvedOffset;
        return view;
      }

      // Record supported view ops and keep walking towards their source.
      if (mlir::isa(defOp)) {
        viewOps.push_back(defOp);
        if (auto subview = mlir::dyn_cast(defOp))
          current = subview.getSource();
        else if (auto collapse = mlir::dyn_cast(defOp))
          current = collapse.getSrc();
        else
          current = mlir::cast(defOp).getSrc();
        continue;
      }

      // Casts do not change the addressed elements; look straight through them.
      if (auto castOp = mlir::dyn_cast(defOp)) {
        current = castOp.getSource();
        continue;
      }

      return mlir::failure();
    }

    // `current` has no defining op: map it to the matching weight operand of
    // `weightOwner` and continue the trace outside the owner's region.
    auto weightIndex = resolveWeightIndex(weightOwner, current);
    if (!weightIndex)
      return mlir::failure();

    if (auto coreOp = mlir::dyn_cast_or_null(weightOwner)) {
      current = coreOp.getWeights()[*weightIndex];
      continue;
    }
    if (auto coreBatchOp = mlir::dyn_cast_or_null(weightOwner)) {
      current = coreBatchOp.getWeights()[*weightIndex];
      continue;
    }
    return mlir::failure();
  }
}

} // namespace onnx_mlir