Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 13:23:31 +02:00
parent 82b44a6387
commit 832bd7f1f7
37 changed files with 285 additions and 153 deletions
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
@@ -14,10 +15,8 @@
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -86,30 +85,6 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
returnOp.setOperand(index, computeResult);
}
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isCompileTimeOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
@@ -117,6 +92,7 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -156,11 +132,13 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXTransposeOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
@@ -187,9 +165,14 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
RewritePatternSet transposePatterns(ctx);
populateTransposePatterns(transposePatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -205,6 +188,7 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -222,8 +206,6 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();