From 3625edc80a896e035d4ca9942e477d64773d147a Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 30 Mar 2026 16:14:26 +0200 Subject: [PATCH] add better createSpatCompute helper --- src/PIM/Conversion/ONNXToSpatial/Common.hpp | 94 ++++++++-- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 175 ++++++++---------- .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 76 +++----- .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 17 +- .../ONNXToSpatial/Patterns/NN/Pool.cpp | 136 +++++++------- 5 files changed, 259 insertions(+), 239 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common.hpp index 56a5920..351a54b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common.hpp @@ -104,20 +104,39 @@ inline auto getTensorShape(mlir::Value tensor) { namespace detail { +inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); } + template -void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence) { - std::forward(fn)(block->getArgument(Is)...); +decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence) { + return std::forward(fn)(block->getArgument(Is)...); } +template +using ValueArg = mlir::Value; + +template +struct InvokeWithBlockArgsResult; + +template +struct InvokeWithBlockArgsResult> { + using type = std::invoke_result_t...>; +}; + +template +using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult::type; + +template +using InvokeWithValueRangeResultT = std::invoke_result_t; + } // namespace detail -template -spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter, - mlir::Location loc, - mlir::TypeRange resultTypes, - mlir::ValueRange weights, - mlir::ValueRange inputs, - BodyFn&& body) { +template +auto createSpatCompute(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { assert(inputs.size() == NumInputs && "NumInputs must match the number of input values"); auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs); @@ -128,10 +147,61 @@ spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& computeOp.getBody().push_back(block); rewriter.setInsertionPointToStart(block); - detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); + using BodyResult = detail::InvokeWithBlockArgsResultT, std::make_index_sequence>; + if constexpr (std::is_same_v) { + auto bodyResult = + detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); + if (mlir::failed(bodyResult)) { + rewriter.setInsertionPointAfter(computeOp); + rewriter.eraseOp(computeOp); + return mlir::FailureOr(mlir::failure()); + } + rewriter.setInsertionPointAfter(computeOp); + return mlir::FailureOr(computeOp); + } + else { + static_assert(std::is_same_v, "createSpatCompute body must return void or mlir::LogicalResult"); + detail::invokeWithBlockArgs(std::forward(body), block, std::make_index_sequence {}); - rewriter.setInsertionPointAfter(computeOp); - return computeOp; + rewriter.setInsertionPointAfter(computeOp); + return computeOp; + } +} + +template +auto createSpatCompute(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { + auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs); + + auto* block = new mlir::Block(); + for (mlir::Value input : inputs) + block->addArgument(input.getType(), loc); + + computeOp.getBody().push_back(block); + rewriter.setInsertionPointToStart(block); + + using BodyResult = detail::InvokeWithValueRangeResultT>; + if constexpr (std::is_same_v) { + auto bodyResult = std::forward(body)(detail::getBlockArgs(block)); + if (mlir::failed(bodyResult)) { + rewriter.setInsertionPointAfter(computeOp); + rewriter.eraseOp(computeOp); + return mlir::FailureOr(mlir::failure()); + } + rewriter.setInsertionPointAfter(computeOp); + return mlir::FailureOr(computeOp); + } + else { + static_assert(std::is_same_v, "createSpatCompute body must return void or mlir::LogicalResult"); + std::forward(body)(detail::getBlockArgs(block)); + + rewriter.setInsertionPointAfter(computeOp); + return computeOp; + } } llvm::SmallVector sliceTensor(const mlir::Value& tensorToSlice, diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 09eb23e..1c4f38b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -6,6 +6,7 @@ #include +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -138,83 +139,76 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, else gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); - auto im2colComputeOp = - spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector(), ValueRange {x}); + constexpr size_t numInputs = 1; + auto im2colComputeOp = createSpatCompute(rewriter, loc, im2colType, {}, x, [&](Value xArg) { + Value paddedInput = xArg; - auto* im2colBlock = new Block(); - im2colBlock->addArgument(x.getType(), loc); - im2colComputeOp.getBody().push_back(im2colBlock); - rewriter.setInsertionPointToStart(im2colBlock); + // Pad input with zeros if needed: + // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] + if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { + const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; + const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; + auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType); + SmallVector lowPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padHeightBegin), + rewriter.getIndexAttr(padWidthBegin)}; + SmallVector highPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padHeightEnd), + rewriter.getIndexAttr(padWidthEnd)}; + auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads); + auto* padBlock = new Block(); + for (int i = 0; i < 4; i++) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); + tensor::YieldOp::create(rewriter, loc, zero.getResult()); + rewriter.setInsertionPointAfter(padOp); + paddedInput = padOp.getResult(); + } - Value paddedInput = im2colBlock->getArgument(0); + // Build im2col [numPatches, patchSize]: + // For each batch/output position (n, oh, ow), extract the patch from x + SmallVector im2colRows; + im2colRows.reserve(numPatches); + for (int64_t n = 0; n < batchSize; n++) { + for (int64_t oh = 0; oh < outHeight; oh++) { + for (int64_t ow = 0; ow < outWidth; ow++) { + SmallVector offsets = {rewriter.getIndexAttr(n), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(oh * strideHeight), + rewriter.getIndexAttr(ow * strideWidth)}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(numChannelsIn), + rewriter.getIndexAttr(wHeight), + rewriter.getIndexAttr(wWidth)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(dilationHeight), + rewriter.getIndexAttr(dilationWidth)}; + auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); + Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); - // Pad input with zeros if needed: - // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] - if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { - const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; - const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; - auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType); - SmallVector lowPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightBegin), - rewriter.getIndexAttr(padWidthBegin)}; - SmallVector highPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightEnd), - rewriter.getIndexAttr(padWidthEnd)}; - auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads); - auto* padBlock = new Block(); - for (int i = 0; i < 4; i++) - padBlock->addArgument(rewriter.getIndexType(), loc); - padOp.getRegion().push_back(padBlock); - rewriter.setInsertionPointToStart(padBlock); - auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); - tensor::YieldOp::create(rewriter, loc, zero.getResult()); - rewriter.setInsertionPointAfter(padOp); - paddedInput = padOp.getResult(); - } - - // Build im2col [numPatches, patchSize]: - // For each batch/output position (n, oh, ow), extract the patch from x - SmallVector im2colRows; - im2colRows.reserve(numPatches); - for (int64_t n = 0; n < batchSize; n++) { - for (int64_t oh = 0; oh < outHeight; oh++) { - for (int64_t ow = 0; ow < outWidth; ow++) { - SmallVector offsets = {rewriter.getIndexAttr(n), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(oh * strideHeight), - rewriter.getIndexAttr(ow * strideWidth)}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(numChannelsIn), - rewriter.getIndexAttr(wHeight), - rewriter.getIndexAttr(wWidth)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(dilationHeight), - rewriter.getIndexAttr(dilationWidth)}; - auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); - Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); - - // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize] - Value row = tensor::CollapseShapeOp::create(rewriter, - loc, - rowType, - patch, - SmallVector { - {0}, - {1, 2, 3} - }); - im2colRows.push_back(row); + // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize] + Value row = tensor::CollapseShapeOp::create(rewriter, + loc, + rowType, + patch, + SmallVector { + {0}, + {1, 2, 3} + }); + im2colRows.push_back(row); + } } } - } - // Concatenate all rows: [numPatches, patchSize] - Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows); - spatial::SpatYieldOp::create(rewriter, loc, im2col); - - rewriter.setInsertionPointAfter(im2colComputeOp); + // Concatenate all rows: [numPatches, patchSize] + Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows); + spatial::SpatYieldOp::create(rewriter, loc, im2col); + }); // Gemm: A @ B + C = im2col @ W^T + b // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut] @@ -231,30 +225,23 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, Value gemmOut = gemmOp.getY(); auto collectComputeOp = - spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector(), ValueRange {gemmOut}); + createSpatCompute(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) { + // Restore to NCHW layout: + // [numPatches, numChannelsOut] + // -> [1, outHeight, outWidth, numChannelsOut] + // -> [1, numChannelsOut, outHeight, outWidth] + Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, + loc, + nhwcType, + gemmOutArg, + SmallVector { + {0, 1, 2}, + {3} + }); + Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); - auto* collectBlock = new Block(); - collectBlock->addArgument(gemmOut.getType(), loc); - collectComputeOp.getBody().push_back(collectBlock); - rewriter.setInsertionPointToStart(collectBlock); - - auto gemmOutArg = collectBlock->getArguments().front(); - - // Restore to NCHW layout: - // [numPatches, numChannelsOut] - // -> [1, outHeight, outWidth, numChannelsOut] - // -> [1, numChannelsOut, outHeight, outWidth] - Value nhwcOut = tensor::ExpandShapeOp::create(rewriter, - loc, - nhwcType, - gemmOutArg, - SmallVector { - {0, 1, 2}, - {3} - }); - Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); - - spatial::SpatYieldOp::create(rewriter, loc, nchwOut); + spatial::SpatYieldOp::create(rewriter, loc, nchwOut); + }); rewriter.replaceOp(convOp, collectComputeOp.getResult(0)); return success(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 7ac5bfb..0e87f7a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -155,18 +155,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, gemvOps.push_back(gemvOp.getY()); } - auto concatComputeOp = - spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector(), gemvOps); - - auto* concatBlock = new Block(); - for (auto gemvOp : gemvOps) - concatBlock->addArgument(gemvOp.getType(), loc); - concatComputeOp.getBody().push_back(concatBlock); - rewriter.setInsertionPointToStart(concatBlock); - - auto blockArgs = concatBlock->getArguments(); - auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs); - spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); + auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) { + auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs); + spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); + }); rewriter.replaceOp(gemmOp, concatComputeOp); return success(); @@ -289,25 +281,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, weights.push_back(bTiles[outSliceId][coreId][aSliceId]); auto computeOp = - spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); + createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) { + SmallVector vmmOutputs; + vmmOutputs.reserve(aHSlicesArgs.size()); + for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) + vmmOutputs.push_back( + spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg)); + assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); - auto* computeBlock = new Block(); - for (auto aHSlice : aHSlices[coreId]) - computeBlock->addArgument(aHSlice.getType(), gemmLoc); - computeOp.getBody().push_back(computeBlock); - rewriter.setInsertionPointToStart(computeBlock); - - auto computeArgs = computeBlock->getArguments(); - SmallVector vmmOutputs; - vmmOutputs.reserve(computeArgs.size()); - for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++) - vmmOutputs.push_back( - spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); - assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); - - Value partialVmmSum = sumTensors(vmmOutputs, rewriter); - spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); - rewriter.setInsertionPointAfter(computeOp); + Value partialVmmSum = sumTensors(vmmOutputs, rewriter); + spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); + }); partialResults.push_back(computeOp.getResult(0)); } @@ -318,34 +302,20 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, } auto reduceComputeOp = - spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector(), partialResults); - - auto* reduceBlock = new Block(); - for (auto partialResult : partialResults) - reduceBlock->addArgument(partialResult.getType(), gemmLoc); - reduceComputeOp.getBody().push_back(reduceBlock); - rewriter.setInsertionPointToStart(reduceBlock); - - auto blockArgs = reduceBlock->getArguments(); - Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter); - spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice); - rewriter.setInsertionPointAfter(reduceComputeOp); + createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) { + SmallVector values(blockArgs.begin(), blockArgs.end()); + Value outHSlice = sumTensors(values, rewriter); + spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice); + }); outHSlices.push_back(reduceComputeOp.getResult(0)); } auto concatComputeOp = - spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector(), outHSlices); - - auto* concatBlock = new Block(); - for (auto outHSlice : outHSlices) - concatBlock->addArgument(outHSlice.getType(), gemmLoc); - concatComputeOp.getBody().push_back(concatBlock); - rewriter.setInsertionPointToStart(concatBlock); - - auto blockArgs = concatBlock->getArguments(); - auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs); - spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult()); + createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) { + auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs); + spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult()); + }); rewriter.replaceOp(gemmOp, concatComputeOp); return success(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 2ec810e..fa5bb20 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -4,6 +4,7 @@ #include "llvm/ADT/SmallVector.h" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -81,19 +82,11 @@ struct MatMulRank3ToGemm : OpRewritePattern { } } - auto concatComputeOp = - spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector(), gemmRows); + auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) { + auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs); + spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); + }); - auto* concatBlock = new Block(); - for (Value gemmRow : gemmRows) - concatBlock->addArgument(gemmRow.getType(), loc); - concatComputeOp.getBody().push_back(concatBlock); - rewriter.setInsertionPointToStart(concatBlock); - - auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments()); - spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); - - rewriter.setInsertionPointAfter(concatComputeOp); Value gemmOut = concatComputeOp.getResult(0); Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter, loc, diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index 281eab0..8d6000a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -12,6 +12,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -154,91 +155,90 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { const int64_t xbarSize = static_cast(crossbarSize.getValue()); const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize; - auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector(), ValueRange {x}); + constexpr size_t numInputs = 1; + auto computeOp = + createSpatCompute(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult { + SmallVector batchResults; + batchResults.reserve(batchSize); - auto* computeBlock = new Block(); - computeBlock->addArgument(xType, loc); - computeOp.getBody().push_back(computeBlock); - rewriter.setInsertionPointToStart(computeBlock); + for (int64_t batch = 0; batch < batchSize; ++batch) { + SmallVector rows; + rows.reserve(outputHeight); - Value input = computeBlock->getArgument(0); - SmallVector batchResults; - batchResults.reserve(batchSize); + for (int64_t outH = 0; outH < outputHeight; ++outH) { + SmallVector rowPixels; + rowPixels.reserve(outputWidth); - for (int64_t batch = 0; batch < batchSize; ++batch) { - SmallVector rows; - rows.reserve(outputHeight); + for (int64_t outW = 0; outW < outputWidth; ++outW) { + SmallVector outputChannelTiles; + outputChannelTiles.reserve(channelTileCount); - for (int64_t outH = 0; outH < outputHeight; ++outH) { - SmallVector rowPixels; - rowPixels.reserve(outputWidth); + for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { + const int64_t tileChannels = std::min(xbarSize, channels - channelTile * xbarSize); + auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); - for (int64_t outW = 0; outW < outputWidth; ++outW) { - SmallVector outputChannelTiles; - outputChannelTiles.reserve(channelTileCount); + SmallVector windowValues; + windowValues.reserve(kernelHeight * kernelWidth); + for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { + const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop; + if (inH < 0 || inH >= inputHeight) + continue; - for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { - const int64_t tileChannels = std::min(xbarSize, channels - channelTile * xbarSize); - auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); + for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { + const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft; + if (inW < 0 || inW >= inputWidth) + continue; - SmallVector windowValues; - windowValues.reserve(kernelHeight * kernelWidth); - for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { - const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop; - if (inH < 0 || inH >= inputHeight) - continue; + SmallVector offsets = {rewriter.getIndexAttr(batch), + rewriter.getIndexAttr(channelTile * xbarSize), + rewriter.getIndexAttr(inH), + rewriter.getIndexAttr(inW)}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(tileChannels), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; + Value windowValue = + tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides); + windowValue = materializeContiguousTile(rewriter, loc, windowValue); + windowValues.push_back(windowValue); + } + } - for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { - const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft; - if (inW < 0 || inW >= inputWidth) - continue; + if (windowValues.empty()) + return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements."); - SmallVector offsets = {rewriter.getIndexAttr(batch), - rewriter.getIndexAttr(channelTile * xbarSize), - rewriter.getIndexAttr(inH), - rewriter.getIndexAttr(inW)}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(tileChannels), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; - Value windowValue = - tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides); - windowValue = materializeContiguousTile(rewriter, loc, windowValue); - windowValues.push_back(windowValue); + Value reducedWindow = reduceWindowValues(rewriter, loc, windowValues); + if constexpr (std::is_same_v) { + const bool countIncludePad = poolOp.getCountIncludePad() == 1; + const int64_t divisor = + countIncludePad ? kernelHeight * kernelWidth : static_cast(windowValues.size()); + reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor); + } + + outputChannelTiles.push_back(reducedWindow); } + + rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles)); } - if (windowValues.empty()) - return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements."); - - Value reducedWindow = reduceWindowValues(rewriter, loc, windowValues); - if constexpr (std::is_same_v) { - const bool countIncludePad = poolOp.getCountIncludePad() == 1; - const int64_t divisor = - countIncludePad ? kernelHeight * kernelWidth : static_cast(windowValues.size()); - reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor); - } - - outputChannelTiles.push_back(reducedWindow); + rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels)); } - rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles)); + batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows)); } - rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels)); - } + Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults); + spatial::SpatYieldOp::create(rewriter, loc, pooledOutput); + return success(); + }); + if (failed(computeOp)) + return failure(); - batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows)); - } - - Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults); - spatial::SpatYieldOp::create(rewriter, loc, pooledOutput); - - rewriter.replaceOp(poolOp, computeOp.getResult(0)); + rewriter.replaceOp(poolOp, computeOp->getResult(0)); return success(); } };