#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Emits the innermost step of the softmax loop nest.
///
/// Extracts the slice of `input` addressed by `outerIndices` (size 1 along
/// every leading dimension, full extent along the last dimension), applies
/// `spatial::SpatSoftmaxOp` to it, and inserts the result into `accumulator`
/// at the same offsets.
///
/// @param input         Tensor the slice is read from.
/// @param accumulator   Tensor the softmax slice is written into.
/// @param inputType     Ranked type describing `input`.
/// @param outerIndices  One index value per leading (non-last) dimension.
/// @param rewriter      Rewriter used to create the new operations.
/// @param loc           Location attached to the created operations.
/// @return The value produced by the `tensor::InsertSliceOp`.
static Value buildLoopSoftmaxSlice(Value input,
                                   Value accumulator,
                                   RankedTensorType inputType,
                                   ArrayRef<Value> outerIndices,
                                   ConversionPatternRewriter& rewriter,
                                   Location loc) {
  int64_t rank = inputType.getRank();

  // Slice shape: 1 for each leading dimension, full extent on the last one.
  SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
  sliceShape.push_back(inputType.getDimSize(rank - 1));
  auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
  offsets.reserve(rank);
  sizes.reserve(rank);

  // Leading dimensions are addressed by the enclosing loop indices.
  for (Value outerIndex : outerIndices) {
    offsets.push_back(outerIndex);
    sizes.push_back(rewriter.getIndexAttr(1));
  }

  // The last dimension is taken whole, starting at offset 0.
  offsets.push_back(rewriter.getIndexAttr(0));
  sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));

  Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
  Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
  return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
}

/// Recursively builds one loop per leading dimension, starting at `axis`.
///
/// When `axis` is the last dimension the recursion bottoms out in
/// `buildLoopSoftmaxSlice`. Otherwise a loop over `[0, dimSize(axis))` with
/// step 1 is requested from `buildNormalizedScfFor`, carrying `accumulator`
/// as its single iteration argument.
///
/// @param input         Tensor the slices are read from.
/// @param accumulator   Current value of the output tensor.
/// @param inputType     Ranked type describing `input`.
/// @param axis          Dimension handled by this recursion level.
/// @param outerIndices  Indices of the enclosing loops; pushed and popped
///                      around the recursive call, so it is unchanged on return.
/// @param rewriter      Rewriter used to create the new operations.
/// @param loc           Location attached to the created operations.
/// @return The updated accumulator, or failure if loop construction failed.
static FailureOr<Value> buildLoopSoftmaxNest(Value input,
                                             Value accumulator,
                                             RankedTensorType inputType,
                                             int64_t axis,
                                             SmallVectorImpl<Value>& outerIndices,
                                             ConversionPatternRewriter& rewriter,
                                             Location loc) {
  if (axis == inputType.getRank() - 1)
    return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);

  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
  Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
  Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis));

  // NOTE(review): the body callback keeps using `rewriter` rather than the
  // `builder` argument; this presumably relies on buildNormalizedScfFor
  // positioning `rewriter` inside the loop body -- confirm against LoopUtils.
  auto loop = buildNormalizedScfFor(
      rewriter,
      loc,
      c0,
      cUpper,
      c1,
      ValueRange {accumulator},
      [&](OpBuilder& builder,
          Location nestedLoc,
          Value loopIndex,
          ValueRange iterArgs,
          SmallVectorImpl<Value>& yielded) -> LogicalResult {
        outerIndices.push_back(loopIndex);
        auto updatedAccumulator =
            buildLoopSoftmaxNest(input, iterArgs.front(), inputType, axis + 1, outerIndices, rewriter, nestedLoc);
        // Pop before the failure check so the index stack is restored on every path.
        outerIndices.pop_back();
        if (failed(updatedAccumulator))
          return failure();
        yielded.push_back(*updatedAccumulator);
        return success();
      });
  if (failed(loop))
    return failure();
  return loop->results.front();
}

/// Wraps a last-axis softmax over `input` in a spatial compute op.
///
/// A rank-1 input is handled by a single `spatial::SpatSoftmaxOp`. Higher
/// ranks get an empty output tensor that is filled slice by slice through
/// `buildLoopSoftmaxNest`.
///
/// @param input     Ranked tensor value; softmax is taken along its last dimension.
/// @param rewriter  Rewriter used to create the new operations.
/// @param loc       Location attached to the created operations.
/// @return Result 0 of the created compute op, or failure.
static FailureOr<Value> createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(input.getType());
  constexpr size_t numInputs = 1;

  auto computeOp = createSpatCompute<numInputs>(
      rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) -> LogicalResult {
        if (inputType.getRank() == 1) {
          Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
          spatial::SpatYieldOp::create(rewriter, loc, softmax);
          return success();
        }

        // Carry the encoding over so the accumulator (and therefore the loop
        // result) has exactly the declared result type `inputType`, matching
        // how the slice type is built.
        Value outputInit = tensor::EmptyOp::create(
            rewriter, loc, inputType.getShape(), inputType.getElementType(), inputType.getEncoding());

        SmallVector<Value> outerIndices;
        auto result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
        if (failed(result))
          return failure();
        spatial::SpatYieldOp::create(rewriter, loc, *result);
        return success();
      });
  if (failed(computeOp))
    return failure();
  return computeOp->getResult(0);
}

/// Converts `ONNXSoftmaxOp` on statically shaped ranked tensors into a
/// spatial compute op.
///
/// A softmax along the last axis is lowered directly. For any other axis the
/// input is transposed so that axis becomes the last one, the last-axis
/// lowering is applied, and the result is transposed back with the inverse
/// permutation.
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSoftmaxOp softmaxOp,
                                ONNXSoftmaxOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    Value input = adaptor.getInput();
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || !inputType.hasStaticShape())
      return failure();

    const int64_t rank = inputType.getRank();
    auto axis = normalizeAxisChecked(softmaxOp.getAxis(), rank);
    if (failed(axis))
      return failure();

    Location loc = softmaxOp.getLoc();
    Value result;
    if (*axis == rank - 1) {
      auto computed = createLoopSoftmaxCompute(input, rewriter, loc);
      if (failed(computed))
        return failure();
      result = *computed;
    }
    else {
      // Permutation that keeps the other dimensions in order and moves the
      // softmax axis to the end.
      SmallVector<int64_t> permutation;
      permutation.reserve(rank);
      for (int64_t dim = 0; dim < rank; ++dim)
        if (dim != *axis)
          permutation.push_back(dim);
      permutation.push_back(*axis);
      SmallVector<int64_t> inversePermutation = invertPermutation(permutation);

      auto transposedType = RankedTensorType::get(
          permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
      Value transposedInput =
          ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(permutation))
              .getResult();

      auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, loc);
      if (failed(transposedResult))
        return failure();

      // Restore the original dimension order.
      result = ONNXTransposeOp::create(
                   rewriter, loc, inputType, *transposedResult, rewriter.getI64ArrayAttr(inversePermutation))
                   .getResult();
    }

    rewriter.replaceOp(softmaxOp, result);
    return success();
  }
};

} // namespace

/// Registers the softmax lowering pattern in `patterns`.
void populateSoftmaxPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<SoftmaxToSpatialCompute>(ctx);
}

} // namespace onnx_mlir