reduce spatial compile-times in convolutions using a scf.for instead of materializing a huge number of instructions
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-04-10 18:50:25 +02:00
parent f3a36e9d43
commit f054e66ed0
18 changed files with 623 additions and 241 deletions

View File

@@ -26,6 +26,7 @@ add_pim_library(OMONNXToSpatial
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
MLIRSCFDialect
MLIRTosaDialect
OMCompilerOptions
OMPimCompilerOptions

View File

@@ -1,4 +1,5 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -75,7 +76,11 @@ void ONNXToSpatialPass::runOnOperation() {
}
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>(
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXAddOp>();

View File

@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -169,44 +170,60 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
paddedInput = padOp.getResult();
}
// Build im2col [numPatches, patchSize]:
// For each batch/output position (n, oh, ow), extract the patch from x
SmallVector<Value> im2colRows;
im2colRows.reserve(numPatches);
for (int64_t n = 0; n < batchSize; n++) {
for (int64_t oh = 0; oh < outHeight; oh++) {
for (int64_t ow = 0; ow < outWidth; ow++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(oh * strideHeight),
rewriter.getIndexAttr(ow * strideWidth)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
// until the late PIM unrolling step.
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
im2colRows.push_back(row);
}
}
}
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
rewriter.setInsertionPointToStart(im2colLoop.getBody());
// Concatenate all rows: [numPatches, patchSize]
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
Value patchIndex = im2colLoop.getInductionVar();
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value updatedIm2col =
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
scf::YieldOp::create(rewriter, loc, updatedIm2col);
rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
});

View File

@@ -12,6 +12,7 @@ add_pim_library(OMSpatialToPim
SpatialToPimIncGen
LINK_LIBS PUBLIC
MLIRSCFDialect
MLIRTosaDialect
OMCompilerOptions
OMPimCommon

View File

@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -134,7 +135,12 @@ void SpatialToPimPass::runOnOperation() {
MLIRContext* ctx = moduleOp.getContext();
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
target.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
func::FuncDialect,
scf::SCFDialect,
BuiltinDialect>();
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);