diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common.hpp
index 351a54b..da5626e 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Common.hpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Common.hpp
@@ -1,16 +1,19 @@
 #pragma once
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ValueRange.h"
-#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 #include <cstddef>
 #include <type_traits>
 #include <utility>
 
+#include "llvm/ADT/SmallPtrSet.h"
+
+#include "src/Accelerators/PIM/Common/PimCommon.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 
@@ -102,6 +105,42 @@
 inline auto getTensorShape(mlir::Value tensor) {
   return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
 }
 
+// Returns true when `value` is a static 2-D tensor produced (possibly through a
+// chain of shape-preserving views) by an op that always carries weights.
+inline bool isWeightLikeComputeOperand(mlir::Value value) {
+  auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
+  if (!rankedType || !isMatrixShape(rankedType.getShape()))
+    return false;
+
+  llvm::SmallPtrSet<mlir::Operation*, 8> visited;
+
+  while (auto* definingOp = value.getDefiningOp()) {
+    if (!visited.insert(definingOp).second)
+      return false;
+    if (hasWeightAlways(definingOp))
+      return true;
+
+    if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
+      value = extractSliceOp.getSource();
+      continue;
+    }
+    if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
+      value = expandShapeOp.getSrc();
+      continue;
+    }
+    if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
+      value = collapseShapeOp.getSrc();
+      continue;
+    }
+    if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
+      value = transposeOp.getData();
+      continue;
+    }
+
+    return false;
+  }
+
+  return false;
+}
+
 namespace detail {
 
 inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
@@ -111,6 +150,11 @@ template <class Fn, size_t... Is>
 decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
   return std::forward<Fn>(fn)(block->getArgument(Is)...);
 }
+
+template <class Fn, size_t... Is>
+decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
+  return std::forward<Fn>(fn)(values[Is]...);
+}
 
 template <size_t>
 using ValueArg = mlir::Value;
diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
index f6e56c6..908cae6 100644
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
@@ -53,6 +53,7 @@ private:
   void annotateWeightsConstants(func::FuncOp funcOp) const;
   void encapsulateGlobalInstruction(func::FuncOp funcOp);
   void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
+  LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
 };
 
 } // namespace
@@ -149,6 +150,12 @@ void ONNXToSpatialPass::runOnOperation() {
 
   annotateWeightsConstants(*entryFunc);
   encapsulateGlobalInstruction(*entryFunc);
+
+  if (failed(promoteConstantInputsToWeights(*entryFunc))) {
+    signalPassFailure();
+    return;
+  }
+
   mergeTriviallyConnectedComputes(*entryFunc);
 
   // Dump to file for debug
@@ -184,8 +191,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
   if (llvm::any_of(
           sources, [](auto source) { return isa_and_present<ONNXConstantOp>(source.getDefiningOp()); })) {
     auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
-    llvm::SmallVector<Type> sourceTypes;
-    llvm::SmallVector<Location> sourceLoc;
+    SmallVector<Type> sourceTypes;
+    SmallVector<Location> sourceLoc;
     for (auto source : sources) {
       sourceTypes.push_back(source.getType());
       sourceLoc.push_back(loc);
@@ -206,6 +213,63 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
   return false;
 }
 
+static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
+  if (auto mapped = mapper.lookupOrNull(value))
+    return cast<Value>(mapped);
+
+  Operation* definingOp = value.getDefiningOp();
+  if (!definingOp)
+    return failure();
+
+  if (isa<ONNXConstantOp>(definingOp)) {
+    auto tensorType = dyn_cast<RankedTensorType>(value.getType());
+    if (!tensorType || !tensorType.hasStaticShape())
+      return failure();
+
+    SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes;
+    SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
+    sizes.reserve(tensorType.getRank());
+    for (int64_t dim : tensorType.getShape())
+      sizes.push_back(rewriter.getIndexAttr(dim));
+
+    auto referencedValue =
+        tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
+    mapper.map(value, referencedValue.getResult());
+    return referencedValue.getResult();
+  }
+
+  if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
+    return failure();
+
+  IRMapping localMapper;
+  for (Value operand : definingOp->getOperands()) {
+    if (auto mapped = mapper.lookupOrNull(operand)) {
+      localMapper.map(operand, cast<Value>(mapped));
+      continue;
+    }
+
+    if (isWeightLikeComputeOperand(operand)) {
+      auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
+      if (failed(clonedOperand))
+        return failure();
+      localMapper.map(operand, *clonedOperand);
+      continue;
+    }
+
+    localMapper.map(operand, operand);
+  }
+
+  Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
+  for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
+    mapper.map(oldResult, newResult);
+
+  auto mapped = mapper.lookupOrNull(value);
+  if (!mapped)
+    return failure();
+  return cast<Value>(mapped);
+}
+
 // TODO what we want to keep in global?
 void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
   Location loc = funcOp.getLoc();
@@ -328,6 +392,85 @@ void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
   });
 }
 
+LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
+  IRRewriter rewriter(&getContext());
+  SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>());
+
+  for (auto compute : computes) {
+    SmallVector<bool> promoteInput(compute.getInputs().size(), false);
+    bool needsRewrite = false;
+    for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
+      if (!isWeightLikeComputeOperand(input))
+        continue;
+      promoteInput[inputIdx] = true;
+      needsRewrite = true;
+    }
+    if (!needsRewrite)
+      continue;
+
+    rewriter.setInsertionPointAfter(compute);
+
+    SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
+    SmallVector<Value> newInputs;
+    SmallVector<Type> newInputTypes;
+    SmallVector<Location> newInputLocs;
+    newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
+    newInputs.reserve(compute.getInputs().size());
+    newInputTypes.reserve(compute.getInputs().size());
+    newInputLocs.reserve(compute.getInputs().size());
+
+    for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
+      if (promoteInput[inputIdx]) {
+        newWeights.push_back(input);
+        continue;
+      }
+      newInputs.push_back(input);
+      newInputTypes.push_back(input.getType());
+      newInputLocs.push_back(input.getLoc());
+    }
+
+    auto newCompute =
+        spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
+    auto* newBlock =
+        rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
+    newCompute.getProperties().setOperandSegmentSizes(
+        {static_cast<int32_t>(newWeights.size()), static_cast<int32_t>(newInputs.size())});
+    rewriter.setInsertionPointToStart(newBlock);
+
+    IRMapping mapper;
+    auto& oldBlock = compute.getBody().front();
+    size_t newInputIdx = 0;
+    for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
+      if (!promoteInput[oldInputIdx]) {
+        mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
+        continue;
+      }
+
+      auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
+      if (failed(clonedValue))
+        return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
+      mapper.map(oldArg, *clonedValue);
+    }
+
+    for (auto& op : oldBlock.without_terminator())
+      rewriter.clone(op, mapper);
+
+    auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
+    SmallVector<Value> newYieldOperands;
+    newYieldOperands.reserve(oldYield.getOutputs().size());
+    for (Value operand : oldYield.getOutputs()) {
+      auto mapped = mapper.lookupOrNull(operand);
+      newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
+    }
+    spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
+
+    compute.replaceAllUsesWith(newCompute);
+    compute.erase();
+  }
+
+  return success();
+}
+
 std::unique_ptr<mlir::Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
 
 } // namespace onnx_mlir