use uniqued constant helpers everywhere materialize transposed constants directly
This commit is contained in:
@@ -50,38 +50,17 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
|
||||
return failure();
|
||||
|
||||
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||
}
|
||||
|
||||
static Value transposeForSpatial(Value value,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value
|
||||
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
|
||||
}
|
||||
|
||||
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
|
||||
}
|
||||
|
||||
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return modIndexByConstant(lane, numOutRows, rewriter, loc);
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaledAttr, denseAttr.getType());
|
||||
}
|
||||
|
||||
static Value createGemmBatchKOffset(
|
||||
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (numKSlices == 1)
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(
|
||||
return createAffineApplyOrFoldedConstant(
|
||||
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
@@ -92,11 +71,11 @@ static Value createGemmBatchHOffset(Value lane,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (numOutHSlices == 1)
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(
|
||||
return createAffineApplyOrFoldedConstant(
|
||||
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
@@ -115,9 +94,9 @@ createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatte
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(
|
||||
rewriter, loc, sourceType.getElementType(), rewriter.getZeroAttr(sourceType.getElementType()));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
@@ -149,7 +128,7 @@ static FailureOr<Value> materializePaddedConstantMatrix(Value value,
|
||||
resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col];
|
||||
|
||||
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
|
||||
@@ -215,7 +194,7 @@ static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
|
||||
}
|
||||
|
||||
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
|
||||
}
|
||||
|
||||
static FailureOr<Value> prepareBias(Value c,
|
||||
@@ -274,7 +253,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
|
||||
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
|
||||
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
|
||||
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
|
||||
@@ -312,12 +291,7 @@ static Value createDynamicGemmBatchRow(
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
}
|
||||
|
||||
static Value createDynamicGemmBatchColumn(
|
||||
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return modIndexByConstant(lane, numOutCols, rewriter, loc);
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
}
|
||||
|
||||
static Value
|
||||
@@ -385,7 +359,7 @@ static Value createScalarTensorConstant(RankedTensorType scalarType,
|
||||
auto elementType = scalarType.getElementType();
|
||||
auto scalarAttr = rewriter.getFloatAttr(elementType, value);
|
||||
auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr);
|
||||
return arith::ConstantOp::create(rewriter, loc, scalarType, denseAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), denseAttr, scalarType);
|
||||
}
|
||||
|
||||
static Value createBroadcastedBiasScalar(Value bias,
|
||||
@@ -435,7 +409,7 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
|
||||
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
|
||||
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
|
||||
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
@@ -475,16 +449,16 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
Value biasArg = bias ? blockArgs[1] : Value();
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value lane = loop.getInductionVar();
|
||||
Value outputAcc = loop.getRegionIterArgs().front();
|
||||
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
|
||||
Value column = createDynamicGemmBatchColumn(lane, numOutCols, rewriter, loc);
|
||||
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols);
|
||||
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
@@ -522,7 +496,7 @@ static Value createPartialGroupOffset(Value hSlice,
|
||||
Location loc) {
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(
|
||||
return createAffineApplyOrFoldedConstant(
|
||||
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
|
||||
}
|
||||
|
||||
@@ -604,7 +578,9 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
||||
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
|
||||
Value reduced =
|
||||
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
|
||||
Value hOffset = multiplyIndexByConstant(hSlice, crossbarSize.getValue(), rewriter, loc);
|
||||
Value hOffset =
|
||||
onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice,
|
||||
crossbarSize.getValue());
|
||||
if (biasArg) {
|
||||
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
|
||||
Value biasSlice =
|
||||
@@ -620,13 +596,14 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
||||
|
||||
Value paddedOutput = outputInit;
|
||||
if (numOutHSlices == 1) {
|
||||
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
paddedOutput = buildOutputSlice(outputInit, hSlice);
|
||||
}
|
||||
else {
|
||||
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cOutHSlices =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||
|
||||
@@ -763,7 +740,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
|
||||
b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user