@@ -110,6 +110,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
     return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
   }
 
+  if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
+    auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
+    auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
+    if (failed(lhs) || failed(rhs))
+      return mlir::failure();
+    return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
+  }
+
   if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
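Note: the new arith.minui case mirrors the existing divui/remui folds: both operands are resolved recursively, and the minimum is taken in the unsigned domain (arith.minui is an unsigned op) before narrowing back to the signed index type. A minimal standalone sketch of the same fold in plain C++, with resolveU64 as a hypothetical stand-in for resolveIndexValueImpl:

#include <algorithm>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for resolveIndexValueImpl: yields the operand's
// statically known value, or nullopt when it cannot be resolved.
static std::optional<int64_t> resolveU64(int64_t known) { return known; }

// Fold min_ui(lhs, rhs): compare as uint64_t, then narrow back to int64_t,
// exactly as the added MinUIOp branch does.
static std::optional<int64_t> foldMinUI(int64_t lhsIn, int64_t rhsIn) {
  auto lhs = resolveU64(lhsIn);
  auto rhs = resolveU64(rhsIn);
  if (!lhs || !rhs)
    return std::nullopt;
  return static_cast<int64_t>(
      std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
}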
@@ -12,6 +12,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
       mlir::arith::SubIOp,
       mlir::arith::MulIOp,
       mlir::arith::DivUIOp,
+      mlir::arith::MinUIOp,
       mlir::arith::RemUIOp,
       mlir::arith::IndexCastOp,
       mlir::memref::AllocOp,
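Note: arith.minui joins the whitelist so that the index chains produced by the new resize lowering (muli/divui/minui, see below) still count as statically resolvable addresses. A sketch of how such a whitelist is typically consumed, assuming the surrounding helper reduces to a single isa<> check (the actual body and the entries before SubIOp fall outside this hunk):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"

// Sketch only: accept exactly the ops listed in the hunk above.
static bool isCoreStaticAddressOpSketch(mlir::Operation* op) {
  return mlir::isa<mlir::arith::SubIOp, mlir::arith::MulIOp,
      mlir::arith::DivUIOp, mlir::arith::MinUIOp, mlir::arith::RemUIOp,
      mlir::arith::IndexCastOp, mlir::memref::AllocOp>(op);
}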
@@ -1,10 +1,10 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 #include "llvm/ADT/STLExtras.h"
 
-#include <algorithm>
-
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -15,42 +15,88 @@ using namespace mlir;
 namespace onnx_mlir {
 namespace {
 
-static Value
-extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
-  auto inputType = cast<RankedTensorType>(input.getType());
-  SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
-  SmallVector<OpFoldResult> sizes;
-  SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
-  sizes.reserve(inputType.getRank());
-  for (int64_t dim : inputType.getShape())
-    sizes.push_back(rewriter.getIndexAttr(dim));
-  offsets[axis] = rewriter.getIndexAttr(offset);
-  sizes[axis] = rewriter.getIndexAttr(1);
-  return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
+static Value buildNearestAsymmetricIndex(
+    Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
+  Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
+  Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
+  Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
+  Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
+  Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
+  return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
 }
 
-static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
-  return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1);
-}
-
-static Value buildNearestResize(Value input,
-    ArrayRef<int64_t> inputShape,
-    ArrayRef<int64_t> outputShape,
-    int64_t axis,
-    ConversionPatternRewriter& rewriter,
-    Location loc) {
-  if (axis == static_cast<int64_t>(outputShape.size()))
-    return input;
-
-  SmallVector<Value> slices;
-  slices.reserve(outputShape[axis]);
-  for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) {
-    int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]);
-    Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc);
-    slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
-  }
-
-  return createSpatConcat(rewriter, loc, axis, slices);
+static Value buildNearestResizeLoop(Value input,
+    RankedTensorType inputType,
+    RankedTensorType resultType,
+    ConversionPatternRewriter& rewriter,
+    Location loc) {
+  auto elemType = resultType.getElementType();
+  SmallVector<int64_t> unitShape(resultType.getRank(), 1);
+  auto unitTensorType = RankedTensorType::get(unitShape, elemType);
+
+  SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
+  SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
+
+  Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
+  Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
+  Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
+  Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
+
+  Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
+
+  auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
+  rewriter.setInsertionPointToStart(batchLoop.getBody());
+
+  Value outputN = batchLoop.getInductionVar();
+  Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
+  Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
+
+  auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
+  rewriter.setInsertionPointToStart(channelLoop.getBody());
+
+  Value outputC = channelLoop.getInductionVar();
+  Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
+  Value inputC =
+      buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
+
+  auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
+  rewriter.setInsertionPointToStart(heightLoop.getBody());
+
+  Value outputH = heightLoop.getInductionVar();
+  Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
+  Value inputH =
+      buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
+
+  auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
+  rewriter.setInsertionPointToStart(widthLoop.getBody());
+
+  Value outputW = widthLoop.getInductionVar();
+  Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
+  Value inputW =
+      buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
+
+  SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
+  Value inputSlice =
+      tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
+
+  SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
+  Value updatedOutput =
+      tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
+  scf::YieldOp::create(rewriter, loc, updatedOutput);
+
+  rewriter.setInsertionPointAfter(widthLoop);
+  scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
+
+  rewriter.setInsertionPointAfter(heightLoop);
+  scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
+
+  rewriter.setInsertionPointAfter(channelLoop);
+  scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
+
+  rewriter.setInsertionPointAfter(batchLoop);
+  return batchLoop.getResult(0);
 }
 
 struct Resize : OpConversionPattern<ONNXResizeOp> {
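Note: buildNearestAsymmetricIndex emits, as IR, the same map the deleted nearestAsymmetricIndex helper evaluated at pattern-application time: inputIndex = min((outputIndex * inputDim) / outputDim, inputDim - 1), i.e. nearest-neighbor with the "asymmetric" coordinate transform and "floor" rounding. A worked check in plain C++ (a sketch, not part of the commit):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Same arithmetic as the emitted muli -> divui -> minui chain.
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
  return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1);
}

int main() {
  // Upscale 2 -> 5: output indices 0..4 read input indices 0,0,0,1,1.
  assert(nearestAsymmetricIndex(0, 2, 5) == 0);
  assert(nearestAsymmetricIndex(2, 2, 5) == 0);
  assert(nearestAsymmetricIndex(3, 2, 5) == 1);
  assert(nearestAsymmetricIndex(4, 2, 5) == 1);
  // Downscale 5 -> 2: output indices 0,1 read input indices 0,2; the
  // min(..., inputDim - 1) term defensively clamps the index in bounds.
  assert(nearestAsymmetricIndex(1, 5, 2) == 2);
  return 0;
}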
@@ -62,20 +108,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
     auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
     auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
     if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
-      return failure();
+      return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
+    if (inputType.getRank() != 4 || resultType.getRank() != 4)
+      return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
 
     if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
         || resizeOp.getNearestMode() != "floor")
-      return failure();
+      return rewriter.notifyMatchFailure(
+          resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
 
     if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
         || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
-      return failure();
+      return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
 
     auto computeOp =
         createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
-          Value result =
-              buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
+          Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
           spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
         });
     rewriter.replaceOp(resizeOp, computeOp.getResults());
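Note: with buildNearestResizeLoop the pattern emits a fixed-size scf.for nest with one extract/insert pair per iteration, instead of the old fully unrolled extract_slice/concat tree whose op count grew with the product of the output dimensions; the notifyMatchFailure messages also make the remaining restrictions (static rank-4 NCHW, nearest + asymmetric + floor, positive dims) visible in conversion debug output. Scalar reference semantics of the generated nest, as a plain C++ sketch (not part of the commit):

#include <cstdint>
#include <vector>

// For every output NCHW coordinate, copy the nearest (asymmetric/floor)
// input element -- the same thing the batch/channel/height/width scf.for
// nest does one tensor element at a time.
static std::vector<float> nearestResizeNCHW(const std::vector<float>& in,
    const int64_t inDim[4], const int64_t outDim[4]) {
  auto nearest = [](int64_t o, int64_t i, int64_t n) {
    int64_t idx = (o * i) / n;
    return idx < i - 1 ? idx : i - 1; // the minui clamp
  };
  std::vector<float> out(outDim[0] * outDim[1] * outDim[2] * outDim[3]);
  for (int64_t n = 0; n < outDim[0]; ++n)
    for (int64_t c = 0; c < outDim[1]; ++c)
      for (int64_t h = 0; h < outDim[2]; ++h)
        for (int64_t w = 0; w < outDim[3]; ++w) {
          const int64_t iN = nearest(n, inDim[0], outDim[0]);
          const int64_t iC = nearest(c, inDim[1], outDim[1]);
          const int64_t iH = nearest(h, inDim[2], outDim[2]);
          const int64_t iW = nearest(w, inDim[3], outDim[3]);
          out[((n * outDim[1] + c) * outDim[2] + h) * outDim[3] + w] =
              in[((iN * inDim[1] + iC) * inDim[2] + iH) * inDim[3] + iW];
        }
  return out;
}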