fix MatMul pattern non-contiguous extract_slices
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s
This commit is contained in:
@@ -1,10 +1,12 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
#include "llvm/ADT/SmallSet.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
@@ -137,6 +139,35 @@ bool isSpatialMvmVmmWeightUse(OpOperand& use) {
|
|||||||
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool hasOnlySpatialMvmVmmWeightUses(Value value) {
|
||||||
|
SmallPtrSet<Value, 8> visited;
|
||||||
|
auto walkUses = [&](Value currentValue, auto& self) -> bool {
|
||||||
|
if (!visited.insert(currentValue).second)
|
||||||
|
return true;
|
||||||
|
if (currentValue.use_empty())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return llvm::all_of(currentValue.getUses(), [&](OpOperand& use) {
|
||||||
|
if (isSpatialMvmVmmWeightUse(use))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
Operation* user = use.getOwner();
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user))
|
||||||
|
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
|
||||||
|
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(user))
|
||||||
|
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(user))
|
||||||
|
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return walkUses(value, walkUses);
|
||||||
|
}
|
||||||
|
|
||||||
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
|
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
|
||||||
assert(root && "expected valid root op");
|
assert(root && "expected valid root op");
|
||||||
root->walk([&](pim::PimCoreOp coreOp) {
|
root->walk([&](pim::PimCoreOp coreOp) {
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ bool hasWeightAlways(mlir::Operation* op);
|
|||||||
void markWeightAlways(mlir::Operation* op);
|
void markWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||||
|
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||||
|
|
||||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||||
|
|
||||||
|
|||||||
@@ -87,8 +87,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
target.addIllegalOp<ONNXMatMulOp>();
|
||||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
|
||||||
target.addIllegalOp<ONNXAddOp>();
|
target.addIllegalOp<ONNXAddOp>();
|
||||||
target.addIllegalOp<ONNXDivOp>();
|
target.addIllegalOp<ONNXDivOp>();
|
||||||
target.addIllegalOp<ONNXMulOp>();
|
target.addIllegalOp<ONNXMulOp>();
|
||||||
@@ -391,11 +390,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
bool isAlwaysWeight =
|
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||||
!constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
|
|
||||||
return isSpatialMvmVmmWeightUse(use);
|
|
||||||
});
|
|
||||||
if (isAlwaysWeight)
|
|
||||||
markWeightAlways(constantOp);
|
markWeightAlways(constantOp);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
@@ -14,7 +15,108 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||||
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractBatchMatrix(Value value,
|
||||||
|
int64_t batchIndex,
|
||||||
|
int64_t batchSize,
|
||||||
|
int64_t rows,
|
||||||
|
int64_t cols,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
if (type.getRank() == 2)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
|
||||||
|
SmallVector<OpFoldResult> offsets = {
|
||||||
|
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
|
||||||
|
|
||||||
|
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
matrixType,
|
||||||
|
slice,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isConstantLikeOperand(Value value) {
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
|
||||||
|
while (auto* definingOp = value.getDefiningOp()) {
|
||||||
|
if (!visited.insert(definingOp).second)
|
||||||
|
return false;
|
||||||
|
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
|
value = extractSliceOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = expandShapeOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = collapseShapeOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||||
|
value = transposeOp.getData();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
auto shape = type.getShape();
|
||||||
|
if (type.getRank() == 2) {
|
||||||
|
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
auto shape = type.getShape();
|
||||||
|
RankedTensorType transposedType;
|
||||||
|
SmallVector<int64_t> perm;
|
||||||
|
if (type.getRank() == 2) {
|
||||||
|
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||||
|
perm = {1, 0};
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||||
|
perm = {0, 2, 1};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transposeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
Value transposed =
|
||||||
|
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||||
|
});
|
||||||
|
return transposeCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
@@ -24,80 +126,115 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||||
|| !outType.hasStaticShape())
|
|| !outType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||||
|
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||||
|
return failure();
|
||||||
|
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||||
|
|| !haveStaticPositiveShape(outType.getShape()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t batch = rhsType.getDimSize(0);
|
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||||
const int64_t k = rhsType.getDimSize(1);
|
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||||
const int64_t n = rhsType.getDimSize(2);
|
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||||
const int64_t m = lhsType.getDimSize(0);
|
|
||||||
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||||
|| outType.getDimSize(2) != n)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||||
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
|
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||||
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
|
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||||
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
|
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||||
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
|
if (k != rhsK)
|
||||||
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
|
return failure();
|
||||||
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
|
|
||||||
|
|
||||||
Value lhsTransposed =
|
if (outType.getRank() == 2) {
|
||||||
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
|
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
|
||||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
return failure();
|
||||||
|
}
|
||||||
SmallVector<Value> gemmRows;
|
else {
|
||||||
gemmRows.reserve(batch * n);
|
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
return failure();
|
||||||
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
|
|
||||||
SmallVector<OpFoldResult> offsets = {
|
|
||||||
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
|
|
||||||
SmallVector<OpFoldResult> strides = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
Value rhsSlice =
|
|
||||||
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
|
||||||
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
rhsRowType,
|
|
||||||
rhsSlice,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2}
|
|
||||||
});
|
|
||||||
|
|
||||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
gemmRowType,
|
|
||||||
rhsRow,
|
|
||||||
lhsTransposed,
|
|
||||||
none,
|
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getBoolAttr(false),
|
|
||||||
rewriter.getBoolAttr(false));
|
|
||||||
gemmRows.push_back(gemmOp.getY());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
|
Location loc = matmulOp.getLoc();
|
||||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
|
bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
|
||||||
});
|
|
||||||
|
|
||||||
Value gemmOut = concatComputeOp.getResult(0);
|
Value lhs = matmulOp.getA();
|
||||||
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
Value rhs = matmulOp.getB();
|
||||||
loc,
|
int64_t lhsBatchForGemm = lhsBatch;
|
||||||
gemmExpandedType,
|
int64_t rhsBatchForGemm = rhsBatch;
|
||||||
gemmOut,
|
int64_t gemmM = m;
|
||||||
SmallVector<ReassociationIndices> {
|
int64_t gemmK = k;
|
||||||
{0, 1},
|
int64_t gemmN = n;
|
||||||
{2}
|
if (useTransposedForm) {
|
||||||
});
|
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||||
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
lhsBatchForGemm = rhsBatch;
|
||||||
|
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||||
|
rhsBatchForGemm = lhsBatch;
|
||||||
|
gemmM = n;
|
||||||
|
gemmN = m;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
|
||||||
|
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
|
||||||
|
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
|
if (outType.getRank() == 2) {
|
||||||
|
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||||
|
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||||
|
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
gemmType,
|
||||||
|
lhsMatrix,
|
||||||
|
rhsMatrix,
|
||||||
|
none,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
if (useTransposedForm)
|
||||||
|
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
rewriter.replaceOp(matmulOp, gemmResult);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> batchResults;
|
||||||
|
batchResults.reserve(batch);
|
||||||
|
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||||
|
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||||
|
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||||
|
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
gemmType,
|
||||||
|
lhsMatrix,
|
||||||
|
rhsMatrix,
|
||||||
|
none,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
if (useTransposedForm)
|
||||||
|
gemmResult = ONNXTransposeOp::create(
|
||||||
|
rewriter,
|
||||||
|
loc,
|
||||||
|
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||||
|
gemmResult,
|
||||||
|
rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
batchedOutType,
|
||||||
|
gemmResult,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value result = batchResults.size() == 1
|
||||||
|
? batchResults.front()
|
||||||
|
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, batchResults).getResult();
|
||||||
rewriter.replaceOp(matmulOp, result);
|
rewriter.replaceOp(matmulOp, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -106,7 +243,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<MatMulRank3ToGemm>(ctx);
|
patterns.insert<MatMulToGemm>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -419,8 +419,16 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
||||||
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
|
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
|
||||||
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
|
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
|
||||||
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
|
if (outTensorOperand == vmmOp.getInput()) {
|
||||||
outTensorOperand.setType(newType);
|
rewriter.setInsertionPoint(vmmOp);
|
||||||
|
auto newOutputBuffer =
|
||||||
|
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
|
||||||
|
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
|
||||||
|
outTensorOperand.setType(newType);
|
||||||
|
}
|
||||||
resultTensor.setType(newType);
|
resultTensor.setType(newType);
|
||||||
|
|
||||||
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
|
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
|
||||||
|
|||||||
@@ -178,8 +178,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -203,8 +205,10 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface,
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||||
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user