robba
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
NiccoloN
2026-06-29 12:22:33 +02:00
parent 78e97f9fd8
commit e8f09fd67f
8 changed files with 1376 additions and 492 deletions
@@ -1334,6 +1334,38 @@ static Value affineAddConst(
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
}
static Value affineMulConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) {
if (factor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor);
}
static Value affineFloorDivConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(divisor > 0 && "expected positive affine floordiv divisor");
if (divisor == 1)
return value;
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
}
static Value affineModConst(
PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) {
assert(modulus > 0 && "expected positive affine mod divisor");
if (modulus == 1)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor);
}
static Value createConvInputPatch(Value input,
RankedTensorType patchType,
Value batchIndex,
@@ -2414,11 +2446,6 @@ static Value createIm2colRows(const ConvLoweringState& state,
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkNumPatches);
Value cChunkStart = getOrCreateIndexConstant(rewriter, anchorOp, plan.chunkStart);
Value cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, plan.numPatchesPerBatch);
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, state.strideHeight);
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.strideWidth);
auto im2colLoop = buildNormalizedScfFor(
rewriter,
@@ -2429,13 +2456,17 @@ static Value createIm2colRows(const ConvLoweringState& state,
ValueRange {im2colInit},
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
Value im2colAcc = iterArgs.front();
Value globalPatchIndex = arith::AddIOp::create(rewriter, nestedLoc, patchIndex, cChunkStart);
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, nestedLoc, globalPatchIndex, cNumPatchesPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutWidth);
Value inputHeightOffset = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight);
Value inputWidthOffset = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth);
Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp);
Value batchIndex =
affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
Value batchPatchIndex =
affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
Value inputHeightOffset =
affineMulConst(rewriter, nestedLoc, outHeightIndex, state.strideHeight, anchorOp);
Value inputWidthOffset =
affineMulConst(rewriter, nestedLoc, outWidthIndex, state.strideWidth, anchorOp);
auto patchType =
RankedTensorType::get({1, state.numChannelsIn, state.wHeight, state.wWidth}, elemType);
@@ -2844,11 +2875,9 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cNumKSlices = getOrCreateIndexConstant(rewriter, anchorOp, numKSlices);
Value cXbar = getOrCreateIndexConstant(rewriter, anchorOp, xbarDim);
Value cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.outWidth);
Value cNumChannels = getOrCreateIndexConstant(rewriter, anchorOp, state.numChannelsIn);
Value cPatchWidth = getOrCreateIndexConstant(rewriter, anchorOp, state.wHeight * state.wWidth);
Value localHeightOffset = arith::MulIOp::create(rewriter, loc, args.lane, c1);
Value localHeightOffset = args.lane;
Value packedRowInit =
tensor::EmptyOp::create(rewriter, loc, ArrayRef<int64_t> {1, state.outWidth, state.numChannelsOut}, elementType);
auto widthLoop = buildNormalizedScfFor(
@@ -2859,7 +2888,7 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
c1,
ValueRange {packedRowInit},
[&](OpBuilder&, Location widthLoc, Value widthIndex, ValueRange widthIterArgs, SmallVectorImpl<Value>& widthYielded) {
Value localWidthOffset = arith::MulIOp::create(rewriter, widthLoc, widthIndex, c1);
Value localWidthOffset = widthIndex;
Value rowInit = tensor::EmptyOp::create(rewriter, widthLoc, ArrayRef<int64_t> {1, patchSize}, elementType);
auto rowLoop = buildNormalizedScfFor(
rewriter,
@@ -2878,7 +2907,8 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
rewriter, rowLoc, flatPatchType, channelPatch, SmallVector<ReassociationIndices> {{0, 1, 2}});
Value rowChunk = tensor::ExpandShapeOp::create(
rewriter, rowLoc, rowChunkType, flatPatch, SmallVector<ReassociationIndices> {{0, 1}});
Value flatOffset = arith::MulIOp::create(rewriter, rowLoc, channel, cPatchWidth);
Value flatOffset = affineMulConst(
rewriter, rowLoc, channel, state.wHeight * state.wWidth, anchorOp);
SmallVector<OpFoldResult> rowOffsets {rewriter.getIndexAttr(0), flatOffset};
SmallVector<OpFoldResult> rowSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(state.wHeight * state.wWidth)};
@@ -2905,7 +2935,7 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
c1,
ValueRange {zeroRow},
[&](OpBuilder&, Location reduceLoc, Value kSlice, ValueRange reduceIterArgs, SmallVectorImpl<Value>& reduceYielded) {
Value kOffset = arith::MulIOp::create(rewriter, reduceLoc, kSlice, cXbar);
Value kOffset = affineMulConst(rewriter, reduceLoc, kSlice, xbarDim, anchorOp);
SmallVector<OpFoldResult> aOffsets {rewriter.getIndexAttr(0), kOffset};
SmallVector<OpFoldResult> aSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarDim)};
Value aTile = tensor::ExtractSliceOp::create(