This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
@@ -1,4 +1,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"
#include "Common/IR/WeightUtils.hpp"
@@ -13,6 +15,8 @@ namespace onnx_mlir {
namespace {
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) {
if (!hasWeightAlways(op))
@@ -23,134 +27,174 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
continue;
diagnostics.report(op, [&](Operation* illegalOp) {
illegalOp->emitOpError(
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
illegalOp->emitOpError()
<< kPhaseMarker
<< " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
});
return;
}
});
}
Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
bool isRegionOrAncestorOf(Region& region, Region* candidate) {
return candidate && (&region == candidate || region.isAncestor(candidate));
}
bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
bool isValueDefinedInsideRegion(Value value, Region& region) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent());
if (Operation* definingOp = value.getDefiningOp())
return isRegionOrAncestorOf(region, definingOp->getParentRegion());
return false;
}
bool isLegalExternalCapture(Value value, Region& region) {
if (isValueDefinedInsideRegion(value, region))
return true;
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
template <typename ComputeOpTy>
void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
Region& body = compute.getBody();
body.walk([&](Operation* nestedOp) {
for (OpOperand& operand : nestedOp->getOpOperands()) {
Value value = operand.get();
if (isLegalExternalCapture(value, body))
continue;
Operation* definingOp = value.getDefiningOp();
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
InFlightDiagnostic diag =
illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
<< operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
diag << " (type " << value.getType() << ")";
if (definingOp)
diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (Operation* owner = blockArg.getOwner()->getParentOp())
diag.attachNote(owner->getLoc())
<< "external block argument belongs to " << owner->getName().getStringRef();
}
});
}
});
}
bool isLegalHostBackedValue(Value value) {
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return isa<BlockArgument>(value);
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
return false;
return definingOp->getDialect()->getNamespace() != "spat";
}
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
ValueRange inputs,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
unsigned currentInputIndex = inputIndex;
template <typename ComputeOpTy>
void verifyScheduledInputs(ComputeOpTy compute,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
Operation* definingOp = input.getDefiningOp();
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
continue;
if (isLegalHostBackedValue(input))
continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
<< kind << " input #" << currentInputIndex
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
"spat.channel_receive"
: " must come from the host");
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
InFlightDiagnostic diag = illegalOp->emitOpError()
<< kPhaseMarker << " " << kind << " input #" << inputIndex
<< (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive"
: " must come from the host");
if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
});
return failure();
}
return success();
}
void verifyNoExternalTensorCaptures(Operation* ownerOp,
Region& region,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (!isa<TensorType>(value.getType()))
continue;
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
continue;
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp,
spatial::SpatGraphCompute,
spatial::SpatGraphComputeBatch,
spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatMaterializeLayoutOp>(&op)) {
continue;
}
if (isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker << " scheduled Spatial compute op is not allowed in logical graph phase";
});
continue;
}
if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker
<< " explicit channel communication is not expected before merge materialization";
});
continue;
}
if (isCompileTimeOp(&op))
continue;
Operation* definingOp = value.getDefiningOp();
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError()
<< kPhaseMarker << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute";
});
}
}
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
<< "values";
diagnostic.attachNote(op->getLoc())
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
for (Operation& op : funcOp.getOps()) {
if (isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(&op)) {
diagnostics.report(&op, [&](Operation* illegalOp) {
illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
});
}
});
}
}
} // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isCompileTimeOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError(
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
});
}
checkWeightUseChains(funcOp, diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
for (auto compute : funcOp.getOps<spatial::SpatGraphCompute>())
verifyComputeBodyCaptures(compute, "graph_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatGraphComputeBatch>())
verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics);
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
return success(!diagnostics.hasFailure());
}
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); }
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
verifyLogicalTopLevelOps(funcOp, diagnostics);
checkWeightUseChains(funcOp, diagnostics);
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
return success(!diagnostics.hasFailure());
}
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
(void) verifyComputeLikeInputs(
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
computeBatchOp.getInputs(),
/*allowChannelReceiveInputs=*/false,
"spat.compute_batch",
diagnostics);
verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
}
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
verifyScheduledTopLevelOps(funcOp, diagnostics);
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>())
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>())
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
if (failed(verifyNoComputeBodyCaptures(funcOp)))
return failure();
diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
return success(!diagnostics.hasFailure());
}