|
|
|
|
@@ -25,9 +25,96 @@ namespace onnx_mlir {
|
|
|
|
|
|
|
|
|
|
// Attribute name used to tag ops for computation with the softmax divisor.
// Not referenced anywhere in this chunk -- presumably read/written by other
// passes; TODO confirm against its consumers before renaming or removing.
const StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
|
|
|
|
|
|
|
|
|
struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
|
|
|
|
ONNXGemmOpTile(MLIRContext* ctx)
|
|
|
|
|
: OpConversionPattern(ctx) {}
|
|
|
|
|
/// Decomposes a static-shaped GEMM whose A operand has M > 1 rows into M
/// single-row ONNXGemmOps (GEMVs), one per row of A (and per row of C when C
/// is 2-D with M rows). The 1xN partial results are concatenated along axis 0
/// inside the body of a spatial::SpatWeightedCompute op that replaces the
/// original GEMM.
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
  // Benefit 2: runs before GemvToSpatialCompute (benefit 1), which then
  // matches the single-row ops this pattern produces.
  GemmToManyGemv(MLIRContext* ctx)
      : OpConversionPattern(ctx, 2) {}

  /// Returns failure() when there is only one output row (nothing to split).
  /// Preconditions (asserted): A is not transposed; A and the output have
  /// static shapes; C -- when present -- is rank 2 and either has one row per
  /// output row or is a (broadcast) vector.
  LogicalResult
  matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = gemmOp.getLoc();
    Value a = adaptor.getA();
    Value b = adaptor.getB();
    Value c = adaptor.getC();

    assert(!adaptor.getTransA() && "A should have been transposed already");

    // C is "absent" when it is produced by an ONNXNoneOp. Guard against C
    // being a block argument, for which getDefiningOp() returns nullptr and
    // the old unguarded isa<> would have dereferenced null.
    Operation* cDefOp = c.getDefiningOp();
    bool hasC = !(cDefOp && isa<ONNXNoneOp>(cDefOp));

    auto aType = cast<RankedTensorType>(a.getType());
    auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
    assert(aType.hasStaticShape() && outType.hasStaticShape()
        && "Only support static shapes");

    const int64_t numOutRows = aType.getDimSize(0);
    // Only decompose when there are multiple rows to split
    if (numOutRows <= 1)
      return failure();

    RankedTensorType cType = nullptr;
    // True when C carries one row per output row and must be sliced per GEMV;
    // otherwise the whole (vector) C is fed to every GEMV unchanged.
    bool cHasNumOutRows = false;
    if (hasC) {
      cType = cast<RankedTensorType>(c.getType());
      assert(cType.getRank() == 2 && "Only support rank 2 tensor for C");
      cHasNumOutRows = cType.getDimSize(0) == numOutRows;
    }

    // Each per-row GEMM produces one 1 x N row of the output.
    auto outRowType =
        RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());

    SmallVector<Value> gemvOps;
    gemvOps.reserve(numOutRows);
    for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
      // Slice row rowIdx out of A: a 1 x K tensor.
      SmallVector<OpFoldResult> offsets = {
          rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
      SmallVector<OpFoldResult> sizes = {
          rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
      SmallVector<OpFoldResult> strides = {
          rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
      auto aSliceType =
          RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
      auto aSlice = rewriter
                        .create<tensor::ExtractSliceOp>(
                            loc, aSliceType, a, offsets, sizes, strides)
                        .getResult();

      // Pick the matching row of C when C is sliced per row; otherwise C must
      // be a vector that applies to every row.
      Value cSlice = c;
      if (hasC) {
        if (cHasNumOutRows) {
          SmallVector<OpFoldResult> cOffsets = {
              rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
          SmallVector<OpFoldResult> cSizes = {
              rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
          SmallVector<OpFoldResult> cStrides = {
              rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
          auto cSliceType =
              RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
          cSlice = rewriter
                       .create<tensor::ExtractSliceOp>(
                           loc, cSliceType, c, cOffsets, cSizes, cStrides)
                       .getResult();
        } else
          assert(isVectorShape(getTensorShape(c)) && "C should be a vector");
      }

      // One GEMV per row; all scalar/transpose attributes carry over.
      auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
          outRowType,
          aSlice,
          b,
          cSlice,
          gemmOp.getAlphaAttr(),
          gemmOp.getBetaAttr(),
          gemmOp.getTransAAttr(),
          gemmOp.getTransBAttr());
      gemvOps.push_back(gemvOp.getY());
    }

    // Wrap the row-wise results in a SpatWeightedCompute whose body stitches
    // them back into the original M x N tensor.
    auto concatComputeOp =
        rewriter.create<spatial::SpatWeightedCompute>(
            loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);

    // Build the body block through the rewriter (rather than `new Block()` +
    // push_back) so the dialect-conversion driver can roll the creation back
    // if the conversion later fails. createBlock also sets the insertion
    // point to the start of the new block.
    SmallVector<Type> argTypes;
    argTypes.reserve(gemvOps.size());
    for (Value gemvValue : gemvOps)
      argTypes.push_back(gemvValue.getType());
    SmallVector<Location> argLocs(gemvOps.size(), loc);
    Block* concatBlock = rewriter.createBlock(&concatComputeOp.getBody(),
        concatComputeOp.getBody().end(), argTypes, argLocs);

    auto blockArgs = concatBlock->getArguments();
    auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
    rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());

    rewriter.replaceOp(gemmOp, concatComputeOp);
    return success();
  }
};
|
|
|
|
|
|
|
|
|
|
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
|
|
|
|
// Benefit 1: lower priority than GemmToManyGemv (benefit 2), so it runs on
// the single-row GEMMs produced by that decomposition.
GemvToSpatialCompute(MLIRContext* ctx)
    : OpConversionPattern(ctx, 1) {}
|
|
|
|
|
|
|
|
|
|
LogicalResult
|
|
|
|
|
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
|
|
|
|
@@ -50,12 +137,16 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
|
|
|
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
|
|
|
|
if (hasC) {
|
|
|
|
|
cType = cast<RankedTensorType>(c.getType());
|
|
|
|
|
assert("Only support 2 tensor for C" && cType.getRank() == 2);
|
|
|
|
|
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
|
|
|
|
|
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
|
|
|
|
|
|
|
|
|
|
// NOTE(review): both operands of the || test aType, so the second check is
// redundant as written -- presumably one side was meant to test a different
// operand's shape (e.g. the output or B). Confirm the intended gemv
// condition before relying on this guard.
if (!isVectorShape(aType.getShape()) || !isVectorShape(aType.getShape()))
// Not a gemv
return failure();
|
|
|
|
|
|
|
|
|
|
if (transA) {
|
|
|
|
|
auto aShape = aType.getShape();
|
|
|
|
|
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
|
|
|
|
@@ -169,9 +260,20 @@ struct ONNXGemmOpTile : public OpConversionPattern<ONNXGemmOp> {
|
|
|
|
|
outHSlices.push_back(reduceComputeOp.getResult(0));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rewriter.setInsertionPoint(gemmOp);
|
|
|
|
|
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, outHSlices);
|
|
|
|
|
rewriter.replaceOp(gemmOp, concatOp);
|
|
|
|
|
auto concatComputeOp =
|
|
|
|
|
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
|
|
|
|
|
|
|
|
|
auto* concatBlock = new Block();
|
|
|
|
|
for (auto outHSlice : outHSlices)
|
|
|
|
|
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
|
|
|
|
|
concatComputeOp.getBody().push_back(concatBlock);
|
|
|
|
|
rewriter.setInsertionPointToStart(concatBlock);
|
|
|
|
|
|
|
|
|
|
auto blockArgs = concatBlock->getArguments();
|
|
|
|
|
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
|
|
|
|
|
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -310,8 +412,9 @@ private:
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void populateTilingGemmOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
|
|
|
patterns.insert<ONNXGemmOpTile>(ctx);
|
|
|
|
|
/// Registers the GEMM lowering patterns: GemmToManyGemv (benefit 2) to split
/// multi-row GEMMs, then GemvToSpatialCompute (benefit 1) for the resulting
/// single-row ops.
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<GemmToManyGemv, GemvToSpatialCompute>(ctx);
}
|
|
|
|
|
|
|
|
|
|
} // namespace onnx_mlir
|
|
|
|
|
|