From 78242e28874cdf0925db0733eb81ecac0fcff387 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 15 May 2026 17:36:12 +0200 Subject: [PATCH] compact resize op lowering --- src/PIM/Common/IR/AddressAnalysis.cpp | 8 ++ src/PIM/Common/IR/CoreBlockUtils.cpp | 1 + .../ONNXToSpatial/Patterns/Tensor/Resize.cpp | 124 ++++++++++++------ 3 files changed, 95 insertions(+), 38 deletions(-) diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index 029f50f..69a4ca2 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -110,6 +110,14 @@ llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticVa return static_cast(static_cast(*lhs) / static_cast(*rhs)); } + if (auto minOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs)) + return mlir::failure(); + return static_cast(std::min(static_cast(*lhs), static_cast(*rhs))); + } + if (auto remOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); diff --git a/src/PIM/Common/IR/CoreBlockUtils.cpp b/src/PIM/Common/IR/CoreBlockUtils.cpp index 09327bb..a5cc241 100644 --- a/src/PIM/Common/IR/CoreBlockUtils.cpp +++ b/src/PIM/Common/IR/CoreBlockUtils.cpp @@ -12,6 +12,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) { mlir::arith::SubIOp, mlir::arith::MulIOp, mlir::arith::DivUIOp, + mlir::arith::MinUIOp, mlir::arith::RemUIOp, mlir::arith::IndexCastOp, mlir::memref::AllocOp, diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp index d3977d8..0749c20 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -1,10 +1,10 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" -#include - #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -15,42 +15,88 @@ using namespace mlir; namespace onnx_mlir { namespace { -static Value -extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) { - auto inputType = cast(input.getType()); - SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); - SmallVector sizes; - SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); - sizes.reserve(inputType.getRank()); - for (int64_t dim : inputType.getShape()) - sizes.push_back(rewriter.getIndexAttr(dim)); - offsets[axis] = rewriter.getIndexAttr(offset); - sizes[axis] = rewriter.getIndexAttr(1); - return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); +static Value buildNearestAsymmetricIndex( + Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) { + Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim); + Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim); + Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1); + Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim); + Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim); + return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast); } -static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) { - return std::min((outputIndex * inputDim) / outputDim, inputDim - 1); -} +static Value buildNearestResizeLoop(Value input, + RankedTensorType inputType, + RankedTensorType resultType, + ConversionPatternRewriter& rewriter, + Location loc) { + auto elemType = resultType.getElementType(); + SmallVector unitShape(resultType.getRank(), 1); + auto unitTensorType = RankedTensorType::get(unitShape, elemType); -static Value buildNearestResize(Value input, - ArrayRef inputShape, - ArrayRef outputShape, - int64_t axis, - ConversionPatternRewriter& rewriter, - Location loc) { - if (axis == static_cast(outputShape.size())) - return input; + SmallVector unitSizes(resultType.getRank(), rewriter.getIndexAttr(1)); + SmallVector unitStrides(resultType.getRank(), rewriter.getIndexAttr(1)); - SmallVector slices; - slices.reserve(outputShape[axis]); - for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) { - int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]); - Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc); - slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc)); - } + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0)); + Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1)); + Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2)); + Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3)); - return createSpatConcat(rewriter, loc, axis, slices); + Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType); + + auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit}); + rewriter.setInsertionPointToStart(batchLoop.getBody()); + + Value outputN = batchLoop.getInductionVar(); + Value outputBatchAcc = batchLoop.getRegionIterArgs().front(); + Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc); + + auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc}); + rewriter.setInsertionPointToStart(channelLoop.getBody()); + + Value outputC = channelLoop.getInductionVar(); + Value outputChannelAcc = channelLoop.getRegionIterArgs().front(); + Value inputC = + buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc); + + auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc}); + rewriter.setInsertionPointToStart(heightLoop.getBody()); + + Value outputH = heightLoop.getInductionVar(); + Value outputHeightAcc = heightLoop.getRegionIterArgs().front(); + Value inputH = + buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc); + + auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc}); + rewriter.setInsertionPointToStart(widthLoop.getBody()); + + Value outputW = widthLoop.getInductionVar(); + Value outputWidthAcc = widthLoop.getRegionIterArgs().front(); + Value inputW = + buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc); + + SmallVector inputOffsets = {inputN, inputC, inputH, inputW}; + Value inputSlice = + tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides); + + SmallVector outputOffsets = {outputN, outputC, outputH, outputW}; + Value updatedOutput = + tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides); + scf::YieldOp::create(rewriter, loc, updatedOutput); + + rewriter.setInsertionPointAfter(widthLoop); + scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0)); + + rewriter.setInsertionPointAfter(heightLoop); + scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0)); + + rewriter.setInsertionPointAfter(channelLoop); + scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0)); + + rewriter.setInsertionPointAfter(batchLoop); + return batchLoop.getResult(0); } struct Resize : OpConversionPattern { @@ -62,20 +108,22 @@ struct Resize : OpConversionPattern { auto inputType = dyn_cast(adaptor.getX().getType()); auto resultType = dyn_cast(resizeOp.getY().getType()); if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) - return failure(); + return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types."); + if (inputType.getRank() != 4 || resultType.getRank() != 4) + return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors."); if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric" || resizeOp.getNearestMode() != "floor") - return failure(); + return rewriter.notifyMatchFailure( + resizeOp, "resize lowering currently supports only nearest + asymmetric + floor."); if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; }) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) - return failure(); + return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions."); auto computeOp = createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { - Value result = - buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc()); + Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc()); spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); }); rewriter.replaceOp(resizeOp, computeOp.getResults());