112 lines
4.5 KiB
C++
112 lines
4.5 KiB
C++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Converts a possibly negative axis index into its offset-from-zero form.
///
/// A negative `axis` is shifted by `rank`; a non-negative `axis` is returned
/// as is. No range validation happens here, so the result can still be
/// negative or >= `rank`; callers are expected to check it.
static int64_t normalizeAxis(int64_t axis, int64_t rank) {
  if (axis < 0)
    return rank + axis;
  return axis;
}
|
|
|
|
/// Returns the shape obtained by reordering `shape` according to
/// `permutation`: element `i` of the result is `shape[permutation[i]]`.
///
/// The result has `permutation.size()` entries. Entries of `permutation` are
/// used directly as indices into `shape` and are not validated here.
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
  SmallVector<int64_t> reordered;
  reordered.reserve(permutation.size());
  for (size_t position = 0, count = permutation.size(); position != count; ++position)
    reordered.push_back(shape[permutation[position]]);
  return reordered;
}
|
|
|
|
/// Emits a single-input spatial compute op whose body applies
/// `spatial::SpatSoftmaxOp` to the region argument and yields the result.
///
/// \param input    Value to process; its type must be a RankedTensorType
///                 (enforced by `cast`).
/// \param rewriter Rewriter used to create all ops.
/// \param loc      Location attached to every created op.
/// \returns Result #0 of the created compute op. Its type is the same as the
///          type of `input`.
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
  auto tensorType = cast<RankedTensorType>(input.getType());

  // The compute op takes exactly one operand and produces one result of the
  // same tensor type; the callback fills in the region body.
  auto softmaxCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {tensorType}, {}, ValueRange {input}, [&](Value regionArg) {
        auto softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, tensorType, regionArg);
        spatial::SpatYieldOp::create(rewriter, loc, softmax.getResult());
      });

  return softmaxCompute.getResult(0);
}
|
|
|
|
/// Recursively decomposes `input` along every axis except `softmaxAxis` and
/// applies a softmax compute op to each leaf.
///
/// Starting at `axis`, each axis other than `softmaxAxis` is split with
/// `sliceTensor(..., /*sliceSize=*/1, ...)`; every slice is processed
/// recursively for the following axes, and the processed slices are joined
/// again with `tensor::ConcatOp` along the same axis. Once `axis` reaches the
/// rank of `input`, the leaf is handed to `createSoftmaxCompute`.
///
/// \param input       Value to process; its type must be a RankedTensorType.
/// \param softmaxAxis Axis that is left unsliced.
/// \param axis        Axis currently being visited; callers start at 0.
/// \param rewriter    Rewriter used to create all ops.
/// \param loc         Location attached to every created op.
/// \returns The value holding the recombined result.
static Value
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
  const int64_t rank = cast<RankedTensorType>(input.getType()).getRank();

  // All axes have been visited: this is a leaf, emit the actual softmax.
  if (axis == rank)
    return createSoftmaxCompute(input, rewriter, loc);

  const int64_t nextAxis = axis + 1;

  // The softmax axis itself is never sliced; move on to the next axis.
  if (axis == softmaxAxis)
    return buildSoftmax(input, softmaxAxis, nextAxis, rewriter, loc);

  SmallVector<Value> pieces = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
  SmallVector<Value> processed;
  processed.reserve(pieces.size());
  for (Value piece : pieces)
    processed.push_back(buildSoftmax(piece, softmaxAxis, nextAxis, rewriter, loc));

  // A single slice needs no concatenation.
  if (processed.size() == 1)
    return processed.front();
  return tensor::ConcatOp::create(rewriter, loc, axis, processed).getResult();
}
|
|
|
|
/// Conversion pattern that rewrites `ONNXSoftmaxOp` in terms of spatial
/// compute ops.
///
/// Only statically shaped ranked tensors with an in-range axis are handled;
/// anything else makes the pattern fail to match. When the softmax axis is the
/// last one, the input goes straight to `buildSoftmax`. Otherwise the softmax
/// axis is first moved to the last position with an `ONNXTransposeOp`, the
/// softmax is built on the transposed tensor, and a second `ONNXTransposeOp`
/// restores the original layout.
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSoftmaxOp softmaxOp,
                                ONNXSoftmaxOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    Value input = adaptor.getInput();
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || !inputType.hasStaticShape())
      return failure();

    const int64_t rank = inputType.getRank();
    const int64_t softmaxAxis = normalizeAxis(softmaxOp.getAxis(), rank);
    if (softmaxAxis < 0 || softmaxAxis >= rank)
      return failure();

    Location loc = softmaxOp.getLoc();
    const int64_t lastAxis = rank - 1;

    // Softmax already runs over the last axis: no transposition required.
    if (softmaxAxis == lastAxis) {
      Value result = buildSoftmax(input, softmaxAxis, /*axis=*/0, rewriter, loc);
      rewriter.replaceOp(softmaxOp, result);
      return success();
    }

    // Build the permutation that keeps all other axes in order and moves the
    // softmax axis to the end, together with its inverse.
    SmallVector<int64_t> toLast(rank);
    SmallVector<int64_t> fromLast(rank);
    int64_t slot = 0;
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dim == softmaxAxis)
        continue;
      toLast[slot] = dim;
      fromLast[dim] = slot;
      ++slot;
    }
    toLast[lastAxis] = softmaxAxis;
    fromLast[softmaxAxis] = lastAxis;

    auto transposedType = RankedTensorType::get(
        permuteShape(inputType.getShape(), toLast), inputType.getElementType(), inputType.getEncoding());

    // The first transpose is emitted inside its own spatial compute op.
    auto transposeCompute =
        createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, input, [&](Value regionArg) {
          Value moved =
              ONNXTransposeOp::create(rewriter, loc, transposedType, regionArg, rewriter.getI64ArrayAttr(toLast));
          spatial::SpatYieldOp::create(rewriter, loc, moved);
        });

    Value softmaxOnLast =
        buildSoftmax(transposeCompute.getResult(0), /*softmaxAxis=*/lastAxis, /*axis=*/0, rewriter, loc);

    // NOTE(review): unlike the first transpose, this one is not wrapped in
    // createSpatCompute — confirm that this asymmetry is intended.
    Value result =
        ONNXTransposeOp::create(rewriter, loc, inputType, softmaxOnLast, rewriter.getI64ArrayAttr(fromLast));

    rewriter.replaceOp(softmaxOp, result);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the ONNX softmax conversion pattern (`SoftmaxToSpatialCompute`) to
/// `patterns`.
///
/// \param patterns Pattern set that receives the new pattern.
/// \param ctx      MLIR context forwarded to the pattern's constructor.
void populateSoftmaxPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<SoftmaxToSpatialCompute>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|