Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
@@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override;
};
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1)
@@ -615,10 +613,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
@@ -626,10 +624,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
padHeightBegin = getI64Attr(*padsAttr, 0);
padWidthBegin = getI64Attr(*padsAttr, 1);
padHeightEnd = getI64Attr(*padsAttr, 2);
padWidthEnd = getI64Attr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
@@ -13,7 +13,7 @@
#include <limits>
#include <utility>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
@@ -58,47 +58,16 @@ static Value transposeForSpatial(Value value,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
if (isCompileTimeComputable(value))
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return computeOp.getResult(0);
}
static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t value) {
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
}
static Value
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
}
static Value
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
if (multiplier == 0)
return createIndexConstant(rewriter, 0);
if (multiplier == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value});
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
}
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
if (divisor == 1)
return createIndexConstant(rewriter, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 % divisor, ValueRange {value});
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
}
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
@@ -108,11 +77,11 @@ static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatter
static Value createGemmBatchKOffset(
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
if (numKSlices == 1)
return createIndexConstant(rewriter, 0);
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(
return createAffineApplyOrConstant(
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
}
@@ -123,11 +92,11 @@ static Value createGemmBatchHOffset(Value lane,
ConversionPatternRewriter& rewriter,
Location loc) {
if (numOutHSlices == 1)
return createIndexConstant(rewriter, 0);
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(
return createAffineApplyOrConstant(
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
}
@@ -303,53 +272,37 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t laneCount = partialPiecesType.getDimSize(0);
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange {partialPiecesType},
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
ValueRange {b},
ValueRange {a});
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), paddedBType, aType, partialPiecesType};
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
Block* body =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
auto aTileType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
auto bTileType = RankedTensorType::get(
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
paddedBType.getElementType());
auto pieceType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);
auto lane = batchOp.getLaneArgument();
auto weight = batchOp.getWeightArgument(0);
auto input = batchOp.getInputArgument(0);
auto output = batchOp.getOutputArgument(0);
assert(lane && weight && input && output && "malformed Gemm compute_batch body");
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
rewriter.getIndexAttr(crossbarSize.getValue())};
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
Value bTile =
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
.getResult();
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc);
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
auto bTileType = RankedTensorType::get(
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
paddedBType.getElementType());
auto pieceType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
Value aTile = extractATile(*input, row, kOffset, aTileType, rewriter, loc);
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
rewriter.getIndexAttr(crossbarSize.getValue())};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value bTile =
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, *weight, bOffsets, bSizes, unitStrides).getResult();
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> pieceOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
tensor::ParallelInsertSliceOp::create(rewriter, loc, piece, *output, pieceOffsets, pieceSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
return batchOp;
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
});
assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed");
return *batchOp;
}
static Value createDynamicGemmBatchRow(
@@ -359,7 +312,7 @@ static Value createDynamicGemmBatchRow(
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
}
static Value createDynamicGemmBatchColumn(
@@ -479,45 +432,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
const int64_t laneCount = numOutRows * numOutCols;
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange {scalarPiecesType},
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
ValueRange {},
ValueRange {a, b});
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType};
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
Block* body =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
auto lane = batchOp.getLaneArgument();
auto inputA = batchOp.getInputArgument(0);
auto inputB = batchOp.getInputArgument(1);
auto output = batchOp.getOutputArgument(0);
assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body");
Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed
? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
return batchOp;
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
});
assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed");
return *batchOp;
}
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
@@ -540,9 +475,9 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value biasArg = bias ? blockArgs[1] : Value();
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = createIndexConstant(rewriter, 0);
Value c1 = createIndexConstant(rewriter, 1);
Value cLaneCount = createIndexConstant(rewriter, laneCount);
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(loop.getBody());
@@ -587,7 +522,8 @@ static Value createPartialGroupOffset(Value hSlice,
Location loc) {
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
return createAffineApplyOrConstant(
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
}
static Value extractReductionPiece(Value partialPiecesArg,
@@ -684,13 +620,13 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
Value paddedOutput = outputInit;
if (numOutHSlices == 1) {
Value hSlice = createIndexConstant(rewriter, 0);
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
paddedOutput = buildOutputSlice(outputInit, hSlice);
}
else {
Value c0 = createIndexConstant(rewriter, 0);
Value c1 = createIndexConstant(rewriter, 1);
Value cOutHSlices = createIndexConstant(rewriter, numOutHSlices);
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(hLoop.getBody());
@@ -19,14 +19,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty())
@@ -54,15 +46,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isCompileTimeComputable(value))
return buildCollapsed(value);
auto collapseCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
});
return collapseCompute.getResult(0);
return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
}
static Value
@@ -76,12 +60,10 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
auto expandCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return expandCompute.getResult(0);
auto buildExpanded = [&](Value input) -> Value {
return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
};
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
}
static Value extractBatchMatrix(Value value,
@@ -100,7 +82,7 @@ static Value extractBatchMatrix(Value value,
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
auto buildMatrix = [&](Value input) -> Value {
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
@@ -114,14 +96,7 @@ static Value extractBatchMatrix(Value value,
});
};
if (isCompileTimeComputable(value))
return buildMatrix(value);
auto batchMatrixCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
});
return batchMatrixCompute.getResult(0);
return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
}
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
@@ -138,18 +113,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
perm = {0, 2, 1};
}
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isCompileTimeComputable(value))
return buildTranspose(value);
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
});
return transposeCompute.getResult(0);
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
@@ -166,10 +130,11 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
perm = {0, 2, 1};
}
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
@@ -203,8 +168,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
return failure();
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
return failure();
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
@@ -1,9 +1,11 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <numeric>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
@@ -16,26 +18,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; axis++)
normalizedAxes.push_back(axis);
return normalizedAxes;
}
normalizedAxes.reserve(axesAttr.size());
for (Attribute attr : axesAttr) {
int64_t axis = cast<IntegerAttr>(attr).getInt();
normalizedAxes.push_back(axis >= 0 ? axis : rank + axis);
}
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
return normalizedAxes;
}
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
SmallVector<bool> reducedAxes(rank, false);
for (int64_t axis : axes) {
@@ -50,6 +32,181 @@ static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementT
return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType);
}
static RankedTensorType getKeepdimsType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
shape.reserve(inputType.getRank());
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
shape.push_back(isReduced ? 1 : dim);
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
}
static RankedTensorType getCompactKeptType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
if (!isReduced)
shape.push_back(dim);
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
}
static RankedTensorType getReducedSliceType(RankedTensorType inputType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
shape.reserve(inputType.getRank());
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
shape.push_back(isReduced ? dim : 1);
return RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding());
}
static RankedTensorType getLanePackedKeepdimsType(int64_t laneCount, RankedTensorType leafType) {
SmallVector<int64_t> shape(leafType.getShape().begin(), leafType.getShape().end());
shape.front() = laneCount;
return RankedTensorType::get(shape, leafType.getElementType(), leafType.getEncoding());
}
static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> keptAxes;
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes))
if (!isReduced)
keptAxes.push_back(static_cast<int64_t>(axis));
return keptAxes;
}
static Value computeLaneIndex(Value lane,
int64_t stride,
int64_t dimSize,
ConversionPatternRewriter& rewriter,
Location loc) {
if (dimSize == 1)
return arith::ConstantIndexOp::create(rewriter, loc, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
AffineExpr expr = d0;
if (stride != 1)
expr = expr.floorDiv(stride);
if (dimSize != 1)
expr = expr % dimSize;
return createAffineApplyOrConstant(rewriter, loc, expr, ValueRange {lane});
}
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
ArrayRef<bool> reducedAxes,
RankedTensorType batchType,
RankedTensorType leafType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
auto sliceType = getReducedSliceType(inputType, reducedAxes);
SmallVector<int64_t> keptAxes = getKeptAxes(reducedAxes);
int64_t laneCount = 1;
SmallVector<int64_t> keptAxisStrides(keptAxes.size(), 1);
for (int64_t index = static_cast<int64_t>(keptAxes.size()) - 1; index >= 0; --index) {
keptAxisStrides[index] = laneCount;
int64_t dimSize = inputType.getDimSize(keptAxes[index]);
if (dimSize <= 0)
return failure();
if (laneCount > std::numeric_limits<int32_t>::max() / dimSize)
return failure();
laneCount *= dimSize;
}
SmallVector<OpFoldResult> sliceOffsets;
SmallVector<OpFoldResult> sliceSizes;
SmallVector<OpFoldResult> insertOffsets;
SmallVector<OpFoldResult> insertSizes(inputType.getRank(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, inputType.getRank());
sliceOffsets.reserve(inputType.getRank());
sliceSizes.reserve(inputType.getRank());
insertOffsets.reserve(inputType.getRank());
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {batchType}, laneCount, {}, ValueRange {input}, [&](detail::SpatComputeBatchBodyArgs args) {
size_t keptAxisIndex = 0;
sliceOffsets.clear();
sliceSizes.clear();
insertOffsets.clear();
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
sliceOffsets.push_back(rewriter.getIndexAttr(0));
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
continue;
}
Value axisIndex =
computeLaneIndex(args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
++keptAxisIndex;
sliceOffsets.push_back(axisIndex);
sliceSizes.push_back(rewriter.getIndexAttr(1));
}
insertOffsets.push_back(args.lane);
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
Value slice =
tensor::ExtractSliceOp::create(rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
});
if (failed(batchOp))
return failure();
return (*batchOp).getResult(0);
}
static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
RankedTensorType keepdimsType,
RankedTensorType compactKeptType,
ArrayRef<bool> reducedAxes,
ConversionPatternRewriter& rewriter,
Location loc) {
auto batchType = cast<RankedTensorType>(batchValue.getType());
if (batchType == keepdimsType)
return batchValue;
SmallVector<ReassociationIndices> collapseToFlat {{}};
for (int64_t axis = 0; axis < batchType.getRank(); ++axis)
collapseToFlat.front().push_back(axis);
SmallVector<ReassociationIndices> expandFlatToCompact(1);
for (int64_t axis = 0; axis < compactKeptType.getRank(); ++axis)
expandFlatToCompact.front().push_back(axis);
SmallVector<ReassociationIndices> expandCompactToKeepdims;
ReassociationIndices pendingLeadingReducedAxes;
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
if (expandCompactToKeepdims.empty())
pendingLeadingReducedAxes.push_back(axis);
else
expandCompactToKeepdims.back().push_back(axis);
continue;
}
expandCompactToKeepdims.emplace_back();
auto& group = expandCompactToKeepdims.back();
group.append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
pendingLeadingReducedAxes.clear();
group.push_back(axis);
}
if (!pendingLeadingReducedAxes.empty())
expandCompactToKeepdims.back().append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
auto reshapeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
auto flatType = RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
Value compact = flat;
if (compactKeptType != flatType)
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
Value keepdims = compact;
if (keepdimsType != compactKeptType)
keepdims =
tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
});
return reshapeCompute.getResult(0);
}
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
SmallVector<ReassociationIndices> reassociation;
ReassociationIndices currentGroup;
@@ -72,56 +229,6 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
return reassociation;
}
static Value
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult());
});
return computeOp.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isCompileTimeComputable))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value buildReduceMeanKeepdims(Value input,
ArrayRef<bool> reducedAxes,
int64_t axis,
RankedTensorType leafType,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = cast<RankedTensorType>(input.getType()).getRank();
if (axis == rank)
return createAverageCompute(input, leafType, rewriter, loc);
if (reducedAxes[axis])
return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
SmallVector<Value> reducedSlices;
reducedSlices.reserve(slices.size());
for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return concatValues(reducedSlices, axis, rewriter, loc);
}
static Value squeezeReducedAxes(Value keepdimsValue,
RankedTensorType resultType,
ArrayRef<bool> reducedAxes,
@@ -156,16 +263,33 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
if (inputType.getRank() == 0) {
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
return success();
}
SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank());
SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, inputType.getRank());
auto axes = normalizeAxesChecked(reduceMeanOp.getAxesAttr(), inputType.getRank());
if (failed(axes))
return failure();
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
if (reducedAxes.empty() && inputType.getRank() != 0)
return failure();
Location loc = reduceMeanOp.getLoc();
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
RankedTensorType compactKeptType = getCompactKeptType(inputType, resultType.getElementType(), reducedAxes);
RankedTensorType keepdimsType = getKeepdimsType(inputType, resultType.getElementType(), reducedAxes);
int64_t laneCount = 1;
for (int64_t dim : compactKeptType.getShape())
laneCount *= dim;
RankedTensorType batchType = getLanePackedKeepdimsType(laneCount, leafType);
auto lanePackedKeepdims =
buildReduceMeanKeepdimsBatch(adaptor.getData(), reducedAxes, batchType, leafType, rewriter, loc);
if (failed(lanePackedKeepdims))
return failure();
Value reducedKeepdims =
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
if (reduceMeanOp.getKeepdims() != 0) {
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);