120 lines
5.4 KiB
C++
120 lines
5.4 KiB
C++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Decomposes a statically-shaped rank-2 x rank-3 ONNXMatMulOp
///   A: [m, k]  x  B: [batch, k, n]  ->  Y: [batch, m, n]
/// into per-column GEMM rows gathered inside a spatial::SpatWeightedCompute
/// region:
///   1. A is transposed once to [k, m] and reused by every GEMM.
///   2. For each (batchIdx, colIdx), B[batchIdx, :, colIdx] is extracted as a
///      [1, k] row; rhsRow x A^T yields one [1, m] row of the transposed
///      result, i.e. row (batchIdx * n + colIdx) of [batch * n, m].
///   3. All rows are concatenated to [batch * n, m] inside the compute
///      region, then expanded to [batch, n, m] and transposed back to the
///      expected [batch, m, n] layout.
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    // Only fully static, ranked shapes are handled.
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
        || !outType.hasStaticShape())
      return failure();
    if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
      return failure();

    const int64_t batch = rhsType.getDimSize(0);
    const int64_t k = rhsType.getDimSize(1);
    const int64_t n = rhsType.getDimSize(2);
    const int64_t m = lhsType.getDimSize(0);
    // Shapes must agree: A [m, k] x B [batch, k, n] -> Y [batch, m, n].
    if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
        || outType.getDimSize(2) != n)
      return failure();
    // FIX: a zero-sized dimension (legal for static ONNX shapes) would leave
    // gemmRows empty and build a zero-operand tensor.concat below, which is
    // invalid. Bail out and leave such degenerate matmuls to other patterns.
    if (batch <= 0 || k <= 0 || m <= 0 || n <= 0)
      return failure();

    Location loc = matmulOp.getLoc();
    // Intermediate tensor types used throughout the decomposition.
    auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
    auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
    auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
    auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
    auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
    auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());

    // Transpose A once, outside the loop; every per-column GEMM reuses it.
    Value lhsTransposed =
        ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
    // Shared 'none' value for the absent GEMM bias input.
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    SmallVector<Value> gemmRows;
    gemmRows.reserve(batch * n);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      for (int64_t colIdx = 0; colIdx < n; colIdx++) {
        // Extract B[batchIdx, 0:k, colIdx] as a [1, k, 1] slice.
        SmallVector<OpFoldResult> offsets = {
            rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
        SmallVector<OpFoldResult> sizes = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
        SmallVector<OpFoldResult> strides = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        Value rhsSlice =
            tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
        // Collapse [1, k, 1] -> [1, k] so the slice feeds the GEMM as a row.
        Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
            loc,
            rhsRowType,
            rhsSlice,
            SmallVector<ReassociationIndices> {
                {0},
                {1, 2}
            });

        // rhsRow [1, k] x A^T [k, m] -> [1, m]: one row of the transposed
        // result. alpha = beta = 1, no bias, no extra transposition.
        auto gemmOp = ONNXGemmOp::create(rewriter,
            loc,
            gemmRowType,
            rhsRow,
            lhsTransposed,
            none,
            rewriter.getF32FloatAttr(1.0f),
            rewriter.getF32FloatAttr(1.0f),
            rewriter.getBoolAttr(false),
            rewriter.getBoolAttr(false));
        gemmRows.push_back(gemmOp.getY());
      }
    }

    // Gather all GEMM rows as operands of a SpatWeightedCompute region whose
    // block arguments mirror them one-to-one.
    auto concatComputeOp =
        spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);

    // Ownership of the block transfers to the region via push_back.
    auto* concatBlock = new Block();
    for (Value gemmRow : gemmRows)
      concatBlock->addArgument(gemmRow.getType(), loc);
    concatComputeOp.getBody().push_back(concatBlock);
    rewriter.setInsertionPointToStart(concatBlock);

    // Stack all [1, m] rows along axis 0 into [batch * n, m] and yield it as
    // the region's result.
    auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
    spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());

    rewriter.setInsertionPointAfter(concatComputeOp);
    Value gemmOut = concatComputeOp.getResult(0);
    // [batch * n, m] -> [batch, n, m], then permute {0, 2, 1} to recover the
    // [batch, m, n] layout expected by the original matmul result.
    Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
        loc,
        gemmExpandedType,
        gemmOut,
        SmallVector<ReassociationIndices> {
            {0, 1},
            {2}
        });
    Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));

    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the MatMul-to-GEMM decomposition pattern on the given pattern
/// set, using the supplied MLIR context.
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  // 'add' is the modern spelling of RewritePatternSet::insert; identical
  // semantics.
  patterns.add<MatMulRank3ToGemm>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|