From 1a0192d1f96865c084f8147815b8fa43618ecf1d Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Thu, 9 Apr 2026 14:25:00 +0200 Subject: [PATCH] add support for softmax, resize, split, gather --- src/PIM/Compiler/PimCodeGen.cpp | 16 ++ src/PIM/Compiler/PimCodeGen.hpp | 1 + .../Conversion/ONNXToSpatial/CMakeLists.txt | 4 + .../ONNXToSpatial/ONNXToSpatialPass.cpp | 7 + src/PIM/Conversion/ONNXToSpatial/Patterns.hpp | 8 + .../ONNXToSpatial/Patterns/NN/Softmax.cpp | 111 +++++++++++++ .../ONNXToSpatial/Patterns/Tensor/Gather.cpp | 152 ++++++++++++++++++ .../ONNXToSpatial/Patterns/Tensor/Resize.cpp | 93 +++++++++++ .../ONNXToSpatial/Patterns/Tensor/Split.cpp | 75 +++++++++ .../Conversion/SpatialToPim/SpatialToPim.td | 6 + .../SpatialToPim/SpatialToPimPass.cpp | 13 +- src/PIM/Dialect/Pim/Pim.td | 23 +++ .../OpBufferizationInterfaces.cpp | 1 + src/PIM/Dialect/Spatial/Spatial.td | 16 ++ .../MergeComputeNode/MergeComputeNodePass.cpp | 6 +- .../Pim/ConstantFolding/Patterns/Subview.cpp | 36 ++++- 16 files changed, 560 insertions(+), 8 deletions(-) create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index d799c46..c48de47 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -386,6 +386,20 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const { emitInstruction(std::move(json)); } +void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const { + auto outputBufferAddr = memory.getValueAddress(vsoftmaxOp.getOutputBuffer()); + auto inputAddr = memory.getValueAddress(vsoftmaxOp.getInput()); + setupRdRs1(outputBufferAddr, 0, inputAddr, 0); + + json::Object json; + json["op"] = "vsoftmax"; + json["rd"] = 0; + json["rs1"] = 1; + json["offset"] = createEmptyOffset(); + json["len"] = getValueSizeInBytes(vsoftmaxOp.getInput()); + emitInstruction(std::move(json)); +} + void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const { auto srcAddr = memory.getValueAddress(transposeOp.getInput()); auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer()); @@ -537,6 +551,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenVTanhOp(vtanhOp); else if (auto vsigmOp = dyn_cast(op)) coreCodeGen.codeGenVSigmOp(vsigmOp); + else if (auto vsoftmaxOp = dyn_cast(op)) + coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp); else { op.emitError("Unsupported codegen for this operation"); op.dump(); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index ab0f1a0..81f340d 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -99,6 +99,7 @@ public: void codeGenVReluOp(pim::PimVReluOp vreluOp) const; void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const; + void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const; }; diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 1e3b3ae..d26673f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -11,8 +11,12 @@ add_pim_library(OMONNXToSpatial Patterns/NN/Pool.cpp Patterns/NN/Relu.cpp Patterns/NN/Sigmoid.cpp + Patterns/NN/Softmax.cpp Patterns/Tensor/Concat.cpp + Patterns/Tensor/Gather.cpp + Patterns/Tensor/Resize.cpp Patterns/Tensor/Reshape.cpp + Patterns/Tensor/Split.cpp ONNXToSpatialPass.cpp Common.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 1649c34..b495f13 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -89,9 +89,12 @@ void ONNXToSpatialPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); RewritePatternSet patterns(ctx); patterns.add(ctx); @@ -103,8 +106,12 @@ void ONNXToSpatialPass::runOnOperation() { populateReduceMeanPatterns(patterns, ctx); populateReluPatterns(patterns, ctx); populateSigmoidPatterns(patterns, ctx); + populateSoftmaxPatterns(patterns, ctx); populateConcatPatterns(patterns, ctx); + populateGatherPatterns(patterns, ctx); + populateResizePatterns(patterns, ctx); populateReshapePatterns(patterns, ctx); + populateSplitPatterns(patterns, ctx); if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { signalPassFailure(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index 38232ba..7c44286 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -21,8 +21,16 @@ void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + +void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp new file mode 100644 index 0000000..34ea12d --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -0,0 +1,111 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } + +static SmallVector permuteShape(ArrayRef shape, ArrayRef permutation) { + SmallVector permutedShape; + permutedShape.reserve(permutation.size()); + for (int64_t axis : permutation) + permutedShape.push_back(shape[axis]); + return permutedShape; +} + +static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { + auto inputType = cast(input.getType()); + constexpr size_t numInputs = 1; + auto computeOp = createSpatCompute(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { + auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x); + spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult()); + }); + return computeOp.getResult(0); +} + +static Value buildSoftmax(Value input, + int64_t softmaxAxis, + int64_t axis, + ConversionPatternRewriter& rewriter, + Location loc) { + auto inputType = cast(input.getType()); + if (axis == inputType.getRank()) + return createSoftmaxCompute(input, rewriter, loc); + + if (axis == softmaxAxis) + return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc); + + SmallVector slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc); + SmallVector rebuiltSlices; + rebuiltSlices.reserve(slices.size()); + for (Value slice : slices) + rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc)); + + return rebuiltSlices.size() == 1 ? rebuiltSlices.front() + : tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult(); +} + +struct SoftmaxToSpatialCompute : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXSoftmaxOp softmaxOp, + ONNXSoftmaxOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto inputType = dyn_cast(adaptor.getInput().getType()); + if (!inputType || !inputType.hasStaticShape()) + return failure(); + + int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank()); + if (axis < 0 || axis >= inputType.getRank()) + return failure(); + + Value input = adaptor.getInput(); + Value result; + if (axis == inputType.getRank() - 1) { + result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc()); + } else { + SmallVector permutation; + permutation.reserve(inputType.getRank()); + for (int64_t dim = 0; dim < inputType.getRank(); ++dim) + if (dim != axis) + permutation.push_back(dim); + permutation.push_back(axis); + + SmallVector inversePermutation(inputType.getRank()); + for (auto [newIndex, oldIndex] : llvm::enumerate(permutation)) + inversePermutation[oldIndex] = static_cast(newIndex); + + auto transposedType = RankedTensorType::get( + permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); + auto preTransposeCompute = createSpatCompute<1>( + rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) { + Value transposed = + ONNXTransposeOp::create(rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation)); + spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); + }); + Value transposedInput = preTransposeCompute.getResult(0); + Value transposedResult = buildSoftmax(transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); + result = ONNXTransposeOp::create( + rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation)); + } + + rewriter.replaceOp(softmaxOp, result); + return success(); + } +}; + +} // namespace + +void populateSoftmaxPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.add(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp new file mode 100644 index 0000000..ef295aa --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp @@ -0,0 +1,152 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } + +static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; } + +static Value extractSliceAt(Value input, + int64_t axis, + int64_t offset, + ConversionPatternRewriter& rewriter, + Location loc) { + auto inputType = cast(input.getType()); + SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); + SmallVector sizes; + SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); + sizes.reserve(inputType.getRank()); + for (int64_t dim : inputType.getShape()) + sizes.push_back(rewriter.getIndexAttr(dim)); + offsets[axis] = rewriter.getIndexAttr(offset); + sizes[axis] = rewriter.getIndexAttr(1); + return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); +} + +static Value concatGatherSlices(Value data, + int64_t axis, + ArrayRef indices, + int64_t axisDim, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector slices; + slices.reserve(indices.size()); + for (int64_t index : indices) { + int64_t normalizedIndex = normalizeIndex(index, axisDim); + if (normalizedIndex < 0 || normalizedIndex >= axisDim) + return {}; + slices.push_back(extractSliceAt(data, axis, normalizedIndex, rewriter, loc)); + } + if (slices.empty()) + return {}; + return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult(); +} + +static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { + auto valueType = cast(value.getType()); + SmallVector resultShape; + SmallVector reassociation; + resultShape.reserve(valueType.getRank() + 1); + reassociation.reserve(valueType.getRank()); + + int64_t resultDim = 0; + for (int64_t dim = 0; dim < valueType.getRank(); ++dim) { + if (dim == axis) { + resultShape.push_back(1); + resultShape.push_back(valueType.getShape()[dim]); + reassociation.push_back({static_cast(resultDim), static_cast(resultDim + 1)}); + resultDim += 2; + continue; + } + resultShape.push_back(valueType.getShape()[dim]); + reassociation.push_back({static_cast(resultDim)}); + resultDim++; + } + + auto resultType = RankedTensorType::get(resultShape, valueType.getElementType(), valueType.getEncoding()); + return tensor::ExpandShapeOp::create(rewriter, loc, resultType, value, reassociation); +} + +struct Gather : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXGatherOp gatherOp, + ONNXGatherOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto dataType = dyn_cast(adaptor.getData().getType()); + auto indicesType = dyn_cast(adaptor.getIndices().getType()); + if (!dataType || !indicesType || !dataType.hasStaticShape() || !indicesType.hasStaticShape()) + return failure(); + + auto indicesConst = adaptor.getIndices().getDefiningOp(); + if (!indicesConst) + return failure(); + auto indicesAttr = dyn_cast(indicesConst.getValue()); + if (!indicesAttr) + return failure(); + + int64_t rank = dataType.getRank(); + int64_t axis = normalizeAxis(gatherOp.getAxis(), rank); + if (axis < 0 || axis >= rank) + return failure(); + + int64_t axisDim = dataType.getShape()[axis]; + if (axisDim <= 0) + return failure(); + + SmallVector flatIndices(indicesAttr.getValues().begin(), indicesAttr.getValues().end()); + Location loc = gatherOp.getLoc(); + + auto computeOp = createSpatCompute<1>( + rewriter, loc, TypeRange {gatherOp.getResult().getType()}, {}, adaptor.getData(), [&](Value data) -> LogicalResult { + Value result; + if (indicesType.getRank() == 1) { + result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc); + } else if (indicesType.getRank() == 2) { + int64_t rowCount = indicesType.getShape()[0]; + int64_t rowWidth = indicesType.getShape()[1]; + SmallVector rows; + rows.reserve(rowCount); + for (int64_t row = 0; row < rowCount; ++row) { + ArrayRef rowIndices(flatIndices.data() + row * rowWidth, rowWidth); + Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc); + if (!gatheredRow) + return failure(); + rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); + } + result = + rows.size() == 1 ? rows.front() : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult(); + } else { + return failure(); + } + + if (!result) + return failure(); + spatial::SpatYieldOp::create(rewriter, loc, result); + return success(); + }); + if (failed(computeOp)) + return failure(); + rewriter.replaceOp(gatherOp, computeOp->getResults()); + return success(); + } +}; + +} // namespace + +void populateGatherPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp new file mode 100644 index 0000000..53bb2d5 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -0,0 +1,93 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/STLExtras.h" + +#include + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static Value extractSliceAt(Value input, + int64_t axis, + int64_t offset, + ConversionPatternRewriter& rewriter, + Location loc) { + auto inputType = cast(input.getType()); + SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); + SmallVector sizes; + SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); + sizes.reserve(inputType.getRank()); + for (int64_t dim : inputType.getShape()) + sizes.push_back(rewriter.getIndexAttr(dim)); + offsets[axis] = rewriter.getIndexAttr(offset); + sizes[axis] = rewriter.getIndexAttr(1); + return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); +} + +static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) { + return std::min((outputIndex * inputDim) / outputDim, inputDim - 1); +} + +static Value buildNearestResize(Value input, + ArrayRef inputShape, + ArrayRef outputShape, + int64_t axis, + ConversionPatternRewriter& rewriter, + Location loc) { + if (axis == static_cast(outputShape.size())) + return input; + + SmallVector slices; + slices.reserve(outputShape[axis]); + for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) { + int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]); + Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc); + slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc)); + } + + return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult(); +} + +struct Resize : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXResizeOp resizeOp, + ONNXResizeOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto inputType = dyn_cast(adaptor.getX().getType()); + auto resultType = dyn_cast(resizeOp.getY().getType()); + if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + + if (resizeOp.getMode() != "nearest" + || resizeOp.getCoordinateTransformationMode() != "asymmetric" + || resizeOp.getNearestMode() != "floor") + return failure(); + + if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; }) + || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) + return failure(); + + auto computeOp = createSpatCompute<1>( + rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { + Value result = buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc()); + spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); + }); + rewriter.replaceOp(resizeOp, computeOp.getResults()); + return success(); + } +}; + +} // namespace + +void populateResizePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp new file mode 100644 index 0000000..1196657 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp @@ -0,0 +1,75 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } + +static Value extractSliceAt(Value input, + int64_t axis, + int64_t offset, + int64_t size, + ConversionPatternRewriter& rewriter, + Location loc) { + auto inputType = cast(input.getType()); + SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); + SmallVector sizes; + SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); + sizes.reserve(inputType.getRank()); + for (int64_t dim : inputType.getShape()) + sizes.push_back(rewriter.getIndexAttr(dim)); + offsets[axis] = rewriter.getIndexAttr(offset); + sizes[axis] = rewriter.getIndexAttr(size); + return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); +} + +struct Split : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXSplitOp splitOp, + ONNXSplitOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto inputType = dyn_cast(adaptor.getInput().getType()); + if (!inputType || !inputType.hasStaticShape()) + return failure(); + + int64_t rank = inputType.getRank(); + int64_t axis = normalizeAxis(splitOp.getAxis(), rank); + if (axis < 0 || axis >= rank) + return failure(); + + SmallVector outputs; + outputs.reserve(splitOp.getNumResults()); + + int64_t offset = 0; + for (Value result : splitOp.getResults()) { + auto resultType = dyn_cast(result.getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + int64_t sliceSize = resultType.getShape()[axis]; + auto computeOp = + createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) { + Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc()); + spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output); + }); + outputs.push_back(computeOp.getResult(0)); + offset += sliceSize; + } + + rewriter.replaceOp(splitOp, outputs); + return success(); + } +}; + +} // namespace + +void populateSplitPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index ce94a90..a0fbce5 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -63,4 +63,10 @@ def spatToPimVSigm : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; +def spatToPimVSoftmax : Pat< + (SpatSoftmaxOp:$srcOpRes $input), + (PimVSoftmaxOp $input, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + #endif // SPATIAL_TO_PIM diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index b527f6d..07148d3 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -618,17 +618,22 @@ void SpatialToPimPass::markOpToRemove(Operation* op) { } void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { - for (auto it : llvm::enumerate(returnOp.getOperands())) { - Operation* returnOperand = it.value().getDefiningOp(); - + SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); + for (auto it : llvm::enumerate(originalOperands)) { size_t orderWithinReturn = it.index(); + Operation* returnOperand = it.value().getDefiningOp(); rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); }); Operation* opToErase = returnOperand; while (opToErase) { - bool isExclusivelyOwnedByReturnChain = opToErase->use_empty() || opToErase->hasOneUse(); + bool isExclusivelyOwnedByReturnChain = opToErase->use_empty(); + if (!isExclusivelyOwnedByReturnChain && opToErase->hasOneUse()) { + Operation* onlyUser = *opToErase->getUsers().begin(); + isExclusivelyOwnedByReturnChain = + isa(onlyUser) || isChannelUseChainOp(onlyUser); + } if (!isExclusivelyOwnedByReturnChain) break; diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 300ddd1..2a6d8e9 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -455,4 +455,27 @@ def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> { }]; } +def PimVSoftmaxOp : PimOp<"vsoftmax", [DestinationStyleOpInterface]> { + let summary = "Softmax over the full input vector"; + + let arguments = (ins + PimTensor:$input, + PimTensor:$outputBuffer + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) + }]; +} + #endif // PIM_DIALECT_H diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 5d29208..eae40c2 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -273,6 +273,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { PimVReluOp::attachInterface>(*ctx); PimVTanhOp::attachInterface>(*ctx); PimVSigmOp::attachInterface>(*ctx); + PimVSoftmaxOp::attachInterface>(*ctx); }); } diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 12fdd87..c8f419e 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -272,6 +272,22 @@ def SpatSigmoidOp : SpatOp<"sigmoid", []> { }]; } +def SpatSoftmaxOp : SpatOp<"softmax", []> { + let summary = "Softmax over the full input tensor slice"; + + let arguments = (ins + SpatTensor:$input + ); + + let results = (outs + SpatTensor:$output + ); + + let assemblyFormat = [{ + `(` $input `)` attr-dict `:` type($input) `->` type($output) + }]; +} + def SpatReluOp : SpatOp<"relu", []> { let summary = "Element-wise ReLU activation"; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp index c64b6b7..550a8a5 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp @@ -202,9 +202,9 @@ private: rewriter.clone(op, mapper); } - for (auto users : oldWeightedCompute->getUsers()) - if (auto funcRet = dyn_cast(users)) - funcRet.setOperand(0, newWeightedCompute.getResult(0)); + for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses())) + if (isa(use.getOwner())) + use.assign(newWeightedCompute.getResult(0)); oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute}); return {cast(newWeightedCompute), computeValueResults}; diff --git a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp index 32e5f43..68034a8 100644 --- a/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/Pim/ConstantFolding/Patterns/Subview.cpp @@ -146,6 +146,37 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override { + auto status = rewriteSubviewCopyLikeOp( + copyOp, + copyOp.getHostTarget(), + copyOp.getDeviceSource(), + copyOp.getHostTargetOffset(), + copyOp.getDeviceSourceOffset(), + copyOp.getSize(), + rewriter, + [&]( + MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) { + pim::PimMemCopyDevToHostOp::create(rewriter, + copyOp.getLoc(), + resultType, + dst, + src, + rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), + rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + }); + if (failed(status)) + return failure(); + + rewriter.replaceOp(copyOp, copyOp.getHostTarget()); + return success(); + } +}; + struct FoldConstantCoreSubviewPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -209,7 +240,10 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern( + patterns.add( patterns.getContext()); }