diff --git a/src/PIM/Common/PIMCommon.cpp b/src/PIM/Common/PIMCommon.cpp index f858003..ada0745 100644 --- a/src/PIM/Common/PIMCommon.cpp +++ b/src/PIM/Common/PIMCommon.cpp @@ -1,17 +1,40 @@ +#include "llvm/Support/raw_os_ostream.h" + +#include +#include + #include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Compiler/CompilerOptions.hpp" using namespace mlir; namespace onnx_mlir { -llvm::FailureOr getOtherEndOfChannel( - Operation *op, bool opIsReceive, RewriterBase &rewriter) { +std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); } + +void createDirectory(const std::string& directory) { + std::error_code errorCode; + std::filesystem::create_directories(directory, errorCode); + assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data()); +} + +void dumpModule(ModuleOp moduleOp, const std::string& name) { + std::string dialectsDir = getOutputDir() + "/dialects"; + createDirectory(dialectsDir); + + std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); + llvm::raw_os_ostream os(file); + os << *moduleOp; + os.flush(); + file.close(); +} + +FailureOr getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) { auto channelNewOp = op->getOperand(0).getDefiningOp(); if (!channelNewOp) { - op->emitError( - "User of Channel must have the first operand created by ChannelNewOp."); + op->emitError("User of Channel must have the first operand created by ChannelNewOp."); return failure(); } // channelNewOp should have two users: `op` and a @@ -35,12 +58,14 @@ llvm::FailureOr getOtherEndOfChannel( "more than two found."); return failure(); } - Operation *notOpUser; + Operation* notOpUser; if (firstUser == op) { notOpUser = secondUser; - } else if (secondUser == op) { + } + else if (secondUser == op) { notOpUser = firstUser; - } else { + } + else { op->emitError("Operand generated by ChannelNewOp must have two users, " "and one of them must be me, but" "none of them is actually me."); @@ -54,7 +79,8 @@ llvm::FailureOr getOtherEndOfChannel( return failure(); } return notOpUser; - } else { + } + else { if (!isa(notOpUser)) { op->emitError("Operand generated by ChannelNewOp has two user, one is " "me, the other is not a ChannelReceiveOp."); @@ -64,4 +90,4 @@ llvm::FailureOr getOtherEndOfChannel( } } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Common/PIMCommon.hpp b/src/PIM/Common/PIMCommon.hpp index 2ae0e30..fac20a6 100644 --- a/src/PIM/Common/PIMCommon.hpp +++ b/src/PIM/Common/PIMCommon.hpp @@ -3,14 +3,22 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" + #include "llvm/ADT/StringRef.h" -const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = - "pim.constant.should_allocate"; +#include "src/Compiler/CompilerOptions.hpp" + +const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate"; namespace onnx_mlir { -llvm::FailureOr getOtherEndOfChannel( - mlir::Operation *op, bool opIsReceive, mlir::RewriterBase &rewriter); +std::string getOutputDir(); -} // namespace onnx_mlir \ No newline at end of file +void createDirectory(const std::string &directory); + +void dumpModule(mlir::ModuleOp moduleOp, const std::string &name); + +llvm::FailureOr +getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter); + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 605028e..cecccd9 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -433,14 +433,7 @@ std::string getMemorySizeAsString(size_t size) { return std::to_string(size) + " Bytes"; } -OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const OwningOpRef& moduleOpRef, std::string& outputDirPath) { - ModuleOp moduleOp = moduleOpRef.get(); - - if (pimEmissionTarget != EmitPimCodegen) { - moduleOp.dump(); - return CompilerSuccess; - } - +OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(ModuleOp& moduleOp, std::string& outputDirPath) { if (!outputDirPath.empty()) { if (auto error = llvm::sys::fs::create_directory(outputDirPath)) { llvm::errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n'; @@ -696,8 +689,6 @@ OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const OwningOpRef& m jsonOS << llvm::json::Value(std::move(configJson)) << '\n'; jsonOS.close(); - showCompilePhase("Code generated into " + configPath); - return CompilerSuccess; } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index b6cc382..70de763 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -36,7 +36,7 @@ public: void allocateCore(Operation* op); size_t getFirstAvailableAddress() const { return firstAvailableAddress; } - MemEntry getMemEntry(Value value) const ; + MemEntry getMemEntry(Value value) const; }; class PimAcceleratorMemory { diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index 695b0a2..e13af09 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -17,11 +17,6 @@ namespace onnx_mlir { -llvm::cl::opt pimOutputDir("pim-output-dir", - llvm::cl::desc("Directory where pim json code will be emitted"), - llvm::cl::init("pim"), - llvm::cl::cat(OnnxMlirOptions)); - llvm::cl::opt pimEmissionTarget( llvm::cl::desc("[Optional] Choose PIM-related target to emit (once selected it will cancel the other targets):"), llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")), diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp index 053b625..cd95654 100644 --- a/src/PIM/Compiler/PimCompilerOptions.hpp +++ b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -21,7 +21,6 @@ typedef enum { } PimEmissionTargetType; extern llvm::cl::OptionCategory OnnxMlirOptions; -extern llvm::cl::opt pimOutputDir; extern llvm::cl::opt pimEmissionTarget; extern llvm::cl::opt pimOnlyCodegen; diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index 6186bd4..c41b6ad 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -1,18 +1,11 @@ -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/Passes.h" + #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Compiler/CompilerPasses.hpp" - -#include "llvm/Support/JSON.h" - -#include -#include +#include "src/Compiler/CompilerUtils.hpp" #define DEBUG_TYPE "PimCompilerUtils" @@ -37,19 +30,25 @@ void addPassesPim(OwningOpRef& module, if (pimEmissionTarget >= EmitSpatial) { pm.addPass(createONNXToSpatialPass()); // pm.addPass(createCountInstructionPass()); - pm.addPass(createMessagePass("ONNX lowered to SPATIAL")); + pm.addPass(createMessagePass("Onnx lowered to Spatial")); } if (pimEmissionTarget >= EmitPim) { pm.addPass(createSpatialToPIMPass()); // pm.addPass(createCountInstructionPass()); - pm.addPass(createMessagePass("SPATIAL lowered to PIM")); + pm.addPass(createMessagePass("Spatial lowered to Pim")); } if (pimEmissionTarget >= EmitPimBufferized) { pm.addPass(createBufferizePimPass()); // pm.addPass(createCountInstructionPass()); - pm.addPass(createMessagePass("PIM bufferized")); + pm.addPass(createMessagePass("Pim bufferized")); + } + + if (pimEmissionTarget >= EmitPimCodegen) { + pm.addPass(createEmitPimJsonPass()); + // pm.addPass(createCountInstructionPass()); + pm.addPass(createMessagePass("Pim json code emitted")); } } diff --git a/src/PIM/Compiler/PimCompilerUtils.hpp b/src/PIM/Compiler/PimCompilerUtils.hpp index 077d566..55a5369 100644 --- a/src/PIM/Compiler/PimCompilerUtils.hpp +++ b/src/PIM/Compiler/PimCompilerUtils.hpp @@ -13,7 +13,6 @@ void addPassesPim(mlir::OwningOpRef& module, EmissionTargetType& emissionTarget, std::string outputNameNoExt); -OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const mlir::OwningOpRef& moduleOpRef, - std::string& outputDirName); +OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(mlir::ModuleOp& moduleOpRef, std::string& outputDirName); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index ed22fbe..0981380 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -9,6 +9,7 @@ #include #include +#include "Common/PIMCommon.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "ONNXToSpatialPass.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" @@ -27,7 +28,7 @@ namespace spatial { void ONNXToSpatialPass::runOnOperation() { llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n"; - ModuleOp module = getOperation(); + ModuleOp moduleOp = getOperation(); MLIRContext* ctx = &getContext(); RewritePatternSet mergeActivationPatterns(ctx); @@ -38,11 +39,11 @@ void ONNXToSpatialPass::runOnOperation() { mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(module, std::move(mergeActivationPatterns)))) + if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(mergeActivationPatterns)))) llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; - IRRewriter rewriter(module); - func::FuncOp funcOp = *module.getOps().begin(); + IRRewriter rewriter(moduleOp); + func::FuncOp funcOp = *moduleOp.getOps().begin(); if (annotateReplication(funcOp, rewriter).failed()) { llvm::dbgs() << "Failed during annotation for replication analysis\n"; signalPassFailure(); @@ -78,7 +79,7 @@ void ONNXToSpatialPass::runOnOperation() { populateONNXConcatToTensorConcatPattern(patterns, ctx); populateReduceMeanConversionPattern(patterns, ctx); - if (failed(applyPartialConversion(module, target, std::move(patterns)))) { + if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { signalPassFailure(); return; } @@ -101,19 +102,13 @@ void ONNXToSpatialPass::runOnOperation() { RewritePatternSet removeUnusedHelperOpsPatterns(ctx); populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx); - if (failed(applyPatternsAndFoldGreedily(module, std::move(removeUnusedHelperOpsPatterns)))) + if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns)))) llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n"; annotateWeightsConstants(funcOp); // Dump to file for debug - std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects"); - std::filesystem::create_directory(outputDir); - std::fstream file(outputDir + "/spatial.mlir", std::ios::out); - llvm::raw_os_ostream os(file); - os << *module; - os.flush(); - file.close(); + dumpModule(moduleOp, "spatial"); } void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp index 0d1b8fe..11d9a8a 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include @@ -75,13 +74,7 @@ void SpatialToPIMPass::runOnOperation() { } // Dump to file for debug - std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects"); - std::filesystem::create_directory(outputDir); - std::fstream file(outputDir + "/pim.mlir", std::ios::out); - llvm::raw_os_ostream os(file); - os << *moduleOp; - os.flush(); - file.close(); + dumpModule(moduleOp, "pim"); } void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { diff --git a/src/PIM/Pass/EmitPimJsonPass.cpp b/src/PIM/Pass/EmitPimJsonPass.cpp index 1b857c2..052a584 100644 --- a/src/PIM/Pass/EmitPimJsonPass.cpp +++ b/src/PIM/Pass/EmitPimJsonPass.cpp @@ -1,8 +1,6 @@ #include "mlir/Pass/Pass.h" -#include - -#include "Compiler/PimCompilerOptions.hpp" +#include "Common/PIMCommon.hpp" #include "Compiler/PimCompilerUtils.hpp" using namespace mlir; @@ -19,20 +17,13 @@ struct EmitPimJsonPass : PassWrapper> { EmitPimJsonPass() {} EmitPimJsonPass(const EmitPimJsonPass& pass) {} - void runOnOperation() final { + void runOnOperation() override { ModuleOp moduleOp = getOperation(); - std::filesystem::path pimDir(pimOutputDir.data()); + std::string pimDir = getOutputDir() + "/pim"; + createDirectory(pimDir); - std::error_code error_code; - std::filesystem::create_directories(pimDir, error_code); - if (error_code) { - moduleOp.emitError("Failed to create PIM output directory: " + error_code.message()); - signalPassFailure(); - return; - } - - int compiler_error_code = compileModuleToPIMJSON(moduleOp, pimOutputDir); + int compiler_error_code = compileModuleToPIMJSON(moduleOp, pimDir); if (compiler_error_code != CompilerSuccess) signalPassFailure(); } diff --git a/src/PIM/Pass/MessagePass.cpp b/src/PIM/Pass/MessagePass.cpp index 1fd61ce..0ac5c29 100644 --- a/src/PIM/Pass/MessagePass.cpp +++ b/src/PIM/Pass/MessagePass.cpp @@ -1,4 +1,5 @@ #include "mlir/Pass/Pass.h" + #include "src/Compiler/CompilerUtils.hpp" using namespace mlir; @@ -7,21 +8,15 @@ namespace onnx_mlir { namespace { -struct MessagePass : public PassWrapper> { - +struct MessagePass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MessagePass) - StringRef getArgument() const override { return "message-pass"; } + StringRef getDescription() const override { return "Print compilation status"; } - StringRef getDescription() const override { - return "Lower ONNX ops to Spatial ops."; - } + MessagePass(std::string message) + : message(message) {} + MessagePass(const MessagePass& pass) {} - // Make sure that we have a valid default constructor and copy - // constructor to make sure that the options are initialized properly. - MessagePass(std::string message) : message(message) {} - MessagePass(const MessagePass &pass) - : PassWrapper>() {} void runOnOperation() final { showCompilePhase(message); } private: @@ -30,8 +25,6 @@ private: } // namespace -std::unique_ptr createMessagePass(std::string message) { - return std::make_unique(message); -} +std::unique_ptr createMessagePass(std::string message) { return std::make_unique(message); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Transforms/PimBufferizationPass.cpp b/src/PIM/Transforms/PimBufferizationPass.cpp index 054d312..18d3522 100644 --- a/src/PIM/Transforms/PimBufferizationPass.cpp +++ b/src/PIM/Transforms/PimBufferizationPass.cpp @@ -7,9 +7,7 @@ #include "llvm/Support/raw_os_ostream.h" -#include -#include - +#include "Common/PIMCommon.hpp" #include "Compiler/PimCodeGen.hpp" #include "PimBufferizationPass.hpp" #include "src/Compiler/CompilerOptions.hpp" @@ -59,14 +57,7 @@ void PimBufferizationPass::runOnOperation() { annotateWeightsMemrefs(moduleOp, funcOp); // Dump to file for debug - ModuleOp module = getOperation(); - std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects"); - std::filesystem::create_directory(outputDir); - std::fstream file(outputDir + "/pim_buf.mlir", std::ios::out); - llvm::raw_os_ostream os(file); - os << *module; - os.flush(); - file.close(); + dumpModule(moduleOp, "pim_buf"); } void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { diff --git a/validation/operations/gemm/with_scalar_constant/gemm_with_scalar_constant.onnx b/validation/operations/gemm/with_scalar_constant/gemm_with_scalar_constant.onnx new file mode 100644 index 0000000..04197fd Binary files /dev/null and b/validation/operations/gemm/with_scalar_constant/gemm_with_scalar_constant.onnx differ