#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

// NOTE(review): the explicit template arguments in this file (e.g.
// cast<RankedTensorType>, SmallVector<int64_t>, createSpatCompute<numInputs>)
// were reconstructed from how the values are used; confirm them against the
// helper declarations in Common.hpp.

namespace onnx_mlir {
namespace {

/// Maps a possibly negative axis onto its non-negative equivalent.
///
/// Non-negative values are returned unchanged; negative values are offset by
/// `rank`. No range check is performed here, so the caller must validate the
/// result.
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }

/// Returns the shape obtained by reordering `shape` according to `permutation`:
/// element `i` of the result is `shape[permutation[i]]`.
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
  SmallVector<int64_t> permutedShape;
  permutedShape.reserve(permutation.size());
  for (int64_t axis : permutation)
    permutedShape.push_back(shape[axis]);
  return permutedShape;
}

/// Emits a single-input compute op (via createSpatCompute) whose body applies
/// SpatSoftmaxOp to the block argument and yields the result.
///
/// The compute op has one result of the same type as `input`; that result is
/// returned.
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(input.getType());
  constexpr size_t numInputs = 1;
  auto computeOp =
      createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
        auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
        spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
      });
  return computeOp.getResult(0);
}

/// Recursively decomposes `input` into unit-sized slices along every axis
/// except `softmaxAxis`, emits one softmax compute per slice, and concatenates
/// the per-slice results back together.
///
/// \param input       Tensor to process at this recursion level.
/// \param softmaxAxis Axis that is never sliced.
/// \param axis        Axis handled by this recursion level; callers start at 0.
/// \returns The rebuilt value. It has the type of `input` assuming the sliced
///          pieces concatenate back to the original extent.
///
/// NOTE(review): presumably SpatSoftmaxOp normalizes along the one dimension
/// that is left unsliced - confirm against the Spatial dialect definition.
static Value buildSoftmax(
    Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(input.getType());

  // Every axis has been visited: emit the actual softmax compute.
  if (axis == inputType.getRank())
    return createSoftmaxCompute(input, rewriter, loc);

  // The softmax axis itself is kept whole; move on to the next axis.
  if (axis == softmaxAxis)
    return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);

  SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
  SmallVector<Value> rebuiltSlices;
  rebuiltSlices.reserve(slices.size());
  for (Value slice : slices)
    rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));

  // A single slice already covers the whole axis, so no concat is needed.
  if (rebuiltSlices.size() == 1)
    return rebuiltSlices.front();

  // NOTE(review): if sliceTensor ever returns no slices (e.g. for a zero-sized
  // dimension) this creates a concat without operands - confirm that cannot
  // happen for the shapes accepted by the pattern.
  return tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
}

/// Rewrites ONNXSoftmaxOp into Spatial compute ops.
///
/// The pattern only matches ranked tensors with a static shape and an axis
/// that is in range after normalization. When the softmax axis is not the last
/// one, the input is transposed so that it becomes the last axis, processed,
/// and transposed back.
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSoftmaxOp softmaxOp,
      ONNXSoftmaxOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType());
    if (!inputType || !inputType.hasStaticShape())
      return failure();

    const int64_t rank = inputType.getRank();
    const int64_t axis = normalizeAxis(softmaxOp.getAxis(), rank);
    if (axis < 0 || axis >= rank)
      return failure();

    Location loc = softmaxOp.getLoc();
    Value input = adaptor.getInput();
    Value result;
    if (axis == rank - 1) {
      // Softmax axis is already the last one: no transposition required.
      result = buildSoftmax(input, axis, /*axis=*/0, rewriter, loc);
    }
    else {
      // Permutation that keeps the relative order of all other dimensions and
      // moves the softmax axis to the end.
      SmallVector<int64_t> permutation;
      permutation.reserve(rank);
      for (int64_t dim = 0; dim < rank; ++dim)
        if (dim != axis)
          permutation.push_back(dim);
      permutation.push_back(axis);

      // Inverse permutation, used to restore the original layout afterwards.
      SmallVector<int64_t> inversePermutation(rank);
      for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
        inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);

      auto transposedType = RankedTensorType::get(
          permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());

      auto preTransposeCompute =
          createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, input, [&](Value x) {
            Value transposed =
                ONNXTransposeOp::create(rewriter, loc, transposedType, x, rewriter.getI64ArrayAttr(permutation));
            spatial::SpatYieldOp::create(rewriter, loc, transposed);
          });
      Value transposedInput = preTransposeCompute.getResult(0);

      Value transposedResult =
          buildSoftmax(transposedInput, /*softmaxAxis=*/rank - 1, /*axis=*/0, rewriter, loc);

      // NOTE(review): unlike the pre-transpose, this transpose is created
      // outside of any compute op - presumably it is lowered by another
      // pattern; confirm this asymmetry is intentional.
      result = ONNXTransposeOp::create(
          rewriter, loc, inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
    }

    rewriter.replaceOp(softmaxOp, result);
    return success();
  }
};

} // namespace

/// Registers the ONNX Softmax -> Spatial conversion pattern in `patterns`.
void populateSoftmaxPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<SoftmaxToSpatialCompute>(ctx);
}

} // namespace onnx_mlir