#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

namespace onnx_mlir {

namespace {

/// Resolves the `func.func` named by a single ONNX entry-point op.
///
/// Reads the entry point's function attribute (as a symbol reference), takes
/// its leaf name and looks that name up directly in `moduleOp`.
///
/// \returns the referenced function, or failure after emitting an op error on
///          `entryPoint` when the attribute is absent or names no `func.func`.
llvm::FailureOr<mlir::func::FuncOp>
resolveEntryPointTarget(mlir::ModuleOp moduleOp,
                        mlir::ONNXEntryPointOp entryPoint) {
  auto funcRef = entryPoint->getAttrOfType<mlir::SymbolRefAttr>(
      mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
  if (!funcRef) {
    entryPoint.emitOpError("is missing the entry point function attribute");
    return mlir::failure();
  }

  // Only the leaf name is used, so the lookup is against the top-level
  // module's symbol table.
  llvm::StringRef funcName = funcRef.getLeafReference().getValue();
  mlir::func::FuncOp target =
      moduleOp.lookupSymbol<mlir::func::FuncOp>(funcName);
  if (!target) {
    entryPoint.emitOpError("references an unknown entry function ")
        << funcName;
    return mlir::failure();
  }
  return target;
}

/// Fallback resolution for a module that has no ONNX entry-point op.
///
/// Prefers a `func.func` named `main_graph`; otherwise accepts the module's
/// only non-external `func.func`.
///
/// \returns the chosen function, or failure after emitting an error on
///          `moduleOp` when neither rule yields exactly one candidate.
llvm::FailureOr<mlir::func::FuncOp>
resolveWithoutEntryPoint(mlir::ModuleOp moduleOp) {
  if (mlir::func::FuncOp mainGraph =
          moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
    return mainGraph;

  // Count function definitions (skipping declarations) and remember the
  // first one; it is the answer only if it turns out to be the sole one.
  mlir::func::FuncOp firstDefinition;
  unsigned definitionCount = 0;
  for (mlir::func::FuncOp candidate : moduleOp.getOps<mlir::func::FuncOp>()) {
    if (candidate.isExternal())
      continue;
    if (definitionCount == 0)
      firstDefinition = candidate;
    ++definitionCount;
  }

  if (definitionCount == 1)
    return firstDefinition;

  moduleOp.emitError("could not resolve a unique PIM entry function");
  return mlir::failure();
}

} // namespace

/// Finds the function the PIM pipeline should treat as the model entry.
///
/// Resolution order:
///   1. Null module: failure, with no diagnostic.
///   2. More than one ONNX entry-point op: error on the module, failure.
///   3. Exactly one entry-point op: the function its attribute references
///      (errors on the op if the attribute is missing or unresolved).
///   4. No entry-point op: `main_graph` if present, else the unique
///      non-external `func.func` (error on the module otherwise).
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
  if (!moduleOp)
    return mlir::failure();

  llvm::SmallVector<mlir::ONNXEntryPointOp> entryPointOps(
      moduleOp.getOps<mlir::ONNXEntryPointOp>());

  switch (entryPointOps.size()) {
  case 0:
    return resolveWithoutEntryPoint(moduleOp);
  case 1:
    return resolveEntryPointTarget(moduleOp, entryPointOps.front());
  default:
    moduleOp.emitError(
        "PIM pipeline requires a single ONNX entry point, but found ")
        << entryPointOps.size();
    return mlir::failure();
  }
}

} // namespace onnx_mlir