219 lines
8.9 KiB
C++
219 lines
8.9 KiB
C++
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
|
|
#include "Common/IR/WeightUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
// Marker text streamed at the front of the phase-verifier diagnostics in this file.
constexpr StringLiteral kPhaseMarker{"phase-check"};
|
|
|
|
/// Walks `func` and reports every op accepted by `hasWeightAlways` that has at
/// least one result rejected by `hasOnlySpatialMvmVmmWeightUses`.
/// At most one diagnostic is reported per offending op.
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
  func.walk([&](Operation* op) {
    if (!hasWeightAlways(op))
      return;

    // any_of stops at the first offending result, so each op is reported once.
    const bool hasIllegalWeightUse =
        llvm::any_of(op->getResults(), [](Value result) { return !hasOnlySpatialMvmVmmWeightUses(result); });
    if (!hasIllegalWeightUse)
      return;

    diagnostics.report(op, [&](Operation* illegalOp) {
      illegalOp->emitOpError()
          << kPhaseMarker
          << " weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights";
    });
  });
}
|
|
|
|
/// Returns true when `candidate` is `region` itself or is nested (at any depth)
/// inside `region`. A null `candidate` (e.g. the parent region of a detached
/// op or block) yields false.
bool isRegionOrAncestorOf(Region& region, Region* candidate) {
  // The null check must come first: Region::isAncestor walks `candidate`'s
  // parent chain and does not accept a null starting region.
  return candidate && (&region == candidate || region.isAncestor(candidate));
}
|
|
|
|
/// Returns true when `value` is defined in `region` or in a region nested
/// inside it. Block arguments are attributed to the region owning their block;
/// op results are attributed to the region containing their defining op.
bool isValueDefinedInsideRegion(Value value, Region& region) {
  // Value::getParentRegion performs exactly that block-argument / op-result
  // dispatch and yields null for detached definitions, which the helper maps
  // to false.
  Region* definingRegion = value.getParentRegion();
  return isRegionOrAncestorOf(region, definingRegion);
}
|
|
|
|
/// A value used inside a compute body is acceptable when it is defined within
/// `region` (or a region nested in it), or when it is produced by an op with
/// the `ConstantLike` trait. External block arguments are never acceptable.
bool isLegalExternalCapture(Value value, Region& region) {
  auto comesFromConstant = [&]() {
    Operation* producer = value.getDefiningOp();
    return producer != nullptr && producer->hasTrait<OpTrait::ConstantLike>();
  };
  return isValueDefinedInsideRegion(value, region) || comesFromConstant();
}
|
|
|
|
template <typename ComputeOpTy>
|
|
void verifyComputeBodyCaptures(ComputeOpTy compute, StringRef kind, pim::CappedDiagnosticReporter& diagnostics) {
|
|
Region& body = compute.getBody();
|
|
body.walk([&](Operation* nestedOp) {
|
|
for (OpOperand& operand : nestedOp->getOpOperands()) {
|
|
Value value = operand.get();
|
|
if (isLegalExternalCapture(value, body))
|
|
continue;
|
|
|
|
Operation* definingOp = value.getDefiningOp();
|
|
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
|
|
InFlightDiagnostic diag =
|
|
illegalOp->emitOpError() << kPhaseMarker << " " << kind << " body captures non-constant external operand #"
|
|
<< operand.getOperandNumber() << " used by " << nestedOp->getName().getStringRef();
|
|
diag << " (type " << value.getType() << ")";
|
|
if (definingOp)
|
|
diag.attachNote(definingOp->getLoc()) << "defining op is " << definingOp->getName().getStringRef();
|
|
else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
|
|
if (Operation* owner = blockArg.getOwner()->getParentOp())
|
|
diag.attachNote(owner->getLoc())
|
|
<< "external block argument belongs to " << owner->getName().getStringRef();
|
|
}
|
|
});
|
|
}
|
|
});
|
|
}
|
|
|
|
/// Returns true when `value` may feed a scheduled compute as a host-side input:
/// either a block argument, or a result of an op that does not belong to the
/// `spat` dialect.
bool isLegalHostBackedValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return isa<BlockArgument>(value);

  // Operation::getDialect() returns null for unregistered ops, so calling
  // getNamespace() on it would dereference null. OperationName's dialect
  // namespace is available for registered and unregistered ops alike.
  return definingOp->getName().getDialectNamespace() != "spat";
}
|
|
|
|
template <typename ComputeOpTy>
|
|
void verifyScheduledInputs(ComputeOpTy compute,
|
|
bool allowChannelReceiveInputs,
|
|
StringRef kind,
|
|
pim::CappedDiagnosticReporter& diagnostics) {
|
|
for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) {
|
|
Operation* definingOp = input.getDefiningOp();
|
|
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
|
|
continue;
|
|
if (isLegalHostBackedValue(input))
|
|
continue;
|
|
|
|
diagnostics.report(compute.getOperation(), [&](Operation* illegalOp) {
|
|
InFlightDiagnostic diag = illegalOp->emitOpError()
|
|
<< kPhaseMarker << " " << kind << " input #" << inputIndex
|
|
<< (allowChannelReceiveInputs ? " must come from the host or explicit spat.channel_receive"
|
|
: " must come from the host");
|
|
if (definingOp)
|
|
diag.attachNote(definingOp->getLoc()) << "illegal producer is " << definingOp->getName().getStringRef();
|
|
});
|
|
}
|
|
}
|
|
|
|
template <typename ComputeOpTy>
|
|
void verifyNoNestedFragmentAssemblyBlueprints(ComputeOpTy compute,
|
|
pim::CappedDiagnosticReporter& diagnostics) {
|
|
compute.getBody().walk([&](spatial::SpatBlueprintOp blueprint) {
|
|
std::optional<StringRef> mode = blueprint.getMode();
|
|
if (!mode || *mode != "fragment_assembly")
|
|
return;
|
|
diagnostics.report(blueprint.getOperation(), [&](Operation* illegalOp) {
|
|
illegalOp->emitOpError("fragment assembly blueprint must be host-level after merge materialization");
|
|
});
|
|
});
|
|
}
|
|
|
|
/// Checks the ops directly in `funcOp`'s body for the logical-graph phase.
/// Accepted: the return, graph computes (single and batched), conv2d/relu plan
/// ops, blueprints, layout materialization, and anything `isCompileTimeOp`
/// accepts. Reported: scheduled computes, explicit channel send/receive, and
/// any other remaining runtime op.
void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
  // Reports `op` with the phase marker followed by `reason` (which begins with a space).
  auto reject = [&](Operation* op, StringRef reason) {
    diagnostics.report(op, [reason](Operation* illegalOp) { illegalOp->emitOpError() << kPhaseMarker << reason; });
  };

  for (Operation& topLevelOp : funcOp.getOps()) {
    Operation* op = &topLevelOp;

    const bool isExpectedLogicalOp = isa<func::ReturnOp,
                                         spatial::SpatGraphCompute,
                                         spatial::SpatGraphComputeBatch,
                                         spatial::SpatConv2DPlanOp,
                                         spatial::SpatReluPlanOp,
                                         spatial::SpatBlueprintOp,
                                         spatial::SpatMaterializeLayoutOp>(op);
    if (isExpectedLogicalOp)
      continue;

    if (isa<spatial::SpatScheduledCompute, spatial::SpatScheduledComputeBatch>(op))
      reject(op, " scheduled Spatial compute op is not allowed in logical graph phase");
    else if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(op))
      reject(op, " explicit channel communication is not expected before merge materialization");
    else if (!isCompileTimeOp(op))
      reject(op,
             " non-foldable top-level runtime op remains in logical Spatial graph; lower it inside spat.graph_compute");
  }
}
|
|
|
|
/// Reports every graph-level compute op (single or batched) that still sits
/// directly in `funcOp`'s body once the scheduled phase is reached.
void verifyScheduledTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter& diagnostics) {
  for (Operation& candidate : funcOp.getOps()) {
    const bool isGraphCompute = isa<spatial::SpatGraphCompute, spatial::SpatGraphComputeBatch>(&candidate);
    if (!isGraphCompute)
      continue;

    diagnostics.report(&candidate, [&](Operation* illegalOp) {
      illegalOp->emitOpError() << kPhaseMarker << " graph Spatial compute op remained after merge materialization";
    });
  }
}
|
|
|
|
} // namespace
|
|
|
|
/// Runs the body-capture check over every graph and scheduled compute op
/// (single and batched) directly in `funcOp`, grouped by op kind, using a
/// reporter owned by this function. Emits that reporter's suppressed summary
/// and returns failure if any capture was reported.
LogicalResult verifyNoComputeBodyCaptures(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;

  // Generic so a single helper handles all four compute op classes.
  auto checkEach = [&](auto computeOps, StringRef kind) {
    for (auto computeOp : computeOps)
      verifyComputeBodyCaptures(computeOp, kind, diagnostics);
  };

  checkEach(funcOp.getOps<spatial::SpatGraphCompute>(), "graph_compute");
  checkEach(funcOp.getOps<spatial::SpatGraphComputeBatch>(), "graph_compute_batch");
  checkEach(funcOp.getOps<spatial::SpatScheduledCompute>(), "scheduled_compute");
  checkEach(funcOp.getOps<spatial::SpatScheduledComputeBatch>(), "scheduled_compute_batch");

  diagnostics.emitSuppressedSummary(funcOp, "compute body capture verification failed");
  const bool anyCaptureReported = diagnostics.hasFailure();
  return success(!anyCaptureReported);
}
|
|
|
|
/// Verifier for the ONNX-to-Spatial conversion output; it delegates entirely to
/// the logical Spatial graph invariant check.
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
  return verifyLogicalSpatialGraphInvariants(funcOp);
}
|
|
|
|
/// Verifies the logical-graph phase for `funcOp`: allowed top-level ops,
/// weight-use chains, and compute body captures. Returns failure if any of
/// these checks reported an error.
LogicalResult verifyLogicalSpatialGraphInvariants(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;
  verifyLogicalTopLevelOps(funcOp, diagnostics);
  checkWeightUseChains(funcOp, diagnostics);

  // The capture check uses its own reporter. Do not return early when it
  // fails: doing so would skip emitSuppressedSummary for *this* reporter and
  // lose the summary of any diagnostics it capped.
  const LogicalResult capturesVerified = verifyNoComputeBodyCaptures(funcOp);

  diagnostics.emitSuppressedSummary(funcOp, "logical Spatial graph verification failed");
  return success(succeeded(capturesVerified) && !diagnostics.hasFailure());
}
|
|
|
|
/// Verifies the scheduled phase for `funcOp`: no graph computes remain at top
/// level, scheduled compute inputs come from legal producers (channel receives
/// only for the non-batched form), no fragment-assembly blueprints are nested
/// in compute bodies, and compute bodies capture only constants.
/// Returns failure if any of these checks reported an error.
LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
  pim::CappedDiagnosticReporter diagnostics;
  verifyScheduledTopLevelOps(funcOp, diagnostics);

  for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>()) {
    verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
    verifyNoNestedFragmentAssemblyBlueprints(compute, diagnostics);
  }

  for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
    verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
    verifyNoNestedFragmentAssemblyBlueprints(batch, diagnostics);
  }

  // The capture check uses its own reporter. Do not return early when it
  // fails: doing so would skip emitSuppressedSummary for *this* reporter and
  // lose the summary of any diagnostics it capped.
  const LogicalResult capturesVerified = verifyNoComputeBodyCaptures(funcOp);

  diagnostics.emitSuppressedSummary(funcOp, "scheduled Spatial verification failed");
  return success(succeeded(capturesVerified) && !diagnostics.hasFailure());
}
|
|
|
|
} // namespace onnx_mlir
|