Conv now lowers correctly down to bufferized PIM

This commit is contained in:
NiccoloN
2026-03-20 12:55:09 +01:00
parent 6e1de865bb
commit db3f52a647
4 changed files with 114 additions and 100 deletions

View File

@@ -6,13 +6,23 @@
#include <cassert> #include <cassert>
#include "Conv.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace {
// Conversion pattern lowering an ONNXConvOp to an im2col + GEMM sequence.
// Declared in an anonymous namespace so it has internal linkage (file-local);
// the matchAndRewrite body is defined out-of-line later in this file.
struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
using OpConversionPattern::OpConversionPattern;
// Rewrites a single ONNXConvOp via the adaptor's remapped operands.
// Returns success() when the op was replaced, failure() otherwise.
LogicalResult matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor, ONNXConvOpAdaptor convOpAdaptor,
@@ -100,76 +110,15 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t numPatches = outHeight * outWidth; const int64_t numPatches = outHeight * outWidth;
auto elemType = xType.getElementType(); auto elemType = xType.getElementType();
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
// Pad input with zeros if needed:
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType);
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightBegin),
rewriter.getIndexAttr(padWidthBegin)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightEnd),
rewriter.getIndexAttr(padWidthEnd)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, x, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 4; i++)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp);
x = padOp.getResult();
}
// Build im2col [numPatches, patchSize]:
// For each output position (oh, ow), extract the patch from x
auto rowType = RankedTensorType::get({1, patchSize}, elemType); auto rowType = RankedTensorType::get({1, patchSize}, elemType);
SmallVector<Value> im2colRows; auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
im2colRows.reserve(numPatches); auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
for (int64_t oh = 0; oh < outHeight; oh++) { auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType());
for (int64_t ow = 0; ow < outWidth; ow++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(oh * strideHeight),
rewriter.getIndexAttr(ow * strideWidth)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, x, offsets, sizes, strides);
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
im2colRows.push_back(row);
}
}
// Concatenate all rows: [numPatches, patchSize]
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
// Prepare weight matrix W for crossbar storage: // Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
Value wFlat = tensor::CollapseShapeOp::create(rewriter, Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc, loc,
wFlatType, wFlatType,
@@ -196,23 +145,98 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
else else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
auto im2colComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x});
auto* im2colBlock = new Block();
im2colBlock->addArgument(x.getType(), loc);
im2colComputeOp.getBody().push_back(im2colBlock);
rewriter.setInsertionPointToStart(im2colBlock);
Value paddedInput = im2colBlock->getArgument(0);
// Pad input with zeros if needed:
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType);
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightBegin),
rewriter.getIndexAttr(padWidthBegin)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightEnd),
rewriter.getIndexAttr(padWidthEnd)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 4; i++)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp);
paddedInput = padOp.getResult();
}
// Build im2col [numPatches, patchSize]:
// For each output position (oh, ow), extract the patch from x
SmallVector<Value> im2colRows;
im2colRows.reserve(numPatches);
for (int64_t oh = 0; oh < outHeight; oh++) {
for (int64_t ow = 0; ow < outWidth; ow++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(oh * strideHeight),
rewriter.getIndexAttr(ow * strideWidth)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
im2colRows.push_back(row);
}
}
// Concatenate all rows: [numPatches, patchSize]
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
rewriter.setInsertionPointAfter(im2colComputeOp);
// Gemm: A @ B + C = im2col @ W^T + b // Gemm: A @ B + C = im2col @ W^T + b
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut] // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto gemmOp = ONNXGemmOp::create(rewriter, auto gemmOp = ONNXGemmOp::create(rewriter,
loc, loc,
gemmOutType, gemmOutType,
im2col, im2colComputeOp.getResult(0),
wTrans, wTrans,
gemmC, gemmC,
rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false), rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false)); rewriter.getBoolAttr(false));
Value gemmOut = gemmOp.getY(); Value gemmOut = gemmOp.getY();
auto collectComputeOp = auto collectComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), gemmOut); spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut});
auto* collectBlock = new Block(); auto* collectBlock = new Block();
collectBlock->addArgument(gemmOut.getType(), loc); collectBlock->addArgument(gemmOut.getType(), loc);
@@ -225,7 +249,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// [numPatches, numChannelsOut] // [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut] // -> [1, outHeight, outWidth, numChannelsOut]
// -> [1, numChannelsOut, outHeight, outWidth] // -> [1, numChannelsOut, outHeight, outWidth]
auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType());
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc, loc,
nhwcType, nhwcType,
@@ -238,7 +261,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
spatial::SpatYieldOp::create(rewriter, loc, nchwOut); spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
rewriter.replaceOp(convOp, collectComputeOp); rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
return success(); return success();
} }

View File

@@ -1,23 +0,0 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
// Conversion pattern lowering an ONNXConvOp to an im2col + GEMM sequence.
// NOTE(review): this public declaration duplicates the file-local ConvToGemm
// in the .cpp — confirm only one of the two is meant to remain.
struct ConvToGemm : mlir::OpConversionPattern<mlir::ONNXConvOp> {
ConvToGemm(mlir::MLIRContext* ctx)
: OpConversionPattern(ctx) {}
// Rewrites a single ONNXConvOp; defined in the corresponding .cpp file.
mlir::LogicalResult matchAndRewrite(mlir::ONNXConvOp convOp,
mlir::ONNXConvOpAdaptor convOpAdaptor,
mlir::ConversionPatternRewriter& rewriter) const override;
};
// Registers the conv lowering patterns (e.g. ConvToGemm) into `patterns`.
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir

View File

@@ -300,7 +300,8 @@ struct ChannelBroadcastReceiveOpInterface
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
auto outputSize = cast<ShapedType>(outputTensor.getType()).getNumElements(); auto outputType = cast<ShapedType>(outputTensor.getType());
auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8;
auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>(); auto channelNewOp = op->getOperand(0).getDefiningOp<SpatChannelNewOp>();
if (!channelNewOp) { if (!channelNewOp) {
@@ -356,7 +357,8 @@ struct ChannelBroadcastSendOpInterface
} }
/* /*
* Turn the channel send to pim.send * Turn the channel send into a device-to-host copy into the shared
* broadcast buffer that receive ops load from later.
*/ */
LogicalResult bufferize(Operation* op, LogicalResult bufferize(Operation* op,
RewriterBase& rewriter, RewriterBase& rewriter,
@@ -389,8 +391,18 @@ struct ChannelBroadcastSendOpInterface
bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter); bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter);
} }
auto srcType = cast<ShapedType>(srcTensor.getType());
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef}); rewriter.create<pim::PimMemCopyDevToHostOp>(op->getLoc(),
bufferAllocation.getType(),
bufferAllocation,
srcMemRef,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(sizeInBytes));
rewriter.eraseOp(op);
return success(); return success();
} }
}; };

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@@ -15,7 +16,8 @@ namespace onnx_mlir {
namespace { namespace {
static bool isAddressOnlyHostOp(Operation* op) { static bool isAddressOnlyHostOp(Operation* op) {
return isa<memref::AllocOp, return isa<arith::ConstantOp,
memref::AllocOp,
memref::GetGlobalOp, memref::GetGlobalOp,
memref::SubViewOp, memref::SubViewOp,
memref::CastOp, memref::CastOp,