Verify fix
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-18 17:20:40 +02:00
parent 2836e759ab
commit aa088e2ba5
2 changed files with 19 additions and 1 deletions
+18 -1
View File
@@ -3,6 +3,7 @@
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
@@ -338,6 +339,19 @@ LogicalResult SpatConcatOp::verify() {
return success(); return success();
} }
LogicalResult verifyComputeResultsUses(Operation* op) {
if (!isa<SpatCompute, SpatComputeBatch>(op))
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
if (!llvm::all_of(op->getResults(), [](Value result) {
return llvm::all_of(result.getUsers(), [](Operation* op) {
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
});
})) {
return op->emitError("ComputeResult used directly inside another Compute" );
}
return success();
}
LogicalResult SpatCompute::verify() { LogicalResult SpatCompute::verify() {
auto& block = getBody().front(); auto& block = getBody().front();
if (block.mightHaveTerminator()) { if (block.mightHaveTerminator()) {
@@ -375,7 +389,8 @@ LogicalResult SpatCompute::verify() {
for (auto arg : block.getArguments()) for (auto arg : block.getArguments())
if (arg.use_empty()) if (arg.use_empty())
return emitError("ComputeOp block argument is not used"); return emitError("ComputeOp block argument is not used");
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return success(); return success();
} }
@@ -485,6 +500,8 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("body block argument type must match input type"); return emitError("body block argument type must match input type");
} }
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane); return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
} }
@@ -65,6 +65,7 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
if (!op) if (!op)
return std::nullopt; return std::nullopt;
//TODO Extract Slice is not the only global non compute operation. There are other legal op
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) { while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
value = extract.getSource(); value = extract.getSource();
op = value.getDefiningOp(); op = value.getDefiningOp();