better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-03 12:26:31 +02:00
parent 636310d0cb
commit 0a5e73c3ea
8 changed files with 75 additions and 165 deletions
@@ -10,8 +10,6 @@
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -68,14 +66,6 @@ FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<Arr
return permutation;
}
Value transposeMaybeInCompute(
Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
};
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
}
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
@@ -78,12 +78,6 @@ llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
mlir::Value transposeMaybeInCompute(mlir::Value value,
mlir::RankedTensorType resultType,
mlir::ArrayRef<int64_t> permutation,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
@@ -8,7 +8,6 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
@@ -114,21 +113,6 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
RewritePatternSet matmulPatterns(ctx);
populateMatMulRewritePatterns(matmulPatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
bool hasUnloweredMatMul = false;
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
hasUnloweredMatMul = true;
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
});
if (hasUnloweredMatMul) {
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
signalPassFailure();
return;
}
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -165,10 +149,6 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
RewritePatternSet transposePatterns(ctx);
populateTransposePatterns(transposePatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -206,15 +186,14 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
return;
}
populateEmptyFunction(*entryFunc);
dumpModule(moduleOp, "spatial0");
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
}
}
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
@@ -10,6 +10,7 @@ void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { popula
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedConversionPatterns(patterns, ctx);
populateElementwisePatterns(patterns, ctx);
populateMatMulRewritePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
@@ -331,14 +331,6 @@ static Value extractDynamicGemmBColumn(
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
}
static Value extractTransposedBRow(
Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, offsets, sizes, strides).getResult();
}
static Value extractDynamicGemmRowVector(
Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
@@ -424,7 +416,6 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
RankedTensorType bType,
RankedTensorType scalarPiecesType,
RankedTensorType outType,
bool bAlreadyTransposed,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t numOutRows = outType.getDimSize(0);
@@ -446,8 +437,7 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value bVector = extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
@@ -739,6 +729,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType(), bType.getEncoding());
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})).getResult();
bType = transposedType;
}
if (!isCompileTimeComputable(b)) {
bool hasC = hasGemmBias(c);
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
@@ -758,10 +755,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
biasType = *verifiedBiasType;
}
const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize;
const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols;
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows
|| bType.getDimSize(1) != expectedBCols) {
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize
|| bType.getDimSize(1) != numOutCols) {
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
return failure();
}
@@ -773,8 +768,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
}
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
auto batchOp =
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc);
if (failed(batchOp))
return failure();
auto outputCompute = createDynamicGemmOutputCompute(
@@ -793,13 +787,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
bType = cast<RankedTensorType>(b.getType());
}
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
return failure();
@@ -123,39 +123,16 @@ static Value extractBatchMatrix(Value value,
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
auto createONNXTranspose = [&](RankedTensorType resultType, ArrayRef<int64_t> permutation) {
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)).getResult();
};
if (type.getRank() == 2) {
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
auto resultType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType(), type.getEncoding());
return createONNXTranspose(resultType, {1, 0});
}
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
auto resultType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType(), type.getEncoding());
return createONNXTranspose(resultType, {0, 2, 1});
}
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
@@ -372,32 +349,6 @@ static Value extractDynamicBatchedBColumn(Value matrix,
.getResult();
}
static Value extractDynamicBatchedBRow(Value matrix,
int64_t sourceBatchCount,
Value batch,
Value row,
RankedTensorType vectorType,
PatternRewriter& rewriter,
Location loc) {
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
: OpFoldResult(batch),
row,
rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
auto rowSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
return tensor::CollapseShapeOp::create(rewriter,
loc,
vectorType,
rowSlice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
}
static Value extractDynamicBatchedRowVector(Value matrix,
int64_t sourceBatchCount,
Value batch,
@@ -432,7 +383,6 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
RankedTensorType bType,
RankedTensorType scalarPiecesType,
RankedTensorType outType,
bool bAlreadyTransposed,
PatternRewriter& rewriter,
Location loc) {
const int64_t numBatches = outType.getDimSize(0);
@@ -459,9 +409,7 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
Value aVector =
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
Value bVector =
bAlreadyTransposed
? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc)
: extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -537,15 +485,6 @@ static FailureOr<Value> createBatchedDynamicOutputCompute(Value scalarPieces,
return computeOp->getResult(0);
}
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
static Value extractBatchedReductionPiece(Value partialPiecesArg,
Value batch,
Value hSlice,
@@ -764,7 +703,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
int64_t gemmK = shapeInfo->k;
int64_t gemmN = shapeInfo->n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = shapeInfo->rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = shapeInfo->lhsBatch;
@@ -787,15 +726,10 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm) {
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
gemmResult = transposeCompute.getResult(0);
}
if (useTransposedForm)
gemmResult =
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
.getResult();
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
@@ -822,7 +756,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
int64_t gemmK = shapeInfo->k;
int64_t gemmN = shapeInfo->n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = shapeInfo->rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = shapeInfo->lhsBatch;
@@ -880,12 +814,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
if (failed(result))
return failure();
Value finalResult = *result;
if (useTransposedForm)
finalResult = transposeBatchedOutput(
finalResult,
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
rewriter,
loc);
if (useTransposedForm) {
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
shapeInfo->outType.getElementType(),
shapeInfo->outType.getEncoding());
finalResult =
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
.getResult();
}
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
rewriter.replaceOp(matmulOp, finalResult);
return success();
@@ -901,7 +837,6 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
rhsBatchedType,
scalarPiecesType,
directOutType,
false,
rewriter,
loc);
if (failed(batchOp))
@@ -911,12 +846,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
if (failed(result))
return failure();
Value finalResult = *result;
if (useTransposedForm)
finalResult = transposeBatchedOutput(
finalResult,
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
rewriter,
loc);
if (useTransposedForm) {
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
shapeInfo->outType.getElementType(),
shapeInfo->outType.getEncoding());
finalResult =
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
.getResult();
}
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
rewriter.replaceOp(matmulOp, finalResult);
return success();
@@ -137,11 +137,17 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
Value transposedInput =
ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), transposedType, input, rewriter.getI64ArrayAttr(permutation))
.getResult();
auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
if (failed(transposedResult))
return failure();
result = transposeMaybeInCompute(*transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
result =
ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, *transposedResult, rewriter.getI64ArrayAttr(inversePermutation))
.getResult();
}
rewriter.replaceOp(softmaxOp, result);
@@ -15,6 +15,10 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isInsideSpatialComputeRegion(Operation* op) {
return op->getParentOfType<spatial::SpatCompute>() || op->getParentOfType<spatial::SpatComputeBatch>();
}
static Value createTransposeInit(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
@@ -102,10 +106,22 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
return success();
}
}
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
Value transposed =
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation).getResult()[0];
rewriter.replaceOp(transposeOp, transposed);
auto buildTranspose = [&](Value input) -> Value {
Value init = createTransposeInit(input, resultType, *permutation, rewriter, transposeOp.getLoc());
return linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), input, init, *permutation).getResult()[0];
};
if (isInsideSpatialComputeRegion(transposeOp.getOperation())) {
rewriter.replaceOp(transposeOp, buildTranspose(adaptor.getData()));
return success();
}
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {adaptor.getData()}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), buildTranspose(input));
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
return success();
}
};