This commit is contained in:
@@ -1,6 +1,4 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "Common/IR/WeightUtils.hpp"
|
||||
@@ -13,17 +11,28 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void checkWeightsDirectlyExtracted(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (auto extractSlice : func.getOps<tensor::ExtractSliceOp>()) {
|
||||
auto source = getCompileTimeSource(extractSlice.getOperation());
|
||||
if (source && hasWeightAlways(source->source) && source->chainLength > 1) {
|
||||
namespace {
|
||||
|
||||
diagnostics.report(extractSlice.getOperation(),
|
||||
[](Operation* illegalOp) { illegalOp->emitOpError("Weight not directly extracted"); });
|
||||
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
func.walk([&](Operation* op) {
|
||||
if (!hasWeightAlways(op))
|
||||
return;
|
||||
|
||||
for (Value result : op->getResults()) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(result))
|
||||
continue;
|
||||
|
||||
diagnostics.report(op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError(
|
||||
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
@@ -38,9 +47,7 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
checkWeightsDirectlyExtracted(funcOp, diagnostics);
|
||||
|
||||
checkWeightUseChains(funcOp, diagnostics);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
|
||||
Reference in New Issue
Block a user