diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp index 795cd37..0ebdaeb 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp @@ -31,6 +31,84 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag }); } +Region* getParentRegion(Value value) { + if (auto blockArg = dyn_cast(value)) + return blockArg.getOwner()->getParent(); + if (Operation* definingOp = value.getDefiningOp()) + return definingOp->getParentRegion(); + return nullptr; +} + +bool isDefinedInsideRegion(Value value, Region& region) { + Region* parentRegion = getParentRegion(value); + return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); +} + +bool isLegalHostBackedValue(Value value) { + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return isa(value); + + if (isa(definingOp)) + return false; + + return definingOp->getDialect()->getNamespace() != "spat"; +} + +LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp, + ValueRange inputs, + bool allowChannelReceiveInputs, + StringRef kind, + pim::CappedDiagnosticReporter& diagnostics) { + for (auto [inputIndex, input] : llvm::enumerate(inputs)) { + unsigned currentInputIndex = inputIndex; + Operation* definingOp = input.getDefiningOp(); + if (allowChannelReceiveInputs && isa_and_nonnull(definingOp)) + continue; + if (isLegalHostBackedValue(input)) + continue; + + diagnostics.report(computeLikeOp, [&](Operation* illegalOp) { + InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " input #" << currentInputIndex + << (allowChannelReceiveInputs + ? " must come from the host or an explicit " + "spat.channel_receive" + : " must come from the host"); + if (definingOp) + diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName(); + }); + return failure(); + } + return success(); +} + +void verifyNoExternalTensorCaptures(Operation* ownerOp, + Region& region, + StringRef kind, + pim::CappedDiagnosticReporter& diagnostics) { + region.walk([&](Operation* op) { + for (OpOperand& operand : op->getOpOperands()) { + Value value = operand.get(); + if (!isa(value.getType())) + continue; + if (isDefinedInsideRegion(value, region) || isa(value)) + continue; + + Operation* definingOp = value.getDefiningOp(); + if (definingOp && definingOp->hasTrait()) + continue; + + diagnostics.report(ownerOp, [&](Operation* illegalOp) { + InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor " + << "values"; + diagnostic.attachNote(op->getLoc()) + << "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by " + << (definingOp ? definingOp->getName().getStringRef() : StringRef("")); + }); + } + }); +} + } // namespace LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { @@ -53,4 +131,27 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { return success(!diagnostics.hasFailure()); } +LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) { + pim::CappedDiagnosticReporter diagnostics; + + for (auto computeOp : funcOp.getOps()) { + (void)verifyComputeLikeInputs( + computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics); + verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics); + } + + for (auto computeBatchOp : funcOp.getOps()) { + (void)verifyComputeLikeInputs(computeBatchOp.getOperation(), + computeBatchOp.getInputs(), + /*allowChannelReceiveInputs=*/false, + "spat.compute_batch", + diagnostics); + verifyNoExternalTensorCaptures( + computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics); + } + + diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed"); + return success(!diagnostics.hasFailure()); +} + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp index a5fd052..3ac5b9c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp @@ -6,5 +6,6 @@ namespace onnx_mlir { mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp); +mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 97d1d81..7c071bc 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -178,25 +178,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute mapper.map(*oldArg, copied); } - auto materializeCapturedTensor = [&](Value capturedTensor) -> Value { - if (auto mapped = mapper.lookupOrNull(capturedTensor)) - return mapped; - - auto capturedType = cast(capturedTensor.getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType); - auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, - loc, - outputBuffer.getType(), - outputBuffer, - capturedTensor, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - getTensorSizeInBytesAttr(rewriter, capturedTensor)) - .getOutput(); - mapper.map(capturedTensor, copied); - return copied; - }; - SmallVector hostOutputTensors(returnOperandIndices.size()); auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value { Value& hostOutputTensor = hostOutputTensors[resultIndex]; @@ -280,7 +261,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute if (definingOp && definingOp->getBlock() == &oldBlock) continue; - materializeCapturedTensor(operand); + return computeBatchOp.emitOpError( + "expected external tensor communication to be materialized in Spatial before batch lowering"); } Operation* cloned = rewriter.clone(op, mapper); diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index 990bad7..1b56f12 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -10,7 +10,6 @@ add_pim_library(OMSpatialToPim ComputeLikeRegionUtils.cpp CoreLoweringPatterns.cpp GlobalTensorMaterialization.cpp - PhaseVerification.cpp ReturnPathNormalization.cpp TensorPackingPatterns.cpp diff --git a/src/PIM/Conversion/SpatialToPim/PhaseVerification.cpp b/src/PIM/Conversion/SpatialToPim/PhaseVerification.cpp deleted file mode 100644 index 4c18886..0000000 --- a/src/PIM/Conversion/SpatialToPim/PhaseVerification.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) { - bool hasFailure = false; - moduleOp.walk([&](Operation* op) { - if (op->getDialect()->getNamespace() != "spat") - return; - - op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering"); - hasFailure = true; - }); - return success(!hasFailure); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/PhaseVerification.hpp b/src/PIM/Conversion/SpatialToPim/PhaseVerification.hpp deleted file mode 100644 index d17da32..0000000 --- a/src/PIM/Conversion/SpatialToPim/PhaseVerification.hpp +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include "mlir/IR/BuiltinOps.h" - -namespace onnx_mlir { - -mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 1b06339..5c1ab93 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -30,7 +30,6 @@ #include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" -#include "Conversion/SpatialToPim/PhaseVerification.hpp" #include "Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "Dialect/Pim/PimOps.hpp" #include "Dialect/Spatial/SpatialOps.hpp" @@ -284,13 +283,6 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { signalPassFailure(); return; } - - if (failed(verifySpatialToPimBoundary(moduleOp))) { - moduleOp.emitError("Spatial-to-PIM boundary verification failed"); - signalPassFailure(); - return; - } - // Dump to file for debug dumpModule(moduleOp, "pim0"); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 3c143b9..4cb3682 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -38,6 +38,7 @@ #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -651,6 +652,11 @@ public: signalPassFailure(); return; } + if (failed(verifySpatialCommunicationInvariants(func))) { + func.emitOpError("merged Spatial communication invariant verification failed"); + signalPassFailure(); + return; + } emitMergeIrCounts("final-post-merge", func); dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());