use uniqued constant helpers everywhere materialize transposed constants directly
This commit is contained in:
@@ -61,9 +61,9 @@ static Value createPaddedRows(Value tensorValue,
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(
|
||||
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()),
|
||||
tensorType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
@@ -106,7 +106,7 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
}
|
||||
|
||||
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
|
||||
}
|
||||
|
||||
static Value createConvWeightMatrix(Value w,
|
||||
@@ -158,7 +158,7 @@ static Value buildPackedBias(bool hasBias,
|
||||
|
||||
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
|
||||
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType);
|
||||
}
|
||||
|
||||
static Value createIm2colRowComputes(Value x,
|
||||
@@ -214,8 +214,8 @@ static Value createIm2colRowComputes(Value x,
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getFloatAttr(elemType, 0.0), elemType);
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
@@ -223,13 +223,14 @@ static Value createIm2colRowComputes(Value x,
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
auto c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
auto c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
auto cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, numPatches);
|
||||
auto cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, numPatchesPerBatch);
|
||||
auto cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, outWidth);
|
||||
auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
|
||||
auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
|
||||
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
|
||||
Reference in New Issue
Block a user