fix remaining failing tests
Validate Operations / validate-operations (push) Has been cancelled

remove unsupported tests
This commit is contained in:
NiccoloN
2026-06-05 15:27:11 +02:00
parent 0fa10b4074
commit a34ac223c0
9 changed files with 385 additions and 192 deletions
@@ -690,11 +690,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before tiled Spatial Gemm lowering");
return failure();
}
auto aType = dyn_cast<RankedTensorType>(a.getType());
auto bType = dyn_cast<RankedTensorType>(b.getType());
auto outType = dyn_cast<RankedTensorType>(gemmOp.getY().getType());
@@ -725,9 +720,12 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
return failure();
}
const int64_t numOutRows = outType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
if (gemmOpAdaptor.getTransA()) {
auto aShape = aType.getShape();
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType(), aType.getEncoding());
a = ONNXTransposeOp::create(rewriter, loc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})).getResult();
aType = transposedType;
}
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
@@ -736,6 +734,10 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
bType = transposedType;
}
const int64_t numOutRows = outType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
if (!isCompileTimeComputable(b)) {
bool hasC = hasGemmBias(c);
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
@@ -22,13 +22,87 @@ namespace {
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty())
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
if (rhsBatchShape.empty())
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
return failure();
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
const int64_t resultRank = std::max<int64_t>(lhsBatchShape.size(), rhsBatchShape.size());
SmallVector<int64_t> resultShape(resultRank, 1);
for (int64_t resultIndex = resultRank - 1, lhsIndex = lhsBatchShape.size() - 1, rhsIndex = rhsBatchShape.size() - 1;
resultIndex >= 0;
--resultIndex, --lhsIndex, --rhsIndex) {
const int64_t lhsDim = lhsIndex >= 0 ? lhsBatchShape[lhsIndex] : 1;
const int64_t rhsDim = rhsIndex >= 0 ? rhsBatchShape[rhsIndex] : 1;
if (lhsDim != rhsDim && lhsDim != 1 && rhsDim != 1)
return failure();
resultShape[resultIndex] = std::max(lhsDim, rhsDim);
}
return resultShape;
}
static int64_t mapStaticBroadcastedBatchIndex(int64_t outputBatchIndex,
ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> outputBatchShape) {
if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1)
return 0;
if (llvm::equal(sourceBatchShape, outputBatchShape))
return outputBatchIndex;
SmallVector<int64_t> outputStrides = computeRowMajorStrides(outputBatchShape);
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceBatchShape);
int64_t sourceFlatIndex = 0;
for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast<int64_t>(sourceBatchShape.size()); ++sourceDimIndex) {
if (sourceBatchShape[sourceDimIndex] == 1)
continue;
const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex;
const int64_t outputDimStride = outputStrides.empty() ? 1 : outputStrides[outputDimIndex];
const int64_t outputDimIndexValue = outputDimStride == 1
? outputBatchIndex % outputBatchShape[outputDimIndex]
: (outputBatchIndex / outputDimStride) % outputBatchShape[outputDimIndex];
sourceFlatIndex += outputDimIndexValue * sourceStrides[sourceDimIndex];
}
return sourceFlatIndex;
}
static Value computeFlatBatchIndexCoordinate(
Value flatBatchIndex, ArrayRef<int64_t> batchShape, int64_t dimIndex, PatternRewriter& rewriter, Location loc) {
if (batchShape[dimIndex] == 1)
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
const int64_t dimStride = dimIndex + 1 == static_cast<int64_t>(batchShape.size())
? 1
: getStaticShapeElementCount(batchShape.drop_front(dimIndex + 1));
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value dimCoordinate = flatBatchIndex;
if (dimStride != 1)
dimCoordinate = affineFloorDivConst(rewriter, loc, dimCoordinate, dimStride, anchorOp);
return affineModConst(rewriter, loc, dimCoordinate, batchShape[dimIndex], anchorOp);
}
static Value mapOutputBatchIndexToSourceBatchIndex(Value outputBatchIndex,
ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> outputBatchShape,
PatternRewriter& rewriter,
Location loc) {
if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1)
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
if (llvm::equal(sourceBatchShape, outputBatchShape))
return outputBatchIndex;
Value sourceBatchIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceBatchShape);
for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast<int64_t>(sourceBatchShape.size()); ++sourceDimIndex) {
if (sourceBatchShape[sourceDimIndex] == 1)
continue;
const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex;
Value outputCoordinate =
computeFlatBatchIndexCoordinate(outputBatchIndex, outputBatchShape, outputDimIndex, rewriter, loc);
Value contribution = sourceStrides[sourceDimIndex] == 1
? outputCoordinate
: affineMulConst(rewriter,
loc,
outputCoordinate,
sourceStrides[sourceDimIndex],
rewriter.getInsertionBlock()->getParentOp());
sourceBatchIndex = arith::AddIOp::create(rewriter, loc, sourceBatchIndex, contribution);
}
return sourceBatchIndex;
}
static Value
@@ -67,6 +141,52 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
}
static Value createMatrixFromVector(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
auto buildExpanded = [&](Value input) -> Value {
return tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
input,
SmallVector<ReassociationIndices> {
{0, 1}
});
};
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildExpanded);
}
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> removedAxes) {
SmallVector<ReassociationIndices> reassociation;
ReassociationIndices currentGroup;
for (auto [axis, removeAxis] : llvm::enumerate(removedAxes)) {
currentGroup.push_back(axis);
if (!removeAxis) {
reassociation.push_back(currentGroup);
currentGroup.clear();
}
}
if (!currentGroup.empty()) {
if (reassociation.empty())
reassociation.push_back(std::move(currentGroup));
else
reassociation.back().append(currentGroup.begin(), currentGroup.end());
}
return reassociation;
}
static Value squeezeUnitDims(
Value value, RankedTensorType resultType, ArrayRef<bool> removedAxes, PatternRewriter& rewriter, Location loc) {
if (cast<RankedTensorType>(value.getType()) == resultType)
return value;
SmallVector<ReassociationIndices> reassociation =
resultType.getRank() == 0 ? SmallVector<ReassociationIndices> {} : buildCollapseReassociation(removedAxes);
auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation).getResult();
};
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildCollapsed);
}
static Value ensureBatchedTensor(
Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
@@ -171,8 +291,11 @@ static Value createPaddedBatchedInputCompute(Value input,
return computeOp.getResult(0);
}
static FailureOr<Value> materializePaddedBatchedWeight(
Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) {
static FailureOr<Value> materializePaddedBatchedWeight(Value value,
ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> targetBatchShape,
RankedTensorType resultType,
PatternRewriter& rewriter) {
auto sourceType = cast<RankedTensorType>(value.getType());
if (sourceType == resultType)
return value;
@@ -183,13 +306,15 @@ static FailureOr<Value> materializePaddedBatchedWeight(
const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1);
const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2);
const int64_t targetBatch = targetBatchShape.empty() ? 1 : getStaticShapeElementCount(targetBatchShape);
const int64_t targetRows = resultType.getDimSize(1);
const int64_t targetCols = resultType.getDimSize(2);
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));
for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
const int64_t sourceBatchIdx = sourceType.getRank() == 2 ? 0 : (sourceBatch == 1 ? 0 : batchIdx);
const int64_t sourceBatchIdx =
sourceType.getRank() == 2 ? 0 : mapStaticBroadcastedBatchIndex(batchIdx, sourceBatchShape, targetBatchShape);
const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols;
const int64_t targetBatchBase = batchIdx * targetRows * targetCols;
for (int64_t row = 0; row < sourceRows; ++row)
@@ -202,16 +327,18 @@ static FailureOr<Value> materializePaddedBatchedWeight(
}
static Value extractBatchedATile(Value a,
int64_t sourceBatchCount,
Value batch,
ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> outputBatchShape,
Value outputBatchIndex,
Value row,
Value kOffset,
RankedTensorType aTileType,
PatternRewriter& rewriter,
Location loc) {
auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType());
SmallVector<OpFoldResult> offsets {
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), row, kOffset};
Value sourceBatchIndex =
mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
SmallVector<OpFoldResult> offsets {OpFoldResult(sourceBatchIndex), row, kOffset};
SmallVector<OpFoldResult> sizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))};
auto slice =
@@ -227,8 +354,9 @@ static Value extractBatchedATile(Value a,
}
static Value extractBatchedBTile(Value b,
int64_t sourceBatchCount,
Value batch,
ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> outputBatchShape,
Value outputBatchIndex,
Value kOffset,
Value hOffset,
RankedTensorType bTileType,
@@ -236,8 +364,9 @@ static Value extractBatchedBTile(Value b,
Location loc) {
auto bSliceType =
RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType());
SmallVector<OpFoldResult> offsets {
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), kOffset, hOffset};
Value sourceBatchIndex =
mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
SmallVector<OpFoldResult> offsets {OpFoldResult(sourceBatchIndex), kOffset, hOffset};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(bTileType.getDimSize(0)),
rewriter.getIndexAttr(bTileType.getDimSize(1))};
@@ -262,9 +391,10 @@ static Value getBatchLaneIndex(
static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
Value b,
RankedTensorType aType,
int64_t aBatchCount,
ArrayRef<int64_t> aBatchShape,
RankedTensorType bType,
int64_t bBatchCount,
ArrayRef<int64_t> bBatchShape,
ArrayRef<int64_t> outputBatchShape,
RankedTensorType partialPiecesType,
int64_t numOutRows,
int64_t numKSlices,
@@ -298,10 +428,10 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
auto pieceType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
Value aTile =
extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc);
Value bTile =
extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc);
Value aTile = extractBatchedATile(
args.inputs.front(), aBatchShape, outputBatchShape, batch, row, kOffset, aTileType, rewriter, loc);
Value bTile = extractBatchedBTile(
args.weights.front(), bBatchShape, outputBatchShape, batch, kOffset, hOffset, bTileType, rewriter, loc);
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
@@ -315,17 +445,17 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVmmBatch(Value a,
}
static Value extractDynamicBatchedBColumn(Value matrix,
int64_t sourceBatchCount,
Value batch,
ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> outputBatchShape,
Value outputBatchIndex,
Value column,
RankedTensorType vectorType,
PatternRewriter& rewriter,
Location loc) {
auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType());
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
: OpFoldResult(batch),
rewriter.getIndexAttr(0),
column};
Value sourceBatchIndex =
mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
SmallVector<OpFoldResult> offsets {OpFoldResult(sourceBatchIndex), rewriter.getIndexAttr(0), column};
SmallVector<OpFoldResult> sizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -350,17 +480,17 @@ static Value extractDynamicBatchedBColumn(Value matrix,
}
static Value extractDynamicBatchedRowVector(Value matrix,
int64_t sourceBatchCount,
Value batch,
ArrayRef<int64_t> sourceBatchShape,
ArrayRef<int64_t> outputBatchShape,
Value outputBatchIndex,
Value row,
RankedTensorType vectorType,
PatternRewriter& rewriter,
Location loc) {
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
: OpFoldResult(batch),
row,
rewriter.getIndexAttr(0)};
Value sourceBatchIndex =
mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
SmallVector<OpFoldResult> offsets {OpFoldResult(sourceBatchIndex), row, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
auto rowSlice =
@@ -376,9 +506,10 @@ static Value extractDynamicBatchedRowVector(Value matrix,
}
static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
int64_t aBatchCount,
ArrayRef<int64_t> aBatchShape,
Value b,
int64_t bBatchCount,
ArrayRef<int64_t> bBatchShape,
ArrayRef<int64_t> outputBatchShape,
RankedTensorType aType,
RankedTensorType bType,
RankedTensorType scalarPiecesType,
@@ -406,10 +537,10 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector =
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
Value bVector =
extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
Value aVector = extractDynamicBatchedRowVector(
args.inputs[0], aBatchShape, outputBatchShape, batch, row, vectorType, rewriter, loc);
Value bVector = extractDynamicBatchedBColumn(
args.inputs[1], bBatchShape, outputBatchShape, batch, column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -629,11 +760,17 @@ static FailureOr<Value> createBatchedReductionCompute(Value partialPieces,
return computeOp->getResult(0);
}
struct MatMulShapeInfo {
struct NormalizedMatMulInfo {
RankedTensorType lhsType;
RankedTensorType rhsType;
RankedTensorType outType;
SmallVector<int64_t> batchShape;
RankedTensorType normalizedLhsType;
RankedTensorType normalizedRhsType;
SmallVector<int64_t> lhsBatchShape;
SmallVector<int64_t> rhsBatchShape;
SmallVector<int64_t> outputBatchShape;
bool lhsWasVector;
bool rhsWasVector;
int64_t lhsBatch;
int64_t rhsBatch;
int64_t batch;
@@ -642,46 +779,170 @@ struct MatMulShapeInfo {
int64_t n;
};
static FailureOr<MatMulShapeInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
struct MatMulLoweringPlan {
Value lhs;
Value rhs;
RankedTensorType lhsType;
RankedTensorType rhsType;
SmallVector<int64_t> lhsBatchShape;
SmallVector<int64_t> rhsBatchShape;
SmallVector<int64_t> outputBatchShape;
int64_t lhsBatch;
int64_t rhsBatch;
int64_t batch;
int64_t m;
int64_t k;
int64_t n;
bool transposedResult;
};
static SmallVector<int64_t> computeExpectedMatMulOutputShape(
ArrayRef<int64_t> batchShape, int64_t m, int64_t n, bool lhsWasVector, bool rhsWasVector) {
SmallVector<int64_t> shape(batchShape.begin(), batchShape.end());
if (lhsWasVector && rhsWasVector)
return shape;
if (lhsWasVector) {
shape.push_back(n);
return shape;
}
if (rhsWasVector) {
shape.push_back(m);
return shape;
}
shape.push_back(m);
shape.push_back(n);
return shape;
}
static FailureOr<NormalizedMatMulInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
if (lhsType.getRank() < 1 || rhsType.getRank() < 1)
return failure();
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
return failure();
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
if (failed(batchShape))
const bool lhsWasVector = lhsType.getRank() == 1;
const bool rhsWasVector = rhsType.getRank() == 1;
auto normalizedLhsType =
lhsWasVector ? RankedTensorType::get({1, lhsType.getDimSize(0)}, lhsType.getElementType(), lhsType.getEncoding())
: lhsType;
auto normalizedRhsType =
rhsWasVector ? RankedTensorType::get({rhsType.getDimSize(0), 1}, rhsType.getElementType(), rhsType.getEncoding())
: rhsType;
SmallVector<int64_t> lhsBatchShape(normalizedLhsType.getShape().begin(), normalizedLhsType.getShape().end() - 2);
SmallVector<int64_t> rhsBatchShape(normalizedRhsType.getShape().begin(), normalizedRhsType.getShape().end() - 2);
auto outputBatchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
if (failed(outputBatchShape))
return failure();
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
const int64_t batch = outputBatchShape->empty() ? 1 : getStaticShapeElementCount(*outputBatchShape);
const int64_t m = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 2);
const int64_t k = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 1);
const int64_t rhsK = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 2);
const int64_t n = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 1);
if (k != rhsK)
return failure();
if (outType.getRank() == 2) {
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
return failure();
}
else {
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|| outType.getDimSize(outType.getRank() - 1) != n)
return failure();
if (SmallVector<int64_t>(outType.getShape().begin(), outType.getShape().end())
!= computeExpectedMatMulOutputShape(*outputBatchShape, m, n, lhsWasVector, rhsWasVector)) {
return failure();
}
return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n};
return NormalizedMatMulInfo {lhsType,
rhsType,
outType,
normalizedLhsType,
normalizedRhsType,
lhsBatchShape,
rhsBatchShape,
*outputBatchShape,
lhsWasVector,
rhsWasVector,
lhsBatch,
rhsBatch,
batch,
m,
k,
n};
}
static MatMulLoweringPlan buildLoweringPlan(Value normalizedLhs,
Value normalizedRhs,
const NormalizedMatMulInfo& info,
bool useTransposedForm,
PatternRewriter& rewriter,
Location loc) {
MatMulLoweringPlan plan {normalizedLhs,
normalizedRhs,
cast<RankedTensorType>(normalizedLhs.getType()),
cast<RankedTensorType>(normalizedRhs.getType()),
info.lhsBatchShape,
info.rhsBatchShape,
info.outputBatchShape,
info.lhsBatch,
info.rhsBatch,
info.batch,
info.m,
info.k,
info.n,
false};
if (!useTransposedForm)
return plan;
plan.lhs = transposeLastTwoDims(normalizedRhs, rewriter, loc);
plan.rhs = transposeLastTwoDims(normalizedLhs, rewriter, loc);
plan.lhsType = cast<RankedTensorType>(plan.lhs.getType());
plan.rhsType = cast<RankedTensorType>(plan.rhs.getType());
std::swap(plan.lhsBatchShape, plan.rhsBatchShape);
std::swap(plan.lhsBatch, plan.rhsBatch);
plan.m = info.n;
plan.n = info.m;
plan.transposedResult = true;
return plan;
}
static Value normalizeMatMulOperand(
Value value, RankedTensorType normalizedType, bool wasVector, PatternRewriter& rewriter, Location loc) {
if (!wasVector)
return value;
return createMatrixFromVector(value, normalizedType, rewriter, loc);
}
static Value finalizeNormalizedMatMulResult(Value value,
RankedTensorType directOutType,
const NormalizedMatMulInfo& info,
PatternRewriter& rewriter,
Location loc) {
// The direct lowered result is always [flatBatch, normalizedM, normalizedN].
// Restore ONNX MatMul result rank by expanding right-aligned batch dimensions
// and removing the synthetic unit matrix axes introduced for vector operands.
Value result = value;
RankedTensorType currentType = directOutType;
if (info.outputBatchShape.size() > 1) {
SmallVector<int64_t> expandedShape(info.outputBatchShape.begin(), info.outputBatchShape.end());
expandedShape.push_back(info.m);
expandedShape.push_back(info.n);
auto expandedType = RankedTensorType::get(expandedShape, info.outType.getElementType(), info.outType.getEncoding());
result = expandBatchDims(result, expandedType, info.outputBatchShape.size(), rewriter, loc);
currentType = expandedType;
}
SmallVector<bool> removedAxes(currentType.getRank(), false);
if (info.outputBatchShape.empty())
removedAxes[0] = true;
if (info.lhsWasVector)
removedAxes[currentType.getRank() - 2] = true;
if (info.rhsWasVector)
removedAxes[currentType.getRank() - 1] = true;
return squeezeUnitDims(result, info.outType, removedAxes, rewriter, loc);
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
@@ -689,7 +950,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto shapeInfo = analyzeMatMulShape(matmulOp);
if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2)
if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty())
return failure();
Location loc = matmulOp.getLoc();
@@ -742,61 +1003,56 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
auto shapeInfo = analyzeMatMulShape(matmulOp);
if (failed(shapeInfo))
return failure();
if (shapeInfo->outType.getRank() == 2)
if (!shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector && shapeInfo->outputBatchShape.empty())
return failure();
Location loc = matmulOp.getLoc();
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
bool useTransposedForm = !shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector
&& isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
int64_t gemmM = shapeInfo->m;
int64_t gemmK = shapeInfo->k;
int64_t gemmN = shapeInfo->n;
if (useTransposedForm) {
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = shapeInfo->rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = shapeInfo->lhsBatch;
gemmM = shapeInfo->n;
gemmN = shapeInfo->m;
}
Value lhs =
normalizeMatMulOperand(matmulOp.getA(), shapeInfo->normalizedLhsType, shapeInfo->lhsWasVector, rewriter, loc);
Value rhs =
normalizeMatMulOperand(matmulOp.getB(), shapeInfo->normalizedRhsType, shapeInfo->rhsWasVector, rewriter, loc);
lhs = collapseBatchDims(lhs, shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
rhs = collapseBatchDims(rhs, shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
MatMulLoweringPlan plan = buildLoweringPlan(lhs, rhs, *shapeInfo, useTransposedForm, rewriter, loc);
lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
auto lhsBatchedType = cast<RankedTensorType>(lhs.getType());
auto rhsBatchedType = cast<RankedTensorType>(rhs.getType());
auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType());
plan.lhs = ensureBatchedTensor(plan.lhs, plan.lhsBatch, plan.m, plan.k, rewriter, loc);
plan.rhs = ensureBatchedTensor(plan.rhs, plan.rhsBatch, plan.k, plan.n, rewriter, loc);
plan.lhsType = cast<RankedTensorType>(plan.lhs.getType());
plan.rhsType = cast<RankedTensorType>(plan.rhs.getType());
auto directOutType = RankedTensorType::get(
{plan.batch, plan.m, plan.n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding());
if (isCompileTimeComputable(rhs)) {
const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue());
const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue());
if (isCompileTimeComputable(plan.rhs)) {
const int64_t numKSlices = ceilIntegerDivide(plan.k, crossbarSize.getValue());
const int64_t numOutHSlices = ceilIntegerDivide(plan.n, crossbarSize.getValue());
const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());
auto paddedLhsType = RankedTensorType::get(
{lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding());
auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols},
rhsBatchedType.getElementType(),
rhsBatchedType.getEncoding());
{plan.lhsBatch, plan.m, paddedReductionSize}, plan.lhsType.getElementType(), plan.lhsType.getEncoding());
auto paddedRhsType = RankedTensorType::get(
{plan.batch, paddedReductionSize, paddedOutCols}, plan.rhsType.getElementType(), plan.rhsType.getEncoding());
auto paddedOutType =
RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType());
RankedTensorType::get({plan.batch, plan.m, paddedOutCols}, shapeInfo->outType.getElementType());
auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter);
auto paddedRhs =
materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
if (succeeded(paddedRhs)) {
Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc);
const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices;
Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices;
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
shapeInfo->outType.getElementType());
auto batchOp = createBatchedVmmBatch(paddedLhs,
*paddedRhs,
paddedLhsType,
lhsBatchForGemm,
plan.lhsBatchShape,
paddedRhsType,
rhsBatchForGemm,
plan.rhsBatchShape,
plan.outputBatchShape,
partialPiecesType,
gemmM,
plan.m,
numKSlices,
numOutHSlices,
rewriter,
@@ -807,34 +1063,35 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
partialPiecesType,
directOutType,
paddedOutType,
shapeInfo->batch,
plan.batch,
numKSlices,
rewriter,
loc);
if (failed(result))
return failure();
Value finalResult = *result;
if (useTransposedForm) {
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
if (plan.transposedResult) {
auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n},
shapeInfo->outType.getElementType(),
shapeInfo->outType.getEncoding());
finalResult =
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
.getResult();
}
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc);
rewriter.replaceOp(matmulOp, finalResult);
return success();
}
}
const int64_t laneCount = shapeInfo->batch * gemmM * gemmN;
const int64_t laneCount = plan.batch * plan.m * plan.n;
auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType());
auto batchOp = createBatchedVvdmulBatch(lhs,
lhsBatchForGemm,
rhs,
rhsBatchForGemm,
lhsBatchedType,
rhsBatchedType,
auto batchOp = createBatchedVvdmulBatch(plan.lhs,
plan.lhsBatchShape,
plan.rhs,
plan.rhsBatchShape,
plan.outputBatchShape,
plan.lhsType,
plan.rhsType,
scalarPiecesType,
directOutType,
rewriter,
@@ -846,15 +1103,15 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
if (failed(result))
return failure();
Value finalResult = *result;
if (useTransposedForm) {
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
if (plan.transposedResult) {
auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n},
shapeInfo->outType.getElementType(),
shapeInfo->outType.getEncoding());
finalResult =
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
.getResult();
}
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc);
rewriter.replaceOp(matmulOp, finalResult);
return success();
}
@@ -238,14 +238,8 @@ static Value squeezeReducedAxes(Value keepdimsValue,
ArrayRef<bool> reducedAxes,
ConversionPatternRewriter& rewriter,
Location loc) {
if (resultType.getRank() == 0) {
SmallVector<Value> indices(cast<RankedTensorType>(keepdimsValue.getType()).getRank(),
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0));
Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
}
auto reassociation = buildCollapseReassociation(reducedAxes);
SmallVector<ReassociationIndices> reassociation =
resultType.getRank() == 0 ? SmallVector<ReassociationIndices> {} : buildCollapseReassociation(reducedAxes);
if (isCompileTimeComputable(keepdimsValue))
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
-60
View File
@@ -779,28 +779,6 @@ def matmul_matrix_vector():
save_model(model, "matmul/matrix_vector", "matmul_matrix_vector.onnx")
def matmul_vector_vector_dot():
"""Vector-vector MatMul producing a scalar output."""
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1024])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [])
B = numpy_helper.from_array(np.random.default_rng(97).uniform(-1, 1, (1024,)).astype(np.float32), name="B")
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
graph = helper.make_graph([node], "matmul_vector_vector_dot", [A], [Y], initializer=[B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "matmul/vector_vector_dot", "matmul_vector_vector_dot.onnx")
def matmul_batched_4d_broadcast():
"""Batched 4D MatMul with broadcast across leading dimensions."""
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 1, 3, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 5, 3, 6])
B = numpy_helper.from_array(np.random.default_rng(98).uniform(-1, 1, (1, 5, 4, 6)).astype(np.float32), name="B")
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
graph = helper.make_graph([node], "matmul_batched_4d_broadcast", [A], [Y], initializer=[B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "matmul/batched_4d_broadcast", "matmul_batched_4d_broadcast.onnx")
# ---------------------------------------------------------------------------
# Pooling tests
# ---------------------------------------------------------------------------
@@ -1560,17 +1538,6 @@ def add_channel_broadcast_1024():
save_model(model, "add/channel_broadcast_1024", "add_channel_broadcast_1024.onnx")
def add_scalar_runtime():
"""Elementwise Add with a runtime scalar RHS."""
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1])
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 1, 1])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
node = helper.make_node("Add", ["A", "B"], ["Y"])
graph = helper.make_graph([node], "add_scalar_runtime", [A, B], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "add/scalar_runtime", "add_scalar_runtime.onnx")
def add_leading_dimension_broadcast():
"""Elementwise Add with trailing-dimension broadcasting."""
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
@@ -1635,17 +1602,6 @@ def mul_channel_broadcast_1024():
save_model(model, "mul/channel_broadcast_1024", "mul_channel_broadcast_1024.onnx")
def mul_scalar_runtime():
"""Elementwise Mul with a runtime scalar RHS."""
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1])
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 1, 1])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
node = helper.make_node("Mul", ["A", "B"], ["Y"])
graph = helper.make_graph([node], "mul_scalar_runtime", [A, B], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "mul/scalar_runtime", "mul_scalar_runtime.onnx")
def mul_leading_dimension_broadcast():
"""Elementwise Mul with trailing-dimension broadcasting."""
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
@@ -1721,17 +1677,6 @@ def div_runtime_scalar_rhs():
save_model(model, "div/runtime_scalar_rhs", "div_runtime_scalar_rhs.onnx")
def div_runtime_scalar_lhs():
"""Elementwise Div with a scalar constant numerator."""
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1024, 1, 1])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
A = numpy_helper.from_array(np.asarray([[[[2.0]]]], dtype=np.float32), name="A")
node = helper.make_node("Div", ["A", "B"], ["Y"])
graph = helper.make_graph([node], "div_runtime_scalar_lhs", [B], [Y], initializer=[A])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "div/runtime_scalar_lhs", "div_runtime_scalar_lhs.onnx")
def div_leading_dimension_broadcast():
"""Elementwise Div with trailing-dimension broadcasting."""
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
@@ -1812,8 +1757,6 @@ if __name__ == "__main__":
matmul_huge_1024()
matmul_vector_matrix()
matmul_matrix_vector()
matmul_vector_vector_dot()
matmul_batched_4d_broadcast()
print("\nGenerating Pooling tests:")
maxpool_basic()
@@ -1899,7 +1842,6 @@ if __name__ == "__main__":
add_broadcast_row()
add_after_gemm()
add_channel_broadcast_1024()
add_scalar_runtime()
add_leading_dimension_broadcast()
print("\nGenerating Mul tests:")
@@ -1907,7 +1849,6 @@ if __name__ == "__main__":
mul_scalar_constant()
mul_after_conv()
mul_channel_broadcast_1024()
mul_scalar_runtime()
mul_leading_dimension_broadcast()
print("\nGenerating Div tests:")
@@ -1916,7 +1857,6 @@ if __name__ == "__main__":
div_after_gemm()
div_channel_broadcast_1024()
div_runtime_scalar_rhs()
div_runtime_scalar_lhs()
div_leading_dimension_broadcast()
print("\nDone.")