158 lines
6.0 KiB
C++
158 lines
6.0 KiB
C++
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
|
|
#include "Common/IR/WeightUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
/// Reports operations for which `hasWeightAlways` holds but whose results are
/// not used exclusively as Spatial MVM/VMM weights (as decided by
/// `hasOnlySpatialMvmVmmWeightUses`).
///
/// At most one diagnostic is produced per offending operation: checking stops
/// at the first result that fails the use-chain test.
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
  func.walk([&](Operation* op) {
    if (!hasWeightAlways(op))
      return;

    const bool allResultsFeedWeights = llvm::all_of(
        op->getResults(), [](Value weightResult) { return hasOnlySpatialMvmVmmWeightUses(weightResult); });
    if (allResultsFeedWeights)
      return;

    diagnostics.report(op, [](Operation* illegalOp) {
      illegalOp->emitOpError(
          "weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
    });
  });
}
|
|
|
|
/// Returns the region in which `value` is defined.
///
/// For an op result this is the parent region of the defining op. For a block
/// argument it is the region that owns the argument's block. Returns nullptr
/// when neither case applies.
Region* getParentRegion(Value value) {
  Operation* producer = value.getDefiningOp();
  if (producer)
    return producer->getParentRegion();

  auto argument = dyn_cast<BlockArgument>(value);
  return argument ? argument.getOwner()->getParent() : nullptr;
}
|
|
|
|
/// Returns true when `value` is defined directly in `region` or in any region
/// nested (at any depth) inside it.
///
/// Block arguments count as defined in the region owning their block. Values
/// whose parent region cannot be determined are treated as outside.
bool isDefinedInsideRegion(Value value, Region& region) {
  Region* parentRegion = getParentRegion(value);
  // Region::isAncestor considers a region to be its own ancestor, so this one
  // call covers both the "same region" and the "nested region" cases.
  return parentRegion && region.isAncestor(parentRegion);
}
|
|
|
|
/// Returns true when `value` may be passed to a compute-like op as a
/// host-provided input.
///
/// Accepted: block arguments, and results of ops outside the `spat` dialect.
/// Rejected: results of `spat.channel_receive` and of any other `spat` op.
bool isLegalHostBackedValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  // A value without a defining op is a block argument.
  if (!definingOp)
    return isa<BlockArgument>(value);

  if (isa<spatial::SpatChannelReceiveOp>(definingOp))
    return false;

  // Operation::getDialect() returns nullptr when the op's dialect is not
  // loaded, e.g. for unregistered ops. The namespace stored in the op name is
  // always available and equals the dialect namespace for registered ops, so
  // the verifier cannot crash on a null dialect here.
  return definingOp->getName().getDialectNamespace() != "spat";
}
|
|
|
|
/// Checks that each value in `inputs` of `computeLikeOp` is legal.
///
/// An input is legal when it is host-backed (see isLegalHostBackedValue), or,
/// if `allowChannelReceiveInputs` is set, when it is produced by a
/// `spat.channel_receive`.
///
/// On the first illegal input, a diagnostic is reported through `diagnostics`
/// and failure is returned; later inputs are not inspected. `kind` is the op
/// name used as the message prefix.
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
                                      ValueRange inputs,
                                      bool allowChannelReceiveInputs,
                                      StringRef kind,
                                      pim::CappedDiagnosticReporter& diagnostics) {
  for (unsigned position = 0, inputCount = inputs.size(); position < inputCount; ++position) {
    Value input = inputs[position];
    Operation* producer = input.getDefiningOp();

    const bool isAcceptedReceive =
        allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(producer);
    if (isAcceptedReceive || isLegalHostBackedValue(input))
      continue;

    diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
      InFlightDiagnostic diagnostic = illegalOp->emitOpError()
                                      << kind << " input #" << position
                                      << (allowChannelReceiveInputs ? " must come from the host or an explicit "
                                                                      "spat.channel_receive"
                                                                    : " must come from the host");
      if (producer)
        diagnostic.attachNote(producer->getLoc()) << "illegal Spatial producer is " << producer->getName();
    });
    return failure();
  }

  return success();
}
|
|
|
|
/// Reports every tensor-typed operand used anywhere inside `region`, including
/// nested regions, whose value comes from outside that region.
///
/// Diagnostics are attributed to `ownerOp`; `kind` is the op name used as the
/// message prefix. The following are exempt:
///   - values defined in `region` or in a region nested inside it;
///   - block arguments of any region, including enclosing ones;
///   - results of ConstantLike ops.
void verifyNoExternalTensorCaptures(Operation* ownerOp,
                                    Region& region,
                                    StringRef kind,
                                    pim::CappedDiagnosticReporter& diagnostics) {
  region.walk([&](Operation* op) {
    for (OpOperand& operand : op->getOpOperands()) {
      Value value = operand.get();
      if (!isa<TensorType>(value.getType()))
        continue;
      // NOTE(review): block arguments of *enclosing* regions (e.g. function
      // arguments) are exempt, so a body may reference them directly instead
      // of receiving them as inputs. Confirm this is intended.
      if (isa<BlockArgument>(value) || isDefinedInsideRegion(value, region))
        continue;

      // Every block argument was skipped above, so `value` is an op result
      // here and always has a defining op.
      Operation* definingOp = value.getDefiningOp();
      if (definingOp->hasTrait<OpTrait::ConstantLike>())
        continue;

      diagnostics.report(ownerOp, [&](Operation* illegalOp) {
        InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
                                                                 << "values";
        diagnostic.attachNote(op->getLoc())
            << "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
            << definingOp->getName().getStringRef();
      });
    }
  });
}
|
|
|
|
} // namespace
|
|
|
|
/// Verifies the output of the ONNX-to-Spatial conversion for `funcOp`.
///
/// Only these ops may remain at the top level of the function:
///   - the `func.return` terminator;
///   - `spat.compute` and `spat.compute_batch`;
///   - ops accepted by `isCompileTimeOp`.
/// Weight use chains are then checked across the whole function. Returns
/// success unless `diagnostics` recorded a failure.
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;

  auto isPermittedTopLevelOp = [](Operation* candidate) {
    return isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(candidate) ||
           isCompileTimeOp(candidate);
  };

  for (Operation& topLevelOp : funcOp.getOps()) {
    if (isPermittedTopLevelOp(&topLevelOp))
      continue;

    diagnostics.report(&topLevelOp, [](Operation* illegalOp) {
      illegalOp->emitOpError(
          "non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
    });
  }

  checkWeightUseChains(funcOp, diagnostics);
  diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");

  return success(!diagnostics.hasFailure());
}
|
|
|
|
/// Verifies the communication invariants of every top-level `spat.compute` and
/// `spat.compute_batch` in `funcOp`:
///   - inputs must be host-backed; plain computes may also take
///     `spat.channel_receive` results;
///   - bodies must not capture external tensor values.
/// All `spat.compute` ops are checked before any `spat.compute_batch` op, which
/// fixes the diagnostic order. Returns success unless `diagnostics` recorded a
/// failure.
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;

  // Shared by both compute flavours. The input check's result is deliberately
  // ignored: failures are tracked by `diagnostics`.
  auto verifyComputeLike = [&](auto computeLikeOp, bool allowChannelReceiveInputs, StringRef kind) {
    Operation* owner = computeLikeOp.getOperation();
    (void) verifyComputeLikeInputs(owner, computeLikeOp.getInputs(), allowChannelReceiveInputs, kind, diagnostics);
    verifyNoExternalTensorCaptures(owner, computeLikeOp.getBody(), kind, diagnostics);
  };

  for (spatial::SpatCompute computeOp : funcOp.getOps<spatial::SpatCompute>())
    verifyComputeLike(computeOp, /*allowChannelReceiveInputs=*/true, "spat.compute");

  for (spatial::SpatComputeBatch computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>())
    verifyComputeLike(computeBatchOp, /*allowChannelReceiveInputs=*/false, "spat.compute_batch");

  diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
  return success(!diagnostics.hasFailure());
}
|
|
|
|
} // namespace onnx_mlir
|