// NOTE(review): in this text several template argument lists appear to be
// missing (e.g. `dyn_cast(...)`, `isa(...)`, `cast(...)`, `FailureOr`,
// `SmallVector`, `llvm::SmallPtrSet`). The comments below describe only what
// the visible tokens establish; confirm the elided types and op kinds against
// the authoritative copy of this file.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include "WeightMaterialization.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Decides whether `value` can be traced back, through a chain of
/// single-source operations, to an operation accepted by `hasWeightAlways`.
///
/// Returns false immediately when the type cast of `value.getType()` fails or
/// when `isMatrixShape` rejects the type's shape. Otherwise the defining-op
/// chain is walked; the walk returns true at the first operation for which
/// `hasWeightAlways` holds, and false when it meets an operation it does not
/// look through, an operation it has already visited, or a value that has no
/// defining operation.
bool isWeightLikeComputeOperand(Value value) {
  // NOTE(review): the cast target type is not visible here; the result is
  // only used for `getShape()`.
  auto rankedType = dyn_cast(value.getType());
  if (!rankedType || !isMatrixShape(rankedType.getShape()))
    return false;

  // Operations already inspected during this walk.
  llvm::SmallPtrSet visited;
  while (auto* definingOp = value.getDefiningOp()) {
    // Seeing the same operation a second time ends the walk with `false`.
    if (!visited.insert(definingOp).second)
      return false;
    if (hasWeightAlways(definingOp))
      return true;
    // Each branch below replaces `value` with the matched operation's input
    // and continues the walk from there.
    // NOTE(review): the op classes being matched are elided in this text; the
    // variable names suggest extract-slice / expand-shape / collapse-shape /
    // transpose ops -- confirm.
    if (auto extractSliceOp = dyn_cast(definingOp)) {
      value = extractSliceOp.getSource();
      continue;
    }
    if (auto expandShapeOp = dyn_cast(definingOp)) {
      value = expandShapeOp.getSrc();
      continue;
    }
    if (auto collapseShapeOp = dyn_cast(definingOp)) {
      value = collapseShapeOp.getSrc();
      continue;
    }
    if (auto transposeOp = dyn_cast(definingOp)) {
      value = transposeOp.getData();
      continue;
    }
    // Any other producer stops the walk.
    return false;
  }
  // `value` has no defining operation.
  return false;
}

/// Produces a counterpart of `value` by creating operations through
/// `rewriter`, recording every original-to-new value pair in `mapper`.
///
/// \param value    Value to materialize.
/// \param rewriter Rewriter used to create the slice op and to clone ops.
/// \param mapper   Mapping consulted first (an existing entry for `value` is
///                 returned as-is) and updated with every value produced.
/// \returns The mapped value on success. `failure()` when `value` has no
///          defining operation, when the defining operation is not one of the
///          kinds handled below, when the first branch sees a type that fails
///          the cast or lacks a static shape, or when materializing a
///          weight-like operand fails.
FailureOr materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
  // Reuse an earlier result for this value if one was recorded.
  if (auto mapped = mapper.lookupOrNull(value))
    return cast(mapped);

  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return failure();

  // NOTE(review): the op kinds tested by this `isa` are not visible in this
  // text -- confirm which producers take this branch.
  if (isa(definingOp)) {
    auto tensorType = dyn_cast(value.getType());
    if (!tensorType || !tensorType.hasStaticShape())
      return failure();

    // Build a slice covering the whole tensor: zero offsets, unit strides and
    // one size entry per dimension of the static shape. The slice reads from
    // the original `value` and has the same result type.
    SmallVector offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
    SmallVector sizes;
    SmallVector strides(tensorType.getRank(), rewriter.getIndexAttr(1));
    sizes.reserve(tensorType.getRank());
    for (int64_t dim : tensorType.getShape())
      sizes.push_back(rewriter.getIndexAttr(dim));
    auto referencedValue =
        tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
    mapper.map(value, referencedValue.getResult());
    return referencedValue.getResult();
  }

  // NOTE(review): the accepted op kinds of this `isa` are elided as well;
  // every other producer is rejected.
  if (!isa(definingOp))
    return failure();

  // Per-clone operand mapping, kept separate from the caller's `mapper`.
  IRMapping localMapper;
  for (Value operand : definingOp->getOperands()) {
    // Prefer a value that has already been recorded in `mapper`.
    if (auto mapped = mapper.lookupOrNull(operand)) {
      localMapper.map(operand, cast(mapped));
      continue;
    }
    // Weight-like operands are materialized recursively; a failure there
    // aborts the whole materialization.
    if (isWeightLikeComputeOperand(operand)) {
      auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
      if (failed(clonedOperand))
        return failure();
      localMapper.map(operand, *clonedOperand);
      continue;
    }
    // Every other operand is mapped to itself, so the clone keeps using the
    // original value.
    localMapper.map(operand, operand);
  }

  // Clone the producer with the remapped operands and record each of its
  // results in the caller's mapping.
  Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
  for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
    mapper.map(oldResult, newResult);

  auto mapped = mapper.lookupOrNull(value);
  if (!mapped)
    return failure();
  return cast(mapped);
}

} // namespace onnx_mlir