implement mem copy codgen (lmv)

add more gemv/gemm tests
refactor
This commit is contained in:
NiccoloN
2026-03-04 18:04:48 +01:00
parent 143c8f960a
commit 8ee1e5ece8
27 changed files with 265 additions and 70 deletions

View File

@@ -25,9 +25,96 @@ namespace onnx_mlir {
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
ONNXGemmOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
GemmToManyGemv(MLIRContext* ctx)
: OpConversionPattern(ctx, 2) {}
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Location loc = gemmOp.getLoc();
Value a = adaptor.getA();
Value b = adaptor.getB();
Value c = adaptor.getC();
assert("A should have been transposed already" && !adaptor.getTransA());
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
const int64_t numOutRows = aType.getDimSize(0);
// Only decompose when there are multiple rows to split
if (numOutRows <= 1)
return failure();
RankedTensorType cType = nullptr;
bool cHasNumOutRows = false;
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Value> gemvOps;
gemvOps.reserve(numOutRows);
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c;
if (hasC) {
if (cHasNumOutRows) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult();
}
else
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
}
auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
outRowType,
aSlice,
b,
cSlice,
gemmOp.getAlphaAttr(),
gemmOp.getBetaAttr(),
gemmOp.getTransAAttr(),
gemmOp.getTransBAttr());
gemvOps.push_back(gemvOp.getY());
}
auto concatComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
auto* concatBlock = new Block();
for (auto gemvOp : gemvOps)
concatBlock->addArgument(gemvOp.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
};
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
@@ -50,12 +137,16 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
assert("Only support 2 tensor for C" && cType.getRank() == 2);
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
}
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape()))
// Not a gemv
return failure();
if (transA) {
auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
@@ -169,9 +260,20 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
outHSlices.push_back(reduceComputeOp.getResult(0));
}
rewriter.setInsertionPoint(gemmOp);
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, outHSlices);
rewriter.replaceOp(gemmOp, concatOp);
auto concatComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
auto* concatBlock = new Block();
for (auto outHSlice : outHSlices)
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
@@ -310,8 +412,9 @@ private:
}
};
void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXGemmOpTile>(ctx);
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx);
}
} // namespace onnx_mlir

View File

@@ -71,7 +71,7 @@ void ONNXToSpatialPass::runOnOperation() {
else {
populateTilingConvOpPattern(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx);
populateTilingGemmOpPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx);
}
populateONNXConcatToTensorConcatPattern(patterns, ctx);

View File

@@ -5,7 +5,7 @@ namespace onnx_mlir {
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);