replace deprecated "rewriter.create()" pattern
refactor PIM to Pim everywhere except for the accelerator name
This commit is contained in:
146
src/PIM/Common/PimCommon.cpp
Normal file
146
src/PIM/Common/PimCommon.cpp
Normal file
@@ -0,0 +1,146 @@
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir() { return outputBaseName.substr(0, outputBaseName.find_last_of('/')); }
|
||||
|
||||
void createDirectory(const std::string& directory) {
|
||||
std::error_code errorCode;
|
||||
std::filesystem::create_directories(directory, errorCode);
|
||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||
}
|
||||
|
||||
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
||||
std::string dialectsDir = getOutputDir() + "/dialects";
|
||||
createDirectory(dialectsDir);
|
||||
|
||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||
llvm::raw_os_ostream os(file);
|
||||
os << *moduleOp;
|
||||
os.flush();
|
||||
file.close();
|
||||
}
|
||||
|
||||
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
||||
if (entryPoints.size() > 1) {
|
||||
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||
return failure();
|
||||
}
|
||||
if (!entryPoints.empty()) {
|
||||
auto entryPointAttr =
|
||||
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
|
||||
if (!entryPointAttr) {
|
||||
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
||||
return failure();
|
||||
}
|
||||
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
||||
if (!entryFunc) {
|
||||
entryPoints.front().emitOpError("references an unknown entry function ")
|
||||
<< entryPointAttr.getLeafReference().getValue();
|
||||
return failure();
|
||||
}
|
||||
return entryFunc;
|
||||
}
|
||||
|
||||
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
|
||||
return mainGraphFunc;
|
||||
|
||||
SmallVector<func::FuncOp> nonExternalFuncs;
|
||||
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
|
||||
if (!funcOp.isExternal())
|
||||
nonExternalFuncs.push_back(funcOp);
|
||||
if (nonExternalFuncs.size() == 1)
|
||||
return nonExternalFuncs.front();
|
||||
|
||||
moduleOp.emitError("could not resolve a unique PIM entry function");
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||
|
||||
void markWeightAlways(Operation* op) {
|
||||
assert(op && "expected valid op");
|
||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||
}
|
||||
|
||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||
if (!moduleOp || !getGlobalOp)
|
||||
return {};
|
||||
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
}
|
||||
|
||||
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||
return failure();
|
||||
}
|
||||
// channelNewOp should have two users: `op` and a
|
||||
// `ChannelSendOp`/`ChannelReceiveOp`
|
||||
auto channelUsers = channelNewOp->getUsers();
|
||||
auto usersIterator = channelUsers.begin();
|
||||
auto firstUser = *usersIterator;
|
||||
usersIterator++;
|
||||
if (usersIterator == channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"only one found.");
|
||||
channelNewOp->dump();
|
||||
op->dump();
|
||||
channelNewOp->getParentOp()->dump();
|
||||
return failure();
|
||||
}
|
||||
auto secondUser = *usersIterator;
|
||||
usersIterator++;
|
||||
if (usersIterator != channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"more than two found.");
|
||||
return failure();
|
||||
}
|
||||
Operation* notOpUser;
|
||||
if (firstUser == op) {
|
||||
notOpUser = secondUser;
|
||||
}
|
||||
else if (secondUser == op) {
|
||||
notOpUser = firstUser;
|
||||
}
|
||||
else {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"and one of them must be me, but"
|
||||
"none of them is actually me.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (opIsReceive) {
|
||||
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||
"me, the other is not a ChannelSendOp.");
|
||||
return failure();
|
||||
}
|
||||
return notOpUser;
|
||||
}
|
||||
else {
|
||||
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||
"me, the other is not a ChannelReceiveOp.");
|
||||
return failure();
|
||||
}
|
||||
return notOpUser;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user