automatic code-reformat
Validate Operations / validate-operations (push) Successful in 18m22s

This commit is contained in:
NiccoloN
2026-04-09 14:27:23 +02:00
parent 1a0192d1f9
commit 9e0d31af50
16 changed files with 88 additions and 103 deletions
@@ -24,18 +24,16 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
});
auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
});
return computeOp.getResult(0);
}
static Value buildSoftmax(Value input,
int64_t softmaxAxis,
int64_t axis,
ConversionPatternRewriter& rewriter,
Location loc) {
static Value
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (axis == inputType.getRank())
return createSoftmaxCompute(input, rewriter, loc);
@@ -71,7 +69,8 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value result;
if (axis == inputType.getRank() - 1) {
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
} else {
}
else {
SmallVector<int64_t> permutation;
permutation.reserve(inputType.getRank());
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
@@ -85,14 +84,15 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
auto preTransposeCompute = createSpatCompute<1>(
rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
Value transposed =
ONNXTransposeOp::create(rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
auto preTransposeCompute =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = buildSoftmax(transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
Value transposedResult = buildSoftmax(
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
result = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
}