#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Returns true when every dimension in `shape` is a static, strictly
/// positive extent (dynamic dims are negative sentinels, so they fail too).
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
  for (int64_t dim : shape)
    if (dim <= 0)
      return false;
  return true;
}
|
|
|
|
/// Extracts the 2-D matrix for batch `batchIndex` out of `value`.
/// A rank-2 operand is broadcast across the batch and returned as-is.
/// For a rank-3 tensor, a [1, rows, cols] slice is taken (batch 0 when
/// the operand itself is broadcast, i.e. batchSize == 1) and collapsed
/// down to [rows, cols].
static Value extractBatchMatrix(Value value,
    int64_t batchIndex,
    int64_t batchSize,
    int64_t rows,
    int64_t cols,
    PatternRewriter& rewriter,
    Location loc) {
  auto tensorType = cast<RankedTensorType>(value.getType());
  // Rank-2 operands are shared by every batch; no slicing needed.
  if (tensorType.getRank() == 2)
    return value;

  Type elemType = tensorType.getElementType();
  // A broadcast operand (batchSize == 1) always reads its only batch.
  const int64_t batchOffset = (batchSize == 1) ? 0 : batchIndex;
  auto idx = [&](int64_t v) { return rewriter.getIndexAttr(v); };
  SmallVector<OpFoldResult> offsets {idx(batchOffset), idx(0), idx(0)};
  SmallVector<OpFoldResult> sizes {idx(1), idx(rows), idx(cols)};
  SmallVector<OpFoldResult> strides {idx(1), idx(1), idx(1)};
  auto sliceType = RankedTensorType::get({1, rows, cols}, elemType);
  Value slice = tensor::ExtractSliceOp::create(
      rewriter, loc, sliceType, value, offsets, sizes, strides);

  // Fold the unit batch dimension away: [1, rows, cols] -> [rows, cols].
  SmallVector<ReassociationIndices> reassoc {{0, 1}, {2}};
  auto matrixType = RankedTensorType::get({rows, cols}, elemType);
  return tensor::CollapseShapeOp::create(
      rewriter, loc, matrixType, slice, reassoc);
}
|
|
|
|
/// Walks the defining chain of `value` through layout-only operations
/// (tensor.extract_slice, tensor.expand_shape, tensor.collapse_shape,
/// onnx.Transpose) and reports whether the chain bottoms out at an op
/// with the ConstantLike trait. A visited set stops the walk if the
/// same op is ever reached twice.
static bool isConstantLikeOperand(Value value) {
  llvm::SmallPtrSet<Operation*, 8> visited;

  for (Operation* definingOp = value.getDefiningOp(); definingOp;
       definingOp = value.getDefiningOp()) {
    // Seen this op before: give up rather than loop.
    if (!visited.insert(definingOp).second)
      return false;
    if (definingOp->hasTrait<OpTrait::ConstantLike>())
      return true;

    // Look through view/layout ops to their source operand.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp))
      value = sliceOp.getSource();
    else if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(definingOp))
      value = expandOp.getSrc();
    else if (auto collapseOp = dyn_cast<tensor::CollapseShapeOp>(definingOp))
      value = collapseOp.getSrc();
    else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp))
      value = transposeOp.getData();
    else
      return false;  // Any other producer is treated as non-constant.
  }
  // Block arguments have no defining op: not constant-like.
  return false;
}
|
|
|
|
/// Emits an onnx.Transpose that swaps the two innermost dimensions of
/// `value`. Supports rank-2 ([r, c] -> [c, r]) and rank-3
/// ([b, r, c] -> [b, c, r]) tensors.
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  ArrayRef<int64_t> shape = type.getShape();
  const bool is2D = type.getRank() == 2;
  SmallVector<int64_t> permutedShape =
      is2D ? SmallVector<int64_t> {shape[1], shape[0]}
           : SmallVector<int64_t> {shape[0], shape[2], shape[1]};
  SmallVector<int64_t> perm =
      is2D ? SmallVector<int64_t> {1, 0} : SmallVector<int64_t> {0, 2, 1};
  auto transposedType =
      RankedTensorType::get(permutedShape, type.getElementType());
  return ONNXTransposeOp::create(
      rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr(perm));
}
|
|
|
|
/// Like transposeLastTwoDims, but wraps the onnx.Transpose inside a
/// spatial compute region built via createSpatCompute; the transposed
/// value is yielded with spat.yield and the compute op's first result
/// is returned.
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  ArrayRef<int64_t> shape = type.getShape();

  // Derive the permutation and result type for rank 2 or rank 3.
  SmallVector<int64_t> perm;
  RankedTensorType transposedType;
  if (type.getRank() == 2) {
    perm = {1, 0};
    transposedType =
        RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
  } else {
    perm = {0, 2, 1};
    transposedType = RankedTensorType::get(
        {shape[0], shape[2], shape[1]}, type.getElementType());
  }

  auto transposeCompute = createSpatCompute<1>(
      rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
        Value transposed = ONNXTransposeOp::create(
            rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
        spatial::SpatYieldOp::create(rewriter, loc, transposed);
      });
  return transposeCompute.getResult(0);
}
|
|
|
|
/// Rewrites a static-shaped ONNXMatMulOp (rank 2 or 3) into one
/// ONNXGemmOp per batch element. Rank-3 results are rebuilt by
/// expanding each per-batch Gemm result to [1, m, n] and concatenating
/// along the batch axis. When A is constant-like but B is not, the
/// product is computed as (B^T * A^T)^T so the non-constant operand
/// becomes the Gemm LHS -- presumably to suit the accelerator's
/// preferred operand placement (TODO confirm).
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    // Only ranked, fully static operand and result types are handled.
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
        || !outType.hasStaticShape())
      return failure();
    // Restrict to rank 2 (plain matmul) or rank 3 (batched matmul).
    if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
        || (outType.getRank() != 2 && outType.getRank() != 3))
      return failure();
    // Reject dynamic or zero-sized dimensions.
    if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
        || !haveStaticPositiveShape(outType.getShape()))
      return failure();

    // Rank-2 operands act as batch 1 (broadcast across the batch).
    const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
    const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
    const int64_t batch = std::max(lhsBatch, rhsBatch);

    // Each operand's batch must be 1 (broadcast) or equal to the max.
    if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
      return failure();

    // A is [.., m, k] and B is [.., k, n]; contraction dims must agree.
    const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
    const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
    const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
    const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
    if (k != rhsK)
      return failure();

    // The declared result type must match the computed [batch, m, n].
    if (outType.getRank() == 2) {
      if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
        return failure();
    }
    else {
      if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
        return failure();
    }

    Location loc = matmulOp.getLoc();
    // Use (B^T A^T)^T only when A alone is constant-like, so the
    // non-constant operand feeds the Gemm LHS.
    bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());

    Value lhs = matmulOp.getA();
    Value rhs = matmulOp.getB();
    int64_t lhsBatchForGemm = lhsBatch;
    int64_t rhsBatchForGemm = rhsBatch;
    int64_t gemmM = m;
    int64_t gemmK = k;
    int64_t gemmN = n;
    if (useTransposedForm) {
      // Swap operands and transpose each; the Gemm now computes
      // B^T A^T = (A B)^T, with result shape [n, m]. B^T is built in a
      // spatial compute region, A^T (constant-like) as a plain ONNX op.
      lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
      lhsBatchForGemm = rhsBatch;
      rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
      rhsBatchForGemm = lhsBatch;
      gemmM = n;
      gemmN = m;
    }

    auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
    auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
    // Gemm bias input C is unused; beta therefore has no effect.
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    // Rank-2 result: a single Gemm (batch is guaranteed to be 1 here).
    if (outType.getRank() == 2) {
      Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
      Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
      Value gemmResult = ONNXGemmOp::create(rewriter,
          loc,
          gemmType,
          lhsMatrix,
          rhsMatrix,
          none,
          rewriter.getF32FloatAttr(1.0f),  // alpha
          rewriter.getF32FloatAttr(1.0f),  // beta (irrelevant: C is none)
          rewriter.getBoolAttr(false),     // transA
          rewriter.getBoolAttr(false))     // transB
          .getY();
      // Undo the operand-swap transposition: (B^T A^T)^T = A B.
      if (useTransposedForm)
        gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
      rewriter.replaceOp(matmulOp, gemmResult);
      return success();
    }

    // Rank-3 result: one Gemm per batch element; expand each result
    // back to [1, m, n] and concatenate along the batch axis.
    SmallVector<Value> batchResults;
    batchResults.reserve(batch);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
      Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
      Value gemmResult = ONNXGemmOp::create(rewriter,
          loc,
          gemmType,
          lhsMatrix,
          rhsMatrix,
          none,
          rewriter.getF32FloatAttr(1.0f),  // alpha
          rewriter.getF32FloatAttr(1.0f),  // beta (irrelevant: C is none)
          rewriter.getBoolAttr(false),     // transA
          rewriter.getBoolAttr(false))     // transB
          .getY();
      // Undo the operand-swap transposition for this batch slice.
      if (useTransposedForm)
        gemmResult = ONNXTransposeOp::create(
            rewriter,
            loc,
            RankedTensorType::get({m, n}, outType.getElementType()),
            gemmResult,
            rewriter.getI64ArrayAttr({1, 0}));
      // Re-introduce the unit batch dim: [m, n] -> [1, m, n].
      batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
          loc,
          batchedOutType,
          gemmResult,
          SmallVector<ReassociationIndices> {
              {0, 1},
              {2}
          }));
    }

    // Stitch the per-batch results together along axis 0.
    Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the MatMul -> Gemm lowering pattern into `patterns`.
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<MatMulToGemm>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|