replace greedy pattern rewrites with partial conversions

Also add better failure messages when a lowering step fails.
NiccoloN
2026-05-14 11:48:16 +02:00
parent 43f3ccdd21
commit c5e608fa5b
10 changed files with 146 additions and 78 deletions
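For context on the diffs below: applyPatternsGreedily drives patterns to a fixpoint but gives no legality guarantee, so an op whose pattern never fires survives silently, while applyPartialConversion fails if any op the ConversionTarget marks illegal (statically or dynamically) remains. A minimal sketch of the distinction, with placeholder dialect and op names (MyIllegalOp and populateMyPatterns are hypothetical, not from this repository):

// Sketch only: contrasts the two MLIR drivers this commit swaps.
void runLowering(ModuleOp moduleOp, MLIRContext* ctx) {
  RewritePatternSet patterns(ctx);
  // populateMyPatterns(patterns, ctx); // hypothetical population hook

  // Greedy driver: rewrites to a fixpoint; if nothing matches
  // MyIllegalOp, the op survives and the call can still succeed.
  //   (void)applyPatternsGreedily(moduleOp, std::move(patterns));

  // Partial conversion: unlisted ops may remain, but any surviving
  // op marked illegal makes the driver return failure.
  ConversionTarget target(*ctx);
  target.addLegalDialect<arith::ArithDialect>();
  target.addIllegalOp<MyIllegalOp>();
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
    moduleOp.emitError("illegal ops survived the partial conversion");
}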
@@ -5,15 +5,12 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
@@ -91,13 +88,25 @@ void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXMatMulOp, ONNXFlattenOp>();
RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
signalPassFailure();
return;
}
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
signalPassFailure();
return;
}
@@ -130,13 +139,24 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet conversionPatterns(ctx);
populateConversionPatterns(conversionPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
signalPassFailure();
return;
}
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
signalPassFailure();
return;
}
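The dynamic legality hook above is what lets a normalization pattern run under the conversion driver: the target re-checks the predicate after each rewrite, so the paired pattern must discharge it for every op where requiresEarlyPostRewrite returns true, or applyPartialConversion reports failure. A hedged sketch of the pairing, with a hypothetical FooOp and a placeholder "normalized" attribute standing in for the real predicate:

// Sketch: keep the legality predicate and the pattern in sync.
// FooOp and the "normalized" attribute are illustrative only.
static bool needsRewrite(FooOp op) { return !op->hasAttr("normalized"); }

struct NormalizeFooPattern : OpRewritePattern<FooOp> {
  using OpRewritePattern<FooOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(FooOp op, PatternRewriter& rewriter) const override {
    if (!needsRewrite(op))
      return failure();
    // Make the op legal under the dynamic predicate.
    rewriter.modifyOpInPlace(op, [&] { op->setAttr("normalized", rewriter.getUnitAttr()); });
    return success();
  }
};

// Target setup mirrors the pass above:
//   target.addDynamicallyLegalOp<FooOp>([](FooOp op) { return !needsRewrite(op); });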
@@ -162,14 +182,27 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
signalPassFailure();
return;
}
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
return;
}
@@ -31,6 +31,21 @@ static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
Block& block = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= block.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
continue;
return true;
}
return false;
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
@@ -262,4 +277,10 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
});
}
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir
@@ -3,8 +3,16 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
auto entryFunc = getPimEntryFunc(module);
if (failed(entryFunc)) {
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
signalPassFailure();
return;
}
@@ -144,6 +144,7 @@ void SpatialToPimPass::runOnOperation() {
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
@@ -169,26 +170,22 @@ void SpatialToPimPass::runOnOperation() {
spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>();
{
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
RewritePatternSet initialPatterns(ctx);
populateWithGenerated(initialPatterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
signalPassFailure();
return;
}
{
RewritePatternSet patterns(ctx);
populateGlobalTensorMaterializationPatterns(patterns);
walkAndApplyPatterns(moduleOp, std::move(patterns));
}
RewritePatternSet globalTensorPatterns(ctx);
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
@@ -197,6 +194,7 @@ void SpatialToPimPass::runOnOperation() {
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core");
signalPassFailure();
return;
}
@@ -205,17 +203,16 @@ void SpatialToPimPass::runOnOperation() {
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
signalPassFailure();
return;
}
}
{
RewritePatternSet patterns(ctx);
populateTensorPackingPatterns(patterns);
walkAndApplyPatterns(funcOp, std::move(patterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
}
RewritePatternSet initialTensorPackingPatterns(ctx);
populateTensorPackingPatterns(initialTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
@@ -229,27 +226,27 @@ void SpatialToPimPass::runOnOperation() {
}
}
{
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
signalPassFailure();
return;
}
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
signalPassFailure();
return;
}
}
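One structural point in the block above: the generated pattern set is frozen once into a FrozenRewritePatternSet and then reused across many scoped applyFullConversion calls, avoiding rebuilding the patterns per core op. A minimal sketch of the idiom (the roots collection and the populate call are placeholders):

// Sketch: freeze once, then fully convert each nested root in isolation.
RewritePatternSet bodyPatterns(ctx);
// populateWithGenerated(bodyPatterns); // tablegen-generated patterns, as above
FrozenRewritePatternSet frozenBodyPatterns(std::move(bodyPatterns));
for (Operation* root : roots) { // e.g. each collected pim.core op
  if (failed(applyFullConversion(root, target, frozenBodyPatterns))) {
    root->emitOpError("nested conversion failed");
    signalPassFailure();
    return;
  }
}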
@@ -259,44 +256,43 @@ void SpatialToPimPass::runOnOperation() {
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
signalPassFailure();
return;
}
{
RewritePatternSet patterns(ctx);
populateTensorPackingPatterns(patterns);
walkAndApplyPatterns(funcOp, std::move(patterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
}
RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
{
ConversionTarget communicationTarget(*ctx);
communicationTarget.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatExtractRowsOp>();
ConversionTarget communicationTarget(*ctx);
communicationTarget.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
signalPassFailure();
return;
}
RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
signalPassFailure();
return;
}
if (failed(verifySpatialToPimBoundary(moduleOp))) {
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
signalPassFailure();
return;
}
@@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() {
return WalkResult::skip();
});
if (hasFailed) {
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
signalPassFailure();
return;
}
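The bufferization failure path above records failure in a flag during a walk instead of aborting traversal mid-rewrite. A small sketch of the idiom, assuming a hypothetical lowerCopy helper:

// Sketch: detect per-op lowering failure during a module walk.
bool hasFailed = false;
moduleOp.walk([&](memref::CopyOp copyOp) {
  if (failed(lowerCopy(copyOp))) { // hypothetical helper
    hasFailed = true;
    return WalkResult::interrupt(); // stop walking on first failure
  }
  return WalkResult::skip(); // do not descend into the rewritten region
});
if (hasFailed) {
  moduleOp.emitError("failed to lower a memref.copy-like op");
  signalPassFailure();
}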
@@ -24,8 +24,11 @@ struct EmitPimCodePass : PassWrapper<EmitPimCodePass, OperationPass<ModuleOp>> {
createDirectory(pimDir);
int compiler_error_code = compileToPimCode(moduleOp, pimDir);
if (compiler_error_code != CompilerSuccess)
if (compiler_error_code != CompilerSuccess) {
moduleOp.emitError() << "failed to emit PIM simulator code artifacts; compiler error code "
<< compiler_error_code;
signalPassFailure();
}
}
};
@@ -32,14 +32,16 @@ struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationP
}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
if (failed(applyPatternsGreedily(moduleOp, *patterns, config))) {
moduleOp.emitError("PIM host constant folding failed in the greedy rewrite driver");
signalPassFailure();
return;
}
dumpModule(getOperation(), "pim3_folded");
dumpModule(moduleOp, "pim3_folded");
}
std::shared_ptr<const FrozenRewritePatternSet> patterns;
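A note on the config used above: GreedyRewriteConfig::enableFolding() makes the greedy driver fold ops during rewriting, and applyPatternsGreedily returns failure only when the rewrite does not converge, which is why this pass can keep the greedy driver where the earlier passes switched to partial conversion. A minimal sketch, assuming the pattern set is already frozen as in the pass:

// Sketch: greedy rewrite with folding, mirroring the pass above.
GreedyRewriteConfig config;
config.enableFolding();
if (failed(applyPatternsGreedily(moduleOp, frozenPatterns, config))) {
  moduleOp.emitError("greedy folding did not converge");
  signalPassFailure();
}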
@@ -160,6 +160,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
}
if (hasFailure) {
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
signalPassFailure();
return;
}
@@ -197,8 +197,10 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
}
}
if (hasFailure)
if (hasFailure) {
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
signalPassFailure();
}
}
private: