#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/LLVM.h"

#include "Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Marker text prepended to the phase-invariant diagnostics emitted below.
constexpr StringLiteral kPhaseMarker = "phase-check";

/// Walks `func` and reports every op accepted by `hasWeightAlways` that has at
/// least one result rejected by `hasOnlySpatialMvmVmmWeightUses`.
/// At most one diagnostic is reported per op: the scan over its results stops
/// at the first offending result.
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
  func.walk([&](Operation* op) {
    if (!hasWeightAlways(op))
      return;
    for (Value result : op->getResults()) {
      if (hasOnlySpatialMvmVmmWeightUses(result))
        continue;
      diagnostics.report(op, [&](Operation* illegalOp) {
        illegalOp->emitOpError()
            << kPhaseMarker
            << " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
      });
      // One report per op is enough; skip its remaining results.
      return;
    }
  });
}

/// Returns true if `candidate` is `region` itself or is nested (at any depth)
/// inside `region`. A null `candidate` yields false.
bool isRegionOrAncestorOf(Region& region, Region* candidate) {
  return candidate && (&region == candidate || region.isAncestor(candidate));
}

/// Returns true if `value` (a block argument or an op result) is defined in
/// `region` or in a region nested inside it.
bool isValueDefinedInsideRegion(Value value, Region& region) {
  if (auto blockArg = dyn_cast(value))
    return isRegionOrAncestorOf(region, blockArg.getOwner()->getParent());
  if (Operation* definingOp = value.getDefiningOp())
    return isRegionOrAncestorOf(region, definingOp->getParentRegion());
  return false;
}

/// A use inside `region` is legal if the value is defined inside the region, or
/// if it is produced by an op carrying the trait tested below. The capture
/// diagnostic calls anything else a "non-constant external operand".
bool isLegalExternalCapture(Value value, Region& region) {
  if (isValueDefinedInsideRegion(value, region))
    return true;
  Operation* definingOp = value.getDefiningOp();
  return definingOp && definingOp->hasTrait();
}

/// Reports one diagnostic on `compute` for every operand of an op nested in its
/// body that is neither defined inside the body nor a legal external capture
/// (see isLegalExternalCapture). `kind` names the compute flavour in the message.
/// A note points at the defining op or, for an external block argument, at the
/// op owning that argument's block.
template void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
  Region& body = compute.getBody();
  body.walk([&](Operation* nestedOp) {
    for (OpOperand& operand : nestedOp->getOpOperands()) {
      Value value = operand.get();
      if (isLegalExternalCapture(value, body))
        continue;
      Operation* definingOp = value.getDefiningOp();
      diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
        InFlightDiagnostic diag = illegalOp->emitOpError()
                                  << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
                                  << operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
        diag << " (type " << value.getType() << ")";
        if (definingOp) {
          diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
        }
        else if (auto blockArg = dyn_cast(value)) {
          if (Operation* owner = blockArg.getOwner()->getParentOp())
            diag.attachNote(owner->getLoc())
                << "external block argument belongs to " << owner->getName().getStringRef();
        }
      });
    }
  });
}

/// Returns true if `value` may feed a scheduled compute op directly.
/// Such a value is either a value with no defining op that passes the `isa`
/// check, or the result of an op whose dialect namespace is not "spat".
bool isLegalHostBackedValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return isa(value);
  // Operation::getDialect() returns nullptr for unregistered ops (or unloaded
  // dialects), so read the namespace from the OperationName, which always has one.
  return definingOp->getName().getDialectNamespace() != "spat";
}

/// Reports each input of `compute` that is not host-backed. When
/// `allowChannelReceiveInputs` is set, inputs produced by the op type checked
/// with isa_and_nonnull below are also accepted; the diagnostic names that case
/// "explicit spat.channel_receive". `kind` names the op in the message.
template void verifyScheduledInputs(ComputeOpTy compute, bool allowChannelReceiveInputs, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
  for (auto indexedInput : llvm::enumerate(compute.getInputs())) {
    // Copy into plain locals: structured bindings cannot be captured by the
    // diagnostic lambda below before C++20.
    const size_t inputIndex = indexedInput.index();
    Value input = indexedInput.value();
    Operation* definingOp = input.getDefiningOp();
    if (allowChannelReceiveInputs && isa_and_nonnull(definingOp))
      continue;
    if (isLegalHostBackedValue(input))
      continue;
    diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
      InFlightDiagnostic diag = illegalOp->emitOpError()
                                << kPhaseMarker << " " << kind << " input #" << inputIndex
                                << (allowChannelReceiveInputs
                                        ? " must come from the host or explicit spat.channel_receive"
                                        : " must come from the host");
      if (definingOp)
        diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
    });
  }
}

/// Reports every spatial::SpatBlueprintOp nested in `compute`'s body whose mode
/// is "fragment_assembly". Unlike the other checks here, this message has no
/// kPhaseMarker prefix.
template void verifyNoNestedFragmentAssemblyBlueprints(ComputeOpTy compute, pim::CappedDiagnosticReporter& diagnostics) {
  compute.getBody().walk([&](spatial::SpatBlueprintOp blueprint) {
    std::optional mode = blueprint.getMode();
    if (!mode || *mode != "fragment_assembly")
      return;
    diagnostics.report(blueprint.getOperation(), [&](Operation* illegalOp) {
      illegalOp->emitOpError("fragment assembly blueprint must be host-level after merge materialization");
    });
  });
}

/// Checks every top-level op of `funcOp` against the logical-graph phase rules.
/// - Ops in the first allow-list are accepted.
/// - Scheduled compute ops and explicit channel ops are rejected, each with its
///   own message.
/// - Compile-time (foldable) ops are accepted.
/// - Anything else is reported as a leftover runtime op.
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
  for (Operation& op : funcOp.getOps()) {
    if (isa(&op))
      continue;
    if (isa(&op)) {
      diagnostics.report(&op, [&](Operation* illegalOp) {
        illegalOp->emitOpError() << kPhaseMarker
                                 << " scheduled Spatial compute op is not allowed in logical graph phase";
      });
      continue;
    }
    if (isa(&op)) {
      diagnostics.report(&op, [&](Operation* illegalOp) {
        illegalOp->emitOpError() << kPhaseMarker
                                 << " explicit channel communication is not expected before merge materialization";
      });
      continue;
    }
    if (isCompileTimeOp(&op))
      continue;
    diagnostics.report(&op, [&](Operation* illegalOp) {
      illegalOp->emitOpError()
          << kPhaseMarker
          << " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute";
    });
  }
}

/// Reports every top-level graph compute op still present in `funcOp` once the
/// IR is expected to be in the scheduled phase.
void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
  for (Operation& op : funcOp.getOps()) {
    if (!isa(&op))
      continue;
    diagnostics.report(&op, [&](Operation* illegalOp) {
      illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
    });
  }
}

} // namespace

/// Runs the body-capture check over every graph/scheduled compute (and batch)
/// op at the top level of `funcOp`, using a dedicated diagnostic reporter.
/// Returns failure if any capture diagnostic was reported.
LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;
  for (auto compute : funcOp.getOps())
    verifyComputeBodyCaptures(compute, "graph_compute", diagnostics);
  for (auto batch : funcOp.getOps())
    verifyComputeBodyCaptures(batch, "graph_compute_batch", diagnostics);
  for (auto compute : funcOp.getOps())
    verifyComputeBodyCaptures(compute, "scheduled_compute", diagnostics);
  for (auto batch : funcOp.getOps())
    verifyComputeBodyCaptures(batch, "scheduled_compute_batch", diagnostics);
  diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
  return success(!diagnostics.hasFailure());
}

/// Entry point used after ONNX-to-Spatial conversion; equivalent to
/// verifyLogicalSpatialGraphInvariants.
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return verifyLogicalSpatialGraphInvariants(funcOp); }

/// Verifies the logical-graph phase.
/// - Top-level op rules (verifyLogicalTopLevelOps).
/// - Weight use chains (checkWeightUseChains).
/// - Compute body captures (verifyNoComputeBodyCaptures).
///
/// The suppressed-diagnostic summary for this reporter is emitted even when
/// the capture check fails, so no reported failure goes unaccounted for.
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;
  verifyLogicalTopLevelOps(funcOp, diagnostics);
  checkWeightUseChains(funcOp, diagnostics);
  const bool capturesOk = succeeded(verifyNoComputeBodyCaptures(funcOp));
  diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
  return success(capturesOk && !diagnostics.hasFailure());
}

/// Verifies the scheduled phase.
/// - No graph compute ops remain at the top level.
/// - Scheduled compute inputs are host-backed; non-batch ops may also take the
///   channel-receive producer.
/// - No nested "fragment_assembly" blueprints.
/// - No illegal body captures.
///
/// As above, the suppressed-diagnostic summary is always emitted.
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;
  verifyScheduledTopLevelOps(funcOp, diagnostics);
  for (auto compute : funcOp.getOps()) {
    verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
    verifyNoNestedFragmentAssemblyBlueprints(compute, diagnostics);
  }
  for (auto batch : funcOp.getOps()) {
    verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
    verifyNoNestedFragmentAssemblyBlueprints(batch, diagnostics);
  }
  const bool capturesOk = succeeded(verifyNoComputeBodyCaptures(funcOp));
  diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
  return success(capturesOk && !diagnostics.hasFailure());
}

} // namespace onnx_mlir