huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -5,9 +5,9 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
|
||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||
}
|
||||
|
||||
static Value transposeForSpatial(Value value,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
return tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
value,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
input,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -81,6 +121,11 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
||||
|
||||
if (isHostFoldableValue(matrix)) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
}
|
||||
|
||||
auto buildRowSlices = [&](Value matrixArg) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
@@ -122,7 +167,8 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
rootValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
return buildRowSlices(matrix);
|
||||
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -175,13 +221,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||
cType = expandedType;
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
@@ -196,25 +236,18 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
SmallVector<Value> cSlices;
|
||||
if (hasC && cHasNumOutRows)
|
||||
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
|
||||
|
||||
SmallVector<Value> gemvOps;
|
||||
gemvOps.reserve(numOutRows);
|
||||
gemvOps.reserve(static_cast<size_t>(numOutRows));
|
||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||
|
||||
Value cSlice = c;
|
||||
if (hasC) {
|
||||
if (cHasNumOutRows) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
if (cHasNumOutRows)
|
||||
cSlice = cSlices[static_cast<size_t>(rowIdx)];
|
||||
else if (!isVectorShape(getTensorShape(c))) {
|
||||
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
|
||||
return failure();
|
||||
@@ -224,7 +257,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
outRowType,
|
||||
aSlice,
|
||||
aSlices[static_cast<size_t>(rowIdx)],
|
||||
b,
|
||||
cSlice,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
@@ -267,13 +300,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
gemmLoc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
|
||||
cType = expandedType;
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
@@ -305,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
|
||||
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||
aType = cast<RankedTensorType>(a.getType());
|
||||
}
|
||||
if (transB) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
@@ -335,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||
auto bNumVSlices = aNumHSlices;
|
||||
auto bLastVSliceSize = aLastHSliceSize;
|
||||
auto cNumHSlices = bNumHSlices;
|
||||
auto cLastHSliceSize = bLastHSliceSize;
|
||||
auto outNumHSlices = cNumHSlices;
|
||||
@@ -469,12 +496,15 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
if (!isHostFoldableValue(b))
|
||||
return failure();
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
@@ -484,13 +514,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto cType = cast<RankedTensorType>(c.getType());
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -36,49 +36,27 @@ static Value extractBatchMatrix(Value value,
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
|
||||
|
||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
matrixType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
auto buildMatrix = [&](Value input) -> Value {
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
matrixType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
};
|
||||
|
||||
static bool isConstantLikeOperand(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
if (isHostFoldableValue(value))
|
||||
return buildMatrix(value);
|
||||
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
value = transposeOp.getData();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
auto batchMatrixCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
||||
});
|
||||
return batchMatrixCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -107,15 +85,31 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@@ -157,7 +151,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
@@ -193,8 +187,14 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm)
|
||||
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
|
||||
if (useTransposedForm) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
gemmResult = transposeCompute.getResult(0);
|
||||
}
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
return success();
|
||||
}
|
||||
@@ -215,24 +215,30 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
if (useTransposedForm)
|
||||
gemmResult = ONNXTransposeOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||
gemmResult,
|
||||
rewriter.getI64ArrayAttr({1, 0}));
|
||||
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedOutType,
|
||||
gemmResult,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
}));
|
||||
auto batchResultCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
|
||||
Value resultMatrix = input;
|
||||
if (useTransposedForm) {
|
||||
resultMatrix = ONNXTransposeOp::create(rewriter,
|
||||
loc,
|
||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||
input,
|
||||
rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedOutType,
|
||||
resultMatrix,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
batchResults.push_back(batchResultCompute.getResult(0));
|
||||
}
|
||||
|
||||
Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildReduceMeanKeepdims(Value input,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
int64_t axis,
|
||||
@@ -100,7 +119,7 @@ static Value buildReduceMeanKeepdims(Value input,
|
||||
for (Value slice : slices)
|
||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, reducedSlices);
|
||||
return concatValues(reducedSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
@@ -115,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
||||
}
|
||||
|
||||
return tensor::CollapseShapeOp::create(
|
||||
rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes))
|
||||
.getResult();
|
||||
auto reassociation = buildCollapseReassociation(reducedAxes);
|
||||
if (isHostFoldableValue(keepdimsValue))
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||
|
||||
auto squeezeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
|
||||
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
|
||||
});
|
||||
return squeezeCompute.getResult(0);
|
||||
}
|
||||
|
||||
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
|
||||
Reference in New Issue
Block a user