#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"

#include 

#include "Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

bool haveSameStaticShape(Value lhs, Value rhs);

namespace {

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"

// Module pass registered as `convert-onnx-to-spatial`: lowers ONNX ops to
// Spatial ops. The individual phases are described in runOnOperation.
struct ONNXToSpatialPass : PassWrapper> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
  StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
  StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }

  ONNXToSpatialPass() = default;
  ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}

  void runOnOperation() override;

private:
  void annotateWeightsConstants(func::FuncOp funcOp) const;
  void encapsulateGlobalInstruction(func::FuncOp funcOp);
};

} // namespace

// Entry point of the pass. It runs, in order:
//   1. a best-effort greedy rewrite (activation-merge patterns plus the MatMul
//      rewrite patterns);
//   2. a partial dialect conversion driven by the populate*Patterns sets;
//   3. an optional check that the number of compute ops in the entry function
//      does not exceed `coresCount` (skipped when it is -1);
//   4. a best-effort canonicalization;
//   5. weight-constant annotation and encapsulation of leftover top-level ops
//      of the entry function, followed by a debug dump of the module.
void ONNXToSpatialPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = &getContext();

  // Phase 1: best-effort pre-conversion rewrites. A greedy-driver failure is
  // only logged; the pass keeps going with whatever IR it produced.
  RewritePatternSet mergeActivationPatterns(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
  if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
    llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";

  // All later phases operate on the PIM entry function; without one the pass fails.
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }

  // Phase 2: partial conversion. Ops registered as illegal must be rewritten by
  // the patterns below, otherwise applyPartialConversion fails.
  ConversionTarget target(*ctx);
  target.addLegalDialect();
  // A MatMul is left untouched unless its result is rank 2; rank-2 MatMuls must be converted.
  target.addDynamicallyLegalOp(
      [](ONNXMatMulOp op) { return cast(op.getY().getType()).getRank() != 2; });
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();

  RewritePatternSet patterns(ctx);
  patterns.add(ctx);
  populateElementwisePatterns(patterns, ctx);
  populateGemmPatterns(patterns, ctx);
  populateConvPatterns(patterns, ctx);
  populatePoolPatterns(patterns, ctx);
  populateReduceMeanPatterns(patterns, ctx);
  populateReluPatterns(patterns, ctx);
  populateSigmoidPatterns(patterns, ctx);
  populateSoftmaxPatterns(patterns, ctx);
  populateConcatPatterns(patterns, ctx);
  populateGatherPatterns(patterns, ctx);
  populateResizePatterns(patterns, ctx);
  populateReshapePatterns(patterns, ctx);
  populateSplitPatterns(patterns, ctx);

  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }

  // Count the number of compute ops and check they do not exceed the core count
  // (only the ops directly in the entry function's first block are counted).
  if (coresCount != -1) {
    int computeOpsCount = 0;
    for (auto& op : entryFunc->getFunctionBody().front().getOperations())
      if (isa(op))
        computeOpsCount++;
    if (computeOpsCount > coresCount) {
      // Report through the diagnostic engine so the failure carries a location
      // and is visible to every MLIR diagnostic handler, not just stderr.
      entryFunc->emitError("Number of compute ops exceeds the core count");
      signalPassFailure();
      return;
    }
  }

  // Phase 4: best-effort canonicalization; a failure is only logged.
  PassManager cleanupPM(ctx);
  cleanupPM.addPass(createCanonicalizerPass());
  if (failed(cleanupPM.run(moduleOp)))
    llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";

  // Phase 5: post-processing of the entry function.
  annotateWeightsConstants(*entryFunc);
  encapsulateGlobalInstruction(*entryFunc);

  // Dump to file for debug
  dumpModule(moduleOp, "spatial");
}

// Moves a single top-level op into its own spatial::SpatWeightedCompute.
//
// Fires only when `inst` is a `T` and the value returned by `funcSource` is
// produced by an op that passes the isa_and_present check below. In that case:
//   - a SpatWeightedCompute with result type `inst`'s first result type and
//     the source as its only operand is created right after `inst`, at `loc`;
//   - its body gets one block argument typed like the source, a clone of `inst`
//     reading that argument, and a SpatYieldOp of the clone's first result;
//   - all uses of `inst` are redirected to the new op and `inst` is erased.
//
// Returns true iff the rewrite happened. On true, `inst` has been destroyed and
// the caller must not touch it again.
template bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function funcSource) {
  if (T toRemoveOp = llvm::dyn_cast_if_present(inst)) {
    Value source = funcSource(toRemoveOp);
    rewriter.setInsertionPointAfter(toRemoveOp);
    if (isa_and_present(source.getDefiningOp())) {
      auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
      auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
      // Operand segments {0, 1}: the first segment is empty and the source is
      // the single entry of the second segment.
      newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
      rewriter.setInsertionPointToEnd(BB);
      IRMapping mapper;
      mapper.map(source, BB->getArgument(0));
      auto newInst = rewriter.clone(*inst, mapper);
      spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
      inst->replaceAllUsesWith(newCompute);
      inst->erase();
      return true;
    }
  }
  return false;
}

// Multi-input variant of `encapsulator` for the concat op: when at least one
// input is produced by an op passing the isa_and_present check, the concat is
// cloned into a new spatial::SpatWeightedCompute that takes every input as an
// operand (all in the second operand segment) and a matching block argument.
// Uses of `inst` are redirected to the new op and `inst` is erased.
//
// Returns true iff the rewrite happened; on true, `inst` has been destroyed.
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
  if (auto toRemoveOp = llvm::dyn_cast_if_present(inst)) {
    auto sources = toRemoveOp.getInputs();
    rewriter.setInsertionPointAfter(toRemoveOp);
    if (llvm::any_of(sources, [](auto source) { return isa_and_present(source.getDefiningOp()); })) {
      auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
      llvm::SmallVector sourceTypes;
      llvm::SmallVector sourceLoc;
      for (auto source : sources) {
        sourceTypes.push_back(source.getType());
        sourceLoc.push_back(loc);
      }
      auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
      newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
      rewriter.setInsertionPointToEnd(BB);
      IRMapping mapper;
      // Inputs and block arguments are in the same order, so map them pairwise.
      for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
        mapper.map(source, bbArg);
      auto newConcat = rewriter.clone(*inst, mapper);
      spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
      inst->replaceAllUsesWith(newCompute);
      inst->erase();
      return true;
    }
  }
  return false;
}

// TODO what we want to keep in global?
// Repeatedly sweeps the top-level ops of `funcOp` and encapsulates eligible
// ones (extract_slice, expand_shape, ONNX transpose, collapse_shape, concat)
// until a full sweep changes nothing. All created ops use `funcOp`'s location.
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
  Location loc = funcOp.getLoc();
  IRRewriter rewriter(&getContext());
  bool keep = true;
  while (keep) {
    keep = false;
    // make_early_inc_range advances before yielding, so erasing `instruction`
    // inside the body does not break the iteration.
    for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
      // A successful encapsulation erases `instruction`. The candidates are
      // chained with short-circuiting `||` so that, once one fires, no later
      // candidate is handed the destroyed op (a `|=` chain evaluates every
      // candidate and would dyn_cast a freed Operation*).
      const bool changed =
          encapsulator(rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); })
          || encapsulator(rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); })
          || encapsulator(rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); })
          || encapsulator(
              rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); })
          || encapsulateConcat(rewriter, loc, &instruction);
      keep |= changed;
    }
  }
}

// Calls markWeightAlways on every arith::ConstantOp in `funcOp` whose users all
// pass the isa check below. llvm::all_of is vacuously true on an empty range,
// so constants without users are marked as well.
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
  funcOp.walk([&](arith::ConstantOp constantOp) {
    bool isAlwaysWeight = llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa(user); });
    if (isAlwaysWeight)
      markWeightAlways(constantOp);
  });
}

// Factory used to register / schedule the ONNX-to-Spatial pass.
std::unique_ptr createONNXToSpatialPass() { return std::make_unique(); }

} // namespace onnx_mlir