//===---- Lowering of ONNX 2-D pooling ops to the PIM Spatial dialect ----===//
//
// Converts onnx.MaxPoolSingleOut / onnx.AveragePool (NCHW, fully static
// shapes) into a spatial.weighted_compute region whose body materializes each
// output pixel explicitly: the pooling window is gathered with
// tensor.extract_slice, reduced element-wise with the pool's reduce op, and
// the output tensor is reassembled with tensor.concat along channel, width,
// height, and batch axes. Channels are tiled to the PIM crossbar width.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <type_traits>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

/// Reads element `index` of an integer ArrayAttr as an int64_t.
template <typename ArrayAttrT>
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
  return cast<IntegerAttr>(arrayAttr[index]).getInt();
}

/// Like getI64, but returns `defaultValue` when the optional attribute is
/// absent (strides/dilations default to 1 in the ONNX spec).
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
    int64_t defaultValue) {
  return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}

/// Concatenates `values` along `axis`. A single value is returned unchanged so
/// callers never emit a one-operand tensor.concat.
static Value concatAlongAxis(ConversionPatternRewriter &rewriter, Location loc,
    int64_t axis, ArrayRef<Value> values) {
  assert(!values.empty() && "Expected at least one value to concatenate.");
  if (values.size() == 1)
    return values.front();
  return tensor::ConcatOp::create(rewriter, loc, axis, values);
}

/// Copies `tile` into a freshly allocated tensor of identical shape via a
/// whole-tensor insert_slice, producing a contiguous (non-view) tile value.
static Value materializeContiguousTile(ConversionPatternRewriter &rewriter,
    Location loc, Value tile) {
  auto tileType = cast<RankedTensorType>(tile.getType());
  Value empty = tensor::EmptyOp::create(
      rewriter, loc, tileType.getShape(), tileType.getElementType());
  SmallVector<OpFoldResult> offsets(
      tileType.getRank(), rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes;
  sizes.reserve(tileType.getRank());
  for (int64_t dimSize : tileType.getShape())
    sizes.push_back(rewriter.getIndexAttr(dimSize));
  SmallVector<OpFoldResult> strides(
      tileType.getRank(), rewriter.getIndexAttr(1));
  return tensor::InsertSliceOp::create(
      rewriter, loc, tile, empty, offsets, sizes, strides);
}

/// Folds the pooling-window values pairwise with the element-wise ReduceOp
/// (vector-max for MaxPool, vector-add for AveragePool).
template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter &rewriter,
    Location loc, ArrayRef<Value> windowValues) {
  assert(!windowValues.empty() && "Expected at least one pool window value.");
  Value reduced = windowValues.front();
  for (Value value : windowValues.drop_front())
    reduced =
        ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
  return reduced;
}

/// Scales a summed AveragePool window by 1/divisor with an element-wise
/// multiply against a splat constant tensor.
static Value scaleAverageWindow(ConversionPatternRewriter &rewriter,
    Location loc, Value reducedWindow, int64_t divisor) {
  assert(divisor > 0 && "AveragePool divisor must be positive.");
  if (divisor == 1)
    return reducedWindow;
  auto tileType = cast<RankedTensorType>(reducedWindow.getType());
  double scale = 1.0 / static_cast<double>(divisor);
  auto scaleAttr = DenseElementsAttr::get(
      tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
  Value scaleTensor =
      arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
  return spatial::SpatVMulOp::create(
      rewriter, loc, tileType, reducedWindow, scaleTensor);
}

/// Primary template; specialized per ONNX pool op below.
template <typename PoolOp>
struct PoolToSpatialCompute;

/// Shared lowering for 2-D NCHW pooling. `ReduceOp` is the element-wise
/// spatial op used to combine window elements.
template <typename PoolOp, typename ReduceOp>
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
  using OpConversionPattern<PoolOp>::OpConversionPattern;
  using OpAdaptor = typename PoolOp::Adaptor;

  LogicalResult matchAndRewrite(PoolOp poolOp, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const final {
    Location loc = poolOp.getLoc();
    Value x = adaptor.getX();
    auto xType = dyn_cast<RankedTensorType>(x.getType());
    auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
    if (!xType || !outType || !xType.hasStaticShape() ||
        !outType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          poolOp, "pool lowering requires static ranked tensor types.");
    if (xType.getRank() != 4 || outType.getRank() != 4)
      return rewriter.notifyMatchFailure(
          poolOp, "only 2D NCHW pool is supported.");

    ArrayAttr kernelAttr = poolOp.getKernelShape();
    if (!kernelAttr || kernelAttr.size() != 2)
      return rewriter.notifyMatchFailure(
          poolOp, "pool lowering expects a 2D kernel.");

    const int64_t batchSize = xType.getDimSize(0);
    const int64_t channels = xType.getDimSize(1);
    const int64_t inputHeight = xType.getDimSize(2);
    const int64_t inputWidth = xType.getDimSize(3);
    const int64_t outputHeight = outType.getDimSize(2);
    const int64_t outputWidth = outType.getDimSize(3);
    const int64_t kernelHeight = getI64(kernelAttr, 0);
    const int64_t kernelWidth = getI64(kernelAttr, 1);
    const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
    const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
    const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
    const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);

    // Resolve padding: explicit `pads` wins; otherwise derive SAME_* padding
    // from the already-inferred output shape (ONNX auto_pad semantics).
    int64_t padTop = 0;
    int64_t padLeft = 0;
    int64_t padBottom = 0;
    int64_t padRight = 0;
    if (auto padsAttr = poolOp.getPads()) {
      if (padsAttr->size() != 4)
        return rewriter.notifyMatchFailure(
            poolOp, "pads must have four elements.");
      padTop = getI64(*padsAttr, 0);
      padLeft = getI64(*padsAttr, 1);
      padBottom = getI64(*padsAttr, 2);
      padRight = getI64(*padsAttr, 3);
    } else {
      StringRef autoPad = poolOp.getAutoPad();
      if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
        const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
        const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
        // std::max<int64_t>: the literal 0 is int, so the plain two-argument
        // std::max(0, <int64_t expr>) would not compile.
        const int64_t totalPadH = std::max<int64_t>(0,
            (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
        const int64_t totalPadW = std::max<int64_t>(0,
            (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
        if (autoPad == "SAME_UPPER") {
          padTop = totalPadH / 2;
          padBottom = totalPadH - padTop;
          padLeft = totalPadW / 2;
          padRight = totalPadW - padLeft;
        } else {
          padBottom = totalPadH / 2;
          padTop = totalPadH - padBottom;
          padRight = totalPadW / 2;
          padLeft = totalPadW - padRight;
        }
      } else if (autoPad != "NOTSET" && autoPad != "VALID") {
        return rewriter.notifyMatchFailure(
            poolOp, "unsupported auto_pad value.");
      }
    }
    // Bottom/right padding only determines which window positions fall outside
    // the input; the loops below already bound-check inH/inW, so these values
    // are never read directly.
    (void)padBottom;
    (void)padRight;

    const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
    const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;

    // NOTE(review): operand grouping (empty list + {x}) mirrors the original
    // call; confirm against the SpatWeightedCompute builder signature.
    auto computeOp = spatial::SpatWeightedCompute::create(
        rewriter, loc, outType, SmallVector<Value>(), ValueRange{x});

    // Create the body block through the rewriter so the conversion driver
    // tracks it and can roll it back if we bail out below. The original code
    // used a raw `new Block()` + push_back, which bypasses that tracking.
    OpBuilder::InsertionGuard guard(rewriter);
    Block *computeBlock = rewriter.createBlock(&computeOp.getBody(),
        computeOp.getBody().end(), TypeRange{xType}, {loc});
    rewriter.setInsertionPointToStart(computeBlock);
    Value input = computeBlock->getArgument(0);

    SmallVector<Value> batchResults;
    batchResults.reserve(batchSize);
    for (int64_t batch = 0; batch < batchSize; ++batch) {
      SmallVector<Value> rows;
      rows.reserve(outputHeight);
      for (int64_t outH = 0; outH < outputHeight; ++outH) {
        SmallVector<Value> rowPixels;
        rowPixels.reserve(outputWidth);
        for (int64_t outW = 0; outW < outputWidth; ++outW) {
          SmallVector<Value> outputChannelTiles;
          outputChannelTiles.reserve(channelTileCount);
          for (int64_t channelTile = 0; channelTile < channelTileCount;
               ++channelTile) {
            // Last tile may be narrower than the crossbar.
            const int64_t tileChannels =
                std::min(xbarSize, channels - channelTile * xbarSize);
            auto tileType = RankedTensorType::get(
                {1, tileChannels, 1, 1}, outType.getElementType());
            SmallVector<Value> windowValues;
            windowValues.reserve(kernelHeight * kernelWidth);
            for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
              const int64_t inH =
                  outH * strideHeight + kernelH * dilationHeight - padTop;
              if (inH < 0 || inH >= inputHeight)
                continue; // Padded row: excluded from the reduction.
              for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
                const int64_t inW =
                    outW * strideWidth + kernelW * dilationWidth - padLeft;
                if (inW < 0 || inW >= inputWidth)
                  continue; // Padded column: excluded from the reduction.
                SmallVector<OpFoldResult> offsets = {
                    rewriter.getIndexAttr(batch),
                    rewriter.getIndexAttr(channelTile * xbarSize),
                    rewriter.getIndexAttr(inH), rewriter.getIndexAttr(inW)};
                SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
                    rewriter.getIndexAttr(tileChannels),
                    rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
                SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
                Value windowValue = tensor::ExtractSliceOp::create(rewriter,
                    loc, tileType, input, offsets, sizes, strides);
                windowValue =
                    materializeContiguousTile(rewriter, loc, windowValue);
                windowValues.push_back(windowValue);
              }
            }
            if (windowValues.empty())
              return rewriter.notifyMatchFailure(
                  poolOp, "pool window resolved to zero valid elements.");
            Value reducedWindow =
                reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
            if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
              const bool countIncludePad = poolOp.getCountIncludePad() == 1;
              // count_include_pad=1 divides by the full kernel area; otherwise
              // only by the number of in-bounds window elements (ONNX spec).
              const int64_t divisor = countIncludePad
                  ? kernelHeight * kernelWidth
                  : static_cast<int64_t>(windowValues.size());
              reducedWindow =
                  scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
            }
            outputChannelTiles.push_back(reducedWindow);
          }
          // Stitch channel tiles (axis 1), then pixels along width (axis 3),
          // rows along height (axis 2), and finally batches (axis 0).
          rowPixels.push_back(
              concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
        }
        rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
      }
      batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
    }
    Value pooledOutput =
        concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
    spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
    rewriter.replaceOp(poolOp, computeOp.getResult(0));
    return success();
  }
};

/// MaxPool combines window elements with the spatial vector-max op.
template <>
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
    : public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp,
          spatial::SpatVMaxOp> {
  using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};

/// AveragePool sums window elements, then scales by 1/divisor.
template <>
struct PoolToSpatialCompute<ONNXAveragePoolOp>
    : public PoolToSpatialComputeBase<ONNXAveragePoolOp, spatial::SpatVAddOp> {
  using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};

} // namespace

/// Registers the pool-to-spatial lowering patterns.
void populatePoolPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
  patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}

} // namespace onnx_mlir