#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Func/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/Support/Debug.h"

#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Accelerators/PIM/PimAccelerator.hpp"

#define DEBUG_TYPE "PimAccelerator"

namespace onnx_mlir {
namespace accel {

/// Factory entry point: returns the PimAccelerator singleton, creating it on
/// first use (see PimAccelerator::getInstance).
Accelerator* createPIM() { return PimAccelerator::getInstance(); }

/// The single PimAccelerator object; stays null until getInstance() is called.
PimAccelerator* PimAccelerator::instance = nullptr;

/// Returns the singleton, allocating it with `new` on the first call.
/// No locking is performed here.
PimAccelerator* PimAccelerator::getInstance() {
  if (instance == nullptr)
    instance = new PimAccelerator();
  return instance;
}

/// Tags the base Accelerator with Kind::PIM and appends this object to
/// `acceleratorTargets`.
PimAccelerator::PimAccelerator()
    : Accelerator(Kind::PIM) {
  LLVM_DEBUG(llvm::dbgs() << "Creating a PIM accelerator\n");
  acceleratorTargets.push_back(this);
}

/// Drops the singleton pointer when the singleton itself is being destroyed.
///
/// This used to be `delete instance;`. When `this == instance`, that re-runs
/// this very destructor on the same object, which recurses without bound
/// (undefined behaviour). Clearing the pointer instead keeps destruction
/// single-shot and lets a later getInstance() create a fresh object.
PimAccelerator::~PimAccelerator() {
  if (instance == this)
    instance = nullptr;
}

/// Version number reported for this accelerator.
uint64_t PimAccelerator::getVersionNumber() const { return 0x000001; }

/// Adds the PIM compilation passes to `pm` by forwarding every argument to
/// addPassesPim().
// NOTE(review): `OwningOpRef` appears here without a template-argument list;
// confirm the parameter type against the declaration in PimAccelerator.hpp.
void PimAccelerator::addPasses(OwningOpRef& module,
    PassManager& pm,
    EmissionTargetType& emissionTarget,
    std::string outputNameNoExt) const {
  LLVM_DEBUG(llvm::dbgs() << "Adding passes for PIM accelerator\n");
  addPassesPim(module, pm, emissionTarget, outputNameNoExt);
}

/// Inserts dialects into `registry` and attaches the BufferizableOpInterface
/// external models for tensor, arith, func, Spatial (including its ONNX
/// models) and PIM.
void PimAccelerator::registerDialects(DialectRegistry& registry) const {
  LLVM_DEBUG(llvm::dbgs() << "Registering dialects for PIM accelerator\n");

  // NOTE(review): these five insert() calls carry no dialect template
  // argument in this view, so which dialects they register cannot be told
  // from here; confirm against the original file before building.
  registry.insert();
  registry.insert();
  registry.insert();
  registry.insert();
  registry.insert();

  tensor::registerBufferizableOpInterfaceExternalModels(registry);
  arith::registerBufferizableOpInterfaceExternalModels(registry);
  bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(registry);
  spatial::registerBufferizableOpInterfaceExternalModels(registry);
  spatial::registerONNXBufferizableOpInterfaceExternalModels(registry);
  pim::registerBufferizableOpInterfaceExternalModels(registry);
}

/// Registers the constructors of the PIM passes. `optLevel` is not used.
void PimAccelerator::registerPasses(int optLevel) const {
  LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
  registerPass(createONNXToSpatialPass);
  registerPass(createSpatialToGraphvizPass);
  registerPass(createSpatialToPIMPass);
  registerPass(createBufferizePimPass);
  registerPass(createEmitPimJsonPass);
}

/// Pass-configuration hook; currently only emits a debug message.
void PimAccelerator::configurePasses() const {
  LLVM_DEBUG(llvm::dbgs() << "Configuring passes for PIM accelerator\n");
  // TODO: This does nothing for now.
}

/// Always returns a null MemRefType, whatever `tensorType` is.
MemRefType PimAccelerator::convertTensorTypeToMemRefType(const TensorType tensorType) const {
  // Do not convert tensor types to memref types.
  return nullptr;
}

/// Marks dialect(s) as legal on the ONNX-to-Krnl conversion target.
void PimAccelerator::conversionTargetONNXToKrnl(ConversionTarget& target) const {
  // NOTE(review): addLegalDialect() carries no dialect template argument in
  // this view; confirm which dialect(s) are meant to be legal.
  target.addLegalDialect();
}

/// ONNX-to-Krnl pattern hook; adds no patterns yet.
void PimAccelerator::rewritePatternONNXToKrnl(RewritePatternSet& patterns,
    TypeConverter& typeConverter,
    MLIRContext* ctx) const {
  // TODO: Add patterns for conversion
}

/// Krnl-to-LLVM conversion-target hook; intentionally empty.
void PimAccelerator::conversionTargetKrnlToLLVM(ConversionTarget& target) const {}

/// Krnl-to-LLVM pattern hook; intentionally empty.
void PimAccelerator::rewritePatternKrnlToLLVM(RewritePatternSet& patterns,
    LLVMTypeConverter& typeConverter,
    MLIRContext* ctx) const {
  // We should not need this, since we offload it all to PIM.
}

} // namespace accel
} // namespace onnx_mlir