From c5e608fa5b13e10ddfc5ff0f9eca68d7e706f9d6 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Thu, 14 May 2026 11:48:16 +0200 Subject: [PATCH] replace greedy pattern rewrites with partial conversions better failure messages --- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 47 ++++++- .../Conversion/ONNXToSpatial/PostPatterns.cpp | 21 +++ .../Conversion/ONNXToSpatial/PostPatterns.hpp | 8 ++ .../SpatialToGraphviz/SpatialToGraphviz.cpp | 1 + .../SpatialToPim/SpatialToPimPass.cpp | 130 +++++++++--------- .../Bufferization/PimBufferizationPass.cpp | 1 + src/PIM/Pass/PimCodegen/EmitPimCodePass.cpp | 5 +- .../HostConstantFoldingPass.cpp | 6 +- .../MaterializeHostConstantsPass.cpp | 1 + src/PIM/Pass/PimCodegen/VerificationPass.cpp | 4 +- 10 files changed, 146 insertions(+), 78 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index efa8ec6..651a4a7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -5,15 +5,12 @@ #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" #include "Common/Common.hpp" #include "Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" @@ -91,13 +88,25 @@ void ONNXToSpatialPass::runOnOperation() { ModuleOp moduleOp = getOperation(); MLIRContext* ctx = &getContext(); + ConversionTarget preTarget(*ctx); + preTarget.addLegalDialect(); + preTarget.addIllegalOp(); + RewritePatternSet prePatterns(ctx); populatePrePatterns(prePatterns, ctx); - if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns)))) - moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing"); + if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) { + moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites"); + signalPassFailure(); + return; + } auto entryFunc = getPimEntryFunc(moduleOp); if (failed(entryFunc)) { + moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering"); signalPassFailure(); return; } @@ -130,13 +139,24 @@ void ONNXToSpatialPass::runOnOperation() { RewritePatternSet conversionPatterns(ctx); populateConversionPatterns(conversionPatterns, ctx); if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) { + moduleOp.emitError("failed to convert required ONNX ops to Spatial ops"); signalPassFailure(); return; } + ConversionTarget earlyPostTarget(*ctx); + earlyPostTarget.addLegalDialect(); + earlyPostTarget.addDynamicallyLegalOp( + [](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); }); + RewritePatternSet earlyPostPatterns(ctx); populateEarlyPostPatterns(earlyPostPatterns, ctx); - if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) { + if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) { + moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks"); signalPassFailure(); return; } @@ -162,14 +182,27 @@ void ONNXToSpatialPass::runOnOperation() { annotateWeightsConstants(*entryFunc); + ConversionTarget postTarget(*ctx); + postTarget.addLegalDialect(); + postTarget.addDynamicallyLegalOp( + [](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); }); + postTarget.addDynamicallyLegalOp( + [](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); }); + RewritePatternSet postPatterns(ctx); populatePostPatterns(postPatterns, ctx); - if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) { + if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) { + moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering"); signalPassFailure(); return; } if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) { + moduleOp.emitError("ONNX-to-Spatial host legality verification failed"); signalPassFailure(); return; } diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp index 4a3861f..09b6459 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp @@ -31,6 +31,21 @@ static bool isDirectConstantValue(Value value) { return isa_and_nonnull(value.getDefiningOp()); } +template +static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) { + Block& block = compute.getBody().front(); + for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { + if (inputIdx >= block.getNumArguments()) + continue; + if (!isWeightLikeComputeOperand(input)) + continue; + if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx))) + continue; + return true; + } + return false; +} + // Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily. struct FoldSingleLaneComputeBatchPattern : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -262,4 +277,10 @@ void annotateWeightsConstants(func::FuncOp funcOp) { }); } +bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; } + +bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); } + +bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); } + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp index b094373..6a1b4bd 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp @@ -3,8 +3,16 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + namespace onnx_mlir { +bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp); + +bool requiresPostRewrite(spatial::SpatCompute computeOp); + +bool requiresPostRewrite(spatial::SpatComputeBatch computeOp); + void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); diff --git a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp index 822fe63..61f90c1 100644 --- a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp +++ b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp @@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() { auto entryFunc = getPimEntryFunc(module); if (failed(entryFunc)) { + module.emitError("failed to locate the PIM entry function for Spatial graph visualization"); signalPassFailure(); return; } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 9885733..5d79acb 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -144,6 +144,7 @@ void SpatialToPimPass::runOnOperation() { auto entryFunc = getPimEntryFunc(moduleOp); if (failed(entryFunc)) { + moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering"); signalPassFailure(); return; } @@ -169,26 +170,22 @@ void SpatialToPimPass::runOnOperation() { spatial::SpatChannelSendTensorBatchOp, spatial::SpatExtractRowsOp>(); - { - RewritePatternSet patterns(ctx); - populateWithGenerated(patterns); - - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { - signalPassFailure(); - return; - } + RewritePatternSet initialPatterns(ctx); + populateWithGenerated(initialPatterns); + if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) { + moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form"); + signalPassFailure(); + return; } - { - RewritePatternSet patterns(ctx); - populateGlobalTensorMaterializationPatterns(patterns); - - walkAndApplyPatterns(moduleOp, std::move(patterns)); - } + RewritePatternSet globalTensorPatterns(ctx); + populateGlobalTensorMaterializationPatterns(globalTensorPatterns); + walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns)); auto returnOp = cast(funcOp.front().getTerminator()); addReturnOutputBuffers(returnOp, rewriter, outputTensors); if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) { + funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering"); signalPassFailure(); return; } @@ -197,6 +194,7 @@ void SpatialToPimPass::runOnOperation() { for (auto computeOp : funcOp.getOps()) { markOpToRemove(computeOp); if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) { + computeOp.emitOpError("failed to lower spat.compute to pim.core"); signalPassFailure(); return; } @@ -205,17 +203,16 @@ void SpatialToPimPass::runOnOperation() { for (auto computeBatchOp : funcOp.getOps()) { markOpToRemove(computeBatchOp); if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) { + computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch"); signalPassFailure(); return; } } - { - RewritePatternSet patterns(ctx); - populateTensorPackingPatterns(patterns); - walkAndApplyPatterns(funcOp, std::move(patterns)); - eraseUnusedTensorPackingOps(funcOp, rewriter); - } + RewritePatternSet initialTensorPackingPatterns(ctx); + populateTensorPackingPatterns(initialTensorPackingPatterns); + walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns)); + eraseUnusedTensorPackingOps(funcOp, rewriter); SmallVector receiveOps; for (auto op : funcOp.getOps()) @@ -229,27 +226,27 @@ void SpatialToPimPass::runOnOperation() { } } - { - RewritePatternSet coreBodyPatterns(ctx); - populateWithGenerated(coreBodyPatterns); - FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns)); + RewritePatternSet coreBodyPatterns(ctx); + populateWithGenerated(coreBodyPatterns); + FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns)); - SmallVector coreOps; - funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); - for (auto coreOp : coreOps) { - if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) { - signalPassFailure(); - return; - } + SmallVector coreOps; + funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); + for (auto coreOp : coreOps) { + if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) { + coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core"); + signalPassFailure(); + return; } + } - SmallVector coreBatchOps; - funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); - for (auto coreBatchOp : coreBatchOps) { - if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) { - signalPassFailure(); - return; - } + SmallVector coreBatchOps; + funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); + for (auto coreBatchOp : coreBatchOps) { + if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) { + coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch"); + signalPassFailure(); + return; } } @@ -259,44 +256,43 @@ void SpatialToPimPass::runOnOperation() { SmallVector pendingRemovals(operationsToRemove.begin(), operationsToRemove.end()); if (failed(erasePendingOps(pendingRemovals, rewriter))) { + funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM"); signalPassFailure(); return; } - { - RewritePatternSet patterns(ctx); - populateTensorPackingPatterns(patterns); - walkAndApplyPatterns(funcOp, std::move(patterns)); - eraseUnusedTensorPackingOps(funcOp, rewriter); - } + RewritePatternSet finalTensorPackingPatterns(ctx); + populateTensorPackingPatterns(finalTensorPackingPatterns); + walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns)); + eraseUnusedTensorPackingOps(funcOp, rewriter); - { - ConversionTarget communicationTarget(*ctx); - communicationTarget.addLegalDialect(); - communicationTarget.addLegalOp(); - communicationTarget.addIllegalOp(); + ConversionTarget communicationTarget(*ctx); + communicationTarget.addLegalDialect(); + communicationTarget.addLegalOp(); + communicationTarget.addIllegalOp(); - RewritePatternSet communicationPatterns(ctx); - populateChannelLoweringPatterns(communicationPatterns); - if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) { - signalPassFailure(); - return; - } + RewritePatternSet communicationPatterns(ctx); + populateChannelLoweringPatterns(communicationPatterns); + if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) { + funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops"); + signalPassFailure(); + return; } if (failed(verifySpatialToPimBoundary(moduleOp))) { + moduleOp.emitError("Spatial-to-PIM boundary verification failed"); signalPassFailure(); return; } diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 50d3b6e..409b92c 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() { return WalkResult::skip(); }); if (hasFailed) { + moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization"); signalPassFailure(); return; } diff --git a/src/PIM/Pass/PimCodegen/EmitPimCodePass.cpp b/src/PIM/Pass/PimCodegen/EmitPimCodePass.cpp index ea7cdf8..a150f3c 100644 --- a/src/PIM/Pass/PimCodegen/EmitPimCodePass.cpp +++ b/src/PIM/Pass/PimCodegen/EmitPimCodePass.cpp @@ -24,8 +24,11 @@ struct EmitPimCodePass : PassWrapper> { createDirectory(pimDir); int compiler_error_code = compileToPimCode(moduleOp, pimDir); - if (compiler_error_code != CompilerSuccess) + if (compiler_error_code != CompilerSuccess) { + moduleOp.emitError() << "failed to emit PIM simulator code artifacts; compiler error code " + << compiler_error_code; signalPassFailure(); + } } }; diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp index 04038b0..7fc2e0e 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/HostConstantFoldingPass.cpp @@ -32,14 +32,16 @@ struct HostConstantFoldingPass : PassWrapper patterns; diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index 3c43ff1..8c705ef 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -160,6 +160,7 @@ struct MaterializeHostConstantsPass : PassWrapper> } } - if (hasFailure) + if (hasFailure) { + moduleOp.emitError("PIM codegen verification failed; see diagnostics above"); signalPassFailure(); + } } private: