#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_os_ostream.h" #include #include #include #include "Common.hpp" #include "Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { bool haveSameStaticShape(Value lhs, Value rhs); namespace { #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" struct ONNXToSpatialPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass) StringRef getArgument() const override { return "convert-onnx-to-spatial"; } StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; } ONNXToSpatialPass() = default; ONNXToSpatialPass(const ONNXToSpatialPass& pass) {} void runOnOperation() override; private: void annotateWeightsConstants(func::FuncOp funcOp) const; void encapsulateGlobalInstruction(func::FuncOp funcOp); LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp); }; } // namespace static void foldSingleLaneComputeBatches(func::FuncOp funcOp) { IRRewriter rewriter(funcOp.getContext()); SmallVector batchOps; funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); }); for (auto batchOp : batchOps) { if (batchOp.getLaneCount() != 1) continue; auto loc = batchOp.getLoc(); rewriter.setInsertionPoint(batchOp); auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs()); computeOp.getProperties().setOperandSegmentSizes( {static_cast(batchOp.getWeights().size()), static_cast(batchOp.getInputs().size())}); Block& templateBlock = batchOp.getBody().front(); SmallVector blockArgTypes; SmallVector blockArgLocs; for (BlockArgument arg : templateBlock.getArguments()) { blockArgTypes.push_back(arg.getType()); blockArgLocs.push_back(loc); } auto* newBlock = rewriter.createBlock( &computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); IRMapping mapper; for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments())) mapper.map(oldArg, newArg); rewriter.setInsertionPointToEnd(newBlock); for (Operation& op : templateBlock) rewriter.clone(op, mapper); batchOp.replaceAllUsesWith(computeOp.getResults()); rewriter.eraseOp(batchOp); } } void ONNXToSpatialPass::runOnOperation() { ModuleOp moduleOp = getOperation(); MLIRContext* ctx = &getContext(); RewritePatternSet mergeActivationPatterns(ctx); mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); populateMatMulRewritePatterns(mergeActivationPatterns, ctx); if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns)))) llvm::dbgs() << "Failed to merge activation patterns, 
void ONNXToSpatialPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = &getContext();

  // First fuse activations into their producers so the lowering patterns see
  // already-merged ops.
  RewritePatternSet mergeActivationPatterns(ctx);
  mergeActivationPatterns.add</*...*/>(ctx);
  populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
  if (failed(
          applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
    llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";

  IRRewriter rewriter(moduleOp);
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }

  // Set up the partial conversion: everything is legal except the ONNX ops
  // that the populate* helpers below know how to lower.
  ConversionTarget target(*ctx);
  target.addLegalDialect</*...*/>();
  target.addIllegalOp</*... the ONNX ops covered by the pattern sets ...*/>();

  RewritePatternSet patterns(ctx);
  patterns.add</*...*/>(ctx);
  populateElementwisePatterns(patterns, ctx);
  populateGemmPatterns(patterns, ctx);
  populateConvPatterns(patterns, ctx);
  populatePoolPatterns(patterns, ctx);
  populateReduceMeanPatterns(patterns, ctx);
  populateReluPatterns(patterns, ctx);
  populateSigmoidPatterns(patterns, ctx);
  populateSoftmaxPatterns(patterns, ctx);
  populateConcatPatterns(patterns, ctx);
  populateGatherPatterns(patterns, ctx);
  populateResizePatterns(patterns, ctx);
  populateReshapePatterns(patterns, ctx);
  populateSplitPatterns(patterns, ctx);
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }

  foldSingleLaneComputeBatches(*entryFunc);

  // Count the compute ops and check that they do not exceed the core count.
  if (coresCount != -1) {
    int computeOpsCount = 0;
    for (auto& op : entryFunc->getFunctionBody().front().getOperations())
      if (isa<spatial::SpatCompute>(op))
        computeOpsCount++;
    if (computeOpsCount > coresCount) {
      llvm::dbgs() << "Number of compute ops exceeds the core count\n";
      signalPassFailure();
      return;
    }
  }

  PassManager cleanupPM(ctx);
  cleanupPM.addPass(createCanonicalizerPass());
  if (failed(cleanupPM.run(moduleOp)))
    llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";

  annotateWeightsConstants(*entryFunc);
  encapsulateGlobalInstruction(*entryFunc);
  if (failed(promoteConstantInputsToWeights(*entryFunc))) {
    signalPassFailure();
    return;
  }

  // Dump the module to a file for debugging.
  dumpModule(moduleOp, "spatial0");
}

// If `inst` is a T whose source (as selected by `funcSource`) is produced by
// a spatial.compute, wrap it in its own single-input spatial.compute so it no
// longer lives at function ("global") scope. Returns true if it rewrote.
template <typename T>
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst,
    std::function<Value(T)> funcSource) {
  if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
    Value source = funcSource(toRemoveOp);
    rewriter.setInsertionPointAfter(toRemoveOp);
    if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
      auto newCompute = spatial::SpatCompute::create(
          rewriter, loc, inst->getResultTypes(), source);
      auto BB = rewriter.createBlock(&newCompute.getBody(),
          newCompute.getBody().end(), {source.getType()}, {loc});
      newCompute.getProperties().setOperandSegmentSizes({0, 1});
      rewriter.setInsertionPointToEnd(BB);
      IRMapping mapper;
      mapper.map(source, BB->getArgument(0));
      auto newInst = rewriter.clone(*inst, mapper);
      spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
      inst->replaceAllUsesWith(newCompute->getResults());
      inst->erase();
      return true;
    }
  }
  return false;
}
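// Illustrative sketch of the encapsulation, with assumed spatial.compute
// syntax: a reshape-like op that consumes a compute result at function scope,
// e.g.
//
//   %c = spatial.compute ...
//   %t = tensor.collapse_shape %c [[0, 1]] : tensor<1x64xf32> into tensor<64xf32>
//
// is rewritten to
//
//   %c = spatial.compute ...
//   %t = spatial.compute inputs(%c) ({
//   ^bb0(%arg0: tensor<1x64xf32>):
//     %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x64xf32> into tensor<64xf32>
//     spatial.yield %0 : tensor<64xf32>
//   })
//
// encapsulateConcat below does the same for concat, which takes a variadic
// list of sources rather than a single one.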
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
  if (auto toRemoveOp =
          llvm::dyn_cast_if_present<spatial::SpatConcatOp>(inst)) {
    auto sources = toRemoveOp.getInputs();
    rewriter.setInsertionPointAfter(toRemoveOp);
    if (llvm::any_of(sources, [](auto source) {
          return isa_and_present<spatial::SpatCompute>(source.getDefiningOp());
        })) {
      auto newCompute = spatial::SpatCompute::create(
          rewriter, loc, inst->getResultTypes(), sources);
      SmallVector<Type> sourceTypes;
      SmallVector<Location> sourceLoc;
      for (auto source : sources) {
        sourceTypes.push_back(source.getType());
        sourceLoc.push_back(loc);
      }
      auto BB = rewriter.createBlock(&newCompute.getBody(),
          newCompute.getBody().end(), sourceTypes, sourceLoc);
      newCompute.getProperties().setOperandSegmentSizes(
          {0, static_cast<int32_t>(sources.size())});
      rewriter.setInsertionPointToEnd(BB);
      IRMapping mapper;
      for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
        mapper.map(source, bbArg);
      auto newConcat = spatial::SpatConcatOp::create(rewriter, loc,
          toRemoveOp.getType(),
          rewriter.getI64IntegerAttr(toRemoveOp.getDim()),
          ValueRange(BB->getArguments()));
      spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());
      inst->replaceAllUsesWith(newCompute->getResults());
      inst->erase();
      return true;
    }
  }
  return false;
}

// Recursively rematerialize a weight-like value inside the block the rewriter
// currently points into. Constants are referenced through a full-size
// tensor.extract_slice rather than duplicated; other producers are cloned,
// with their weight-like operands materialized first.
static FailureOr<Value> materializeWeightLikeValueInBlock(
    Value value, IRRewriter& rewriter, IRMapping& mapper) {
  if (auto mapped = mapper.lookupOrNull(value))
    return mapped;

  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return failure();

  if (isa<arith::ConstantOp>(definingOp)) {
    // Build an identity (full-tensor) extract_slice that references the
    // constant from inside the block.
    auto tensorType = dyn_cast<RankedTensorType>(value.getType());
    if (!tensorType || !tensorType.hasStaticShape())
      return failure();
    SmallVector<OpFoldResult> offsets(
        tensorType.getRank(), rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes;
    SmallVector<OpFoldResult> strides(
        tensorType.getRank(), rewriter.getIndexAttr(1));
    sizes.reserve(tensorType.getRank());
    for (int64_t dim : tensorType.getShape())
      sizes.push_back(rewriter.getIndexAttr(dim));
    auto referencedValue = tensor::ExtractSliceOp::create(rewriter,
        value.getLoc(), tensorType, value, offsets, sizes, strides);
    mapper.map(value, referencedValue.getResult());
    return referencedValue.getResult();
  }

  // Only recognized weight-transform ops may be cloned into the body.
  if (!isa</*...*/>(definingOp))
    return failure();

  IRMapping localMapper;
  for (Value operand : definingOp->getOperands()) {
    if (auto mapped = mapper.lookupOrNull(operand)) {
      localMapper.map(operand, mapped);
      continue;
    }
    if (isWeightLikeComputeOperand(operand)) {
      auto clonedOperand =
          materializeWeightLikeValueInBlock(operand, rewriter, mapper);
      if (failed(clonedOperand))
        return failure();
      localMapper.map(operand, *clonedOperand);
      continue;
    }
    localMapper.map(operand, operand);
  }
  Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
  for (auto [oldResult, newResult] :
      llvm::zip(definingOp->getResults(), clonedOp->getResults()))
    mapper.map(oldResult, newResult);
  auto mapped = mapper.lookupOrNull(value);
  if (!mapped)
    return failure();
  return mapped;
}
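// The encapsulation below runs to a fixed point: wrapping one op in a
// spatial.compute gives its users a compute-defined source, which can turn
// ops already visited in the current sweep into new candidates. For a chain
// such as (assumed syntax)
//
//   %t = tensor.expand_shape %c ...   // %c defined by a spatial.compute
//   %u = "onnx.Transpose"(%t) ...
//
// wrapping the expand_shape makes the transpose eligible too, so sweeps
// repeat until one of them changes nothing.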
// TODO: what do we want to keep in global scope?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
  Location loc = funcOp.getLoc();
  IRRewriter rewriter(&getContext());
  bool keep = true;
  while (keep) {
    keep = false;
    for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
      keep |= encapsulator<tensor::ExtractSliceOp>(rewriter, loc, &instruction,
          [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
      keep |= encapsulator<tensor::ExpandShapeOp>(rewriter, loc, &instruction,
          [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
      keep |= encapsulator<ONNXTransposeOp>(rewriter, loc, &instruction,
          [](ONNXTransposeOp transpose) { return transpose.getData(); });
      keep |= encapsulator<tensor::CollapseShapeOp>(rewriter, loc,
          &instruction,
          [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
      keep |= encapsulateConcat(rewriter, loc, &instruction);
    }
  }
}

// Mark constants whose only uses are as MVM/VMM weights so later stages
// always treat them as weights.
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
  funcOp.walk([&](arith::ConstantOp constantOp) {
    if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
      markWeightAlways(constantOp);
  });
}
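// Sketch of the promotion below, with assumed assembly syntax: a weight-like
// input (e.g. one fed by a constant) moves to the weights segment, its block
// argument disappears, and the value is rematerialized inside the body:
//
//   %r = spatial.compute weights() inputs(%cst, %x) ({
//   ^bb0(%a0: tensor<4x4xf32>, %a1: tensor<4xf32>):
//     ...
//   })
//
// becomes
//
//   %r = spatial.compute weights(%cst) inputs(%x) ({
//   ^bb0(%a1: tensor<4xf32>):
//     %w = tensor.extract_slice %cst ...  // identity slice, built by
//                                         // materializeWeightLikeValueInBlock
//     ...
//   })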
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(
    func::FuncOp funcOp) {
  IRRewriter rewriter(&getContext());
  SmallVector<spatial::SpatCompute> computes(
      funcOp.getOps<spatial::SpatCompute>());
  for (auto compute : computes) {
    // Decide which inputs are weight-like and should move to the weights
    // segment.
    SmallVector<bool> promoteInput(compute.getInputs().size(), false);
    bool needsRewrite = false;
    for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
      if (!isWeightLikeComputeOperand(input))
        continue;
      promoteInput[inputIdx] = true;
      needsRewrite = true;
    }
    if (!needsRewrite)
      continue;

    rewriter.setInsertionPointAfter(compute);
    SmallVector<Value> newWeights(
        compute.getWeights().begin(), compute.getWeights().end());
    SmallVector<Value> newInputs;
    SmallVector<Type> newInputTypes;
    SmallVector<Location> newInputLocs;
    newWeights.reserve(
        compute.getWeights().size() + compute.getInputs().size());
    newInputs.reserve(compute.getInputs().size());
    newInputTypes.reserve(compute.getInputs().size());
    newInputLocs.reserve(compute.getInputs().size());
    for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
      if (promoteInput[inputIdx]) {
        newWeights.push_back(input);
        continue;
      }
      newInputs.push_back(input);
      newInputTypes.push_back(input.getType());
      newInputLocs.push_back(input.getLoc());
    }

    auto newCompute = spatial::SpatCompute::create(rewriter, compute.getLoc(),
        compute.getResultTypes(), newWeights, newInputs);
    auto* newBlock = rewriter.createBlock(&newCompute.getBody(),
        newCompute.getBody().end(), newInputTypes, newInputLocs);
    newCompute.getProperties().setOperandSegmentSizes(
        {static_cast<int32_t>(newWeights.size()),
         static_cast<int32_t>(newInputs.size())});

    // Rebuild the body: non-promoted arguments map to the new block
    // arguments; promoted ones are rematerialized inside the block.
    rewriter.setInsertionPointToStart(newBlock);
    IRMapping mapper;
    auto& oldBlock = compute.getBody().front();
    size_t newInputIdx = 0;
    for (auto [oldInputIdx, oldArg] :
        llvm::enumerate(oldBlock.getArguments())) {
      if (!promoteInput[oldInputIdx]) {
        mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
        continue;
      }
      auto clonedValue = materializeWeightLikeValueInBlock(
          compute.getInputs()[oldInputIdx], rewriter, mapper);
      if (failed(clonedValue))
        return compute.emitError(
            "failed to materialize promoted weight-like operand inside "
            "compute body");
      mapper.map(oldArg, *clonedValue);
    }
    for (auto& op : oldBlock.without_terminator())
      rewriter.clone(op, mapper);

    auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
    SmallVector<Value> newYieldOperands;
    newYieldOperands.reserve(oldYield.getOutputs().size());
    for (Value operand : oldYield.getOutputs()) {
      auto mapped = mapper.lookupOrNull(operand);
      newYieldOperands.push_back(mapped ? mapped : operand);
    }
    spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);

    compute.replaceAllUsesWith(newCompute);
    compute.erase();
  }
  return success();
}

std::unique_ptr<mlir::Pass> createONNXToSpatialPass() {
  return std::make_unique<ONNXToSpatialPass>();
}

} // namespace onnx_mlir