finish helper refactoring
Validate Operations / validate-operations (push) Has been cancelled

use uniqued constant helpers everywhere
materialize transposed constants directly
This commit is contained in:
NiccoloN
2026-05-29 17:05:45 +02:00
parent 819d8af0f7
commit 8bb0babf1b
32 changed files with 300 additions and 467 deletions
@@ -50,38 +50,17 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
return failure();
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
static Value transposeForSpatial(Value value,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
}
static Value
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
}
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
}
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
return modIndexByConstant(lane, numOutRows, rewriter, loc);
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaledAttr, denseAttr.getType());
}
static Value createGemmBatchKOffset(
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
if (numKSlices == 1)
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(
return createAffineApplyOrFoldedConstant(
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
}
@@ -92,11 +71,11 @@ static Value createGemmBatchHOffset(Value lane,
ConversionPatternRewriter& rewriter,
Location loc) {
if (numOutHSlices == 1)
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(
return createAffineApplyOrFoldedConstant(
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
}
@@ -115,9 +94,9 @@ createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatte
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(
rewriter, loc, sourceType.getElementType(), rewriter.getZeroAttr(sourceType.getElementType()));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
@@ -149,7 +128,7 @@ static FailureOr<Value> materializePaddedConstantMatrix(Value value,
resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col];
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult();
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
@@ -215,7 +194,7 @@ static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
}
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult();
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
static FailureOr<Value> prepareBias(Value c,
@@ -274,7 +253,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
const int64_t laneCount = partialPiecesType.getDimSize(0);
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
@@ -312,12 +291,7 @@ static Value createDynamicGemmBatchRow(
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
}
static Value createDynamicGemmBatchColumn(
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
return modIndexByConstant(lane, numOutCols, rewriter, loc);
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
}
static Value
@@ -385,7 +359,7 @@ static Value createScalarTensorConstant(RankedTensorType scalarType,
auto elementType = scalarType.getElementType();
auto scalarAttr = rewriter.getFloatAttr(elementType, value);
auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr);
return arith::ConstantOp::create(rewriter, loc, scalarType, denseAttr).getResult();
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), denseAttr, scalarType);
}
static Value createBroadcastedBiasScalar(Value bias,
@@ -435,7 +409,7 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
@@ -475,16 +449,16 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
Value biasArg = bias ? blockArgs[1] : Value();
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(loop.getBody());
Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front();
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
Value column = createDynamicGemmBatchColumn(lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols);
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -522,7 +496,7 @@ static Value createPartialGroupOffset(Value hSlice,
Location loc) {
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createAffineApplyOrConstant(
return createAffineApplyOrFoldedConstant(
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
}
@@ -604,7 +578,9 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
Value reduced =
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
Value hOffset = multiplyIndexByConstant(hSlice, crossbarSize.getValue(), rewriter, loc);
Value hOffset =
onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice,
crossbarSize.getValue());
if (biasArg) {
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
Value biasSlice =
@@ -620,13 +596,14 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
Value paddedOutput = outputInit;
if (numOutHSlices == 1) {
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
paddedOutput = buildOutputSlice(outputInit, hSlice);
}
else {
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cOutHSlices =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(hLoop.getBody());
@@ -763,7 +740,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
bType = cast<RankedTensorType>(b.getType());
}