Add support for softmax, resize, split, and gather
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -386,6 +386,20 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
|
||||
/// Emits the `vsoftmax` instruction for a PIM softmax op.
///
/// Resolves the addresses of the op's output buffer and input, hands them to
/// setupRdRs1 (output first, both with offset 0), and then emits a JSON
/// instruction that refers to register slots 0 (`rd`) and 1 (`rs1`) and carries
/// the byte size of the input as `len`.
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const {
  auto dstAddr = memory.getValueAddress(vsoftmaxOp.getOutputBuffer());
  auto srcAddr = memory.getValueAddress(vsoftmaxOp.getInput());
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  json::Object instruction;
  instruction["op"] = "vsoftmax";
  instruction["rd"] = 0;
  instruction["rs1"] = 1;
  instruction["offset"] = createEmptyOffset();
  instruction["len"] = getValueSizeInBytes(vsoftmaxOp.getInput());
  emitInstruction(std::move(instruction));
}
|
||||
|
||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
|
||||
auto srcAddr = memory.getValueAddress(transposeOp.getInput());
|
||||
auto dstAddr = memory.getValueAddress(transposeOp.getOutputBuffer());
|
||||
@@ -537,6 +551,8 @@ static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenVTanhOp(vtanhOp);
|
||||
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
||||
coreCodeGen.codeGenVSigmOp(vsigmOp);
|
||||
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
||||
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp);
|
||||
else {
|
||||
op.emitError("Unsupported codegen for this operation");
|
||||
op.dump();
|
||||
|
||||
@@ -99,6 +99,7 @@ public:
|
||||
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
|
||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const;
|
||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const;
|
||||
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp) const;
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp) const;
|
||||
};
|
||||
|
||||
|
||||
@@ -11,8 +11,12 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/NN/Pool.cpp
|
||||
Patterns/NN/Relu.cpp
|
||||
Patterns/NN/Sigmoid.cpp
|
||||
Patterns/NN/Softmax.cpp
|
||||
Patterns/Tensor/Concat.cpp
|
||||
Patterns/Tensor/Gather.cpp
|
||||
Patterns/Tensor/Resize.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
Patterns/Tensor/Split.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
Common.cpp
|
||||
|
||||
|
||||
@@ -89,9 +89,12 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addIllegalOp<ONNXSigmoidOp>();
|
||||
target.addIllegalOp<ONNXSoftmaxOp>();
|
||||
target.addIllegalOp<ONNXConcatOp>();
|
||||
target.addIllegalOp<ONNXGatherOp>();
|
||||
target.addIllegalOp<ONNXReshapeOp>();
|
||||
target.addIllegalOp<ONNXResizeOp>();
|
||||
target.addIllegalOp<ONNXLRNOp>();
|
||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||
target.addIllegalOp<ONNXSplitOp>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<removeLRN>(ctx);
|
||||
@@ -103,8 +106,12 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
populateReduceMeanPatterns(patterns, ctx);
|
||||
populateReluPatterns(patterns, ctx);
|
||||
populateSigmoidPatterns(patterns, ctx);
|
||||
populateSoftmaxPatterns(patterns, ctx);
|
||||
populateConcatPatterns(patterns, ctx);
|
||||
populateGatherPatterns(patterns, ctx);
|
||||
populateResizePatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
populateSplitPatterns(patterns, ctx);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
|
||||
@@ -21,8 +21,16 @@ void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext*
|
||||
|
||||
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
111
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp
Normal file
111
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp
Normal file
@@ -0,0 +1,111 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Converts a negative axis into its non-negative equivalent by adding `rank`;
/// non-negative axes are returned unchanged. The result is not range-checked.
static int64_t normalizeAxis(int64_t axis, int64_t rank) {
  if (axis < 0)
    return rank + axis;
  return axis;
}
|
||||
|
||||
/// Returns the shape obtained by reordering `shape` according to `permutation`:
/// entry i of the result is `shape[permutation[i]]`.
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
  SmallVector<int64_t> reordered(permutation.size());
  for (size_t position = 0; position < permutation.size(); ++position)
    reordered[position] = shape[permutation[position]];
  return reordered;
}
|
||||
|
||||
/// Wraps `input` in a single-input Spatial compute op whose body applies
/// SpatSoftmaxOp to the block argument and yields the result. The compute op's
/// result type is the (ranked) type of `input`; its first result is returned.
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
  auto tensorType = cast<RankedTensorType>(input.getType());
  auto wrapper = createSpatCompute<1>(rewriter, loc, TypeRange {tensorType}, {}, ValueRange {input}, [&](Value arg) {
    auto softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, tensorType, arg);
    spatial::SpatYieldOp::create(rewriter, loc, softmax.getResult());
  });
  return wrapper.getResult(0);
}
|
||||
|
||||
/// Recursively lowers a softmax along `softmaxAxis`, visiting axes from `axis`
/// up to the rank of `input`.
///
/// Every visited axis other than `softmaxAxis` is cut into slices of size 1
/// (via sliceTensor); each slice is processed recursively and the pieces are
/// concatenated back along that axis. Once all axes have been visited, the
/// remaining tensor is handed to createSoftmaxCompute.
static Value buildSoftmax(Value input,
                          int64_t softmaxAxis,
                          int64_t axis,
                          ConversionPatternRewriter& rewriter,
                          Location loc) {
  const int64_t rank = cast<RankedTensorType>(input.getType()).getRank();

  // Every axis has been visited: emit the actual softmax compute.
  if (axis == rank)
    return createSoftmaxCompute(input, rewriter, loc);

  const int64_t nextAxis = axis + 1;

  // The softmax axis itself is never sliced.
  if (axis == softmaxAxis)
    return buildSoftmax(input, softmaxAxis, nextAxis, rewriter, loc);

  SmallVector<Value> pieces = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
  SmallVector<Value> processed;
  processed.reserve(pieces.size());
  for (Value piece : pieces)
    processed.push_back(buildSoftmax(piece, softmaxAxis, nextAxis, rewriter, loc));

  // A single piece needs no concatenation.
  if (processed.size() == 1)
    return processed.front();
  return tensor::ConcatOp::create(rewriter, loc, axis, processed).getResult();
}
|
||||
|
||||
/// Converts onnx.Softmax on a statically shaped ranked tensor into Spatial
/// compute ops.
///
/// buildSoftmax is always invoked with the softmax axis in last position. When
/// the requested axis is not already last, the input is first transposed
/// (inside a compute op) so that it is, and the softmax result is transposed
/// back with the inverse permutation.
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSoftmaxOp softmaxOp,
                                ONNXSoftmaxOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    Value input = adaptor.getInput();
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || !inputType.hasStaticShape())
      return failure();

    const int64_t rank = inputType.getRank();
    const int64_t axis = normalizeAxis(softmaxOp.getAxis(), rank);
    if (axis < 0 || axis >= rank)
      return failure();

    Location loc = softmaxOp.getLoc();
    const int64_t lastAxis = rank - 1;

    // Fast path: the softmax axis is already innermost.
    if (axis == lastAxis) {
      rewriter.replaceOp(softmaxOp, buildSoftmax(input, axis, /*axis=*/0, rewriter, loc));
      return success();
    }

    // Keep every other dimension in order and move the softmax axis to the end.
    SmallVector<int64_t> permutation;
    permutation.reserve(rank);
    for (int64_t dim = 0; dim < rank; ++dim)
      if (dim != axis)
        permutation.push_back(dim);
    permutation.push_back(axis);

    // inversePermutation[old] = new, so that it undoes `permutation`.
    SmallVector<int64_t> inversePermutation(rank);
    for (int64_t newIndex = 0; newIndex < rank; ++newIndex)
      inversePermutation[permutation[newIndex]] = newIndex;

    auto transposedType = RankedTensorType::get(
        permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
    auto preTransposeCompute =
        createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, input, [&](Value x) {
          Value transposed =
              ONNXTransposeOp::create(rewriter, loc, transposedType, x, rewriter.getI64ArrayAttr(permutation));
          spatial::SpatYieldOp::create(rewriter, loc, transposed);
        });

    Value transposedResult =
        buildSoftmax(preTransposeCompute.getResult(0), /*softmaxAxis=*/lastAxis, /*axis=*/0, rewriter, loc);

    // NOTE(review): unlike the pre-transpose, this transpose is created outside
    // of a compute op - confirm that this is intended.
    Value restored = ONNXTransposeOp::create(
        rewriter, loc, inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));

    rewriter.replaceOp(softmaxOp, restored);
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Adds the onnx.Softmax lowering pattern to `patterns`.
void populateSoftmaxPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<SoftmaxToSpatialCompute>(ctx); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
152
src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp
Normal file
152
src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp
Normal file
@@ -0,0 +1,152 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Converts a negative axis into its non-negative equivalent by adding `rank`;
/// non-negative axes are returned unchanged. The result is not range-checked.
static int64_t normalizeAxis(int64_t axis, int64_t rank) {
  if (axis < 0)
    return rank + axis;
  return axis;
}
|
||||
|
||||
/// Converts a negative gather index into its non-negative equivalent by adding
/// `dimSize`; non-negative indices are returned unchanged. The result is not
/// range-checked.
static int64_t normalizeIndex(int64_t index, int64_t dimSize) {
  if (index < 0)
    return dimSize + index;
  return index;
}
|
||||
|
||||
/// Creates a tensor.extract_slice of `input` that keeps every dimension whole
/// except `axis`, where a single element at position `offset` is selected.
/// All strides are 1; `input` must have a ranked tensor type.
static Value extractSliceAt(Value input,
                            int64_t axis,
                            int64_t offset,
                            ConversionPatternRewriter& rewriter,
                            Location loc) {
  auto tensorType = cast<RankedTensorType>(input.getType());
  const int64_t rank = tensorType.getRank();
  ArrayRef<int64_t> shape = tensorType.getShape();

  SmallVector<OpFoldResult> sliceOffsets(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sliceStrides(rank, rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> sliceSizes;
  sliceSizes.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim)
    sliceSizes.push_back(rewriter.getIndexAttr(shape[dim]));

  // Narrow the selected axis to one element starting at `offset`.
  sliceOffsets[axis] = rewriter.getIndexAttr(offset);
  sliceSizes[axis] = rewriter.getIndexAttr(1);

  return tensor::ExtractSliceOp::create(rewriter, loc, input, sliceOffsets, sliceSizes, sliceStrides);
}
|
||||
|
||||
/// Gathers the size-1 slices of `data` along `axis` selected by `indices` and
/// concatenates them along the same axis.
///
/// Negative indices are interpreted relative to `axisDim`. A null Value is
/// returned when `indices` is empty or when any index falls outside
/// [0, axisDim) after normalization; a single slice is returned as-is without
/// a concat.
static Value concatGatherSlices(Value data,
                                int64_t axis,
                                ArrayRef<int64_t> indices,
                                int64_t axisDim,
                                ConversionPatternRewriter& rewriter,
                                Location loc) {
  if (indices.empty())
    return {};

  SmallVector<Value> gathered;
  gathered.reserve(indices.size());
  for (int64_t rawIndex : indices) {
    const int64_t position = normalizeIndex(rawIndex, axisDim);
    const bool inRange = position >= 0 && position < axisDim;
    if (!inRange)
      return {};
    gathered.push_back(extractSliceAt(data, axis, position, rewriter, loc));
  }

  if (gathered.size() == 1)
    return gathered.front();
  return tensor::ConcatOp::create(rewriter, loc, axis, gathered).getResult();
}
|
||||
|
||||
/// Inserts a unit dimension immediately before dimension `axis` of `value`
/// using tensor.expand_shape.
///
/// The source dimension `axis` is reassociated with two result dimensions
/// (the new extent-1 dimension followed by the original extent); every other
/// source dimension maps to exactly one result dimension.
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  const int64_t rank = sourceType.getRank();
  ArrayRef<int64_t> sourceShape = sourceType.getShape();

  SmallVector<int64_t> expandedShape;
  SmallVector<ReassociationIndices> groups;
  expandedShape.reserve(rank + 1);
  groups.reserve(rank);

  int64_t nextResultDim = 0;
  for (int64_t dim = 0; dim < rank; ++dim) {
    ReassociationIndices& group = groups.emplace_back();
    if (dim == axis) {
      // The new unit dimension goes in front of the original one.
      expandedShape.push_back(1);
      group.push_back(nextResultDim++);
    }
    expandedShape.push_back(sourceShape[dim]);
    group.push_back(nextResultDim++);
  }

  auto expandedType = RankedTensorType::get(expandedShape, sourceType.getElementType(), sourceType.getEncoding());
  return tensor::ExpandShapeOp::create(rewriter, loc, expandedType, value, groups);
}
|
||||
|
||||
/// Lowers onnx.Gather with constant indices into a Spatial compute op built
/// from tensor.extract_slice / tensor.concat (and tensor.expand_shape for
/// rank-2 indices).
///
/// Supported: statically shaped data and indices, indices produced by an
/// arith.constant holding a dense integer attribute, and index tensors of rank
/// 1 or 2. Everything else is rejected with failure().
///
/// All preconditions, including the bounds of every index, are checked before
/// any IR is created, so a rejected op leaves the IR untouched.
struct Gather : OpConversionPattern<ONNXGatherOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGatherOp gatherOp,
                                ONNXGatherOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
    auto indicesType = dyn_cast<RankedTensorType>(adaptor.getIndices().getType());
    if (!dataType || !indicesType || !dataType.hasStaticShape() || !indicesType.hasStaticShape())
      return failure();

    // Only rank-1 and rank-2 index tensors are handled below.
    const int64_t indicesRank = indicesType.getRank();
    if (indicesRank != 1 && indicesRank != 2)
      return failure();

    auto indicesConst = adaptor.getIndices().getDefiningOp<arith::ConstantOp>();
    if (!indicesConst)
      return failure();
    auto indicesAttr = dyn_cast<DenseIntElementsAttr>(indicesConst.getValue());
    if (!indicesAttr)
      return failure();

    int64_t rank = dataType.getRank();
    int64_t axis = normalizeAxis(gatherOp.getAxis(), rank);
    if (axis < 0 || axis >= rank)
      return failure();

    int64_t axisDim = dataType.getShape()[axis];
    if (axisDim <= 0)
      return failure();

    // Read the indices through APInt so that any integer width is accepted;
    // getValues<int64_t>() is only valid for 64-bit element types.
    SmallVector<int64_t> flatIndices;
    flatIndices.reserve(indicesAttr.getNumElements());
    for (const llvm::APInt& index : indicesAttr.getValues<llvm::APInt>())
      flatIndices.push_back(index.getSExtValue());

    // Reject empty or out-of-range indices up front, before creating any op.
    if (flatIndices.empty())
      return failure();
    for (int64_t index : flatIndices) {
      int64_t normalizedIndex = normalizeIndex(index, axisDim);
      if (normalizedIndex < 0 || normalizedIndex >= axisDim)
        return failure();
    }

    Location loc = gatherOp.getLoc();

    auto computeOp = createSpatCompute<1>(
        rewriter, loc, TypeRange {gatherOp.getResult().getType()}, {}, adaptor.getData(), [&](Value data) -> LogicalResult {
          Value result;
          if (indicesRank == 1) {
            result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc);
          } else {
            // Rank-2 indices: gather each row of indices separately, give every
            // gathered row its own unit dimension in front of `axis`, then
            // stack the rows along that new dimension.
            int64_t rowCount = indicesType.getShape()[0];
            int64_t rowWidth = indicesType.getShape()[1];
            SmallVector<Value> rows;
            rows.reserve(rowCount);
            for (int64_t row = 0; row < rowCount; ++row) {
              ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth);
              Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc);
              if (!gatheredRow)
                return failure();
              rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
            }
            result =
                rows.size() == 1 ? rows.front() : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
          }

          // Defensive: the helpers signal failure with a null Value.
          if (!result)
            return failure();
          spatial::SpatYieldOp::create(rewriter, loc, result);
          return success();
        });
    if (failed(computeOp))
      return failure();
    rewriter.replaceOp(gatherOp, computeOp->getResults());
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Adds the onnx.Gather lowering pattern to `patterns`.
void populateGatherPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<Gather>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
93
src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp
Normal file
93
src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp
Normal file
@@ -0,0 +1,93 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Creates a tensor.extract_slice of `input` that keeps every dimension whole
/// except `axis`, where a single element at position `offset` is selected.
/// All strides are 1; `input` must have a ranked tensor type.
static Value extractSliceAt(Value input,
                            int64_t axis,
                            int64_t offset,
                            ConversionPatternRewriter& rewriter,
                            Location loc) {
  auto tensorType = cast<RankedTensorType>(input.getType());
  const int64_t rank = tensorType.getRank();
  ArrayRef<int64_t> shape = tensorType.getShape();

  SmallVector<OpFoldResult> sliceOffsets(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sliceStrides(rank, rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> sliceSizes;
  sliceSizes.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim)
    sliceSizes.push_back(rewriter.getIndexAttr(shape[dim]));

  // Narrow the selected axis to one element starting at `offset`.
  sliceOffsets[axis] = rewriter.getIndexAttr(offset);
  sliceSizes[axis] = rewriter.getIndexAttr(1);

  return tensor::ExtractSliceOp::create(rewriter, loc, input, sliceOffsets, sliceSizes, sliceStrides);
}
|
||||
|
||||
/// Maps an output coordinate to the input coordinate it is copied from:
/// (outputIndex * inputDim) / outputDim using integer division, clamped so it
/// never exceeds the last valid input index (inputDim - 1).
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
  const int64_t scaledIndex = (outputIndex * inputDim) / outputDim;
  const int64_t lastInputIndex = inputDim - 1;
  return scaledIndex < lastInputIndex ? scaledIndex : lastInputIndex;
}
|
||||
|
||||
/// Recursively builds a nearest-neighbour resize of `input`, one axis at a
/// time starting from `axis`.
///
/// For every output position along the current axis, the size-1 input slice
/// chosen by nearestAsymmetricIndex is extracted and then resized along the
/// remaining axes; the resulting pieces are concatenated along the current
/// axis. `inputShape` / `outputShape` are the shapes of the original input and
/// of the final result and are indexed by `axis`.
static Value buildNearestResize(Value input,
                                ArrayRef<int64_t> inputShape,
                                ArrayRef<int64_t> outputShape,
                                int64_t axis,
                                ConversionPatternRewriter& rewriter,
                                Location loc) {
  const int64_t rank = static_cast<int64_t>(outputShape.size());
  // All axes processed: nothing left to resize.
  if (axis == rank)
    return input;

  const int64_t inputExtent = inputShape[axis];
  const int64_t outputExtent = outputShape[axis];

  SmallVector<Value> pieces;
  pieces.reserve(outputExtent);
  for (int64_t outputIndex = 0; outputIndex < outputExtent; ++outputIndex) {
    const int64_t sourceIndex = nearestAsymmetricIndex(outputIndex, inputExtent, outputExtent);
    Value picked = extractSliceAt(input, axis, sourceIndex, rewriter, loc);
    pieces.push_back(buildNearestResize(picked, inputShape, outputShape, axis + 1, rewriter, loc));
  }

  // A single piece needs no concatenation.
  if (pieces.size() == 1)
    return pieces.front();
  return tensor::ConcatOp::create(rewriter, loc, axis, pieces).getResult();
}
|
||||
|
||||
/// Lowers onnx.Resize into a Spatial compute op built from size-1 slices and
/// concats (see buildNearestResize).
///
/// Only the combination mode="nearest",
/// coordinate_transformation_mode="asymmetric", nearest_mode="floor" is
/// supported, and only for statically shaped input and result tensors of equal
/// rank whose dimensions are all positive. Anything else is rejected with
/// failure() before any IR is created.
///
/// NOTE(review): source indices are derived from the ratio of input to result
/// extents. When the op was given explicit non-integral `scales`, the ONNX
/// definition (x_original = x_resized / scale) can select a different input
/// element - confirm whether such models need to be supported.
struct Resize : OpConversionPattern<ONNXResizeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXResizeOp resizeOp,
                                ONNXResizeOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
    auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
    if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // buildNearestResize indexes the input shape with every result axis, so
    // the two ranks have to agree.
    if (inputType.getRank() != resultType.getRank())
      return failure();

    if (resizeOp.getMode() != "nearest"
        || resizeOp.getCoordinateTransformationMode() != "asymmetric"
        || resizeOp.getNearestMode() != "floor")
      return failure();

    // Non-positive extents would lead to a division by zero or an empty concat.
    if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
        || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
      return failure();

    Location loc = resizeOp.getLoc();
    auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
      Value result = buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, loc);
      spatial::SpatYieldOp::create(rewriter, loc, result);
    });
    rewriter.replaceOp(resizeOp, computeOp.getResults());
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Adds the onnx.Resize lowering pattern to `patterns`.
void populateResizePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<Resize>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
75
src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp
Normal file
75
src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp
Normal file
@@ -0,0 +1,75 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Converts a negative axis into its non-negative equivalent by adding `rank`;
/// non-negative axes are returned unchanged. The result is not range-checked.
static int64_t normalizeAxis(int64_t axis, int64_t rank) {
  if (axis < 0)
    return rank + axis;
  return axis;
}
|
||||
|
||||
/// Creates a tensor.extract_slice of `input` that keeps every dimension whole
/// except `axis`, where `size` consecutive elements starting at `offset` are
/// selected. All strides are 1; `input` must have a ranked tensor type.
static Value extractSliceAt(Value input,
                            int64_t axis,
                            int64_t offset,
                            int64_t size,
                            ConversionPatternRewriter& rewriter,
                            Location loc) {
  auto tensorType = cast<RankedTensorType>(input.getType());
  const int64_t rank = tensorType.getRank();
  ArrayRef<int64_t> shape = tensorType.getShape();

  SmallVector<OpFoldResult> sliceOffsets(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sliceStrides(rank, rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> sliceSizes;
  sliceSizes.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim)
    sliceSizes.push_back(rewriter.getIndexAttr(shape[dim]));

  // Narrow the selected axis to the requested window.
  sliceOffsets[axis] = rewriter.getIndexAttr(offset);
  sliceSizes[axis] = rewriter.getIndexAttr(size);

  return tensor::ExtractSliceOp::create(rewriter, loc, input, sliceOffsets, sliceSizes, sliceStrides);
}
|
||||
|
||||
/// Lowers onnx.Split into one Spatial compute op per result; each compute op
/// extracts the corresponding window of the input along the split axis.
///
/// The window sizes are taken from the (static) result types, and the windows
/// are laid out back to back starting at offset 0. The pattern fails, without
/// touching the IR, when the input or any result is not a statically shaped
/// ranked tensor, when a result's rank differs from the input's, when the axis
/// is out of range, or when the windows would run past the input extent.
struct Split : OpConversionPattern<ONNXSplitOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSplitOp splitOp,
                                ONNXSplitOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType());
    if (!inputType || !inputType.hasStaticShape())
      return failure();

    int64_t rank = inputType.getRank();
    int64_t axis = normalizeAxis(splitOp.getAxis(), rank);
    if (axis < 0 || axis >= rank)
      return failure();

    // Phase 1: validate every result before creating any op, so that a
    // rejected split never leaves partially built compute ops behind.
    SmallVector<RankedTensorType> resultTypes;
    resultTypes.reserve(splitOp.getNumResults());
    int64_t totalSize = 0;
    for (Value result : splitOp.getResults()) {
      auto resultType = dyn_cast<RankedTensorType>(result.getType());
      if (!resultType || !resultType.hasStaticShape() || resultType.getRank() != rank)
        return failure();
      totalSize += resultType.getShape()[axis];
      resultTypes.push_back(resultType);
    }
    // A window reaching past the input extent would be an invalid slice.
    if (totalSize > inputType.getShape()[axis])
      return failure();

    // Phase 2: emit one compute op per result.
    Location loc = splitOp.getLoc();
    SmallVector<Value> outputs;
    outputs.reserve(resultTypes.size());

    int64_t offset = 0;
    for (RankedTensorType resultType : resultTypes) {
      int64_t sliceSize = resultType.getShape()[axis];
      auto computeOp =
          createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
            Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, loc);
            spatial::SpatYieldOp::create(rewriter, loc, output);
          });
      outputs.push_back(computeOp.getResult(0));
      offset += sliceSize;
    }

    rewriter.replaceOp(splitOp, outputs);
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Adds the onnx.Split lowering pattern to `patterns`.
void populateSplitPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<Split>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -63,4 +63,10 @@ def spatToPimVSigm : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
// Rewrites SpatSoftmaxOp into PimVSoftmaxOp. The destination operand is
// obtained from getBestOutputTensorFromOperandsOrAllocate, called with the
// builder and the op that defines the matched result.
def spatToPimVSoftmax : Pat<
  (SpatSoftmaxOp:$srcOpRes $input),
  (PimVSoftmaxOp
    $input,
    (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
|
||||
|
||||
#endif // SPATIAL_TO_PIM
|
||||
|
||||
@@ -618,17 +618,22 @@ void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
}
|
||||
|
||||
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
for (auto it : llvm::enumerate(originalOperands)) {
|
||||
size_t orderWithinReturn = it.index();
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
|
||||
rewriter.modifyOpInPlace(returnOp,
|
||||
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
|
||||
|
||||
Operation* opToErase = returnOperand;
|
||||
while (opToErase) {
|
||||
bool isExclusivelyOwnedByReturnChain = opToErase->use_empty() || opToErase->hasOneUse();
|
||||
bool isExclusivelyOwnedByReturnChain = opToErase->use_empty();
|
||||
if (!isExclusivelyOwnedByReturnChain && opToErase->hasOneUse()) {
|
||||
Operation* onlyUser = *opToErase->getUsers().begin();
|
||||
isExclusivelyOwnedByReturnChain =
|
||||
isa<func::ReturnOp, tensor::ConcatOp>(onlyUser) || isChannelUseChainOp(onlyUser);
|
||||
}
|
||||
if (!isExclusivelyOwnedByReturnChain)
|
||||
break;
|
||||
|
||||
|
||||
@@ -455,4 +455,27 @@ def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Softmax in destination-passing style: `outputBuffer` is the DPS init operand
// (see getDpsInitsMutable) and `output` is the op's single result.
def PimVSoftmaxOp : PimOp<"vsoftmax", [DestinationStyleOpInterface]> {
  let summary = "Softmax over the full input vector";

  let arguments = (ins PimTensor:$input, PimTensor:$outputBuffer);
  let results = (outs PimTensor:$output);

  let extraClassDeclaration = [{
    mlir::MutableOperandRange getDpsInitsMutable() {
      return getOutputBufferMutable();
    }
  }];

  let assemblyFormat = [{
    `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
  }];
}
|
||||
|
||||
#endif // PIM_DIALECT_H
|
||||
|
||||
@@ -273,6 +273,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
||||
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
|
||||
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
|
||||
PimVSoftmaxOp::attachInterface<UnaryDstOpInterface<PimVSoftmaxOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -272,6 +272,22 @@ def SpatSigmoidOp : SpatOp<"sigmoid", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Softmax on a Spatial tensor: one tensor operand, one tensor result.
def SpatSoftmaxOp : SpatOp<"softmax", []> {
  let summary = "Softmax over the full input tensor slice";

  let arguments = (ins SpatTensor:$input);
  let results = (outs SpatTensor:$output);

  let assemblyFormat = [{
    `(` $input `)` attr-dict `:` type($input) `->` type($output)
  }];
}
|
||||
|
||||
def SpatReluOp : SpatOp<"relu", []> {
|
||||
let summary = "Element-wise ReLU activation";
|
||||
|
||||
|
||||
@@ -202,9 +202,9 @@ private:
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
|
||||
for (auto users : oldWeightedCompute->getUsers())
|
||||
if (auto funcRet = dyn_cast<func::ReturnOp>(users))
|
||||
funcRet.setOperand(0, newWeightedCompute.getResult(0));
|
||||
for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses()))
|
||||
if (isa<func::ReturnOp>(use.getOwner()))
|
||||
use.assign(newWeightedCompute.getResult(0));
|
||||
|
||||
oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
|
||||
return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
|
||||
|
||||
@@ -146,6 +146,37 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewrites a device-to-host copy through rewriteSubviewCopyLikeOp.
///
/// The helper receives the copy's host target, device source, both offsets and
/// the size, and calls the supplied callback to emit each replacement
/// PimMemCopyDevToHostOp; the byte offsets and slice size it passes are stored
/// as 32-bit integer attributes. On success the original copy is replaced by
/// its host target value.
struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDevToHostOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyDevToHostOp copyOp, PatternRewriter& rewriter) const override {
    const bool rewriteFailed = failed(rewriteSubviewCopyLikeOp(
        copyOp,
        copyOp.getHostTarget(),
        copyOp.getDeviceSource(),
        copyOp.getHostTargetOffset(),
        copyOp.getDeviceSourceOffset(),
        copyOp.getSize(),
        rewriter,
        [&](MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
          auto dstOffsetAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset));
          auto srcOffsetAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset));
          auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes));
          pim::PimMemCopyDevToHostOp::create(
              rewriter, copyOp.getLoc(), resultType, dst, src, dstOffsetAttr, srcOffsetAttr, sizeAttr);
        }));
    if (rewriteFailed)
      return failure();

    rewriter.replaceOp(copyOp, copyOp.getHostTarget());
    return success();
  }
};
|
||||
|
||||
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -209,7 +240,10 @@ struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp
|
||||
} // namespace
|
||||
|
||||
/// Registers the subview rewrite/folding patterns (core subview copy, host
/// subview load, host subview store, constant core subview fold) in
/// `patterns`, using the pattern set's own context.
void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
  patterns.add<RewriteCoreSubviewCopyPattern,
               RewriteHostSubviewLoadPattern,
               RewriteHostSubviewStorePattern,
               FoldConstantCoreSubviewPattern>(patterns.getContext());
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user