multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s

This commit is contained in:
NiccoloN
2026-04-23 09:28:57 +02:00
parent 0f13269040
commit 412ca957f6
16 changed files with 415 additions and 420 deletions

View File

@@ -147,33 +147,37 @@ static Value buildPackedBias(bool hasBias,
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
}
static Value createIm2colCompute(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType rowType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
int64_t xWidth,
int64_t wHeight,
int64_t wWidth,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t outWidth,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
ConversionPatternRewriter& rewriter,
Location loc) {
static SmallVector<Value> createIm2colRowComputes(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType im2colRowType,
RankedTensorType gemmInputRowType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
int64_t xWidth,
int64_t wHeight,
int64_t wWidth,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t outWidth,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = xType.getElementType();
constexpr size_t numInputs = 1;
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
Value paddedInput = xArg;
// Pad input with zeros if needed:
@@ -240,7 +244,7 @@ static Value createIm2colCompute(Value x,
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
im2colRowType,
patch,
SmallVector<ReassociationIndices> {
{0},
@@ -256,121 +260,115 @@ static Value createIm2colCompute(Value x,
rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
});
return im2colComputeOp.getResult(0);
}
static Value createPackedIm2colRows(Value im2col,
RankedTensorType im2colType,
Type elemType,
int64_t numPatches,
int64_t patchSize,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return im2col;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
});
return packedComputeOp.getResult(0);
}
static Value createUnpackedOutput(Value packedOutput,
RankedTensorType gemmOutType,
RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return packedOutput;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutputArg,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value unpackedOutput = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
unpackedOutput =
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
SmallVector<Value> rowResults;
rowResults.reserve(packedNumRows);
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
rowResults.push_back(
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
}
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
});
return unpackComputeOp.getResult(0);
SmallVector<Value> rows;
rows.reserve(im2colComputeOp.getNumResults());
for (Value result : im2colComputeOp.getResults())
rows.push_back(result);
return rows;
}
static Value createCollectedConvOutput(Value gemmOut,
static Value createCollectedConvOutput(ValueRange gemmRows,
Type convType,
RankedTensorType gemmOutType,
RankedTensorType nhwcType,
RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
auto collectComputeOp =
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
Value gemmOutArg = gemmOutArgs.front();
// Restore to NCHW layout:
// [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut]
// -> [1, numChannelsOut, outHeight, outWidth]
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc,
nhwcType,
gemmOutArg,
SmallVector<ReassociationIndices> {
{0, 1, 2},
{3}
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut;
if (packFactor == 1) {
gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
}
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput = gemmRowArgs.size() == 1
? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutput,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmOut = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
}
}
// Restore to NCHW layout:
// [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut]
// -> [1, numChannelsOut, outHeight, outWidth]
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc,
nhwcType,
gemmOut,
SmallVector<ReassociationIndices> {
{0, 1, 2},
{3}
});
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
});
return collectComputeOp.getResult(0);
}
@@ -487,11 +485,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value biasMatrix;
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmC = b;
gemmBias = b;
biasDenseAttr = getDenseConstantAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
@@ -500,94 +498,89 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t effectiveMaxParallelPixels =
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
Value im2col = createIm2colCompute(x,
xType,
im2colType,
rowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
rewriter,
loc);
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
//
// We want to process N pixels at the same time. Instead of doing N separate operations
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
xType,
im2colType,
rowType,
gemmInputRowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmOut;
if (effectiveMaxParallelPixels == 1) {
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
gemmOut = ONNXGemmOp::create(rewriter,
loc,
gemmOutType,
im2col,
wTrans,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
}
else {
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// but repack several old rows into one new row so we use the available crossbar size better.
//
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto packedOutType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value gemmB = buildPackedWeight(wDenseAttr,
wTrans,
wType,
numChannelsIn,
numChannelsOut,
wHeight,
wWidth,
patchSize,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmC = buildPackedBias(
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value packedA = createPackedIm2colRows(
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
Value packedB = buildPackedWeight(wDenseAttr,
wTrans,
wType,
numChannelsIn,
numChannelsOut,
wHeight,
wWidth,
patchSize,
effectiveMaxParallelPixels,
rewriter,
loc);
Value packedC = buildPackedBias(
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value packedOut = ONNXGemmOp::create(rewriter,
loc,
packedOutType,
packedA,
packedB,
packedC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmOut = createUnpackedOutput(
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
SmallVector<Value> gemmRows;
gemmRows.reserve(gemmInputRows.size());
for (Value gemmInputRow : gemmInputRows) {
Value gemmRow = ONNXGemmOp::create(rewriter,
loc,
gemmOutputRowType,
gemmInputRow,
gemmB,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmRows.push_back(gemmRow);
}
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
rewriter.replaceOp(convOp,
createCollectedConvOutput(gemmRows,
convOp.getType(),
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc));
return success();
}