huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -32,6 +33,24 @@ static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewrit
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
@@ -47,7 +66,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
|
||||
for (Value slice : slices)
|
||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, rebuiltSlices);
|
||||
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
@@ -92,8 +111,13 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
result = postTransposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(softmaxOp, result);
|
||||
|
||||
Reference in New Issue
Block a user