add verification of communication invariants at the end of spatial
Validate Operations / validate-operations (push) Has been cancelled

remove dead logic
This commit is contained in:
NiccoloN
2026-05-27 19:17:48 +02:00
parent 783dffe553
commit 00414dd1d9
8 changed files with 110 additions and 58 deletions
@@ -31,6 +31,84 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
}); });
} }
Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
}
bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
bool isLegalHostBackedValue(Value value) {
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return isa<BlockArgument>(value);
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
return false;
return definingOp->getDialect()->getNamespace() != "spat";
}
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
ValueRange inputs,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
unsigned currentInputIndex = inputIndex;
Operation* definingOp = input.getDefiningOp();
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
continue;
if (isLegalHostBackedValue(input))
continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " input #" << currentInputIndex
<< (allowChannelReceiveInputs
? " must come from the host or an explicit "
"spat.channel_receive"
: " must come from the host");
if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
});
return failure();
}
return success();
}
void verifyNoExternalTensorCaptures(Operation* ownerOp,
Region& region,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (!isa<TensorType>(value.getType()))
continue;
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
continue;
Operation* definingOp = value.getDefiningOp();
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
<< "values";
diagnostic.attachNote(op->getLoc())
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
});
}
});
}
} // namespace } // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
@@ -53,4 +131,27 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
return success(!diagnostics.hasFailure()); return success(!diagnostics.hasFailure());
} }
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
(void)verifyComputeLikeInputs(
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
(void)verifyComputeLikeInputs(computeBatchOp.getOperation(),
computeBatchOp.getInputs(),
/*allowChannelReceiveInputs=*/false,
"spat.compute_batch",
diagnostics);
verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
}
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -6,5 +6,6 @@
namespace onnx_mlir { namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp); mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -178,25 +178,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
mapper.map(*oldArg, copied); mapper.map(*oldArg, copied);
} }
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
if (auto mapped = mapper.lookupOrNull(capturedTensor))
return mapped;
auto capturedType = cast<ShapedType>(capturedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
capturedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, capturedTensor))
.getOutput();
mapper.map(capturedTensor, copied);
return copied;
};
SmallVector<Value> hostOutputTensors(returnOperandIndices.size()); SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value { auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
Value& hostOutputTensor = hostOutputTensors[resultIndex]; Value& hostOutputTensor = hostOutputTensors[resultIndex];
@@ -280,7 +261,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
if (definingOp && definingOp->getBlock() == &oldBlock) if (definingOp && definingOp->getBlock() == &oldBlock)
continue; continue;
materializeCapturedTensor(operand); return computeBatchOp.emitOpError(
"expected external tensor communication to be materialized in Spatial before batch lowering");
} }
Operation* cloned = rewriter.clone(op, mapper); Operation* cloned = rewriter.clone(op, mapper);
@@ -10,7 +10,6 @@ add_pim_library(OMSpatialToPim
ComputeLikeRegionUtils.cpp ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp CoreLoweringPatterns.cpp
GlobalTensorMaterialization.cpp GlobalTensorMaterialization.cpp
PhaseVerification.cpp
ReturnPathNormalization.cpp ReturnPathNormalization.cpp
TensorPackingPatterns.cpp TensorPackingPatterns.cpp
@@ -1,20 +0,0 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
bool hasFailure = false;
moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() != "spat")
return;
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
hasFailure = true;
});
return success(!hasFailure);
}
} // namespace onnx_mlir
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
@@ -30,7 +30,6 @@
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" #include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Spatial/SpatialOps.hpp" #include "Dialect/Spatial/SpatialOps.hpp"
@@ -284,13 +283,6 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
signalPassFailure(); signalPassFailure();
return; return;
} }
if (failed(verifySpatialToPimBoundary(moduleOp))) {
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
signalPassFailure();
return;
}
// Dump to file for debug // Dump to file for debug
dumpModule(moduleOp, "pim0"); dumpModule(moduleOp, "pim0");
} }
@@ -38,6 +38,7 @@
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -651,6 +652,11 @@ public:
signalPassFailure(); signalPassFailure();
return; return;
} }
if (failed(verifySpatialCommunicationInvariants(func))) {
func.emitOpError("merged Spatial communication invariant verification failed");
signalPassFailure();
return;
}
emitMergeIrCounts("final-post-merge", func); emitMergeIrCounts("final-post-merge", func);
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged"); dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size()); generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());