fix output paths

add gemm test
This commit is contained in:
NiccoloN
2026-02-25 17:24:31 +01:00
parent d036c02160
commit ae6e815c7b
14 changed files with 86 additions and 106 deletions

View File

@@ -9,6 +9,7 @@
#include <filesystem>
#include <fstream>
#include "Common/PIMCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "ONNXToSpatialPass.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -27,7 +28,7 @@ namespace spatial {
void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
ModuleOp module = getOperation();
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx);
@@ -38,11 +39,11 @@ void ONNXToSpatialPass::runOnOperation() {
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(mergeActivationPatterns))))
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(module);
func::FuncOp funcOp = *module.getOps<func::FuncOp>().begin();
IRRewriter rewriter(moduleOp);
func::FuncOp funcOp = *moduleOp.getOps<func::FuncOp>().begin();
if (annotateReplication(funcOp, rewriter).failed()) {
llvm::dbgs() << "Failed during annotation for replication analysis\n";
signalPassFailure();
@@ -78,7 +79,7 @@ void ONNXToSpatialPass::runOnOperation() {
populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx);
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
@@ -101,19 +102,13 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet removeUnusedHelperOpsPatterns(ctx);
populateRemoveUnusedHelperOpsPatterns(removeUnusedHelperOpsPatterns, ctx);
if (failed(applyPatternsAndFoldGreedily(module, std::move(removeUnusedHelperOpsPatterns))))
if (failed(applyPatternsGreedily(moduleOp, std::move(removeUnusedHelperOpsPatterns))))
llvm::dbgs() << "Failed to remove unused helper ops, continuing...\n";
annotateWeightsConstants(funcOp);
// Dump to file for debug
std::string outputDir = outputBaseName.substr(0, outputBaseName.find_last_of('/')).append("/dialects");
std::filesystem::create_directory(outputDir);
std::fstream file(outputDir + "/spatial.mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *module;
os.flush();
file.close();
dumpModule(moduleOp, "spatial");
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {