add verification of communication invariants at the end of spatial
Validate Operations / validate-operations (push) Has been cancelled

remove dead logic
This commit is contained in:
NiccoloN
2026-05-27 19:17:48 +02:00
parent 783dffe553
commit 00414dd1d9
8 changed files with 110 additions and 58 deletions
@@ -31,6 +31,84 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
});
}
Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
}
bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
bool isLegalHostBackedValue(Value value) {
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return isa<BlockArgument>(value);
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
return false;
return definingOp->getDialect()->getNamespace() != "spat";
}
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
ValueRange inputs,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
unsigned currentInputIndex = inputIndex;
Operation* definingOp = input.getDefiningOp();
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
continue;
if (isLegalHostBackedValue(input))
continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " input #" << currentInputIndex
<< (allowChannelReceiveInputs
? " must come from the host or an explicit "
"spat.channel_receive"
: " must come from the host");
if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
});
return failure();
}
return success();
}
void verifyNoExternalTensorCaptures(Operation* ownerOp,
Region& region,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (!isa<TensorType>(value.getType()))
continue;
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
continue;
Operation* definingOp = value.getDefiningOp();
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
<< "values";
diagnostic.attachNote(op->getLoc())
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
});
}
});
}
} // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
@@ -53,4 +131,27 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
return success(!diagnostics.hasFailure());
}
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
(void)verifyComputeLikeInputs(
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
(void)verifyComputeLikeInputs(computeBatchOp.getOperation(),
computeBatchOp.getInputs(),
/*allowChannelReceiveInputs=*/false,
"spat.compute_batch",
diagnostics);
verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
}
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir