46 lines
1.6 KiB
C++
46 lines
1.6 KiB
C++
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
|
|
if (!moduleOp)
|
|
return mlir::failure();
|
|
|
|
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
|
|
if (entryPoints.size() > 1) {
|
|
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
|
return mlir::failure();
|
|
}
|
|
if (!entryPoints.empty()) {
|
|
auto entryPointAttr =
|
|
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
|
|
if (!entryPointAttr) {
|
|
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
|
return mlir::failure();
|
|
}
|
|
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
|
if (!entryFunc) {
|
|
entryPoints.front().emitOpError("references an unknown entry function ")
|
|
<< entryPointAttr.getLeafReference().getValue();
|
|
return mlir::failure();
|
|
}
|
|
return entryFunc;
|
|
}
|
|
|
|
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
|
|
return mainGraphFunc;
|
|
|
|
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
|
|
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
|
|
if (!funcOp.isExternal())
|
|
nonExternalFuncs.push_back(funcOp);
|
|
if (nonExternalFuncs.size() == 1)
|
|
return nonExternalFuncs.front();
|
|
|
|
moduleOp.emitError("could not resolve a unique PIM entry function");
|
|
return mlir::failure();
|
|
}
|
|
|
|
} // namespace onnx_mlir
|