fix MatMul pattern non-contiguous extract_slices
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m31s

This commit is contained in:
NiccoloN
2026-04-23 14:44:30 +02:00
parent cff929a083
commit 5545b0f672
6 changed files with 254 additions and 78 deletions

View File

@@ -87,8 +87,7 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>(
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
@@ -391,11 +390,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
!constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
return isSpatialMvmVmmWeightUse(use);
});
if (isAlwaysWeight)
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}