less affine code and better affine helpers
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -1184,48 +1184,6 @@ static Value createZeroPaddedTensor(Value value,
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static Value affineAddConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) {
|
||||
if (offset == 0)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineMulConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) {
|
||||
if (factor == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineFloorDivConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(divisor > 0 && "expected positive affine floordiv divisor");
|
||||
if (divisor == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value affineModConst(
|
||||
PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) {
|
||||
assert(modulus > 0 && "expected positive affine mod divisor");
|
||||
if (modulus == 1)
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
static Value createConvInputPatch(Value input,
|
||||
RankedTensorType patchType,
|
||||
Value batchIndex,
|
||||
@@ -2316,11 +2274,10 @@ static Value createIm2colRows(const ConvLoweringState& state,
|
||||
ValueRange {im2colInit},
|
||||
[&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
|
||||
Value im2colAcc = iterArgs.front();
|
||||
Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp);
|
||||
Value batchIndex =
|
||||
affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
||||
affineAddFloorDivConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
|
||||
Value batchPatchIndex =
|
||||
affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp);
|
||||
affineAddModConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp);
|
||||
Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||
Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp);
|
||||
Value inputHeightOffset =
|
||||
|
||||
Reference in New Issue
Block a user