fix MatMul pattern non-contiguous extract_slices
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
@@ -137,6 +139,35 @@ bool isSpatialMvmVmmWeightUse(OpOperand& use) {
|
||||
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
||||
}
|
||||
|
||||
bool hasOnlySpatialMvmVmmWeightUses(Value value) {
|
||||
SmallPtrSet<Value, 8> visited;
|
||||
auto walkUses = [&](Value currentValue, auto& self) -> bool {
|
||||
if (!visited.insert(currentValue).second)
|
||||
return true;
|
||||
if (currentValue.use_empty())
|
||||
return false;
|
||||
|
||||
return llvm::all_of(currentValue.getUses(), [&](OpOperand& use) {
|
||||
if (isSpatialMvmVmmWeightUse(use))
|
||||
return true;
|
||||
|
||||
Operation* user = use.getOwner();
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user))
|
||||
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
|
||||
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(user))
|
||||
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(user))
|
||||
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
||||
|
||||
return false;
|
||||
});
|
||||
};
|
||||
|
||||
return walkUses(value, walkUses);
|
||||
}
|
||||
|
||||
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
|
||||
@@ -42,6 +42,7 @@ bool hasWeightAlways(mlir::Operation* op);
|
||||
void markWeightAlways(mlir::Operation* op);
|
||||
|
||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
|
||||
@@ -87,8 +87,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||
target.addIllegalOp<ONNXMatMulOp>();
|
||||
target.addIllegalOp<ONNXAddOp>();
|
||||
target.addIllegalOp<ONNXDivOp>();
|
||||
target.addIllegalOp<ONNXMulOp>();
|
||||
@@ -391,11 +390,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
bool isAlwaysWeight =
|
||||
!constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
|
||||
return isSpatialMvmVmmWeightUse(use);
|
||||
});
|
||||
if (isAlwaysWeight)
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
@@ -14,7 +15,108 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
int64_t batchIndex,
|
||||
int64_t batchSize,
|
||||
int64_t rows,
|
||||
int64_t cols,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 2)
|
||||
return value;
|
||||
|
||||
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
|
||||
|
||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
matrixType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
|
||||
static bool isConstantLikeOperand(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
value = transposeOp.getData();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
if (type.getRank() == 2) {
|
||||
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
if (type.getRank() == 2) {
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||
@@ -24,80 +126,115 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|
||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
return failure();
|
||||
|
||||
const int64_t batch = rhsType.getDimSize(0);
|
||||
const int64_t k = rhsType.getDimSize(1);
|
||||
const int64_t n = rhsType.getDimSize(2);
|
||||
const int64_t m = lhsType.getDimSize(0);
|
||||
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|
||||
|| outType.getDimSize(2) != n)
|
||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||
|
||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||
return failure();
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
|
||||
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
|
||||
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
|
||||
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
|
||||
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
|
||||
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
|
||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||
if (k != rhsK)
|
||||
return failure();
|
||||
|
||||
Value lhsTransposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
|
||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
SmallVector<Value> gemmRows;
|
||||
gemmRows.reserve(batch * n);
|
||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value rhsSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
||||
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rhsRowType,
|
||||
rhsSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
|
||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmRowType,
|
||||
rhsRow,
|
||||
lhsTransposed,
|
||||
none,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
gemmRows.push_back(gemmOp.getY());
|
||||
}
|
||||
if (outType.getRank() == 2) {
|
||||
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
|
||||
return failure();
|
||||
}
|
||||
else {
|
||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
|
||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||
});
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
|
||||
|
||||
Value gemmOut = concatComputeOp.getResult(0);
|
||||
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
gemmExpandedType,
|
||||
gemmOut,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
int64_t lhsBatchForGemm = lhsBatch;
|
||||
int64_t rhsBatchForGemm = rhsBatch;
|
||||
int64_t gemmM = m;
|
||||
int64_t gemmK = k;
|
||||
int64_t gemmN = n;
|
||||
if (useTransposedForm) {
|
||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||
lhsBatchForGemm = rhsBatch;
|
||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||
rhsBatchForGemm = lhsBatch;
|
||||
gemmM = n;
|
||||
gemmN = m;
|
||||
}
|
||||
|
||||
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
|
||||
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
|
||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
if (outType.getRank() == 2) {
|
||||
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmType,
|
||||
lhsMatrix,
|
||||
rhsMatrix,
|
||||
none,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm)
|
||||
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> batchResults;
|
||||
batchResults.reserve(batch);
|
||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmType,
|
||||
lhsMatrix,
|
||||
rhsMatrix,
|
||||
none,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm)
|
||||
gemmResult = ONNXTransposeOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||
gemmResult,
|
||||
rewriter.getI64ArrayAttr({1, 0}));
|
||||
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedOutType,
|
||||
gemmResult,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
}));
|
||||
}
|
||||
|
||||
Value result = batchResults.size() == 1
|
||||
? batchResults.front()
|
||||
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, batchResults).getResult();
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
@@ -106,7 +243,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
} // namespace
|
||||
|
||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<MatMulRank3ToGemm>(ctx);
|
||||
patterns.insert<MatMulToGemm>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -419,8 +419,16 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
||||
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
|
||||
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
|
||||
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
|
||||
outTensorOperand.setType(newType);
|
||||
if (outTensorOperand == vmmOp.getInput()) {
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
auto newOutputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
|
||||
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
|
||||
}
|
||||
else {
|
||||
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
|
||||
outTensorOperand.setType(newType);
|
||||
}
|
||||
resultTensor.setType(newType);
|
||||
|
||||
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
|
||||
|
||||
@@ -178,8 +178,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -203,8 +205,10 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface,
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user