better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -8,7 +8,6 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
@@ -114,21 +113,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet matmulPatterns(ctx);
|
||||
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||
|
||||
bool hasUnloweredMatMul = false;
|
||||
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||
hasUnloweredMatMul = true;
|
||||
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||
});
|
||||
if (hasUnloweredMatMul) {
|
||||
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
@@ -165,10 +149,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet transposePatterns(ctx);
|
||||
populateTransposePatterns(transposePatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
|
||||
|
||||
ConversionTarget earlyPostTarget(*ctx);
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
@@ -206,15 +186,14 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
populateEmptyFunction(*entryFunc);
|
||||
|
||||
dumpModule(moduleOp, "spatial0");
|
||||
|
||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||
|
||||
Reference in New Issue
Block a user