use uniqued constant helpers everywhere materialize transposed constants directly
This commit is contained in:
@@ -40,14 +40,6 @@ static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr,
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
|
||||
return normalizeAxesImpl(std::optional<ArrayAttr>(axesAttr), rank);
|
||||
}
|
||||
|
||||
SmallVector<int64_t> normalizeAxes(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
return normalizeAxesImpl(axesAttr, rank);
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
|
||||
for (int64_t axis : normalizedAxes)
|
||||
@@ -56,11 +48,7 @@ FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> ax
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(ArrayAttr axesAttr, int64_t rank) {
|
||||
return normalizeAxesChecked(std::optional<ArrayAttr>(axesAttr), rank);
|
||||
}
|
||||
|
||||
Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||
Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
||||
@@ -68,22 +56,22 @@ Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, Affin
|
||||
|
||||
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
|
||||
if (multiplier == 0)
|
||||
return getOrCreateHostIndexConstant(rewriter, anchorOp, 0);
|
||||
return getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
if (multiplier == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
|
||||
return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
|
||||
}
|
||||
|
||||
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
||||
if (divisor == 1)
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(rewriter, loc, d0 % divisor, ValueRange {value});
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value});
|
||||
}
|
||||
|
||||
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
||||
@@ -92,12 +80,12 @@ Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value val
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
|
||||
}
|
||||
|
||||
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, Location loc, OpFoldResult value) {
|
||||
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
|
||||
if (auto attr = dyn_cast<Attribute>(value))
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||
return cast<Value>(value);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user