This commit is contained in:
@@ -24,18 +24,16 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
|
||||
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
||||
});
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildSoftmax(Value input,
|
||||
int64_t softmaxAxis,
|
||||
int64_t axis,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (axis == inputType.getRank())
|
||||
return createSoftmaxCompute(input, rewriter, loc);
|
||||
@@ -71,7 +69,8 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value result;
|
||||
if (axis == inputType.getRank() - 1) {
|
||||
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
permutation.reserve(inputType.getRank());
|
||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
|
||||
@@ -85,14 +84,15 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
|
||||
auto transposedType = RankedTensorType::get(
|
||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||
auto preTransposeCompute = createSpatCompute<1>(
|
||||
rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
|
||||
auto preTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user