#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"

#include <algorithm>
#include <optional>
#include <type_traits>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Reads element `index` of an integer array attribute as an int64_t.
/// The caller is responsible for `index` being in range.
template <typename ArrayAttrT>
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
  return cast<IntegerAttr>(arrayAttr[index]).getInt();
}

/// Like getI64, but returns `defaultValue` when the optional attribute is absent.
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
  return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}

/// Returns true when the optional array attribute is absent or holds exactly
/// `expectedSize` elements, i.e. when getOptionalI64 may safely index it.
template <typename ArrayAttrT>
static bool hasOptionalSize(const std::optional<ArrayAttrT>& arrayAttr, size_t expectedSize) {
  return !arrayAttr || arrayAttr->size() == expectedSize;
}

/// Static sizes {1, channels, 1, 1} of a single-pixel NCHW channel tile.
static SmallVector<OpFoldResult> getPixelTileSizes(OpBuilder& builder, int64_t channels) {
  return {builder.getIndexAttr(1), builder.getIndexAttr(channels), builder.getIndexAttr(1), builder.getIndexAttr(1)};
}

/// Copies `tile` into a fresh tensor.empty of the same shape/element type through a
/// full-size tensor.insert_slice and returns the result. The reduction ops below
/// consume this copy instead of the raw extract_slice result.
/// NOTE(review): presumably required by the downstream Spatial lowering — confirm.
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
  auto tileType = cast<RankedTensorType>(tile.getType());
  Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
  SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes;
  sizes.reserve(tileType.getRank());
  for (int64_t dimSize : tileType.getShape())
    sizes.push_back(rewriter.getIndexAttr(dimSize));
  SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
  return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
}

/// Creates the scalar arith.constant that seeds a pooling reduction and fills
/// padding: zero when `useMinimumValue` is false; otherwise -inf for float types
/// and the signed minimum for integer types. Any other element type is unsupported.
static Value
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
  if (!useMinimumValue)
    return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));

  if (auto floatType = dyn_cast<FloatType>(elementType)) {
    auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
    return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
  }

  if (auto integerType = dyn_cast<IntegerType>(elementType)) {
    auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
    return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
  }

  llvm_unreachable("unsupported pool element type");
}

/// Splats the fill element from createPoolFillElement into a tensor of `tensorType`.
static Value
createPoolFillTensor(ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) {
  auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
  return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
}

/// Pads dimensions 2 and 3 (H, W) of the rank-4 `input` with a tensor.pad whose pad
/// value comes from createPoolFillElement (minimum value for max pooling, zero
/// otherwise). Returns `input` unchanged when every pad is zero. `poolOp` is only
/// used to deduce the pooling op type.
template <typename PoolOp>
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
                                   Location loc,
                                   PoolOp poolOp,
                                   Value input,
                                   RankedTensorType inputType,
                                   int64_t padTop,
                                   int64_t padLeft,
                                   int64_t padBottom,
                                   int64_t padRight) {
  if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
    return input;

  auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
                                           inputType.getDimSize(1),
                                           inputType.getDimSize(2) + padTop + padBottom,
                                           inputType.getDimSize(3) + padLeft + padRight},
                                          inputType.getElementType(),
                                          inputType.getEncoding());
  SmallVector<OpFoldResult> lowPads = {
    rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
  SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
                                        rewriter.getIndexAttr(0),
                                        rewriter.getIndexAttr(padBottom),
                                        rewriter.getIndexAttr(padRight)};
  auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);

  // The pad region takes one index argument per result dimension and yields the
  // fill value. Creating the block through the rewriter (rather than `new Block`)
  // lets the rewriter's listener observe the insertion.
  const int64_t rank = paddedType.getRank();
  SmallVector<Type> blockArgTypes(rank, rewriter.getIndexType());
  SmallVector<Location> blockArgLocs(rank, loc);
  rewriter.createBlock(&padOp.getRegion(), padOp.getRegion().end(), blockArgTypes, blockArgLocs);
  Value padValue = createPoolFillElement(
    rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
  tensor::YieldOp::create(rewriter, loc, padValue);

  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}

/// Number of kernel taps along one axis whose input coordinate
/// (outIndex * stride + k * dilation - pad) falls inside [0, inputSize).
static int64_t countValidTaps(
  int64_t outIndex, int64_t stride, int64_t kernelSize, int64_t dilation, int64_t pad, int64_t inputSize) {
  int64_t count = 0;
  for (int64_t k = 0; k < kernelSize; ++k) {
    const int64_t inIndex = outIndex * stride + k * dilation - pad;
    if (inIndex >= 0 && inIndex < inputSize)
      ++count;
  }
  return count;
}

/// Builds an arith.constant of shape {1, channels, outputHeight, outputWidth} holding
/// 1 / divisor for every output pixel. The divisor is kernelHeight * kernelWidth when
/// `countIncludePad` is set, otherwise the number of taps that hit real input.
/// Emits an op error and fails for non-float element types or a non-positive divisor.
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
                                                 Location loc,
                                                 Operation* op,
                                                 RankedTensorType outType,
                                                 int64_t channels,
                                                 int64_t inputHeight,
                                                 int64_t inputWidth,
                                                 int64_t outputHeight,
                                                 int64_t outputWidth,
                                                 int64_t kernelHeight,
                                                 int64_t kernelWidth,
                                                 int64_t strideHeight,
                                                 int64_t strideWidth,
                                                 int64_t dilationHeight,
                                                 int64_t dilationWidth,
                                                 int64_t padTop,
                                                 int64_t padLeft,
                                                 bool countIncludePad) {
  auto elemType = dyn_cast<FloatType>(outType.getElementType());
  if (!elemType) {
    op->emitOpError("AveragePool lowering requires a floating-point element type");
    return failure();
  }

  // A tap is valid iff both its row and its column are valid, so the 2D count is the
  // product of the per-axis counts. The divisor does not depend on the channel, so
  // one (outputHeight x outputWidth) plane is computed and replicated below.
  SmallVector<Attribute> planeScales;
  planeScales.reserve(static_cast<size_t>(outputHeight * outputWidth));
  for (int64_t outH = 0; outH < outputHeight; ++outH) {
    const int64_t validRows = countValidTaps(outH, strideHeight, kernelHeight, dilationHeight, padTop, inputHeight);
    for (int64_t outW = 0; outW < outputWidth; ++outW) {
      const int64_t validCols = countValidTaps(outW, strideWidth, kernelWidth, dilationWidth, padLeft, inputWidth);
      const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validRows * validCols;
      if (divisor <= 0) {
        op->emitOpError("AveragePool divisor must be positive");
        return failure();
      }
      planeScales.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
    }
  }

  // Row-major {1, C, OH, OW}: the channel index is the outermost varying one.
  SmallVector<Attribute> scaleValues;
  scaleValues.reserve(static_cast<size_t>(channels) * planeScales.size());
  for (int64_t channel = 0; channel < channels; ++channel)
    scaleValues.append(planeScales.begin(), planeScales.end());

  auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
  auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
  return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
}

/// Returns the index values base + k * dilation for k in [0, kernelSize). The k == 0
/// entry (and any zero offset) reuses `base` without emitting an add.
static SmallVector<Value> createWindowCoordinates(
  ConversionPatternRewriter& rewriter, Location loc, Value base, int64_t kernelSize, int64_t dilation) {
  SmallVector<Value> coordinates;
  coordinates.reserve(kernelSize);
  for (int64_t k = 0; k < kernelSize; ++k) {
    const int64_t offset = k * dilation;
    if (offset == 0) {
      coordinates.push_back(base);
      continue;
    }
    Value cOffset = arith::ConstantIndexOp::create(rewriter, loc, offset);
    coordinates.push_back(arith::AddIOp::create(rewriter, loc, base, cOffset));
  }
  return coordinates;
}

template <typename PoolOp>
struct PoolToSpatialCompute;

/// Shared conversion of a 2D NCHW ONNX pool op into a Spatial compute op.
/// `ReduceOp` combines the running per-pixel accumulator with one window element.
template <typename PoolOp, typename ReduceOp>
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
  using OpConversionPattern<PoolOp>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<PoolOp>::OpAdaptor;

  static constexpr bool isMaxPool = std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>;
  static constexpr bool isAveragePool = std::is_same_v<PoolOp, ONNXAveragePoolOp>;

  /// Lowers a static-shape 2D NCHW pool. The generated body loops over every
  /// (batch, out_h, out_w) output pixel; for each pixel and each channel tile of at
  /// most `crossbarSize` channels it folds the kernel window with ReduceOp, scales
  /// the result by 1/divisor for AveragePool, and inserts it into the output tensor.
  /// All validation happens before any IR is created.
  LogicalResult
  matchAndRewrite(PoolOp poolOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
    Location loc = poolOp.getLoc();
    Value x = adaptor.getX();
    auto xType = dyn_cast<RankedTensorType>(x.getType());
    auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
    if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
      return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
    if (xType.getRank() != 4 || outType.getRank() != 4)
      return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");

    ArrayAttr kernelAttr = poolOp.getKernelShape();
    if (!kernelAttr || kernelAttr.size() != 2)
      return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
    // getOptionalI64 reads elements 0 and 1 unchecked, so reject malformed arrays here.
    if (!hasOptionalSize(poolOp.getStrides(), 2) || !hasOptionalSize(poolOp.getDilations(), 2))
      return rewriter.notifyMatchFailure(poolOp, "strides and dilations must have two elements.");

    const int64_t batchSize = xType.getDimSize(0);
    const int64_t channels = xType.getDimSize(1);
    const int64_t inputHeight = xType.getDimSize(2);
    const int64_t inputWidth = xType.getDimSize(3);
    const int64_t outputHeight = outType.getDimSize(2);
    const int64_t outputWidth = outType.getDimSize(3);
    if (outType.getDimSize(0) != batchSize || outType.getDimSize(1) != channels)
      return rewriter.notifyMatchFailure(poolOp, "pool output must keep the input batch and channel dimensions.");

    const int64_t kernelHeight = getI64(kernelAttr, 0);
    const int64_t kernelWidth = getI64(kernelAttr, 1);
    const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
    const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
    const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
    const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
    if (kernelHeight <= 0 || kernelWidth <= 0 || strideHeight <= 0 || strideWidth <= 0 || dilationHeight <= 0
        || dilationWidth <= 0)
      return rewriter.notifyMatchFailure(poolOp, "kernel, stride and dilation values must be positive.");

    // Explicit pads use ONNX order [top, left, bottom, right]; otherwise derive them
    // from auto_pad (SAME_UPPER puts the odd extra pad at the end, SAME_LOWER at the start).
    int64_t padTop = 0;
    int64_t padLeft = 0;
    int64_t padBottom = 0;
    int64_t padRight = 0;
    if (auto padsAttr = poolOp.getPads()) {
      if (padsAttr->size() != 4)
        return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
      padTop = getI64(*padsAttr, 0);
      padLeft = getI64(*padsAttr, 1);
      padBottom = getI64(*padsAttr, 2);
      padRight = getI64(*padsAttr, 3);
    }
    else {
      StringRef autoPad = poolOp.getAutoPad();
      if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
        const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
        const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
        const int64_t totalPadH =
          std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
        const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);

        if (autoPad == "SAME_UPPER") {
          padTop = totalPadH / 2;
          padBottom = totalPadH - padTop;
          padLeft = totalPadW / 2;
          padRight = totalPadW - padLeft;
        }
        else {
          padBottom = totalPadH / 2;
          padTop = totalPadH - padBottom;
          padRight = totalPadW / 2;
          padLeft = totalPadW - padRight;
        }
      }
      else if (autoPad != "NOTSET" && autoPad != "VALID") {
        return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
      }
    }
    if (padTop < 0 || padLeft < 0 || padBottom < 0 || padRight < 0)
      return rewriter.notifyMatchFailure(poolOp, "negative pads are not supported.");

    // Every window visited below must lie inside the padded input, otherwise the
    // tensor.extract_slice would read out of bounds (e.g. outputs sized with ceil_mode).
    if (outputHeight > 0 && outputWidth > 0) {
      const int64_t lastRowEnd = (outputHeight - 1) * strideHeight + (kernelHeight - 1) * dilationHeight + 1;
      const int64_t lastColEnd = (outputWidth - 1) * strideWidth + (kernelWidth - 1) * dilationWidth + 1;
      if (lastRowEnd > inputHeight + padTop + padBottom || lastColEnd > inputWidth + padLeft + padRight)
        return rewriter.notifyMatchFailure(poolOp, "pooling windows exceed the padded input (ceil_mode is unsupported).");
    }

    const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
    if (xbarSize <= 0)
      return rewriter.notifyMatchFailure(poolOp, "crossbar size must be positive.");
    const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
    const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;

    // Only AveragePool has count_include_pad; the value is unused for other ops.
    const bool countIncludePad = [&]() {
      if constexpr (isAveragePool)
        return poolOp.getCountIncludePad() == 1;
      return true;
    }();

    Value averageScaleTensor;
    if constexpr (isAveragePool) {
      auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
                                                              loc,
                                                              poolOp,
                                                              outType,
                                                              channels,
                                                              inputHeight,
                                                              inputWidth,
                                                              outputHeight,
                                                              outputWidth,
                                                              kernelHeight,
                                                              kernelWidth,
                                                              strideHeight,
                                                              strideWidth,
                                                              dilationHeight,
                                                              dilationWidth,
                                                              padTop,
                                                              padLeft,
                                                              countIncludePad);
      if (failed(maybeAverageScaleTensor))
        return failure();
      averageScaleTensor = *maybeAverageScaleTensor;
    }

    constexpr size_t numInputs = 1;
    auto computeOp =
      createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
        Value paddedInput =
          createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
        Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());

        Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
        Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
        Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
        Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
        Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
        Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
        Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);

        // One iteration per (batch, out_h, out_w) pixel; the output tensor is carried
        // as the loop's iter_arg.
        auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
        rewriter.setInsertionPointToStart(outputLoop.getBody());

        Value outputPatchIndex = outputLoop.getInductionVar();
        Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();

        Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
        Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
        Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
        Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
        Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
        Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);

        // Window coordinates (in padded-input space) depend only on the output pixel,
        // so they are emitted once and shared by every channel tile and kernel row.
        SmallVector<Value> windowRows =
          createWindowCoordinates(rewriter, loc, windowBaseH, kernelHeight, dilationHeight);
        SmallVector<Value> windowCols = createWindowCoordinates(rewriter, loc, windowBaseW, kernelWidth, dilationWidth);
        SmallVector<OpFoldResult> unitStrides(4, rewriter.getIndexAttr(1));

        Value updatedOutput = pooledOutputAcc;
        for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
          const int64_t channelOffset = channelTile * xbarSize;
          const int64_t tileChannels = std::min(xbarSize, channels - channelOffset);
          auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
          SmallVector<OpFoldResult> tileSizes = getPixelTileSizes(rewriter, tileChannels);

          Value reducedWindow = createPoolFillTensor(rewriter, loc, tileType, isMaxPool);
          for (Value row : windowRows) {
            for (Value col : windowCols) {
              SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(channelOffset), row, col};
              Value windowValue =
                tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, tileSizes, unitStrides);
              windowValue = materializeContiguousTile(rewriter, loc, windowValue);
              reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
            }
          }

          if constexpr (isAveragePool) {
            // The scale tensor has batch size 1, so it is always sliced at batch 0.
            SmallVector<OpFoldResult> scaleOffsets = {
              rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelOffset), outHeightIndex, outWidthIndex};
            Value scaleSlice = tensor::ExtractSliceOp::create(
              rewriter, loc, tileType, averageScaleTensor, scaleOffsets, tileSizes, unitStrides);
            scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
            reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
          }

          SmallVector<OpFoldResult> outputOffsets = {
            batchIndex, rewriter.getIndexAttr(channelOffset), outHeightIndex, outWidthIndex};
          updatedOutput = tensor::InsertSliceOp::create(
            rewriter, loc, reducedWindow, updatedOutput, outputOffsets, tileSizes, unitStrides);
        }

        scf::YieldOp::create(rewriter, loc, updatedOutput);
        rewriter.setInsertionPointAfter(outputLoop);
        spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
        return success();
      });
    if (failed(computeOp))
      return failure();

    rewriter.replaceOp(poolOp, computeOp->getResult(0));
    return success();
  }
};

// NOTE(review): confirm these are the Spatial element-wise max/add ops intended for
// the MaxPool and AveragePool window reductions.
template <>
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
    : public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, spatial::SpatVMaxOp> {
  using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};

template <>
struct PoolToSpatialCompute<ONNXAveragePoolOp>
    : public PoolToSpatialComputeBase<ONNXAveragePoolOp, spatial::SpatVAddOp> {
  using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};

} // namespace

/// Registers the MaxPool and AveragePool to Spatial conversion patterns.
void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
  patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}

} // namespace onnx_mlir