This commit is contained in:
@@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
@@ -615,10 +613,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
@@ -626,10 +624,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
padHeightBegin = getI64Attr(*padsAttr, 0);
|
||||
padWidthBegin = getI64Attr(*padsAttr, 1);
|
||||
padHeightEnd = getI64Attr(*padsAttr, 2);
|
||||
padWidthEnd = getI64Attr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
@@ -58,47 +58,16 @@ static Value transposeForSpatial(Value value,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (isCompileTimeComputable(value))
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t value) {
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
|
||||
}
|
||||
|
||||
static Value
|
||||
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
||||
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value
|
||||
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (multiplier == 0)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
if (multiplier == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value});
|
||||
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
|
||||
}
|
||||
|
||||
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (divisor == 1)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(rewriter, loc, d0 % divisor, ValueRange {value});
|
||||
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
|
||||
}
|
||||
|
||||
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -108,11 +77,11 @@ static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatter
|
||||
static Value createGemmBatchKOffset(
|
||||
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (numKSlices == 1)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(
|
||||
return createAffineApplyOrConstant(
|
||||
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
@@ -123,11 +92,11 @@ static Value createGemmBatchHOffset(Value lane,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (numOutHSlices == 1)
|
||||
return createIndexConstant(rewriter, 0);
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(
|
||||
return createAffineApplyOrConstant(
|
||||
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
@@ -303,53 +272,37 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange {partialPiecesType},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
||||
ValueRange {b},
|
||||
ValueRange {a});
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
|
||||
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
|
||||
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
|
||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), paddedBType, aType, partialPiecesType};
|
||||
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
|
||||
Block* body =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
auto aTileType =
|
||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
||||
auto bTileType = RankedTensorType::get(
|
||||
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
|
||||
paddedBType.getElementType());
|
||||
auto pieceType =
|
||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
||||
Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);
|
||||
|
||||
auto lane = batchOp.getLaneArgument();
|
||||
auto weight = batchOp.getWeightArgument(0);
|
||||
auto input = batchOp.getInputArgument(0);
|
||||
auto output = batchOp.getOutputArgument(0);
|
||||
assert(lane && weight && input && output && "malformed Gemm compute_batch body");
|
||||
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||
Value bTile =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
|
||||
.getResult();
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||
|
||||
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc);
|
||||
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc);
|
||||
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
|
||||
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
||||
auto bTileType = RankedTensorType::get(
|
||||
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
|
||||
paddedBType.getElementType());
|
||||
auto pieceType =
|
||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
||||
Value aTile = extractATile(*input, row, kOffset, aTileType, rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value bTile =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, *weight, bOffsets, bSizes, unitStrides).getResult();
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> pieceOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, piece, *output, pieceOffsets, pieceSizes, unitStrides);
|
||||
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return batchOp;
|
||||
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed");
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static Value createDynamicGemmBatchRow(
|
||||
@@ -359,7 +312,7 @@ static Value createDynamicGemmBatchRow(
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
}
|
||||
|
||||
static Value createDynamicGemmBatchColumn(
|
||||
@@ -479,45 +432,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
||||
const int64_t numOutCols = outType.getDimSize(1);
|
||||
const int64_t reductionSize = aType.getDimSize(1);
|
||||
const int64_t laneCount = numOutRows * numOutCols;
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange {scalarPiecesType},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
||||
ValueRange {},
|
||||
ValueRange {a, b});
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
|
||||
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
|
||||
|
||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType};
|
||||
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
|
||||
Block* body =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
|
||||
Value bVector = bAlreadyTransposed
|
||||
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
|
||||
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
|
||||
auto lane = batchOp.getLaneArgument();
|
||||
auto inputA = batchOp.getInputArgument(0);
|
||||
auto inputB = batchOp.getInputArgument(1);
|
||||
auto output = batchOp.getOutputArgument(0);
|
||||
assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body");
|
||||
|
||||
Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc);
|
||||
Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc);
|
||||
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc);
|
||||
Value bVector = bAlreadyTransposed
|
||||
? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc)
|
||||
: extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides);
|
||||
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return batchOp;
|
||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed");
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
@@ -540,9 +475,9 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
Value biasArg = bias ? blockArgs[1] : Value();
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||
Value c0 = createIndexConstant(rewriter, 0);
|
||||
Value c1 = createIndexConstant(rewriter, 1);
|
||||
Value cLaneCount = createIndexConstant(rewriter, laneCount);
|
||||
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
@@ -587,7 +522,8 @@ static Value createPartialGroupOffset(Value hSlice,
|
||||
Location loc) {
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApply(rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
|
||||
return createAffineApplyOrConstant(
|
||||
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
|
||||
}
|
||||
|
||||
static Value extractReductionPiece(Value partialPiecesArg,
|
||||
@@ -684,13 +620,13 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
||||
|
||||
Value paddedOutput = outputInit;
|
||||
if (numOutHSlices == 1) {
|
||||
Value hSlice = createIndexConstant(rewriter, 0);
|
||||
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
paddedOutput = buildOutputSlice(outputInit, hSlice);
|
||||
}
|
||||
else {
|
||||
Value c0 = createIndexConstant(rewriter, 0);
|
||||
Value c1 = createIndexConstant(rewriter, 1);
|
||||
Value cOutHSlices = createIndexConstant(rewriter, numOutHSlices);
|
||||
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||
|
||||
|
||||
@@ -19,14 +19,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||
ArrayRef<int64_t> rhsBatchShape) {
|
||||
if (lhsBatchShape.empty())
|
||||
@@ -54,15 +46,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
|
||||
auto buildCollapsed = [&](Value input) -> Value {
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||
});
|
||||
return collapseCompute.getResult(0);
|
||||
return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
|
||||
}
|
||||
|
||||
static Value
|
||||
@@ -76,12 +60,10 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
auto expandCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return expandCompute.getResult(0);
|
||||
auto buildExpanded = [&](Value input) -> Value {
|
||||
return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
|
||||
};
|
||||
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
@@ -100,7 +82,7 @@ static Value extractBatchMatrix(Value value,
|
||||
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
|
||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||
auto buildMatrix = [&](Value input) -> Value {
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
@@ -114,14 +96,7 @@ static Value extractBatchMatrix(Value value,
|
||||
});
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildMatrix(value);
|
||||
|
||||
auto batchMatrixCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
||||
});
|
||||
return batchMatrixCompute.getResult(0);
|
||||
return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -138,18 +113,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -166,10 +130,11 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
@@ -203,8 +168,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
@@ -16,26 +18,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes;
|
||||
if (!axesAttr) {
|
||||
normalizedAxes.reserve(rank);
|
||||
for (int64_t axis = 0; axis < rank; axis++)
|
||||
normalizedAxes.push_back(axis);
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
normalizedAxes.reserve(axesAttr.size());
|
||||
for (Attribute attr : axesAttr) {
|
||||
int64_t axis = cast<IntegerAttr>(attr).getInt();
|
||||
normalizedAxes.push_back(axis >= 0 ? axis : rank + axis);
|
||||
}
|
||||
|
||||
llvm::sort(normalizedAxes);
|
||||
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
|
||||
SmallVector<bool> reducedAxes(rank, false);
|
||||
for (int64_t axis : axes) {
|
||||
@@ -50,6 +32,181 @@ static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementT
|
||||
return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType);
|
||||
}
|
||||
|
||||
static RankedTensorType getKeepdimsType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<int64_t> shape;
|
||||
shape.reserve(inputType.getRank());
|
||||
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||
shape.push_back(isReduced ? 1 : dim);
|
||||
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
|
||||
}
|
||||
|
||||
static RankedTensorType getCompactKeptType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<int64_t> shape;
|
||||
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||
if (!isReduced)
|
||||
shape.push_back(dim);
|
||||
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
|
||||
}
|
||||
|
||||
static RankedTensorType getReducedSliceType(RankedTensorType inputType, ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<int64_t> shape;
|
||||
shape.reserve(inputType.getRank());
|
||||
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||
shape.push_back(isReduced ? dim : 1);
|
||||
return RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding());
|
||||
}
|
||||
|
||||
static RankedTensorType getLanePackedKeepdimsType(int64_t laneCount, RankedTensorType leafType) {
|
||||
SmallVector<int64_t> shape(leafType.getShape().begin(), leafType.getShape().end());
|
||||
shape.front() = laneCount;
|
||||
return RankedTensorType::get(shape, leafType.getElementType(), leafType.getEncoding());
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<int64_t> keptAxes;
|
||||
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes))
|
||||
if (!isReduced)
|
||||
keptAxes.push_back(static_cast<int64_t>(axis));
|
||||
return keptAxes;
|
||||
}
|
||||
|
||||
static Value computeLaneIndex(Value lane,
|
||||
int64_t stride,
|
||||
int64_t dimSize,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (dimSize == 1)
|
||||
return arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
AffineExpr expr = d0;
|
||||
if (stride != 1)
|
||||
expr = expr.floorDiv(stride);
|
||||
if (dimSize != 1)
|
||||
expr = expr % dimSize;
|
||||
return createAffineApplyOrConstant(rewriter, loc, expr, ValueRange {lane});
|
||||
}
|
||||
|
||||
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
RankedTensorType batchType,
|
||||
RankedTensorType leafType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto sliceType = getReducedSliceType(inputType, reducedAxes);
|
||||
SmallVector<int64_t> keptAxes = getKeptAxes(reducedAxes);
|
||||
|
||||
int64_t laneCount = 1;
|
||||
SmallVector<int64_t> keptAxisStrides(keptAxes.size(), 1);
|
||||
for (int64_t index = static_cast<int64_t>(keptAxes.size()) - 1; index >= 0; --index) {
|
||||
keptAxisStrides[index] = laneCount;
|
||||
int64_t dimSize = inputType.getDimSize(keptAxes[index]);
|
||||
if (dimSize <= 0)
|
||||
return failure();
|
||||
if (laneCount > std::numeric_limits<int32_t>::max() / dimSize)
|
||||
return failure();
|
||||
laneCount *= dimSize;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> sliceOffsets;
|
||||
SmallVector<OpFoldResult> sliceSizes;
|
||||
SmallVector<OpFoldResult> insertOffsets;
|
||||
SmallVector<OpFoldResult> insertSizes(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, inputType.getRank());
|
||||
sliceOffsets.reserve(inputType.getRank());
|
||||
sliceSizes.reserve(inputType.getRank());
|
||||
insertOffsets.reserve(inputType.getRank());
|
||||
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {batchType}, laneCount, {}, ValueRange {input}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
size_t keptAxisIndex = 0;
|
||||
sliceOffsets.clear();
|
||||
sliceSizes.clear();
|
||||
insertOffsets.clear();
|
||||
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
|
||||
if (isReduced) {
|
||||
sliceOffsets.push_back(rewriter.getIndexAttr(0));
|
||||
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
|
||||
continue;
|
||||
}
|
||||
|
||||
Value axisIndex =
|
||||
computeLaneIndex(args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
|
||||
++keptAxisIndex;
|
||||
sliceOffsets.push_back(axisIndex);
|
||||
sliceSizes.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
insertOffsets.push_back(args.lane);
|
||||
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
|
||||
|
||||
Value slice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
|
||||
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
|
||||
});
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
return (*batchOp).getResult(0);
|
||||
}
|
||||
|
||||
static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
|
||||
RankedTensorType keepdimsType,
|
||||
RankedTensorType compactKeptType,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto batchType = cast<RankedTensorType>(batchValue.getType());
|
||||
if (batchType == keepdimsType)
|
||||
return batchValue;
|
||||
|
||||
SmallVector<ReassociationIndices> collapseToFlat {{}};
|
||||
for (int64_t axis = 0; axis < batchType.getRank(); ++axis)
|
||||
collapseToFlat.front().push_back(axis);
|
||||
|
||||
SmallVector<ReassociationIndices> expandFlatToCompact(1);
|
||||
for (int64_t axis = 0; axis < compactKeptType.getRank(); ++axis)
|
||||
expandFlatToCompact.front().push_back(axis);
|
||||
|
||||
SmallVector<ReassociationIndices> expandCompactToKeepdims;
|
||||
ReassociationIndices pendingLeadingReducedAxes;
|
||||
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
|
||||
if (isReduced) {
|
||||
if (expandCompactToKeepdims.empty())
|
||||
pendingLeadingReducedAxes.push_back(axis);
|
||||
else
|
||||
expandCompactToKeepdims.back().push_back(axis);
|
||||
continue;
|
||||
}
|
||||
|
||||
expandCompactToKeepdims.emplace_back();
|
||||
auto& group = expandCompactToKeepdims.back();
|
||||
group.append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
|
||||
pendingLeadingReducedAxes.clear();
|
||||
group.push_back(axis);
|
||||
}
|
||||
if (!pendingLeadingReducedAxes.empty())
|
||||
expandCompactToKeepdims.back().append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
|
||||
|
||||
auto reshapeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
|
||||
auto flatType = RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
|
||||
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
|
||||
Value compact = flat;
|
||||
if (compactKeptType != flatType)
|
||||
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
|
||||
Value keepdims = compact;
|
||||
if (keepdimsType != compactKeptType)
|
||||
keepdims =
|
||||
tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
|
||||
});
|
||||
return reshapeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
ReassociationIndices currentGroup;
|
||||
@@ -72,56 +229,6 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static Value
|
||||
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
|
||||
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult());
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isCompileTimeComputable))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildReduceMeanKeepdims(Value input,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
int64_t axis,
|
||||
RankedTensorType leafType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
int64_t rank = cast<RankedTensorType>(input.getType()).getRank();
|
||||
if (axis == rank)
|
||||
return createAverageCompute(input, leafType, rewriter, loc);
|
||||
|
||||
if (reducedAxes[axis])
|
||||
return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);
|
||||
|
||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
||||
SmallVector<Value> reducedSlices;
|
||||
reducedSlices.reserve(slices.size());
|
||||
for (Value slice : slices)
|
||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||
|
||||
return concatValues(reducedSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
@@ -156,16 +263,33 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
|
||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
if (inputType.getRank() == 0) {
|
||||
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank());
|
||||
SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, inputType.getRank());
|
||||
auto axes = normalizeAxesChecked(reduceMeanOp.getAxesAttr(), inputType.getRank());
|
||||
if (failed(axes))
|
||||
return failure();
|
||||
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
|
||||
if (reducedAxes.empty() && inputType.getRank() != 0)
|
||||
return failure();
|
||||
|
||||
Location loc = reduceMeanOp.getLoc();
|
||||
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
|
||||
RankedTensorType compactKeptType = getCompactKeptType(inputType, resultType.getElementType(), reducedAxes);
|
||||
RankedTensorType keepdimsType = getKeepdimsType(inputType, resultType.getElementType(), reducedAxes);
|
||||
int64_t laneCount = 1;
|
||||
for (int64_t dim : compactKeptType.getShape())
|
||||
laneCount *= dim;
|
||||
RankedTensorType batchType = getLanePackedKeepdimsType(laneCount, leafType);
|
||||
|
||||
auto lanePackedKeepdims =
|
||||
buildReduceMeanKeepdimsBatch(adaptor.getData(), reducedAxes, batchType, leafType, rewriter, loc);
|
||||
if (failed(lanePackedKeepdims))
|
||||
return failure();
|
||||
Value reducedKeepdims =
|
||||
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
|
||||
buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
|
||||
|
||||
if (reduceMeanOp.getKeepdims() != 0) {
|
||||
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
|
||||
|
||||
Reference in New Issue
Block a user