compact resize op lowering
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-15 17:36:12 +02:00
parent fe244d5aa1
commit 78242e2887
3 changed files with 95 additions and 38 deletions
+8
View File
@@ -110,6 +110,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
+1
View File
@@ -12,6 +12,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp,
mlir::arith::IndexCastOp,
mlir::memref::AllocOp,
@@ -1,10 +1,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -15,42 +15,88 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static Value
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(inputType.getRank());
for (int64_t dim : inputType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(1);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
// Emits IR computing the nearest-neighbor source index for one dimension under
// the "asymmetric" coordinate transform with floor rounding:
//   min(outputIndex * inputDim / outputDim, inputDim - 1)
// The unsigned div/min match the non-negative index domain; the final min
// clamps the result to the last valid input index.
static Value buildNearestAsymmetricIndex(
    Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
  Value dimInVal = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
  Value dimOutVal = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
  Value lastInVal = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
  Value scaledIdx = arith::MulIOp::create(rewriter, loc, outputIndex, dimInVal);
  Value rawIdx = arith::DivUIOp::create(rewriter, loc, scaledIdx, dimOutVal);
  return arith::MinUIOp::create(rewriter, loc, rawIdx, lastInVal);
}
// Compile-time/static-shape counterpart of the IR builder above: the
// nearest-neighbor source index under the "asymmetric" coordinate transform
// with floor rounding, clamped to the last valid input index.
//   result = min(floor(outputIndex * inputDim / outputDim), inputDim - 1)
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
  const int64_t scaled = (outputIndex * inputDim) / outputDim;
  const int64_t lastIndex = inputDim - 1;
  return scaled < lastIndex ? scaled : lastIndex;
}
static Value buildNearestResizeLoop(Value input,
RankedTensorType inputType,
RankedTensorType resultType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = resultType.getElementType();
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
static Value buildNearestResize(Value input,
ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> outputShape,
int64_t axis,
ConversionPatternRewriter& rewriter,
Location loc) {
if (axis == static_cast<int64_t>(outputShape.size()))
return input;
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
SmallVector<Value> slices;
slices.reserve(outputShape[axis]);
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) {
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]);
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc);
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
}
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
return createSpatConcat(rewriter, loc, axis, slices);
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(batchLoop.getBody());
Value outputN = batchLoop.getInductionVar();
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
rewriter.setInsertionPointToStart(channelLoop.getBody());
Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC =
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH =
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW =
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice =
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
Value updatedOutput =
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
scf::YieldOp::create(rewriter, loc, updatedOutput);
rewriter.setInsertionPointAfter(widthLoop);
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
rewriter.setInsertionPointAfter(heightLoop);
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
rewriter.setInsertionPointAfter(channelLoop);
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
rewriter.setInsertionPointAfter(batchLoop);
return batchLoop.getResult(0);
}
struct Resize : OpConversionPattern<ONNXResizeOp> {
@@ -62,20 +108,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
if (inputType.getRank() != 4 || resultType.getRank() != 4)
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor")
return failure();
return rewriter.notifyMatchFailure(
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
return failure();
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
auto computeOp =
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
Value result =
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
});
rewriter.replaceOp(resizeOp, computeOp.getResults());