240 lines
7.9 KiB
C++
240 lines
7.9 KiB
C++
#include <filesystem>
#include <fstream>
#include <iostream>
#include <numeric>

#include "llvm/Support/raw_os_ostream.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
std::string getOutputDir() {
|
|
if (outputBaseName.empty() || outputBaseName == "-")
|
|
return {};
|
|
|
|
size_t lastSlash = outputBaseName.find_last_of('/');
|
|
if (lastSlash == std::string::npos)
|
|
return ".";
|
|
return outputBaseName.substr(0, lastSlash);
|
|
}
|
|
|
|
/// Creates `directory` and any missing parents. An already-existing directory
/// is not an error. On failure, prints the OS error and aborts in debug
/// builds (via assert); release builds log the error and continue.
void createDirectory(const std::string& directory) {
  std::error_code errorCode;
  std::filesystem::create_directories(directory, errorCode);
  if (errorCode) {
    // assert() only stringifies its expression, so the original's dynamically
    // built message was never shown; print it explicitly. This also avoids
    // constructing the message string on the success path.
    std::cerr << "Failed to create directory '" << directory
              << "': " << errorCode.message() << "\n";
    assert(false && "Failed to create directory");
  }
}
|
|
|
|
/// Dumps `moduleOp` as textual MLIR to `<outputDir>/dialects/<name>.mlir`.
/// Does nothing when no output directory is configured; this is a best-effort
/// debug aid, so a file that cannot be opened is skipped silently.
void dumpModule(ModuleOp moduleOp, const std::string& name) {
  std::string outputDir = getOutputDir();
  if (outputDir.empty())
    return;

  std::string dialectsDir = outputDir + "/dialects";
  createDirectory(dialectsDir);

  // std::ofstream is the idiomatic write-only stream; the original used a
  // bidirectional std::fstream with std::ios::out.
  std::ofstream file(dialectsDir + "/" + name + ".mlir");
  if (!file.is_open())
    return; // the original silently wrote into a failed stream; bail out early

  llvm::raw_os_ostream os(file);
  os << *moduleOp;
  os.flush(); // push buffered bytes into `file` before its destructor closes it
}
|
|
|
|
/// Resolves the function that acts as the PIM entry point of `moduleOp`.
///
/// Resolution order:
///   1. The function referenced by the module's single ONNXEntryPointOp
///      (more than one entry point is an error).
///   2. A function literally named "main_graph".
///   3. The module's unique func::FuncOp that has a body.
/// Emits a diagnostic and returns failure() when none of these resolve.
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
  if (!moduleOp)
    return failure();

  SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
  if (entryPoints.size() > 1) {
    moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
    return failure();
  }
  if (!entryPoints.empty()) {
    // The entry point op names its target function through a symbol-ref
    // attribute; missing or dangling references are hard errors.
    auto entryPointAttr =
        entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
    if (!entryPointAttr) {
      entryPoints.front().emitOpError("is missing the entry point function attribute");
      return failure();
    }
    auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
    if (!entryFunc) {
      entryPoints.front().emitOpError("references an unknown entry function ")
          << entryPointAttr.getLeafReference().getValue();
      return failure();
    }
    return entryFunc;
  }

  // No ONNX entry point op: fall back to the conventional name used here.
  if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
    return mainGraphFunc;

  // Last resort: accept the module's only non-external (body-carrying)
  // function; anything else is ambiguous.
  SmallVector<func::FuncOp> nonExternalFuncs;
  for (auto funcOp : moduleOp.getOps<func::FuncOp>())
    if (!funcOp.isExternal())
      nonExternalFuncs.push_back(funcOp);
  if (nonExternalFuncs.size() == 1)
    return nonExternalFuncs.front();

  moduleOp.emitError("could not resolve a unique PIM entry function");
  return failure();
}
|
|
|
|
/// True when `op` is non-null and carries the PIM "weight always" marker
/// attribute (see markWeightAlways).
bool hasWeightAlways(Operation* op) {
  if (op == nullptr)
    return false;
  return op->getAttr(PimWeightAlwaysAttrName) != nullptr;
}
|
|
|
|
/// Tags `op` with the PIM "weight always" marker (a UnitAttr); `op` must be
/// non-null. Pairs with hasWeightAlways.
void markWeightAlways(Operation* op) {
  assert(op && "expected valid op");
  auto* context = op->getContext();
  op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(context));
}
|
|
|
|
/// Resolves the memref::GlobalOp that `getGlobalOp` refers to by name inside
/// `moduleOp`. Returns a null op when either argument is null or the symbol
/// is not found.
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
  const bool haveBoth = moduleOp && getGlobalOp;
  return haveBoth ? moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName())
                  : memref::GlobalOp{};
}
|
|
|
|
/// Resolves the operation at the opposite end of a spatial channel.
///
/// `op` must take the channel value (produced by a SpatChannelNewOp) as its
/// first operand. A well-formed channel value has exactly two users: `op`
/// itself and the matching SpatChannelSendOp (when `opIsReceive`) or
/// SpatChannelReceiveOp (otherwise). Returns the other user, or failure()
/// with a diagnostic when the channel is malformed.
///
/// `rewriter` is currently unused; kept for interface stability.
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
  auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
  if (!channelNewOp) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }

  // channelNewOp should have two users: `op` and a
  // `ChannelSendOp`/`ChannelReceiveOp`
  SmallVector<Operation*, 2> users(channelNewOp->getUsers().begin(), channelNewOp->getUsers().end());
  if (users.size() < 2) {
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "only one found.");
    // Dump the malformed IR to ease debugging.
    channelNewOp->dump();
    op->dump();
    channelNewOp->getParentOp()->dump();
    return failure();
  }
  if (users.size() > 2) {
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "more than two found.");
    return failure();
  }

  Operation* notOpUser = nullptr;
  if (users[0] == op) {
    notOpUser = users[1];
  } else if (users[1] == op) {
    notOpUser = users[0];
  } else {
    // BUGFIX: the original concatenated "...but" + "none..." into "butnone";
    // a separating space has been added.
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "and one of them must be me, but "
                  "none of them is actually me.");
    return failure();
  }

  // The far end must be the opposite channel endpoint kind.
  if (opIsReceive) {
    if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
      op->emitError("Operand generated by ChannelNewOp has two user, one is "
                    "me, the other is not a ChannelSendOp.");
      return failure();
    }
    return notOpUser;
  } else {
    if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
      op->emitError("Operand generated by ChannelNewOp has two user, one is "
                    "me, the other is not a ChannelReceiveOp.");
      return failure();
    }
    return notOpUser;
  }
}
|
|
|
|
/// Computes row-major (C-order) strides for `shape`: the innermost stride is
/// 1 and each outer stride is the product of all inner dimension extents.
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  SmallVector<int64_t> strides(rank, 1);
  if (rank < 2)
    return strides;

  // Accumulate the running product of extents from innermost to outermost.
  int64_t runningProduct = 1;
  for (size_t dim = rank - 1; dim > 0; --dim) {
    runningProduct *= shape[dim];
    strides[dim - 1] = runningProduct;
  }
  return strides;
}
|
|
|
|
/// Splits a flat `linearIndex` into one index per dimension using `strides`
/// (outermost dimension first). `shape` contributes only the rank of the
/// result; `strides` is expected to have the same size.
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
  SmallVector<int64_t> indices(shape.size(), 0);
  int64_t remaining = linearIndex;
  for (size_t dim = 0, rank = strides.size(); dim < rank; ++dim) {
    const int64_t stride = strides[dim];
    indices[dim] = remaining / stride;
    remaining = remaining % stride;
  }
  return indices;
}
|
|
|
|
/// Folds per-dimension `indices` into a flat offset as the dot product with
/// `strides`; llvm::zip_equal enforces that both arrays have the same size.
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
  int64_t flatOffset = 0;
  for (auto [coordinate, stride] : llvm::zip_equal(indices, strides)) {
    flatOffset += coordinate * stride;
  }
  return flatOffset;
}
|
|
|
|
/// Returns the total element count of a buffer with the given static `shape`;
/// an empty shape (scalar) yields 1.
/// NOTE(review): assumes every dimension is static and non-negative -- a
/// dynamic-dimension sentinel value would corrupt the product. TODO confirm
/// callers only pass static shapes.
int64_t getNumElements(ArrayRef<int64_t> shape) {
  // std::accumulate replaces the hand-rolled product loop.
  return std::accumulate(shape.begin(), shape.end(), int64_t(1),
      [](int64_t product, int64_t dim) { return product * dim; });
}
|
|
|
|
/// Returns true when the hyper-rectangular slice described by
/// `offsets`/`sizes`/`strides` over a row-major buffer of shape `srcShape`
/// occupies one contiguous run of memory.
///
/// All four arrays are ordered outermost dimension first; the checks below
/// walk them innermost-first via reverse iterators. The slice is accepted
/// only if:
///   * every stride is 1,
///   * the innermost dimension with a non-zero offset fully contains its
///     slice, and every dimension outside it has slice size 1,
///   * every dimension outside the innermost partially-covered dimension
///     (size != full extent) has slice size 1.
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
    ArrayRef<int64_t> offsets,
    ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  // Any stride other than 1 leaves gaps between consecutive elements.
  if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
    return false;

  // (offset, size, extent) tuples, innermost dimension first.
  auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
      llvm::make_range(sizes.rbegin(), sizes.rend()),
      llvm::make_range(srcShape.rbegin(), srcShape.rend()));

  // Innermost dimension the slice is offset into, if any.
  auto firstNonZeroOffset = std::find_if(
      offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
        auto [offset, _size, _dimension] = offsetAndSizeAndShape;
        return offset != 0;
      });

  if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
    auto [offset, size, dimension] = *firstNonZeroOffset;
    // The slice must fit within the dimension it is offset into.
    if (size > dimension - offset)
      return false;
    ++firstNonZeroOffset;

    // Every dimension outside the offset one must be a degenerate (size-1)
    // slice; otherwise each outer step would re-skip the offset prefix.
    if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
          auto [_offset, size, _dimension] = offsetAndSizeAndShape;
          return size != 1;
        }))
      return false;
  }

  // Same reasoning applied to sizes: once a dimension is only partially
  // covered, all dimensions outside it must have slice size 1.
  auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
      llvm::make_range(srcShape.rbegin(), srcShape.rend()));

  // Innermost dimension not covered in full, if any.
  auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
    auto [size, dimension] = sizeAndShape;
    return size != dimension;
  });

  if (firstDifferentSize != sizesAndShape.end()) {
    ++firstDifferentSize;

    if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
          auto [size, _dimension] = sizeAndShape;
          return size != 1;
        }))
      return false;
  }

  return true;
}
|
|
|
|
} // namespace onnx_mlir
|