fix remaining failing tests
Validate Operations / validate-operations (push) Has been cancelled

remove unsupported tests
This commit is contained in:
NiccoloN
2026-06-05 15:27:11 +02:00
parent 0fa10b4074
commit a34ac223c0
9 changed files with 385 additions and 192 deletions
@@ -690,11 +690,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before tiled Spatial Gemm lowering");
return failure();
}
auto aType = dyn_cast<RankedTensorType>(a.getType());
auto bType = dyn_cast<RankedTensorType>(b.getType());
auto outType = dyn_cast<RankedTensorType>(gemmOp.getY().getType());
@@ -725,9 +720,12 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
return failure();
}
const int64_t numOutRows = outType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
if (gemmOpAdaptor.getTransA()) {
auto aShape = aType.getShape();
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType(), aType.getEncoding());
a = ONNXTransposeOp::create(rewriter, loc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})).getResult();
aType = transposedType;
}
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
@@ -736,6 +734,10 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
bType = transposedType;
}
const int64_t numOutRows = outType.getDimSize(0);
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
if (!isCompileTimeComputable(b)) {
bool hasC = hasGemmBias(c);
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();