This commit is contained in:
@@ -23,28 +23,10 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
||||
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
||||
}
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||
}
|
||||
|
||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(tileType.getRank());
|
||||
for (int64_t dimSize : tileType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dimSize));
|
||||
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
|
||||
}
|
||||
|
||||
static Value
|
||||
@@ -197,12 +179,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
const int64_t inputWidth = xType.getDimSize(3);
|
||||
const int64_t outputHeight = outType.getDimSize(2);
|
||||
const int64_t outputWidth = outType.getDimSize(3);
|
||||
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
||||
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
||||
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
||||
const int64_t kernelHeight = getI64Attr(kernelAttr, 0);
|
||||
const int64_t kernelWidth = getI64Attr(kernelAttr, 1);
|
||||
const int64_t strideHeight = getOptionalI64Attr(poolOp.getStrides(), 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64Attr(poolOp.getStrides(), 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64Attr(poolOp.getDilations(), 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64Attr(poolOp.getDilations(), 1, 1);
|
||||
|
||||
int64_t padTop = 0;
|
||||
int64_t padLeft = 0;
|
||||
@@ -212,10 +194,10 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
if (auto padsAttr = poolOp.getPads()) {
|
||||
if (padsAttr->size() != 4)
|
||||
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
||||
padTop = getI64(*padsAttr, 0);
|
||||
padLeft = getI64(*padsAttr, 1);
|
||||
padBottom = getI64(*padsAttr, 2);
|
||||
padRight = getI64(*padsAttr, 3);
|
||||
padTop = getI64Attr(*padsAttr, 0);
|
||||
padLeft = getI64Attr(*padsAttr, 1);
|
||||
padBottom = getI64Attr(*padsAttr, 2);
|
||||
padRight = getI64Attr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
StringRef autoPad = poolOp.getAutoPad();
|
||||
|
||||
@@ -13,16 +13,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||
|
||||
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
static Value buildLoopSoftmaxSlice(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
@@ -36,7 +26,7 @@ static Value buildLoopSoftmaxSlice(Value input,
|
||||
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
|
||||
offsets.reserve(rank);
|
||||
sizes.reserve(rank);
|
||||
|
||||
@@ -110,44 +100,31 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
if (!inputType || !inputType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank());
|
||||
if (axis < 0 || axis >= inputType.getRank())
|
||||
auto axis = normalizeAxisChecked(softmaxOp.getAxis(), inputType.getRank());
|
||||
if (failed(axis))
|
||||
return failure();
|
||||
|
||||
Value input = adaptor.getInput();
|
||||
Value result;
|
||||
if (axis == inputType.getRank() - 1) {
|
||||
if (*axis == inputType.getRank() - 1) {
|
||||
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
permutation.reserve(inputType.getRank());
|
||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
|
||||
if (dim != axis)
|
||||
if (dim != *axis)
|
||||
permutation.push_back(dim);
|
||||
permutation.push_back(axis);
|
||||
|
||||
SmallVector<int64_t> inversePermutation(inputType.getRank());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
permutation.push_back(*axis);
|
||||
SmallVector<int64_t> inversePermutation = invertPermutation(permutation);
|
||||
|
||||
auto transposedType = RankedTensorType::get(
|
||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||
auto preTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedInput =
|
||||
transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
result = postTransposeCompute.getResult(0);
|
||||
result = transposeMaybeInCompute(
|
||||
transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(softmaxOp, result);
|
||||
|
||||
Reference in New Issue
Block a user