#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"

#include <cassert>
#include <string>

#include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"

using namespace llvm;
using namespace mlir;

namespace onnx_mlir {

namespace {} // namespace

/// Writes one binary file per distinct weight view into `<outputDirPath>/weights`.
///
/// Each file holds a full `xbarSize x xbarSize` matrix in row-major order: the
/// elements of the weight view occupy the top-left `numRows x numCols` corner
/// and every remaining cell is zero-filled. Views that compare equal to one
/// already written reuse the existing file instead of producing a new one.
///
/// @param requests      Per-core lists of weight views to materialize.
/// @param outputDirPath Directory in which the `weights` subdirectory is created.
/// @return The file names assigned to each core (in request order) and the total
///         number of bytes written across all distinct weight files.
///
/// Failure to create the directory or to open a weight file is fatal.
WeightEmissionResult createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
  // Materialize the path as an owning string: keeping the intermediate
  // llvm::Twine alive in a local (and capturing it by reference below) is
  // fragile, since a Twine only references its operands.
  const std::string coreWeightsDirPath = (outputDirPath + "/weights").str();

  // I/O failures must not be guarded by assert alone, otherwise they are
  // silently ignored in builds with NDEBUG.
  if (std::error_code error = sys::fs::create_directory(coreWeightsDirPath))
    report_fatal_error(Twine("Error creating weights directory `") + coreWeightsDirPath + "`: " + error.message());

  size_t indexFileName = 0;
  const int64_t xbarSize = crossbarSize.getValue();
  WeightEmissionResult result;

  // Weight views already written to disk, paired with the file that holds them.
  llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;

  // Returns the file name holding `weightView`, writing the file first if no
  // equal view has been emitted yet.
  auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
    if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
        it != materializedWeights.end())
      return it->second;

    auto globalOp = weightView.globalOp;
    auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
    assert(denseAttr && "Weight global must have dense initial value");

    ArrayRef<int64_t> shape = weightView.shape;
    assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
    const int64_t numRows = shape[0];
    const int64_t numCols = shape[1];
    assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");

    const size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());

    std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
    const std::string weightFilePath = coreWeightsDirPath + "/" + newFileName;

    std::error_code errorCode;
    raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
    // The previous `assert(errorCode)` inside this branch could never fire, so
    // an unopened stream was written to. Abort explicitly instead.
    if (errorCode)
      report_fatal_error(Twine("Error while opening weight file `") + weightFilePath + "`: " + errorCode.message());

    // Hoisted out of the loops: the value range does not depend on row/col.
    auto values = denseAttr.getValues<APFloat>();

    const uint64_t zero = 0;
    for (int64_t row = 0; row < xbarSize; row++) {
      for (int64_t col = 0; col < xbarSize; col++) {
        const bool insideWeight = row < numRows && col < numCols;
        if (!insideWeight) {
          // Padding cell outside the weight view.
          weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
          continue;
        }

        const int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
        const APInt bits = values[elementIndex].bitcastToAPInt();
        // NOTE(review): emitting the low `elementByteWidth` bytes of a host
        // uint64_t presumably assumes a little-endian host - confirm.
        const uint64_t word = bits.getZExtValue();
        weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
      }
    }

    weightFileStream.close();
    materializedWeights.push_back({weightView, newFileName});

    // Every file stores the full crossbar, padding included.
    const uint64_t weightBytes = pim::checkedMulOrCrash(
        pim::checkedMulOrCrash(static_cast<size_t>(xbarSize), static_cast<size_t>(xbarSize), "weight element count"),
        elementByteWidth,
        "weight byte size");
    result.totalWeightBytes = pim::checkedAddOrCrash(result.totalWeightBytes, weightBytes, "total weight bytes");
    return newFileName;
  };

  for (const WeightFileRequest& request : requests) {
    auto& coreFiles = result.mapCoreWeightToFileName[request.coreId];
    coreFiles.reserve(request.weights.size());
    for (const ResolvedWeightView& weight : request.weights)
      coreFiles.push_back(materializeWeight(weight));
  }

  return result;
}

} // namespace onnx_mlir