165 lines
6.6 KiB
C++
165 lines
6.6 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Softmaxes a single innermost row of `input` and writes it into `accumulator`.
///
/// The row is addressed by `outerIndices` (one dynamic index per leading dimension)
/// together with the full static extent N of the last dimension. It is extracted as a
/// non-rank-reduced slice of shape [1, ..., 1, N] (rank - 1 unit dims, then N), fed to
/// spatial::SpatSoftmaxOp, and written back with tensor.insert_slice using the same
/// offsets, sizes and unit strides.
///
/// NOTE(review): the slice type always has rank - 1 leading unit dims, so callers are
/// expected to pass exactly rank - 1 outer indices (buildLoopSoftmaxNest does).
///
/// @return the updated accumulator produced by tensor.insert_slice.
static Value buildLoopSoftmaxSlice(Value input,
                                   Value accumulator,
                                   RankedTensorType inputType,
                                   ArrayRef<Value> outerIndices,
                                   ConversionPatternRewriter& rewriter,
                                   Location loc) {
  const int64_t rank = inputType.getRank();
  const int64_t rowLength = inputType.getDimSize(rank - 1);

  SmallVector<int64_t> rowShape(static_cast<size_t>(rank - 1), 1);
  rowShape.push_back(rowLength);
  auto rowType = RankedTensorType::get(rowShape, inputType.getElementType(), inputType.getEncoding());

  SmallVector<OpFoldResult> rowStrides = getUnitStrides(rewriter, rank);

  // Leading dims: offset taken from the enclosing loop induction variables, size 1.
  SmallVector<OpFoldResult> rowOffsets(outerIndices.begin(), outerIndices.end());
  SmallVector<OpFoldResult> rowSizes(outerIndices.size(), rewriter.getIndexAttr(1));
  // Last dim: the whole row, starting at 0.
  rowOffsets.push_back(rewriter.getIndexAttr(0));
  rowSizes.push_back(rewriter.getIndexAttr(rowLength));

  Value row = tensor::ExtractSliceOp::create(rewriter, loc, rowType, input, rowOffsets, rowSizes, rowStrides);
  Value normalizedRow = spatial::SpatSoftmaxOp::create(rewriter, loc, rowType, row).getResult();
  return tensor::InsertSliceOp::create(rewriter, loc, normalizedRow, accumulator, rowOffsets, rowSizes, rowStrides);
}
|
|
|
|
/// Emits one loop per leading dimension of `inputType`, starting at dimension `axis`.
/// `accumulator` is threaded through the loops as the single iteration argument.
/// Once `axis` reaches the last dimension, the row selected by `outerIndices` is
/// softmaxed by buildLoopSoftmaxSlice.
///
/// `outerIndices` behaves as a stack. Each loop level pushes its induction variable
/// before recursing and pops it afterwards, so the caller's vector is unchanged on
/// return.
///
/// Loop bounds (0, dim size, step 1) come from getOrCreateIndexConstant. They are
/// anchored on the op that owns the rewriter's current insertion block.
///
/// @return the resulting accumulator, or failure if the loop builder or any nested
///         level fails.
static FailureOr<Value> buildLoopSoftmaxNest(Value input,
                                             Value accumulator,
                                             RankedTensorType inputType,
                                             int64_t axis,
                                             SmallVectorImpl<Value>& outerIndices,
                                             ConversionPatternRewriter& rewriter,
                                             Location loc) {
  // Base case: every leading dimension already has an index.
  const bool reachedInnermostDim = axis == inputType.getRank() - 1;
  if (reachedInnermostDim)
    return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);

  Operation* constantAnchor = rewriter.getInsertionBlock()->getParentOp();
  Value lowerBound = getOrCreateIndexConstant(rewriter, constantAnchor, 0);
  Value step = getOrCreateIndexConstant(rewriter, constantAnchor, 1);
  Value upperBound = getOrCreateIndexConstant(rewriter, constantAnchor, inputType.getDimSize(axis));

  // Loop body: recurse one dimension deeper with this level's induction variable
  // appended to the outer indices, then yield the updated accumulator.
  auto emitBody = [&](OpBuilder& /*builder*/,
                      Location bodyLoc,
                      Value inductionVar,
                      ValueRange carried,
                      SmallVectorImpl<Value>& yielded) -> LogicalResult {
    outerIndices.push_back(inductionVar);
    FailureOr<Value> innerResult =
        buildLoopSoftmaxNest(input, carried.front(), inputType, axis + 1, outerIndices, rewriter, bodyLoc);
    outerIndices.pop_back();
    if (failed(innerResult))
      return failure();
    yielded.push_back(*innerResult);
    return success();
  };

  auto forLoop =
      buildNormalizedScfFor(rewriter, loc, lowerBound, upperBound, step, ValueRange {accumulator}, emitBody);
  if (failed(forLoop))
    return failure();
  return forLoop->results.front();
}
|
|
|
|
/// Wraps an innermost-dimension softmax of `input` in a spatial compute op.
///
/// A rank-1 input is softmaxed with a single spatial::SpatSoftmaxOp. A higher-rank
/// input gets a tensor.empty destination of the same type. A loop nest
/// (buildLoopSoftmaxNest) then softmaxes each innermost row and inserts it into that
/// destination.
///
/// @param input  must have a RankedTensorType (cast<> asserts otherwise).
/// @return the compute op's first result, created with result type `input`'s type;
///         failure for rank-0 inputs or if building the compute body fails.
static FailureOr<Value> createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(input.getType());

  // A rank-0 tensor has no innermost dimension. The loop nest below would call
  // getDimSize(0) on it, so reject it before creating any IR.
  if (inputType.getRank() == 0)
    return failure();

  constexpr size_t numInputs = 1;
  auto computeOp = createSpatCompute<numInputs>(
      rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) -> LogicalResult {
        if (inputType.getRank() == 1) {
          Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
          spatial::SpatYieldOp::create(rewriter, loc, softmax);
          return success();
        }

        // Preserve the encoding so that the destination, the loop-carried value and the
        // yielded value all have exactly `inputType`. That is the type the compute op
        // was declared with above, and the type the row slices use.
        Value outputInit = tensor::EmptyOp::create(
            rewriter, loc, inputType.getShape(), inputType.getElementType(), inputType.getEncoding());
        SmallVector<Value> outerIndices;
        auto result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
        if (failed(result))
          return failure();
        spatial::SpatYieldOp::create(rewriter, loc, *result);
        return success();
      });
  if (failed(computeOp))
    return failure();
  return computeOp->getResult(0);
}
|
|
|
|
/// Lowers onnx.Softmax on statically shaped ranked tensors to a spatial compute op.
///
/// The softmax loop lowering only normalizes over the innermost dimension. When the
/// requested axis is already last, the input is lowered directly. Otherwise the axis
/// is moved to the innermost position with onnx.Transpose, the remaining dimensions
/// keep their relative order, and the result is transposed back with the inverse
/// permutation.
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSoftmaxOp softmaxOp,
                                ONNXSoftmaxOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    Value input = adaptor.getInput();
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || !inputType.hasStaticShape())
      return failure();

    const int64_t rank = inputType.getRank();
    auto normalizedAxis = normalizeAxisChecked(softmaxOp.getAxis(), rank);
    if (failed(normalizedAxis))
      return failure();
    const int64_t softmaxDim = *normalizedAxis;

    Location loc = softmaxOp.getLoc();

    // Axis already innermost: no layout change needed.
    if (softmaxDim == rank - 1) {
      auto direct = createLoopSoftmaxCompute(input, rewriter, loc);
      if (failed(direct))
        return failure();
      rewriter.replaceOp(softmaxOp, *direct);
      return success();
    }

    // Permutation that sends `softmaxDim` to the end and keeps the others in order.
    SmallVector<int64_t> toInnermost;
    toInnermost.reserve(rank);
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dim == softmaxDim)
        continue;
      toInnermost.push_back(dim);
    }
    toInnermost.push_back(softmaxDim);
    SmallVector<int64_t> fromInnermost = invertPermutation(toInnermost);

    auto innermostType = RankedTensorType::get(
        permuteShape(inputType.getShape(), toInnermost), inputType.getElementType(), inputType.getEncoding());
    auto forwardTranspose =
        ONNXTransposeOp::create(rewriter, loc, innermostType, input, rewriter.getI64ArrayAttr(toInnermost));

    auto innermostSoftmax = createLoopSoftmaxCompute(forwardTranspose.getResult(), rewriter, loc);
    if (failed(innermostSoftmax))
      return failure();

    auto backwardTranspose =
        ONNXTransposeOp::create(rewriter, loc, inputType, *innermostSoftmax, rewriter.getI64ArrayAttr(fromInnermost));
    rewriter.replaceOp(softmaxOp, backwardTranspose.getResult());
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the onnx.Softmax -> spatial compute lowering pattern into `patterns`.
///
/// @param patterns  pattern set to extend.
/// @param ctx       MLIR context used to construct the pattern.
void populateSoftmaxPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<SoftmaxToSpatialCompute>(/*context=*/ctx);
}
|
|
|
|
} // namespace onnx_mlir
|