implement mem copy codegen (lmv)
add more gemv/gemm tests; refactor
This commit is contained in:
@@ -25,9 +25,96 @@ namespace onnx_mlir {
|
||||
|
||||
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||
|
||||
struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
||||
ONNXGemmOpTile(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
/// Decomposes a multi-row ONNXGemmOp (M x K * K x N, M > 1) into one
/// single-row gemv per output row, then reassembles the rows inside a
/// spatial::SpatWeightedCompute whose body concatenates the partial
/// results along axis 0.
///
/// Registered with benefit 2 so it fires before GemvToSpatialCompute
/// (benefit 1), which then lowers each emitted 1-row gemv.
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
  GemmToManyGemv(MLIRContext* ctx)
      : OpConversionPattern(ctx, 2) {}

  LogicalResult
  matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
    Location loc = gemmOp.getLoc();
    Value a = adaptor.getA();
    Value b = adaptor.getB();
    Value c = adaptor.getC();

    assert(!adaptor.getTransA() && "A should have been transposed already");

    // C is optional; an absent C is materialized as an ONNXNoneOp.
    // Guard getDefiningOp(): it returns null when c is a block argument,
    // and isa<> on a null Operation* would crash.
    Operation* cDefOp = c.getDefiningOp();
    bool hasC = !(cDefOp && isa<ONNXNoneOp>(cDefOp));

    auto aType = cast<RankedTensorType>(a.getType());
    auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
    assert(aType.hasStaticShape() && outType.hasStaticShape() && "Only support static shapes");

    const int64_t numOutRows = aType.getDimSize(0);

    // Only decompose when there are multiple rows to split.
    if (numOutRows <= 1)
      return failure();

    RankedTensorType cType = nullptr;
    bool cHasNumOutRows = false;
    if (hasC) {
      cType = cast<RankedTensorType>(c.getType());
      assert(cType.getRank() == 2 && "Only support rank 2 tensor for C");
      // C either carries one row per output row (sliced below) or
      // broadcasts a single row vector to every gemv.
      cHasNumOutRows = cType.getDimSize(0) == numOutRows;
    }

    // Each gemv produces one 1 x N row of the output.
    auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());

    SmallVector<Value> gemvOps;
    gemvOps.reserve(numOutRows);
    for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
      // Slice row rowIdx (1 x K) out of A.
      SmallVector<OpFoldResult> aOffsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
      SmallVector<OpFoldResult> aSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
      SmallVector<OpFoldResult> aStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
      auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
      auto aSlice =
          rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, aOffsets, aSizes, aStrides).getResult();

      Value cSlice = c;
      if (hasC) {
        if (cHasNumOutRows) {
          // C has one row per output row: slice out the matching row.
          SmallVector<OpFoldResult> cOffsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
          SmallVector<OpFoldResult> cSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
          SmallVector<OpFoldResult> cStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
          auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
          cSlice =
              rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, cOffsets, cSizes, cStrides).getResult();
        } else {
          // Broadcast case: C must be a (row) vector shared by every gemv.
          assert(isVectorShape(getTensorShape(c)) && "C should be a vector");
        }
      }

      // Re-emit a Gemm restricted to this single row; alpha/beta/trans
      // attributes are carried over unchanged.
      auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
          outRowType,
          aSlice,
          b,
          cSlice,
          gemmOp.getAlphaAttr(),
          gemmOp.getBetaAttr(),
          gemmOp.getTransAAttr(),
          gemmOp.getTransBAttr());
      gemvOps.push_back(gemvOp.getY());
    }

    // Wrap the per-row results in a SpatWeightedCompute whose body
    // concatenates them back into the original M x N output.
    auto concatComputeOp =
        rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);

    // Build the body block through the rewriter (not `new Block()`) so the
    // dialect-conversion driver can roll it back on failure; the guard
    // restores the insertion point when we are done building the body.
    SmallVector<Type> argTypes;
    SmallVector<Location> argLocs;
    for (Value gemv : gemvOps) {
      argTypes.push_back(gemv.getType());
      argLocs.push_back(loc);
    }
    OpBuilder::InsertionGuard guard(rewriter);
    Block* concatBlock = rewriter.createBlock(
        &concatComputeOp.getBody(), concatComputeOp.getBody().end(), argTypes, argLocs);

    auto blockArgs = concatBlock->getArguments();
    auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
    rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());

    rewriter.replaceOp(gemmOp, concatComputeOp);
    return success();
  }
};
|
||||
|
||||
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
||||
GemvToSpatialCompute(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx, 1) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
@@ -50,12 +137,16 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
assert("Only support 2 tensor for C" && cType.getRank() == 2);
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
}
|
||||
|
||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
||||
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
||||
|
||||
if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape()))
|
||||
// Not a gemv
|
||||
return failure();
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
@@ -169,9 +260,20 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
||||
outHSlices.push_back(reduceComputeOp.getResult(0));
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(gemmOp);
|
||||
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, outHSlices);
|
||||
rewriter.replaceOp(gemmOp, concatOp);
|
||||
auto concatComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (auto outHSlice : outHSlices)
|
||||
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto blockArgs = concatBlock->getArguments();
|
||||
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
|
||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -310,8 +412,9 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ONNXGemmOpTile>(ctx);
|
||||
/// Registers the Gemm lowering patterns: GemmToManyGemv (benefit 2) splits a
/// multi-row gemm into per-row gemvs, then GemvToSpatialCompute (benefit 1)
/// lowers each gemv to a spatial compute op.
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<GemmToManyGemv, GemvToSpatialCompute>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -71,7 +71,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
else {
|
||||
populateTilingConvOpPattern(patterns, ctx);
|
||||
populatePoolingTilingPattern(patterns, ctx);
|
||||
populateTilingGemmOpPattern(patterns, ctx);
|
||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
||||
|
||||
@@ -5,7 +5,7 @@ namespace onnx_mlir {
|
||||
|
||||
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
Reference in New Issue
Block a user