#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"

#include <cassert>
#include <filesystem>
#include <fstream>
#include <string>
#include <system_error>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Returns the directory component of the global `outputBaseName` option.
///
/// Returns an empty string when no output base name is configured (empty or
/// "-"), "." when the name contains no '/', and "/" when the only '/' is the
/// leading one.
std::string getOutputDir() {
  if (outputBaseName.empty() || outputBaseName == "-")
    return {};

  const size_t lastSlash = outputBaseName.find_last_of('/');
  if (lastSlash == std::string::npos)
    return ".";
  // "/model" lives in the root directory. substr(0, 0) would return "",
  // which dumpModule() interprets as "dumping disabled".
  if (lastSlash == 0)
    return "/";
  return outputBaseName.substr(0, lastSlash);
}

/// Creates `directory` together with any missing parent directories.
///
/// A failure is printed on llvm::errs() so it is visible in release builds,
/// and additionally trips an assertion in debug builds.
void createDirectory(const std::string& directory) {
  std::error_code errorCode;
  std::filesystem::create_directories(directory, errorCode);
  if (!errorCode)
    return;

  llvm::errs() << "Failed to create directory '" << directory
               << "': " << errorCode.message() << "\n";
  assert(false && "Failed to create directory");
}

/// Writes the textual IR of `moduleOp` to `<outputDir>/dialects/<name>.mlir`.
///
/// Does nothing when getOutputDir() returns an empty string. If the output
/// file cannot be opened, a message is printed on llvm::errs() and the dump
/// is skipped.
void dumpModule(ModuleOp moduleOp, const std::string& name) {
  const std::string outputDir = getOutputDir();
  if (outputDir.empty())
    return;

  const std::string dialectsDir = outputDir + "/dialects";
  createDirectory(dialectsDir);

  const std::string filePath = dialectsDir + "/" + name + ".mlir";
  std::ofstream file(filePath);
  if (!file.is_open()) {
    llvm::errs() << "Failed to open '" << filePath << "' for writing\n";
    return;
  }

  // `os` is declared after `file`, so it is destroyed (and flushed) first;
  // the stream and file are then closed by their destructors.
  llvm::raw_os_ostream os(file);
  os << *moduleOp;
  os.flush();
}

/// Resolves the function the PIM pipeline should compile.
///
/// Resolution order:
///  1. The function referenced by the module's ONNXEntryPointOp (at most one
///     entry point op is allowed).
///  2. Otherwise, a function named "main_graph".
///  3. Otherwise, the only non-external function in the module.
///
/// Returns failure() for a null module (without a diagnostic), and emits an
/// error before returning failure() in every other unresolved case.
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
  if (!moduleOp)
    return failure();

  auto entryPoints = llvm::to_vector(moduleOp.getOps<ONNXEntryPointOp>());
  if (entryPoints.size() > 1) {
    moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
    return failure();
  }

  if (!entryPoints.empty()) {
    ONNXEntryPointOp entryPoint = entryPoints.front();
    auto entryPointAttr =
        entryPoint->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
    if (!entryPointAttr) {
      entryPoint.emitOpError("is missing the entry point function attribute");
      return failure();
    }

    StringRef entryName = entryPointAttr.getLeafReference().getValue();
    auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryName);
    if (!entryFunc) {
      entryPoint.emitOpError("references an unknown entry function ") << entryName;
      return failure();
    }
    return entryFunc;
  }

  if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
    return mainGraphFunc;

  SmallVector<func::FuncOp> nonExternalFuncs;
  for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>())
    if (!funcOp.isExternal())
      nonExternalFuncs.push_back(funcOp);
  if (nonExternalFuncs.size() == 1)
    return nonExternalFuncs.front();

  moduleOp.emitError("could not resolve a unique PIM entry function");
  return failure();
}

/// Returns true when `op` is non-null and carries the PimWeightAlwaysAttrName
/// attribute.
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }

/// Sets the PimWeightAlwaysAttrName unit attribute on `op` (must be non-null).
void markWeightAlways(Operation* op) {
  assert(op && "expected valid op");
  op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}

/// Looks up, in `moduleOp`, the memref.global referenced by `getGlobalOp`.
/// Returns a null op when either argument is null or the symbol is not found.
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
  if (!moduleOp || !getGlobalOp)
    return {};
  return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}

/// Given one endpoint `op` of a channel, returns the opposite endpoint.
///
/// The channel is the first operand of `op`, and it must be produced by a
/// channel-new op with exactly two uses: `op` itself and the other endpoint.
/// - When `opIsReceive` is true, the other endpoint must be a send op.
/// - Otherwise, the other endpoint must be a receive op.
///
/// Emits an error on `op` and returns failure() if any of these conditions
/// does not hold. `rewriter` is currently unused.
// NOTE(review): the Spatial op class names below were reconstructed from the
// diagnostic texts; confirm them against SpatialOps.hpp.
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
  if (op->getNumOperands() == 0) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }

  auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
  if (!channelNewOp) {
    op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
    return failure();
  }

  // One entry per use: `op` and the opposite ChannelSendOp/ChannelReceiveOp.
  auto channelUsers = llvm::to_vector<2>(channelNewOp->getUsers());
  if (channelUsers.size() < 2) {
    InFlightDiagnostic diag = op->emitError("Operand generated by ChannelNewOp must have two users, "
                                            "only one found.");
    diag.attachNote(channelNewOp->getLoc()) << "channel created here";
    return failure();
  }
  if (channelUsers.size() > 2) {
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "more than two found.");
    return failure();
  }

  Operation* otherEnd = nullptr;
  if (channelUsers[0] == op) {
    otherEnd = channelUsers[1];
  } else if (channelUsers[1] == op) {
    otherEnd = channelUsers[0];
  } else {
    op->emitError("Operand generated by ChannelNewOp must have two users, "
                  "and one of them must be me, but "
                  "none of them is actually me.");
    return failure();
  }

  if (opIsReceive) {
    if (!isa<spatial::SpatChannelSendOp>(otherEnd)) {
      op->emitError("Operand generated by ChannelNewOp has two users, one is "
                    "me, the other is not a ChannelSendOp.");
      return failure();
    }
  } else if (!isa<spatial::SpatChannelReceiveOp>(otherEnd)) {
    op->emitError("Operand generated by ChannelNewOp has two users, one is "
                  "me, the other is not a ChannelReceiveOp.");
    return failure();
  }
  return otherEnd;
}

/// Returns row-major (last dimension fastest) element strides for `shape`.
/// The last stride is 1; an empty shape yields an empty vector.
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

/// Converts `linearIndex` into per-dimension indices using `strides`
/// (outermost stride first). `shape` and `strides` must have the same rank,
/// and every stride must be non-zero.
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
  assert(shape.size() == strides.size() && "shape and strides must have the same rank");
  SmallVector<int64_t> indices(shape.size(), 0);
  for (auto [dim, stride] : llvm::enumerate(strides)) {
    assert(stride != 0 && "strides must be non-zero");
    indices[dim] = linearIndex / stride;
    linearIndex %= stride;
  }
  return indices;
}

/// Returns sum(indices[i] * strides[i]); both ranges must have equal length.
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
  int64_t linearIndex = 0;
  for (auto [index, stride] : llvm::zip_equal(indices, strides))
    linearIndex += index * stride;
  return linearIndex;
}

/// Returns the product of all entries of `shape` (1 for an empty shape).
int64_t getNumElements(ArrayRef<int64_t> shape) {
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  return numElements;
}

/// Returns true when the slice (`offsets`, `sizes`, `strides`) of a row-major
/// buffer of shape `srcShape` covers one contiguous run of memory.
///
/// Requirements, scanning from the innermost dimension outward:
/// - All strides are 1.
/// - Innermost dimension with a non-zero offset: its size must fit in the
///   remaining extent, and every dimension outside it must have size 1.
/// - Innermost dimension whose size differs from `srcShape`: every dimension
///   outside it must have size 1.
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
                        ArrayRef<int64_t> offsets,
                        ArrayRef<int64_t> sizes,
                        ArrayRef<int64_t> strides) {
  if (llvm::any_of(strides, [](int64_t stride) { return stride != 1; }))
    return false;

  assert(offsets.size() == srcShape.size() && sizes.size() == srcShape.size()
         && "offsets, sizes and srcShape must have the same rank");
  const int64_t rank = static_cast<int64_t>(srcShape.size());

  // True when every dimension strictly outside `dim` has size 1.
  auto outerDimsHaveUnitSize = [&](int64_t dim) {
    for (int64_t outer = dim - 1; outer >= 0; --outer)
      if (sizes[outer] != 1)
        return false;
    return true;
  };

  for (int64_t dim = rank - 1; dim >= 0; --dim) {
    if (offsets[dim] == 0)
      continue;
    if (sizes[dim] > srcShape[dim] - offsets[dim])
      return false;
    if (!outerDimsHaveUnitSize(dim))
      return false;
    break;
  }

  for (int64_t dim = rank - 1; dim >= 0; --dim) {
    if (sizes[dim] == srcShape[dim])
      continue;
    if (!outerDimsHaveUnitSize(dim))
      return false;
    break;
  }

  return true;
}

} // namespace onnx_mlir