#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LLVM.h"

#include "Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

// NOTE(review): in the copy of this file that was reviewed, the template
// argument lists of several casts/queries were not visible: the bare `isa(...)`,
// `isa_and_nonnull(...)` and `hasTrait()` calls, and the typed `funcOp.getOps()`
// loops in verifySpatialCommunicationInvariants. They are kept exactly as they
// appeared; confirm them against SpatialOps.hpp before building.

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Reports every operation for which `hasWeightAlways` holds and that has a
/// result for which `hasOnlySpatialMvmVmmWeightUses` is false.
/// At most one diagnostic is issued per operation.
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
  func.walk([&](Operation* op) {
    if (!hasWeightAlways(op))
      return;
    for (Value result : op->getResults()) {
      if (hasOnlySpatialMvmVmmWeightUses(result))
        continue;
      diagnostics.report(op, [&](Operation* illegalOp) {
        illegalOp->emitOpError(
            "weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
      });
      // One diagnostic per op is enough; skip the op's remaining results.
      return;
    }
  });
}

/// Returns true when `value` is defined directly in `region` or in any region
/// nested inside it.
bool isDefinedInsideRegion(Value value, Region& region) {
  // Region::isAncestor counts a region as its own ancestor and returns false for
  // a null region, so this covers both "same region" and "nested region".
  return region.isAncestor(value.getParentRegion());
}

/// Returns true when `value` may be fed to a compute-like op directly from the
/// host. Block arguments are accepted. Op results are accepted unless the
/// producer matches the `isa` check below or belongs to the "spat" dialect.
bool isLegalHostBackedValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return isa<BlockArgument>(value);
  if (isa(definingOp))
    return false;
  // Take the namespace from the op name instead of getDialect(). getDialect()
  // returns nullptr for ops whose dialect is not loaded (e.g. unregistered ops),
  // and dereferencing that would crash the verifier.
  return definingOp->getName().getDialectNamespace() != "spat";
}

/// Checks that each operand in `inputs` of `computeLikeOp` is host-backed (see
/// isLegalHostBackedValue). When `allowChannelReceiveInputs` is set, an operand
/// may also come from an op matched by the `isa_and_nonnull` check below.
///
/// The first offending input is reported through `diagnostics`, using `kind` to
/// name the op in the message.
///
/// \returns failure on the first illegal input, success otherwise.
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
                                      ValueRange inputs,
                                      bool allowChannelReceiveInputs,
                                      StringRef kind,
                                      pim::CappedDiagnosticReporter& diagnostics) {
  for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
    // Copied out of the structured binding: lambdas cannot capture structured
    // bindings before C++20.
    unsigned currentInputIndex = inputIndex;
    Operation* definingOp = input.getDefiningOp();
    if (allowChannelReceiveInputs && isa_and_nonnull(definingOp))
      continue;
    if (isLegalHostBackedValue(input))
      continue;
    diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
      InFlightDiagnostic diagnostic = illegalOp->emitOpError()
                                      << kind << " input #" << currentInputIndex
                                      << (allowChannelReceiveInputs ? " must come from the host or an explicit "
                                                                      "spat.channel_receive"
                                                                    : " must come from the host");
      if (definingOp)
        diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
    });
    return failure();
  }
  return success();
}

/// Reports each operand inside `region` that is tensor-typed (per the type check
/// below) and defined outside the region. Three kinds of value are exempt:
/// - block arguments (of any region, including enclosing ones);
/// - results of ops with the trait checked below;
/// - values defined in `region` or in regions nested inside it.
///
/// Diagnostics are attached to `ownerOp` and use `kind` to name it.
void verifyNoExternalTensorCaptures(Operation* ownerOp,
                                    Region& region,
                                    StringRef kind,
                                    pim::CappedDiagnosticReporter& diagnostics) {
  region.walk([&](Operation* op) {
    for (OpOperand& operand : op->getOpOperands()) {
      Value value = operand.get();
      if (!isa(value.getType()))
        continue;
      // Any block argument is accepted, including arguments of enclosing
      // regions such as the function's own arguments.
      if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
        continue;
      Operation* definingOp = value.getDefiningOp();
      if (definingOp && definingOp->hasTrait())
        continue;
      diagnostics.report(ownerOp, [&](Operation* illegalOp) {
        InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
                                                                 << "values";
        diagnostic.attachNote(op->getLoc())
            << "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
            << (definingOp ? definingOp->getName().getStringRef() : StringRef(""));
      });
    }
  });
}

}  // namespace

/// Verifies the output of the ONNX-to-Spatial conversion for `funcOp`:
/// - every top-level op is either matched by the `isa` check below or is a
///   compile-time op (isCompileTimeOp);
/// - weight-marked values only reach accepted uses (checkWeightUseChains).
///
/// \returns failure if any diagnostic was reported.
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;
  for (Operation& op : funcOp.getOps()) {
    if (isa(&op))
      continue;
    if (isCompileTimeOp(&op))
      continue;
    diagnostics.report(&op, [](Operation* illegalOp) {
      illegalOp->emitOpError(
          "non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
    });
  }
  checkWeightUseChains(funcOp, diagnostics);
  // Presumably summarizes diagnostics the capped reporter did not print.
  diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
  return success(!diagnostics.hasFailure());
}

/// Verifies the communication invariants of the top-level compute ops in
/// `funcOp`.
/// - spat.compute inputs must be host-backed or come from the channel-receive
///   op.
/// - spat.compute_batch inputs must be host-backed.
/// - Neither body may capture external tensor values.
///
/// \returns failure if any diagnostic was reported.
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;
  for (auto computeOp : funcOp.getOps()) {
    // The result is discarded: failures are already recorded in `diagnostics`.
    (void) verifyComputeLikeInputs(computeOp.getOperation(),
                                   computeOp.getInputs(),
                                   /*allowChannelReceiveInputs=*/true,
                                   "spat.compute",
                                   diagnostics);
    verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
  }
  for (auto computeBatchOp : funcOp.getOps()) {
    (void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
                                   computeBatchOp.getInputs(),
                                   /*allowChannelReceiveInputs=*/false,
                                   "spat.compute_batch",
                                   diagnostics);
    verifyNoExternalTensorCaptures(
        computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
  }
  diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
  return success(!diagnostics.hasFailure());
}

}  // namespace onnx_mlir