huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
|
||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||
}
|
||||
|
||||
static Value transposeForSpatial(Value value,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (isHostFoldableValue(value))
|
||||
return tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
value,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
resultType,
|
||||
input,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -81,6 +121,11 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
||||
|
||||
if (isHostFoldableValue(matrix)) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
}
|
||||
|
||||
auto buildRowSlices = [&](Value matrixArg) {
|
||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
|
||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
||||
@@ -122,7 +167,8 @@ static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
||||
rootValue = definingOp->getOperand(0);
|
||||
}
|
||||
|
||||
return buildRowSlices(matrix);
|
||||
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
||||
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -175,13 +221,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||
cType = expandedType;
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
@@ -196,25 +236,18 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
SmallVector<Value> cSlices;
|
||||
if (hasC && cHasNumOutRows)
|
||||
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
|
||||
|
||||
SmallVector<Value> gemvOps;
|
||||
gemvOps.reserve(numOutRows);
|
||||
gemvOps.reserve(static_cast<size_t>(numOutRows));
|
||||
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||
|
||||
Value cSlice = c;
|
||||
if (hasC) {
|
||||
if (cHasNumOutRows) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||
}
|
||||
if (cHasNumOutRows)
|
||||
cSlice = cSlices[static_cast<size_t>(rowIdx)];
|
||||
else if (!isVectorShape(getTensorShape(c))) {
|
||||
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
|
||||
return failure();
|
||||
@@ -224,7 +257,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
outRowType,
|
||||
aSlice,
|
||||
aSlices[static_cast<size_t>(rowIdx)],
|
||||
b,
|
||||
cSlice,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
@@ -267,13 +300,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
gemmLoc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
|
||||
cType = expandedType;
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
@@ -305,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
if (transA) {
|
||||
auto aShape = aType.getShape();
|
||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
|
||||
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||
aType = cast<RankedTensorType>(a.getType());
|
||||
}
|
||||
if (transB) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
@@ -335,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
|
||||
auto bNumVSlices = aNumHSlices;
|
||||
auto bLastVSliceSize = aLastHSliceSize;
|
||||
auto cNumHSlices = bNumHSlices;
|
||||
auto cLastHSliceSize = bLastHSliceSize;
|
||||
auto outNumHSlices = cNumHSlices;
|
||||
@@ -469,12 +496,15 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
if (!isHostFoldableValue(b))
|
||||
return failure();
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
@@ -484,13 +514,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto cType = cast<RankedTensorType>(c.getType());
|
||||
if (cType.getRank() == 1) {
|
||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||
c = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
c,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
});
|
||||
c = expandRankOneBias(c, expandedType, rewriter, loc);
|
||||
cType = cast<RankedTensorType>(c.getType());
|
||||
}
|
||||
if (!cType.hasStaticShape()) {
|
||||
|
||||
Reference in New Issue
Block a user