fix output paths

add gemm test
This commit is contained in:
NiccoloN
2026-02-25 17:24:31 +01:00
parent d036c02160
commit ae6e815c7b
14 changed files with 86 additions and 106 deletions

View File

@@ -1,17 +1,40 @@
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
llvm::FailureOr<Operation *> getOtherEndOfChannel( std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); }
Operation *op, bool opIsReceive, RewriterBase &rewriter) {
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string dialectsDir = getOutputDir() + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>(); auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) { if (!channelNewOp) {
op->emitError( op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
"User of Channel must have the first operand created by ChannelNewOp.");
return failure(); return failure();
} }
// channelNewOp should have two users: `op` and a // channelNewOp should have two users: `op` and a
@@ -38,9 +61,11 @@ llvm::FailureOr<Operation *> getOtherEndOfChannel(
Operation* notOpUser; Operation* notOpUser;
if (firstUser == op) { if (firstUser == op) {
notOpUser = secondUser; notOpUser = secondUser;
} else if (secondUser == op) { }
else if (secondUser == op) {
notOpUser = firstUser; notOpUser = firstUser;
} else { }
else {
op->emitError("Operand generated by ChannelNewOp must have two users, " op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but" "and one of them must be me, but"
"none of them is actually me."); "none of them is actually me.");
@@ -54,7 +79,8 @@ llvm::FailureOr<Operation *> getOtherEndOfChannel(
return failure(); return failure();
} }
return notOpUser; return notOpUser;
} else { }
else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) { if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is " op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp."); "me, the other is not a ChannelReceiveOp.");

View File

@@ -3,14 +3,22 @@
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = #include "src/Compiler/CompilerOptions.hpp"
"pim.constant.should_allocate";
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
namespace onnx_mlir { namespace onnx_mlir {
llvm::FailureOr<mlir::Operation *> getOtherEndOfChannel( std::string getOutputDir();
mlir::Operation *op, bool opIsReceive, mlir::RewriterBase &rewriter);
void createDirectory(const std::string &directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string &name);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -433,14 +433,7 @@ std::string getMemorySizeAsString(size_t size) {
return std::to_string(size) + " Bytes"; return std::to_string(size) + " Bytes";
} }
OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const OwningOpRef<ModuleOp>& moduleOpRef, std::string& outputDirPath) { OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(ModuleOp& moduleOp, std::string& outputDirPath) {
ModuleOp moduleOp = moduleOpRef.get();
if (pimEmissionTarget != EmitPimCodegen) {
moduleOp.dump();
return CompilerSuccess;
}
if (!outputDirPath.empty()) { if (!outputDirPath.empty()) {
if (auto error = llvm::sys::fs::create_directory(outputDirPath)) { if (auto error = llvm::sys::fs::create_directory(outputDirPath)) {
llvm::errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n'; llvm::errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
@@ -696,8 +689,6 @@ OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const OwningOpRef<ModuleOp>& m
jsonOS << llvm::json::Value(std::move(configJson)) << '\n'; jsonOS << llvm::json::Value(std::move(configJson)) << '\n';
jsonOS.close(); jsonOS.close();
showCompilePhase("Code generated into " + configPath);
return CompilerSuccess; return CompilerSuccess;
} }

View File

@@ -17,11 +17,6 @@
namespace onnx_mlir { namespace onnx_mlir {
llvm::cl::opt<std::string> pimOutputDir("pim-output-dir",
llvm::cl::desc("Directory where pim json code will be emitted"),
llvm::cl::init("pim"),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget( llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::desc("[Optional] Choose PIM-related target to emit (once selected it will cancel the other targets):"), llvm::cl::desc("[Optional] Choose PIM-related target to emit (once selected it will cancel the other targets):"),
llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")), llvm::cl::values(clEnumVal(EmitSpatial, "Lower model to spatial IR")),

View File

@@ -21,7 +21,6 @@ typedef enum {
} PimEmissionTargetType; } PimEmissionTargetType;
extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<std::string> pimOutputDir;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget; extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<bool> pimOnlyCodegen; extern llvm::cl::opt<bool> pimOnlyCodegen;

View File

@@ -1,18 +1,11 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
#include "llvm/Support/JSON.h"
#include <cassert>
#include <cstddef>
#define DEBUG_TYPE "PimCompilerUtils" #define DEBUG_TYPE "PimCompilerUtils"
@@ -37,19 +30,25 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) { if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass()); pm.addPass(createONNXToSpatialPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("ONNX lowered to SPATIAL")); pm.addPass(createMessagePass("Onnx lowered to Spatial"));
} }
if (pimEmissionTarget >= EmitPim) { if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPIMPass()); pm.addPass(createSpatialToPIMPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("SPATIAL lowered to PIM")); pm.addPass(createMessagePass("Spatial lowered to Pim"));
} }
if (pimEmissionTarget >= EmitPimBufferized) { if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createBufferizePimPass()); pm.addPass(createBufferizePimPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("PIM bufferized")); pm.addPass(createMessagePass("Pim bufferized"));
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted"));
} }
} }

View File

@@ -13,7 +13,6 @@ void addPassesPim(mlir::OwningOpRef<mlir::ModuleOp>& module,
EmissionTargetType& emissionTarget, EmissionTargetType& emissionTarget,
std::string outputNameNoExt); std::string outputNameNoExt);
OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(const mlir::OwningOpRef<mlir::ModuleOp>& moduleOpRef, OnnxMlirCompilerErrorCodes compileModuleToPIMJSON(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
std::string& outputDirName);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -9,6 +9,7 @@
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include "Common/PIMCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "ONNXToSpatialPass.hpp" #include "ONNXToSpatialPass.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -27,7 +28,7 @@ namespace spatial {
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n"; llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
ModuleOp module = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx); RewritePatternSet mergeActivationPatterns(ctx);
@@ -38,11 +39,11 @@ void ONNXToSpatialPass::runOnOperation() {
mergeActivationPatterns.add<matMulToGemmPattern>(ctx); mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx); mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(mergeActivationPatterns)))) if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(module); IRRewriter rewriter(moduleOp);
func::FuncOp funcOp = *module.getOps<func::FuncOp>().begin(); func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
if (annotateReplication(funcOp, rewriter).failed()) { if (annotateReplication(funcOp, rewriter).failed()) {
llvm::dbgs() << "Failed during annotation for replication analysis\n"; llvm::dbgs() << "Failed during annotation for replication analysis\n";
signalPassFailure(); signalPassFailure();
@@ -78,7 +79,7 @@ void ONNXToSpatialPass::runOnOperation() {
populateONNXConcatToTensorConcatPattern(patterns, ctx); populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx); populateReduceMeanConversionPattern(patterns, ctx);
if (failed(applyPartialConversion(module, target, std::move(patterns)))) { if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -101,19 +102,13 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet removeUnusedHelperOpsPatterns(ctx); RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx); populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(removeUnusedHelperOpsPatterns)))) if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n"; llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(funcOp); annotateWeightsConstants(funcOp);
// Dump to file for debug // Dump to file for debug
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects"); dumpModule(moduleOp, "spatial");
std::filesystem::create_directory(outputDir);
std::fstream file(outputDir + "/spatial.mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *module;
os.flush();
file.close();
} }
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {

View File

@@ -13,7 +13,6 @@
#include <cassert> #include <cassert>
#include <filesystem> #include <filesystem>
#include <fstream>
#include <string> #include <string>
#include <utility> #include <utility>
@@ -75,13 +74,7 @@ void SpatialToPIMPass::runOnOperation() {
} }
// Dump to file for debug // Dump to file for debug
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects"); dumpModule(moduleOp, "pim");
std::filesystem::create_directory(outputDir);
std::fstream file(outputDir + "/pim.mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
} }
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {

View File

@@ -1,8 +1,6 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <filesystem> #include "Common/PIMCommon.hpp"
#include "Compiler/PimCompilerOptions.hpp"
#include "Compiler/PimCompilerUtils.hpp" #include "Compiler/PimCompilerUtils.hpp"
using namespace mlir; using namespace mlir;
@@ -19,20 +17,13 @@ struct EmitPimJsonPass : PassWrapper<EmitPimJsonPass, OperationPass<ModuleOp>> {
EmitPimJsonPass() {} EmitPimJsonPass() {}
EmitPimJsonPass(const EmitPimJsonPass& pass) {} EmitPimJsonPass(const EmitPimJsonPass& pass) {}
void runOnOperation() final { void runOnOperation() override {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
std::filesystem::path pimDir(pimOutputDir.data()); std::string pimDir = getOutputDir() + "/pim";
createDirectory(pimDir);
std::error_code error_code; int compiler_error_code = compileModuleToPIMJSON(moduleOp, pimDir);
std::filesystem::create_directories(pimDir, error_code);
if (error_code) {
moduleOp.emitError("Failed to create PIM output directory: " + error_code.message());
signalPassFailure();
return;
}
int compiler_error_code = compileModuleToPIMJSON(moduleOp, pimOutputDir);
if (compiler_error_code != CompilerSuccess) if (compiler_error_code != CompilerSuccess)
signalPassFailure(); signalPassFailure();
} }

View File

@@ -1,4 +1,5 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"
using namespace mlir; using namespace mlir;
@@ -7,21 +8,15 @@ namespace onnx_mlir {
namespace { namespace {
struct MessagePass : public PassWrapper<MessagePass, OperationPass<ModuleOp>> { struct MessagePass : PassWrapper<MessagePass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MessagePass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MessagePass)
StringRef getArgument() const override { return "message-pass"; } StringRef getArgument() const override { return "message-pass"; }
StringRef getDescription() const override { return "Print compilation status"; }
StringRef getDescription() const override { MessagePass(std::string message)
return "Lower ONNX ops to Spatial ops."; : message(message) {}
} MessagePass(const MessagePass& pass) {}
// Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly.
MessagePass(std::string message) : message(message) {}
MessagePass(const MessagePass &pass)
: PassWrapper<MessagePass, OperationPass<ModuleOp>>() {}
void runOnOperation() final { showCompilePhase(message); } void runOnOperation() final { showCompilePhase(message); }
private: private:
@@ -30,8 +25,6 @@ private:
} // namespace } // namespace
std::unique_ptr<Pass> createMessagePass(std::string message) { std::unique_ptr<Pass> createMessagePass(std::string message) { return std::make_unique<MessagePass>(message); }
return std::make_unique<MessagePass>(message);
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -7,9 +7,7 @@
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <filesystem> #include "Common/PIMCommon.hpp"
#include <fstream>
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "PimBufferizationPass.hpp" #include "PimBufferizationPass.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
@@ -59,14 +57,7 @@ void PimBufferizationPass::runOnOperation() {
annotateWeightsMemrefs(moduleOp, funcOp); annotateWeightsMemrefs(moduleOp, funcOp);
// Dump to file for debug // Dump to file for debug
ModuleOp module = getOperation(); dumpModule(moduleOp, "pim_buf");
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
std::filesystem::create_directory(outputDir);
std::fstream file(outputDir + "/pim_buf.mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *module;
os.flush();
file.close();
} }
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {