add better createSpatCompute helper

This commit is contained in:
NiccoloN
2026-03-30 16:14:26 +02:00
parent 39830be888
commit 3625edc80a
5 changed files with 259 additions and 239 deletions

View File

@@ -104,15 +104,34 @@ inline auto getTensorShape(mlir::Value tensor) {
namespace detail { namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is> template <typename Fn, size_t... Is>
void invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) { decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
std::forward<Fn>(fn)(block->getArgument(Is)...); return std::forward<Fn>(fn)(block->getArgument(Is)...);
} }
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail } // namespace detail
template <size_t NumInputs, typename BodyFn> template <size_t NumInputs, typename RewriterT, typename BodyFn>
spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter& rewriter, auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc, mlir::Location loc,
mlir::TypeRange resultTypes, mlir::TypeRange resultTypes,
mlir::ValueRange weights, mlir::ValueRange weights,
@@ -128,11 +147,62 @@ spatial::SpatWeightedCompute createSpatCompute(mlir::ConversionPatternRewriter&
computeOp.getBody().push_back(block); computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block); rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {}); detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return computeOp; return computeOp;
} }
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice, llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis, size_t axis,

View File

@@ -6,6 +6,7 @@
#include <cassert> #include <cassert>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -138,15 +139,9 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
else else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
auto im2colComputeOp = constexpr size_t numInputs = 1;
spatial::SpatWeightedCompute::create(rewriter, loc, im2colType, SmallVector<Value>(), ValueRange {x}); auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
Value paddedInput = xArg;
auto* im2colBlock = new Block();
im2colBlock->addArgument(x.getType(), loc);
im2colComputeOp.getBody().push_back(im2colBlock);
rewriter.setInsertionPointToStart(im2colBlock);
Value paddedInput = im2colBlock->getArgument(0);
// Pad input with zeros if needed: // Pad input with zeros if needed:
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
@@ -213,8 +208,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// Concatenate all rows: [numPatches, patchSize] // Concatenate all rows: [numPatches, patchSize]
Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows); Value im2col = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, im2colRows);
spatial::SpatYieldOp::create(rewriter, loc, im2col); spatial::SpatYieldOp::create(rewriter, loc, im2col);
});
rewriter.setInsertionPointAfter(im2colComputeOp);
// Gemm: A @ B + C = im2col @ W^T + b // Gemm: A @ B + C = im2col @ W^T + b
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut] // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
@@ -231,15 +225,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
Value gemmOut = gemmOp.getY(); Value gemmOut = gemmOp.getY();
auto collectComputeOp = auto collectComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, convOp.getType(), SmallVector<Value>(), ValueRange {gemmOut}); createSpatCompute<numInputs>(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) {
auto* collectBlock = new Block();
collectBlock->addArgument(gemmOut.getType(), loc);
collectComputeOp.getBody().push_back(collectBlock);
rewriter.setInsertionPointToStart(collectBlock);
auto gemmOutArg = collectBlock->getArguments().front();
// Restore to NCHW layout: // Restore to NCHW layout:
// [numPatches, numChannelsOut] // [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut] // -> [1, outHeight, outWidth, numChannelsOut]
@@ -255,6 +241,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut); spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
});
rewriter.replaceOp(convOp, collectComputeOp.getResult(0)); rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
return success(); return success();

View File

@@ -155,18 +155,10 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
gemvOps.push_back(gemvOp.getY()); gemvOps.push_back(gemvOp.getY());
} }
auto concatComputeOp = auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps); auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
auto* concatBlock = new Block();
for (auto gemvOp : gemvOps)
concatBlock->addArgument(gemvOp.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
});
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
return success(); return success();
@@ -289,25 +281,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = auto computeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
auto* computeBlock = new Block();
for (auto aHSlice : aHSlices[coreId])
computeBlock->addArgument(aHSlice.getType(), gemmLoc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
auto computeArgs = computeBlock->getArguments();
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(computeArgs.size()); vmmOutputs.reserve(aHSlicesArgs.size());
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++) for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back( vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
Value partialVmmSum = sumTensors(vmmOutputs, rewriter); Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
rewriter.setInsertionPointAfter(computeOp); });
partialResults.push_back(computeOp.getResult(0)); partialResults.push_back(computeOp.getResult(0));
} }
@@ -318,34 +302,20 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
} }
auto reduceComputeOp = auto reduceComputeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults); createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) {
SmallVector<Value> values(blockArgs.begin(), blockArgs.end());
auto* reduceBlock = new Block(); Value outHSlice = sumTensors(values, rewriter);
for (auto partialResult : partialResults)
reduceBlock->addArgument(partialResult.getType(), gemmLoc);
reduceComputeOp.getBody().push_back(reduceBlock);
rewriter.setInsertionPointToStart(reduceBlock);
auto blockArgs = reduceBlock->getArguments();
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice); spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
rewriter.setInsertionPointAfter(reduceComputeOp); });
outHSlices.push_back(reduceComputeOp.getResult(0)); outHSlices.push_back(reduceComputeOp.getResult(0));
} }
auto concatComputeOp = auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices); createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
auto* concatBlock = new Block();
for (auto outHSlice : outHSlices)
concatBlock->addArgument(outHSlice.getType(), gemmLoc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments();
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs); auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult()); spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
});
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
return success(); return success();

View File

@@ -4,6 +4,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -81,19 +82,11 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
} }
} }
auto concatComputeOp = auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows); auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
auto* concatBlock = new Block();
for (Value gemmRow : gemmRows)
concatBlock->addArgument(gemmRow.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
});
rewriter.setInsertionPointAfter(concatComputeOp);
Value gemmOut = concatComputeOp.getResult(0); Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter, Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
loc, loc,

View File

@@ -12,6 +12,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -154,14 +155,9 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue()); const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize; const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x}); constexpr size_t numInputs = 1;
auto computeOp =
auto* computeBlock = new Block(); createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
computeBlock->addArgument(xType, loc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
Value input = computeBlock->getArgument(0);
SmallVector<Value> batchResults; SmallVector<Value> batchResults;
batchResults.reserve(batchSize); batchResults.reserve(batchSize);
@@ -206,7 +202,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1)};
Value windowValue = Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides); tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue); windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue); windowValues.push_back(windowValue);
} }
@@ -237,8 +233,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults); Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput); spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
return success();
});
if (failed(computeOp))
return failure();
rewriter.replaceOp(poolOp, computeOp.getResult(0)); rewriter.replaceOp(poolOp, computeOp->getResult(0));
return success(); return success();
} }
}; };