use uniqued constant helpers everywhere materialize transposed constants directly
This commit is contained in:
@@ -31,17 +31,18 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
||||
|
||||
static Value
|
||||
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
if (!useMinimumValue)
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getZeroAttr(elementType), elementType);
|
||||
|
||||
if (auto floatType = dyn_cast<FloatType>(elementType)) {
|
||||
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getFloatAttr(floatType, minValue), elementType);
|
||||
}
|
||||
|
||||
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
|
||||
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getIntegerAttr(integerType, minValue), elementType);
|
||||
}
|
||||
|
||||
llvm_unreachable("unsupported pool element type");
|
||||
@@ -148,7 +149,7 @@ static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewr
|
||||
}
|
||||
|
||||
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaleAttr, scaleType);
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
@@ -265,13 +266,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
|
||||
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
|
||||
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
|
||||
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cOutputPatchCount = getOrCreateIndexConstant(rewriter, anchorOp, outputPatchCount);
|
||||
Value cOutputPixelsPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, outputHeight * outputWidth);
|
||||
Value cOutputWidth = getOrCreateIndexConstant(rewriter, anchorOp, outputWidth);
|
||||
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
|
||||
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
|
||||
|
||||
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
|
||||
rewriter.setInsertionPointToStart(outputLoop.getBody());
|
||||
@@ -296,14 +298,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
Value paddedInH = windowBaseH;
|
||||
if (kernelH * dilationHeight != 0) {
|
||||
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
|
||||
Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
|
||||
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
|
||||
}
|
||||
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
Value paddedInW = windowBaseW;
|
||||
if (kernelW * dilationWidth != 0) {
|
||||
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
|
||||
Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
|
||||
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user