Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
@@ -13,7 +13,7 @@
#include <limits>
#include <utility>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
@@ -58,47 +58,16 @@ static Value transposeForSpatial(Value value,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
if (isCompileTimeComputable(value))
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return computeOp.getResult(0);
}
static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t value) {
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
}
static Value
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
}
static Value
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
if (multiplier == 0)
return createIndexConstant(rewriter, 0);
if (multiplier == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value});
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
}
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
if (divisor == 1)
return createIndexConstant(rewriter, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 % divisor, ValueRange {value});
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
}
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
@@ -108,11 +77,11 @@ static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatter
static Value createGemmBatchKOffset(
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
if (numKSlices == 1)
return createIndexConstant(rewriter, 0);
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(
return createAffineApplyOrConstant(
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
}
@@ -123,11 +92,11 @@ static Value createGemmBatchHOffset(Value lane,
ConversionPatternRewriter& rewriter,
Location loc) {
if (numOutHSlices == 1)
return createIndexConstant(rewriter, 0);
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(
return createAffineApplyOrConstant(
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
}
@@ -303,53 +272,37 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t laneCount = partialPiecesType.getDimSize(0);
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange {partialPiecesType},
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
ValueRange {b},
ValueRange {a});
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), paddedBType, aType, partialPiecesType};
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
Block* body =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
auto aTileType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
auto bTileType = RankedTensorType::get(
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
paddedBType.getElementType());
auto pieceType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);
auto lane = batchOp.getLaneArgument();
auto weight = batchOp.getWeightArgument(0);
auto input = batchOp.getInputArgument(0);
auto output = batchOp.getOutputArgument(0);
assert(lane && weight && input && output && "malformed Gemm compute_batch body");
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
rewriter.getIndexAttr(crossbarSize.getValue())};
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
Value bTile =
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
.getResult();
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc);
Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
auto aTileType = RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
auto bTileType = RankedTensorType::get(
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
paddedBType.getElementType());
auto pieceType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
Value aTile = extractATile(*input, row, kOffset, aTileType, rewriter, loc);
SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
rewriter.getIndexAttr(crossbarSize.getValue())};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value bTile =
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, *weight, bOffsets, bSizes, unitStrides).getResult();
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> pieceOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
tensor::ParallelInsertSliceOp::create(rewriter, loc, piece, *output, pieceOffsets, pieceSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
return batchOp;
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
});
assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed");
return *batchOp;
}
static Value createDynamicGemmBatchRow(
@@ -359,7 +312,7 @@ static Value createDynamicGemmBatchRow(
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
}
static Value createDynamicGemmBatchColumn(
@@ -479,45 +432,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
const int64_t laneCount = numOutRows * numOutCols;
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange {scalarPiecesType},
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
ValueRange {},
ValueRange {a, b});
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType};
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
Block* body =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
auto lane = batchOp.getLaneArgument();
auto inputA = batchOp.getInputArgument(0);
auto inputB = batchOp.getInputArgument(1);
auto output = batchOp.getOutputArgument(0);
assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body");
Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed
? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
return batchOp;
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
});
assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed");
return *batchOp;
}
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
@@ -540,9 +475,9 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value biasArg = bias ? blockArgs[1] : Value();
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = createIndexConstant(rewriter, 0);
Value c1 = createIndexConstant(rewriter, 1);
Value cLaneCount = createIndexConstant(rewriter, laneCount);
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(loop.getBody());
@@ -587,7 +522,8 @@ static Value createPartialGroupOffset(Value hSlice,
Location loc) {
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApply(rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
return createAffineApplyOrConstant(
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
}
static Value extractReductionPiece(Value partialPiecesArg,
@@ -684,13 +620,13 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
Value paddedOutput = outputInit;
if (numOutHSlices == 1) {
Value hSlice = createIndexConstant(rewriter, 0);
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
paddedOutput = buildOutputSlice(outputInit, hSlice);
}
else {
Value c0 = createIndexConstant(rewriter, 0);
Value c1 = createIndexConstant(rewriter, 1);
Value cOutHSlices = createIndexConstant(rewriter, numOutHSlices);
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(hLoop.getBody());