diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp index 278ddc1..ad88c52 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -1,9 +1,10 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -22,53 +23,83 @@ static SmallVector permuteShape(ArrayRef shape, ArrayRef outerIndices, + ConversionPatternRewriter& rewriter, + Location loc) { + int64_t rank = inputType.getRank(); + SmallVector sliceShape(static_cast(rank - 1), 1); + sliceShape.push_back(inputType.getDimSize(rank - 1)); + auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding()); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides(rank, rewriter.getIndexAttr(1)); + offsets.reserve(rank); + sizes.reserve(rank); + + for (Value outerIndex : outerIndices) { + offsets.push_back(outerIndex); + sizes.push_back(rewriter.getIndexAttr(1)); + } + offsets.push_back(rewriter.getIndexAttr(0)); + sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1))); + + Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides); + Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult(); + return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides); +} + +static Value buildLoopSoftmaxNest(Value input, + Value accumulator, + RankedTensorType inputType, + int64_t axis, + SmallVectorImpl& outerIndices, + ConversionPatternRewriter& rewriter, + Location loc) { + if (axis == inputType.getRank() - 1) + return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc); + + Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis)); + + auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator}); + rewriter.setInsertionPointToStart(loop.getBody()); + + Value loopIndex = loop.getInductionVar(); + Value loopAccumulator = loop.getRegionIterArgs().front(); + outerIndices.push_back(loopIndex); + Value updatedAccumulator = + buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc); + outerIndices.pop_back(); + + scf::YieldOp::create(rewriter, loc, updatedAccumulator); + rewriter.setInsertionPointAfter(loop); + return loop.getResult(0); +} + +static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { auto inputType = cast(input.getType()); constexpr size_t numInputs = 1; auto computeOp = createSpatCompute(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { - auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x); - spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult()); + if (inputType.getRank() == 1) { + Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult(); + spatial::SpatYieldOp::create(rewriter, loc, softmax); + return; + } + + Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType()); + SmallVector outerIndices; + Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc); + spatial::SpatYieldOp::create(rewriter, loc, result); }); return computeOp.getResult(0); } -static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { - auto firstType = cast(inputs.front().getType()); - SmallVector outputShape(firstType.getShape().begin(), firstType.getShape().end()); - int64_t concatDimSize = 0; - for (Value input : inputs) - concatDimSize += cast(input.getType()).getDimSize(axis); - outputShape[axis] = concatDimSize; - auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); - - if (llvm::all_of(inputs, isHostFoldableValue)) - return createSpatConcat(rewriter, loc, axis, inputs); - - auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) { - spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args)); - }); - return concatCompute.getResult(0); -} - -static Value -buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { - auto inputType = cast(input.getType()); - if (axis == inputType.getRank()) - return createSoftmaxCompute(input, rewriter, loc); - - if (axis == softmaxAxis) - return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc); - - SmallVector slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc); - SmallVector rebuiltSlices; - rebuiltSlices.reserve(slices.size()); - for (Value slice : slices) - rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc)); - - return concatValues(rebuiltSlices, axis, rewriter, loc); -} - struct SoftmaxToSpatialCompute : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -86,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern { Value input = adaptor.getInput(); Value result; if (axis == inputType.getRank() - 1) { - result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc()); + result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc()); } else { SmallVector permutation; @@ -109,8 +140,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern { spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); }); Value transposedInput = preTransposeCompute.getResult(0); - Value transposedResult = buildSoftmax( - transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); + Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc()); auto postTransposeCompute = createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) { Value transposed = ONNXTransposeOp::create(