add better createSpatCompute helper

This commit is contained in:
NiccoloN
2026-03-30 16:14:26 +02:00
parent 39830be888
commit 3625edc80a
5 changed files with 259 additions and 239 deletions

View File

@@ -12,6 +12,7 @@
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -154,91 +155,90 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
constexpr size_t numInputs = 1;
auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
SmallVector<Value> batchResults;
batchResults.reserve(batchSize);
auto* computeBlock = new Block();
computeBlock->addArgument(xType, loc);
computeOp.getBody().push_back(computeBlock);
rewriter.setInsertionPointToStart(computeBlock);
for (int64_t batch = 0; batch < batchSize; ++batch) {
SmallVector<Value> rows;
rows.reserve(outputHeight);
Value input = computeBlock->getArgument(0);
SmallVector<Value> batchResults;
batchResults.reserve(batchSize);
for (int64_t outH = 0; outH < outputHeight; ++outH) {
SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
for (int64_t batch = 0; batch < batchSize; ++batch) {
SmallVector<Value> rows;
rows.reserve(outputHeight);
for (int64_t outW = 0; outW < outputWidth; ++outW) {
SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
for (int64_t outH = 0; outH < outputHeight; ++outH) {
SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
for (int64_t outW = 0; outW < outputWidth; ++outW) {
SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH),
rewriter.getIndexAttr(inW)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
}
}
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH),
rewriter.getIndexAttr(inW)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
}
outputChannelTiles.push_back(reducedWindow);
}
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
}
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
}
outputChannelTiles.push_back(reducedWindow);
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
}
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
}
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
return success();
});
if (failed(computeOp))
return failure();
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
rewriter.replaceOp(poolOp, computeOp.getResult(0));
rewriter.replaceOp(poolOp, computeOp->getResult(0));
return success();
}
};