extend operation support for conv and gemm
add more tests in validation
This commit is contained in:
@@ -5,9 +5,11 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
add_onnx_mlir_library(OMONNXToSpatial
|
||||
Math/Gemm.cpp
|
||||
Math/Conv.cpp
|
||||
Math/MatMul.cpp
|
||||
NN/Pooling.cpp
|
||||
NN/ReduceMean.cpp
|
||||
Tensor/ONNXConcatToTensorConcat.cpp
|
||||
Tensor/ONNXReshapeToTensorReshape.cpp
|
||||
Tensor/RemoveUnusedHelperOps.cpp
|
||||
Utils/SpatialReducer.cpp
|
||||
Utils/WeightSubdivider.cpp
|
||||
|
||||
@@ -130,19 +130,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
});
|
||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||
|
||||
// Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
|
||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
Value gemmC;
|
||||
if (hasB) {
|
||||
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
|
||||
gemmC = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
biasType,
|
||||
b,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
}
|
||||
if (hasB)
|
||||
gemmC = b;
|
||||
else
|
||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
|
||||
@@ -23,6 +23,38 @@ namespace {
|
||||
|
||||
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||
|
||||
static FailureOr<Value> materializeScaledConstantTensor(Value value,
|
||||
float factor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (factor == 1.0f)
|
||||
return value;
|
||||
|
||||
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
|
||||
if (!constantOp)
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
SmallVector<APFloat> scaledValues;
|
||||
scaledValues.reserve(denseAttr.getNumElements());
|
||||
APFloat scale(factor);
|
||||
bool hadFailure = false;
|
||||
for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
|
||||
APFloat scaledValue(originalValue);
|
||||
if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven))
|
||||
hadFailure = true;
|
||||
scaledValues.push_back(std::move(scaledValue));
|
||||
}
|
||||
if (hadFailure)
|
||||
return failure();
|
||||
|
||||
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||
}
|
||||
|
||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -74,10 +106,25 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
if (numOutRows <= 1)
|
||||
return failure();
|
||||
|
||||
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
||||
if (failed(scaledB))
|
||||
return failure();
|
||||
b = *scaledB;
|
||||
|
||||
RankedTensorType cType = nullptr;
|
||||
bool cHasNumOutRows = false;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
if (failed(scaledC))
|
||||
return failure();
|
||||
c = *scaledC;
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
|
||||
cType = expandedType;
|
||||
}
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||
}
|
||||
@@ -112,8 +159,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
aSlice,
|
||||
b,
|
||||
cSlice,
|
||||
gemmOp.getAlphaAttr(),
|
||||
gemmOp.getBetaAttr(),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
gemmOp.getTransAAttr(),
|
||||
gemmOp.getTransBAttr());
|
||||
gemvOps.push_back(gemvOp.getY());
|
||||
@@ -158,6 +205,12 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||
if (hasC) {
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
|
||||
cType = expandedType;
|
||||
}
|
||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||
}
|
||||
|
||||
@@ -177,19 +230,24 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
if (alpha != 1.0f) {
|
||||
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
||||
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
||||
auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
|
||||
a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
|
||||
auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
|
||||
if (failed(scaledB))
|
||||
return failure();
|
||||
b = *scaledB;
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
alpha = 1.0f;
|
||||
}
|
||||
if (hasC && beta != 1.0f) {
|
||||
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
||||
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
||||
auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
|
||||
c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
|
||||
auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
|
||||
if (failed(scaledC))
|
||||
return failure();
|
||||
c = *scaledC;
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
beta = 1.0f;
|
||||
}
|
||||
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
|
||||
108
src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp
Normal file
108
src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp
Normal file
@@ -0,0 +1,108 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
|
||||
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
|
||||
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|
||||
return failure();
|
||||
|
||||
const int64_t batch = rhsType.getDimSize(0);
|
||||
const int64_t k = rhsType.getDimSize(1);
|
||||
const int64_t n = rhsType.getDimSize(2);
|
||||
const int64_t m = lhsType.getDimSize(0);
|
||||
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|
||||
|| outType.getDimSize(2) != n)
|
||||
return failure();
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
|
||||
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
|
||||
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
|
||||
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
|
||||
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
|
||||
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
|
||||
|
||||
Value lhsTransposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
|
||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
SmallVector<Value> gemmRows;
|
||||
gemmRows.reserve(batch * n);
|
||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value rhsSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
||||
Value rhsRow = tensor::CollapseShapeOp::create(
|
||||
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
|
||||
|
||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmRowType,
|
||||
rhsRow,
|
||||
lhsTransposed,
|
||||
none,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false));
|
||||
gemmRows.push_back(gemmOp.getY());
|
||||
}
|
||||
}
|
||||
|
||||
auto concatComputeOp =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
|
||||
|
||||
auto* concatBlock = new Block();
|
||||
for (Value gemmRow : gemmRows)
|
||||
concatBlock->addArgument(gemmRow.getType(), loc);
|
||||
concatComputeOp.getBody().push_back(concatBlock);
|
||||
rewriter.setInsertionPointToStart(concatBlock);
|
||||
|
||||
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
|
||||
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||
|
||||
rewriter.setInsertionPointAfter(concatComputeOp);
|
||||
Value gemmOut = concatComputeOp.getResult(0);
|
||||
Value gemmExpanded = tensor::ExpandShapeOp::create(
|
||||
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
|
||||
Value result = ONNXTransposeOp::create(
|
||||
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<MatMulRank3ToGemm>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat<
|
||||
|
||||
// ONNXMatMulOp to ONNXGemmOp patterns
|
||||
|
||||
def IsRank2Result: Constraint<
|
||||
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
|
||||
"Result is rank 2">;
|
||||
|
||||
def matMulAddToGemmPattern : Pat<
|
||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||
(ONNXGemmOp $A, $B, $C,
|
||||
@@ -22,19 +26,21 @@ def matMulAddToGemmPattern : Pat<
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||
)
|
||||
),
|
||||
[(IsRank2Result $matmulres)]
|
||||
>;
|
||||
|
||||
def matMulToGemmPattern : Pat<
|
||||
(ONNXMatMulOp:$matmulres $A, $B),
|
||||
(
|
||||
ONNXGemmOp $A, $B,
|
||||
ONNXGemmOp $A, $B,
|
||||
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||
)
|
||||
),
|
||||
[(IsRank2Result $matmulres)]
|
||||
>;
|
||||
|
||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
@@ -56,6 +55,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
|
||||
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
||||
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
||||
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
||||
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||
@@ -74,7 +74,9 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addIllegalOp<ONNXMatMulOp>();
|
||||
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
|
||||
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
|
||||
});
|
||||
target.addIllegalOp<ONNXGemmOp>();
|
||||
target.addIllegalOp<ONNXConvOp>();
|
||||
target.addIllegalOp<ONNXLRNOp>();
|
||||
@@ -83,6 +85,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addIllegalOp<ONNXConcatOp>();
|
||||
target.addIllegalOp<ONNXSoftmaxOp>();
|
||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||
target.addIllegalOp<ONNXReshapeOp>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<removeLRNPattern>(ctx);
|
||||
@@ -90,6 +93,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
populateConvOpPatterns(patterns, ctx);
|
||||
populatePoolingTilingPattern(patterns, ctx);
|
||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||
populateReshapeConversionPattern(patterns, ctx);
|
||||
|
||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
||||
populateReduceMeanConversionPattern(patterns, ctx);
|
||||
|
||||
@@ -7,12 +7,16 @@ namespace onnx_mlir {
|
||||
|
||||
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<int64_t> resultShape,
|
||||
SmallVector<ReassociationIndices>& reassociation) {
|
||||
reassociation.clear();
|
||||
|
||||
size_t sourceIdx = 0;
|
||||
size_t resultIdx = 0;
|
||||
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
|
||||
int64_t sourceProduct = sourceShape[sourceIdx];
|
||||
int64_t resultProduct = resultShape[resultIdx];
|
||||
|
||||
ReassociationIndices group;
|
||||
group.push_back(sourceIdx);
|
||||
while (sourceProduct != resultProduct) {
|
||||
if (sourceProduct > resultProduct)
|
||||
return false;
|
||||
sourceIdx++;
|
||||
if (sourceIdx >= sourceShape.size())
|
||||
return false;
|
||||
group.push_back(sourceIdx);
|
||||
sourceProduct *= sourceShape[sourceIdx];
|
||||
}
|
||||
|
||||
reassociation.push_back(group);
|
||||
sourceIdx++;
|
||||
resultIdx++;
|
||||
}
|
||||
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<int64_t> resultShape,
|
||||
SmallVector<ReassociationIndices>& reassociation) {
|
||||
reassociation.clear();
|
||||
|
||||
size_t sourceIdx = 0;
|
||||
size_t resultIdx = 0;
|
||||
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
|
||||
int64_t sourceProduct = sourceShape[sourceIdx];
|
||||
int64_t resultProduct = resultShape[resultIdx];
|
||||
|
||||
ReassociationIndices group;
|
||||
group.push_back(resultIdx);
|
||||
while (resultProduct != sourceProduct) {
|
||||
if (resultProduct > sourceProduct)
|
||||
return false;
|
||||
resultIdx++;
|
||||
if (resultIdx >= resultShape.size())
|
||||
return false;
|
||||
group.push_back(resultIdx);
|
||||
resultProduct *= resultShape[resultIdx];
|
||||
}
|
||||
|
||||
reassociation.push_back(group);
|
||||
sourceIdx++;
|
||||
resultIdx++;
|
||||
}
|
||||
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
|
||||
ONNXReshapeOpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
|
||||
return failure();
|
||||
|
||||
if (sourceType == resultType) {
|
||||
rewriter.replaceOp(reshapeOp, adaptor.getData());
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
if (sourceType.getRank() > resultType.getRank()
|
||||
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||
return success();
|
||||
}
|
||||
|
||||
if (sourceType.getRank() < resultType.getRank()
|
||||
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ONNXReshapeToTensorReshape>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
@@ -79,8 +80,31 @@ private:
|
||||
} // namespace
|
||||
|
||||
static bool isChannelUseChainOp(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
|
||||
op);
|
||||
return isa<tensor::ExtractSliceOp,
|
||||
tensor::CollapseShapeOp,
|
||||
tensor::ExpandShapeOp,
|
||||
tensor::CastOp,
|
||||
tosa::ReshapeOp,
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
}
|
||||
}
|
||||
|
||||
static size_t countComputeLeafUsers(Value value) {
|
||||
@@ -204,6 +228,56 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
OpOperand& resultUse = *resultUses.begin();
|
||||
Operation* resultUser = resultUse.getOwner();
|
||||
|
||||
if (isChannelUseChainOp(resultUser)) {
|
||||
SmallVector<Operation*> returnChain;
|
||||
Value chainedValue = result;
|
||||
Operation* chainUser = resultUser;
|
||||
|
||||
while (isChannelUseChainOp(chainUser)) {
|
||||
returnChain.push_back(chainUser);
|
||||
auto chainUses = chainUser->getResult(0).getUses();
|
||||
if (rangeLength(chainUses) != 1)
|
||||
break;
|
||||
chainedValue = chainUser->getResult(0);
|
||||
chainUser = chainUses.begin()->getOwner();
|
||||
}
|
||||
|
||||
if (isa<func::ReturnOp>(chainUser)) {
|
||||
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
|
||||
|
||||
rewriter.setInsertionPoint(yieldOp);
|
||||
IRMapping mapping;
|
||||
mapping.map(result, yieldValue);
|
||||
|
||||
Value storedValue = yieldValue;
|
||||
for (Operation* op : returnChain) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
storedValue = clonedOp->getResult(0);
|
||||
rewriter.setInsertionPointAfter(clonedOp);
|
||||
markOpToRemove(op);
|
||||
}
|
||||
|
||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
|
||||
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||
if (auto storedOp = storedValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
outputTensor,
|
||||
storedValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t offset = 0;
|
||||
@@ -493,6 +567,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
||||
IRMapping mapping;
|
||||
mapping.map(channelSourceOp, receivedValue);
|
||||
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
|
||||
@@ -30,6 +30,24 @@ static Value stripMemRefCasts(Value value) {
|
||||
return value;
|
||||
}
|
||||
|
||||
static Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
@@ -204,6 +222,7 @@ struct StaticSubviewInfo {
|
||||
};
|
||||
|
||||
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
@@ -321,6 +340,77 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
||||
}
|
||||
};
|
||||
|
||||
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc());
|
||||
auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst());
|
||||
const bool splitSrc = succeeded(srcSubview)
|
||||
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||
const bool splitDst = succeeded(dstSubview)
|
||||
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||
if (!splitSrc && !splitDst)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<MemRefType>(copyOp.getHostSrc().getType());
|
||||
auto dstType = dyn_cast<MemRefType>(copyOp.getDeviceDst().getType());
|
||||
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||
return failure();
|
||||
if (sourceType.getElementType() != dstType.getElementType())
|
||||
return failure();
|
||||
|
||||
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||
return failure();
|
||||
|
||||
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||
if (elementByteWidth <= 0)
|
||||
return failure();
|
||||
|
||||
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||
if (copyOp.getSize() != totalBytes)
|
||||
return failure();
|
||||
|
||||
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||
if (sliceBytes <= 0)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||
const int64_t srcByteOffset = copyOp.getHostSrcOffset()
|
||||
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
const int64_t dstByteOffset = copyOp.getDeviceDstOffset()
|
||||
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||
: linearIndex * sliceBytes);
|
||||
pim::PimMemCopyHostToDevOp::create(
|
||||
rewriter,
|
||||
copyOp.getLoc(),
|
||||
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||
splitDst ? dstSubview->source : copyOp.getDeviceDst(),
|
||||
splitSrc ? srcSubview->source : copyOp.getHostSrc(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||
}
|
||||
|
||||
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
@@ -578,6 +668,170 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||
// Only match top-level memcp (not inside pim.core)
|
||||
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||
return failure();
|
||||
|
||||
// dst must be an alloc with static shape
|
||||
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
|
||||
if (!allocOp)
|
||||
return failure();
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!allocType || !allocType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// The copy must cover the full destination (offsets both zero)
|
||||
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
|
||||
return failure();
|
||||
|
||||
// Resolve the source through an optional subview to a get_global
|
||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
|
||||
|
||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
// Build the folded dense attribute
|
||||
DenseElementsAttr foldedAttr;
|
||||
if (succeeded(srcSubview)) {
|
||||
// Extract the sub-tensor from the source constant
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; }))
|
||||
return failure();
|
||||
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
|
||||
resultValues[i] = sourceValues[srcLinear];
|
||||
}
|
||||
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
}
|
||||
else {
|
||||
// Direct copy from a global — just reuse its dense attribute
|
||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||
if (resultTensorType != denseAttr->getType())
|
||||
return failure();
|
||||
foldedAttr = *denseAttr;
|
||||
}
|
||||
|
||||
// Verify that the alloc's remaining users are supported ops.
|
||||
bool allLiveUsersAreCores = true;
|
||||
for (Operation* user : allocOp->getUsers()) {
|
||||
if (user == copyOp)
|
||||
continue;
|
||||
if (isa<memref::DeallocOp>(user))
|
||||
continue;
|
||||
if (isa<pim::PimCoreOp>(user))
|
||||
continue;
|
||||
if (isa<memref::SubViewOp>(user)) {
|
||||
allLiveUsersAreCores = false;
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(allocOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||
if (allLiveUsersAreCores)
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
|
||||
rewriter.eraseOp(copyOp);
|
||||
if (allocOp.use_empty())
|
||||
rewriter.eraseOp(allocOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
||||
// Only handle subviews whose users are all pim.core ops.
|
||||
if (subviewOp.use_empty())
|
||||
return failure();
|
||||
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
||||
return failure();
|
||||
|
||||
// Source must resolve to a constant get_global.
|
||||
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
||||
if (failed(denseAttr))
|
||||
return failure();
|
||||
|
||||
// Static subview info.
|
||||
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
||||
if (failed(subviewInfo))
|
||||
return failure();
|
||||
if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; }))
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
// Build the contiguous result type.
|
||||
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||
auto resultMemRefType = MemRefType::get(
|
||||
SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
||||
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||
|
||||
// Extract the sub-tensor.
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(numResultElements);
|
||||
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||
SmallVector<int64_t> sourceIndices;
|
||||
sourceIndices.reserve(resultIndices.size());
|
||||
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
||||
sourceIndices.push_back(off + idx);
|
||||
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||
}
|
||||
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||
|
||||
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||
markWeightAlways(newGlobal);
|
||||
|
||||
rewriter.setInsertionPoint(subviewOp);
|
||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||
markWeightAlways(newGetGlobal);
|
||||
|
||||
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||
|
||||
@@ -591,7 +845,13 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
|
||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||
owningPatterns
|
||||
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
|
||||
.add<FoldConstantTransposePattern,
|
||||
FoldConstantAllocPattern,
|
||||
FoldConstantCoreMapPattern,
|
||||
RewriteCoreSubviewCopyPattern,
|
||||
RewriteHostSubviewLoadPattern,
|
||||
FoldConstantMemCpPattern,
|
||||
FoldConstantCoreSubviewPattern>(
|
||||
context);
|
||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||
return success();
|
||||
|
||||
Reference in New Issue
Block a user