From db3f52a64787b19fb42b59c516a79e204408a7a8 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 20 Mar 2026 12:55:09 +0100 Subject: [PATCH] conv now lowers correctly down to bufferized pim --- .../Conversion/ONNXToSpatial/Math/Conv.cpp | 169 ++++++++++-------- .../Conversion/ONNXToSpatial/Math/Conv.hpp | 23 --- .../SpatialBufferizableOpInterface.cpp | 18 +- src/PIM/Pass/PimHostVerificationPass.cpp | 4 +- 4 files changed, 114 insertions(+), 100 deletions(-) delete mode 100644 src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp index 4c5b1ac..cf8de30 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp @@ -6,13 +6,23 @@ #include -#include "Conv.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { +namespace { + +struct ConvToGemm : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXConvOp convOp, + ONNXConvOpAdaptor convOpAdaptor, + ConversionPatternRewriter& rewriter) const override; +}; + +} // namespace LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, ONNXConvOpAdaptor convOpAdaptor, @@ -100,76 +110,15 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, const int64_t numPatches = outHeight * outWidth; auto elemType = xType.getElementType(); - - // Pad input with zeros if needed: - // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] - if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { - const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; - const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; - auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); - auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType); - SmallVector lowPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightBegin), - rewriter.getIndexAttr(padWidthBegin)}; - SmallVector highPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightEnd), - rewriter.getIndexAttr(padWidthEnd)}; - auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, x, lowPads, highPads); - auto* padBlock = new Block(); - for (int i = 0; i < 4; i++) - padBlock->addArgument(rewriter.getIndexType(), loc); - padOp.getRegion().push_back(padBlock); - rewriter.setInsertionPointToStart(padBlock); - tensor::YieldOp::create(rewriter, loc, zero.getResult()); - rewriter.setInsertionPointAfter(padOp); - x = padOp.getResult(); - } - - // Build im2col [numPatches, patchSize]: - // For each output position (oh, ow), extract the patch from x + auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType); auto rowType = RankedTensorType::get({1, patchSize}, elemType); - SmallVector im2colRows; - im2colRows.reserve(numPatches); - - for (int64_t oh = 0; oh < outHeight; oh++) { - for (int64_t ow = 0; ow < outWidth; ow++) { - SmallVector offsets = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(oh * strideHeight), - rewriter.getIndexAttr(ow * strideWidth)}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(numChannelsIn), - rewriter.getIndexAttr(wHeight), - rewriter.getIndexAttr(wWidth)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(dilationHeight), - rewriter.getIndexAttr(dilationWidth)}; - auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); - Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, x, offsets, sizes, strides); - - // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize] - Value row = tensor::CollapseShapeOp::create(rewriter, - loc, - rowType, - patch, - SmallVector { - {0}, - {1, 2, 3} - }); - im2colRows.push_back(row); - } - } - - // Concatenate all rows: [numPatches, patchSize] - Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows); + auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType()); + auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType()); + auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType()); + auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType()); // Prepare weight matrix W for crossbar storage: // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] - auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType()); - auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType()); Value wFlat = tensor::CollapseShapeOp::create(rewriter, loc, wFlatType, @@ -196,23 +145,98 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, else gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + auto im2colComputeOp = + spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector(), ValueRange {x}); + + auto* im2colBlock = new Block(); + im2colBlock->addArgument(x.getType(), loc); + im2colComputeOp.getBody().push_back(im2colBlock); + rewriter.setInsertionPointToStart(im2colBlock); + + Value paddedInput = im2colBlock->getArgument(0); + + // Pad input with zeros if needed: + // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] + if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { + const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; + const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; + auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType); + SmallVector lowPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padHeightBegin), + rewriter.getIndexAttr(padWidthBegin)}; + SmallVector highPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padHeightEnd), + rewriter.getIndexAttr(padWidthEnd)}; + auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads); + auto* padBlock = new Block(); + for (int i = 0; i < 4; i++) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); + tensor::YieldOp::create(rewriter, loc, zero.getResult()); + rewriter.setInsertionPointAfter(padOp); + paddedInput = padOp.getResult(); + } + + // Build im2col [numPatches, patchSize]: + // For each output position (oh, ow), extract the patch from x + SmallVector im2colRows; + im2colRows.reserve(numPatches); + for (int64_t oh = 0; oh < outHeight; oh++) { + for (int64_t ow = 0; ow < outWidth; ow++) { + SmallVector offsets = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(oh * strideHeight), + rewriter.getIndexAttr(ow * strideWidth)}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(numChannelsIn), + rewriter.getIndexAttr(wHeight), + rewriter.getIndexAttr(wWidth)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(dilationHeight), + rewriter.getIndexAttr(dilationWidth)}; + auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); + Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); + + // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize] + Value row = tensor::CollapseShapeOp::create(rewriter, + loc, + rowType, + patch, + SmallVector { + {0}, + {1, 2, 3} + }); + im2colRows.push_back(row); + } + } + + // Concatenate all rows: [numPatches, patchSize] + Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows); + spatial::SpatYieldOp::create(rewriter, loc, im2col); + + rewriter.setInsertionPointAfter(im2colComputeOp); + // Gemm: A @ B + C = im2col @ W^T + b // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut] - auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType()); auto gemmOp = ONNXGemmOp::create(rewriter, loc, gemmOutType, - im2col, + im2colComputeOp.getResult(0), wTrans, gemmC, rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f), rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)); - Value gemmOut = gemmOp.getY(); + auto collectComputeOp = - spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector(), gemmOut); + spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector(), ValueRange {gemmOut}); auto* collectBlock = new Block(); collectBlock->addArgument(gemmOut.getType(), loc); @@ -225,7 +249,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, // [numPatches, numChannelsOut] // -> [1, outHeight, outWidth, numChannelsOut] // -> [1, numChannelsOut, outHeight, outWidth] - auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType()); Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, loc, nhwcType, @@ -238,7 +261,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, spatial::SpatYieldOp::create(rewriter, loc, nchwOut); - rewriter.replaceOp(convOp, collectComputeOp); + rewriter.replaceOp(convOp, collectComputeOp.getResult(0)); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp deleted file mode 100644 index b29ace7..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.hpp +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/Support/LogicalResult.h" - -#include "src/Dialect/ONNX/ONNXOps.hpp" - -namespace onnx_mlir { - -struct ConvToGemm : mlir::OpConversionPattern { - ConvToGemm(mlir::MLIRContext* ctx) - : OpConversionPattern(ctx) {} - - mlir::LogicalResult matchAndRewrite(mlir::ONNXConvOp convOp, - mlir::ONNXConvOpAdaptor convOpAdaptor, - mlir::ConversionPatternRewriter& rewriter) const override; -}; - -void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - -} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp index 3fcf477..e888e77 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp @@ -300,7 +300,8 @@ struct ChannelBroadcastReceiveOpInterface auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); - auto outputSize = cast(outputTensor.getType()).getNumElements(); + auto outputType = cast(outputTensor.getType()); + auto outputSize = outputType.getNumElements() * outputType.getElementTypeBitWidth() / 8; auto channelNewOp = op->getOperand(0).getDefiningOp(); if (!channelNewOp) { @@ -356,7 +357,8 @@ struct ChannelBroadcastSendOpInterface } /* - * Turn the channel send to pim.send + * Turn the channel send into a device-to-host copy into the shared + * broadcast buffer that receive ops load from later. */ LogicalResult bufferize(Operation* op, RewriterBase& rewriter, @@ -389,8 +391,18 @@ struct ChannelBroadcastSendOpInterface bufferAllocation = createEmptyFromType(srcTensor.getType(), op->getLoc(), rewriter); } + auto srcType = cast(srcTensor.getType()); + auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8; + rewriter.setInsertionPoint(op); - replaceOpWithBufferizedValues(rewriter, op, {bufferAllocation, srcMemRef}); + rewriter.create(op->getLoc(), + bufferAllocation.getType(), + bufferAllocation, + srcMemRef, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(sizeInBytes)); + rewriter.eraseOp(op); return success(); } }; diff --git a/src/PIM/Pass/PimHostVerificationPass.cpp b/src/PIM/Pass/PimHostVerificationPass.cpp index 9b1dbc0..b2bedd5 100644 --- a/src/PIM/Pass/PimHostVerificationPass.cpp +++ b/src/PIM/Pass/PimHostVerificationPass.cpp @@ -1,3 +1,4 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" @@ -15,7 +16,8 @@ namespace onnx_mlir { namespace { static bool isAddressOnlyHostOp(Operation* op) { - return isa