#pragma once

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"

#include "llvm/ADT/StringRef.h"

// NOTE(review): the header name on the next line is missing. Angle-bracket
// contents throughout this file appear to have been stripped. Because
// std::function is used below, this was most likely <functional>; confirm
// against the original file.
#include

#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

namespace onnx_mlir {
namespace raptor {

/// Pass that lowers Spatial-dialect ops to a PIM-ready format.
/// Its command-line argument is `convert-spatial-to-pim` (see getArgument()).
///
/// NOTE(review): the template arguments of the base class below are missing
/// (stripped). They were presumably of the form
/// `mlir::PassWrapper<SpatialToPimPass, mlir::OperationPass<...>>`; confirm
/// the anchor op type against the original file.
struct SpatialToPimPass : mlir::PassWrapper> {
  // MLIR macro that defines the TypeID for this class.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)

  /// Command-line name of the pass.
  llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; }

  /// Human-readable description of the pass.
  llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }

  SpatialToPimPass() = default;

  /// Copy constructor.
  ///
  /// The body is empty and there is no member-initializer list, so:
  /// - the base class is default-constructed;
  /// - `outputTensors` and `operationsToRemove` are NOT copied from `pass`;
  ///   they start out default-constructed (empty) in the copy.
  ///
  /// Presumably this is deliberate, so that a cloned pass instance does not
  /// inherit per-run bookkeeping state. Confirm this before replacing it
  /// with `= default`.
  /// NOTE(review): the parameter `pass` is unused.
  SpatialToPimPass(const SpatialToPimPass& pass) {}

  /// Pass entry point; defined out of line (not visible here).
  void runOnOperation() final;

private:
  // NOTE(review): the std::function signature on the next line is missing
  // (stripped) and cannot be recovered from this header; confirm it against
  // the original file and the uses in the implementation.
  using OutputTensorFactory = std::function;

  // Per-run state. It is not copied by the copy constructor above.
  // NOTE(review): the SmallVector element types are missing (stripped).
  // Presumably they are OutputTensorFactory and mlir::Operation*
  // respectively; confirm.
  llvm::SmallVector outputTensors;
  llvm::SmallVector operationsToRemove;

  // All member functions below are declarations only; their definitions are
  // not visible here. Descriptions are inferred from names and signatures;
  // verify them against the implementation.

  // Presumably allocates and initializes the per-core local variables used
  // inside `funcOp`.
  mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);

  // Presumably lowers a single scheduled Spatial compute op. `constantFolder`
  // is threaded through, presumably to materialize/deduplicate constants.
  mlir::LogicalResult lowerComputeOp(spatial::SpatScheduledCompute computeOp,
                                     mlir::IRRewriter& rewriter,
                                     mlir::OperationFolder& constantFolder);

  // Presumably lowers a batched scheduled Spatial compute op.
  mlir::LogicalResult lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);

  // Outcome type returned by lowerComputeResultReturnPath and
  // lowerProducedValueReturnPath. Presumed meaning (confirm in the .cpp):
  //   Handled       - the value feeds the function's return and was lowered;
  //   NotReturnPath - the value is not on a return path; the caller handles it;
  //   Failure       - lowering was attempted and failed.
  enum class ReturnPathLoweringResult {
    Handled,
    NotReturnPath,
    Failure
  };

  // Presumably creates the output buffers associated with `returnOp`'s
  // operands.
  void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);

  // Return-path lowering for one result of a scheduled compute op, given the
  // value yielded for it.
  ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatScheduledCompute computeOp,
                                                        mlir::OpResult result,
                                                        mlir::Value yieldValue,
                                                        mlir::IRRewriter& rewriter);

  // Return-path lowering for a value produced by `producerOp`, where
  // `storedValue` is presumably the value to be written to the output buffer.
  ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
                                                        mlir::Value producedValue,
                                                        mlir::Value storedValue,
                                                        mlir::IRRewriter& rewriter);

  // Presumably rewrites `returnOp` to return the output buffers instead of
  // the original values.
  void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);

  // Presumably records `op` in operationsToRemove for deferred erasure.
  void markOpToRemove(mlir::Operation* op);

  // Presumably erases the ops recorded via markOpToRemove.
  void eraseOpsToRemove();

  // Presumably enlarges VMM output tensors in `funcOp` to the crossbar size.
  mlir::LogicalResult enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
};

} // namespace raptor
} // namespace onnx_mlir