reduce spatial compile-times in convolutions using a scf.for instead of materializing a huge number of instructions
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -26,6 +26,7 @@ add_pim_library(OMONNXToSpatial
|
||||
ONNXToSpatialIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRSCFDialect
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCompilerOptions
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
@@ -75,7 +76,11 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
}
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||
target.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
||||
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||
target.addIllegalOp<ONNXAddOp>();
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
@@ -169,44 +170,60 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
|
||||
// Build im2col [numPatches, patchSize]:
|
||||
// For each batch/output position (n, oh, ow), extract the patch from x
|
||||
SmallVector<Value> im2colRows;
|
||||
im2colRows.reserve(numPatches);
|
||||
for (int64_t n = 0; n < batchSize; n++) {
|
||||
for (int64_t oh = 0; oh < outHeight; oh++) {
|
||||
for (int64_t ow = 0; ow < outWidth; ow++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(oh * strideHeight),
|
||||
rewriter.getIndexAttr(ow * strideWidth)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
|
||||
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
im2colRows.push_back(row);
|
||||
}
|
||||
}
|
||||
}
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
|
||||
// Concatenate all rows: [numPatches, patchSize]
|
||||
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
|
||||
Value patchIndex = im2colLoop.getInductionVar();
|
||||
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
||||
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(numChannelsIn),
|
||||
rewriter.getIndexAttr(wHeight),
|
||||
rewriter.getIndexAttr(wWidth)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(dilationHeight),
|
||||
rewriter.getIndexAttr(dilationWidth)};
|
||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
|
||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value updatedIm2col =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colLoop);
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||
});
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ add_pim_library(OMSpatialToPim
|
||||
SpatialToPimIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRSCFDialect
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCommon
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
@@ -134,7 +135,12 @@ void SpatialToPimPass::runOnOperation() {
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect, tensor::TensorDialect, arith::ArithDialect, func::FuncDialect, BuiltinDialect>();
|
||||
target.addLegalDialect<PimDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
func::FuncDialect,
|
||||
scf::SCFDialect,
|
||||
BuiltinDialect>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
Reference in New Issue
Block a user