add verification of communication invariants at the end of spatial
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove dead logic
This commit is contained in:
@@ -31,6 +31,84 @@ void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diag
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Region* getParentRegion(Value value) {
|
||||||
|
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||||
|
return blockArg.getOwner()->getParent();
|
||||||
|
if (Operation* definingOp = value.getDefiningOp())
|
||||||
|
return definingOp->getParentRegion();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isDefinedInsideRegion(Value value, Region& region) {
|
||||||
|
Region* parentRegion = getParentRegion(value);
|
||||||
|
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isLegalHostBackedValue(Value value) {
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return isa<BlockArgument>(value);
|
||||||
|
|
||||||
|
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return definingOp->getDialect()->getNamespace() != "spat";
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
|
||||||
|
ValueRange inputs,
|
||||||
|
bool allowChannelReceiveInputs,
|
||||||
|
StringRef kind,
|
||||||
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
|
||||||
|
unsigned currentInputIndex = inputIndex;
|
||||||
|
Operation* definingOp = input.getDefiningOp();
|
||||||
|
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
|
||||||
|
continue;
|
||||||
|
if (isLegalHostBackedValue(input))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
|
||||||
|
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " input #" << currentInputIndex
|
||||||
|
<< (allowChannelReceiveInputs
|
||||||
|
? " must come from the host or an explicit "
|
||||||
|
"spat.channel_receive"
|
||||||
|
: " must come from the host");
|
||||||
|
if (definingOp)
|
||||||
|
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
|
||||||
|
});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void verifyNoExternalTensorCaptures(Operation* ownerOp,
|
||||||
|
Region& region,
|
||||||
|
StringRef kind,
|
||||||
|
pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
|
region.walk([&](Operation* op) {
|
||||||
|
for (OpOperand& operand : op->getOpOperands()) {
|
||||||
|
Value value = operand.get();
|
||||||
|
if (!isa<TensorType>(value.getType()))
|
||||||
|
continue;
|
||||||
|
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
|
||||||
|
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
|
||||||
|
<< "values";
|
||||||
|
diagnostic.attachNote(op->getLoc())
|
||||||
|
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
|
||||||
|
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||||
@@ -53,4 +131,27 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
|||||||
return success(!diagnostics.hasFailure());
|
return success(!diagnostics.hasFailure());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
|
||||||
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
|
||||||
|
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
|
(void)verifyComputeLikeInputs(
|
||||||
|
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
|
||||||
|
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||||
|
(void)verifyComputeLikeInputs(computeBatchOp.getOperation(),
|
||||||
|
computeBatchOp.getInputs(),
|
||||||
|
/*allowChannelReceiveInputs=*/false,
|
||||||
|
"spat.compute_batch",
|
||||||
|
diagnostics);
|
||||||
|
verifyNoExternalTensorCaptures(
|
||||||
|
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
|
||||||
|
}
|
||||||
|
|
||||||
|
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
|
||||||
|
return success(!diagnostics.hasFailure());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -6,5 +6,6 @@
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
||||||
|
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -178,25 +178,6 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
mapper.map(*oldArg, copied);
|
mapper.map(*oldArg, copied);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
|
||||||
if (auto mapped = mapper.lookupOrNull(capturedTensor))
|
|
||||||
return mapped;
|
|
||||||
|
|
||||||
auto capturedType = cast<ShapedType>(capturedTensor.getType());
|
|
||||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
|
|
||||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
outputBuffer.getType(),
|
|
||||||
outputBuffer,
|
|
||||||
capturedTensor,
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
rewriter.getI32IntegerAttr(0),
|
|
||||||
getTensorSizeInBytesAttr(rewriter, capturedTensor))
|
|
||||||
.getOutput();
|
|
||||||
mapper.map(capturedTensor, copied);
|
|
||||||
return copied;
|
|
||||||
};
|
|
||||||
|
|
||||||
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
|
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
|
||||||
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
|
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
|
||||||
Value& hostOutputTensor = hostOutputTensors[resultIndex];
|
Value& hostOutputTensor = hostOutputTensors[resultIndex];
|
||||||
@@ -280,7 +261,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
materializeCapturedTensor(operand);
|
return computeBatchOp.emitOpError(
|
||||||
|
"expected external tensor communication to be materialized in Spatial before batch lowering");
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* cloned = rewriter.clone(op, mapper);
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ add_pim_library(OMSpatialToPim
|
|||||||
ComputeLikeRegionUtils.cpp
|
ComputeLikeRegionUtils.cpp
|
||||||
CoreLoweringPatterns.cpp
|
CoreLoweringPatterns.cpp
|
||||||
GlobalTensorMaterialization.cpp
|
GlobalTensorMaterialization.cpp
|
||||||
PhaseVerification.cpp
|
|
||||||
ReturnPathNormalization.cpp
|
ReturnPathNormalization.cpp
|
||||||
TensorPackingPatterns.cpp
|
TensorPackingPatterns.cpp
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
|
|
||||||
bool hasFailure = false;
|
|
||||||
moduleOp.walk([&](Operation* op) {
|
|
||||||
if (op->getDialect()->getNamespace() != "spat")
|
|
||||||
return;
|
|
||||||
|
|
||||||
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
|
|
||||||
hasFailure = true;
|
|
||||||
});
|
|
||||||
return success(!hasFailure);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -30,7 +30,6 @@
|
|||||||
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||||
#include "Conversion/SpatialToPim/Common.hpp"
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||||
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
|
|
||||||
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||||
#include "Dialect/Pim/PimOps.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -284,13 +283,6 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(verifySpatialToPimBoundary(moduleOp))) {
|
|
||||||
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "pim0");
|
dumpModule(moduleOp, "pim0");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,6 +38,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -651,6 +652,11 @@ public:
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (failed(verifySpatialCommunicationInvariants(func))) {
|
||||||
|
func.emitOpError("merged Spatial communication invariant verification failed");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
emitMergeIrCounts("final-post-merge", func);
|
emitMergeIrCounts("final-post-merge", func);
|
||||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||||
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
|
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
|
||||||
|
|||||||
Reference in New Issue
Block a user