diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 3e88a69..0d744da 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -1,10 +1,10 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" @@ -12,13 +12,14 @@ #include #include #include +#include +#include #include "Common/PimCommon.hpp" #include "Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Compiler/CompilerPasses.hpp" using namespace llvm; using namespace mlir; @@ -644,6 +645,93 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp, return CompilerSuccess; } +llvm::DenseMap> +createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { + ModuleOp moduleOp = funcOp->getParentOfType(); + auto coreWeightsDirPath = outputDirPath + "/weights"; + auto error = sys::fs::create_directory(coreWeightsDirPath); + assert(!error && "Error creating weights directory"); + size_t indexFileName = 0; + + int64_t xbarSize = crossbarSize.getValue(); + llvm::DenseMap> mapCoreWeightToFileName; + llvm::DenseMap mapGlobalOpToFileName; + + for (pim::PimCoreOp coreOp : funcOp.getOps()) { + for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) { + + auto getGlobalOp = weight.getDefiningOp(); + if (!getGlobalOp) { + coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index)); + assert(!getGlobalOp && "Weight is not from a memref.get_global"); + } + + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp) { + coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index)); + assert(!globalOp && "Could not find memref.global"); + } + + auto initialValue = globalOp.getInitialValue(); + if (!initialValue) { + coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index)); + assert(!initialValue && "memref.global has no initial value"); + } + + auto denseAttr = dyn_cast(*initialValue); + if (!denseAttr) { + coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index)); + assert(!denseAttr && "memref.global initial value is not dense"); + } + + if (mapGlobalOpToFileName.contains(globalOp)) { + auto& fileName = mapGlobalOpToFileName[globalOp]; + std::pair weightToFile = {weight, fileName}; + mapCoreWeightToFileName[coreOp].insert(weightToFile); + continue; + } + + auto type = denseAttr.getType(); + auto shape = type.getShape(); + assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); + int64_t numRows = shape[0]; + int64_t numCols = shape[1]; + assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); + + size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8; + + std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; + auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); + std::error_code errorCode; + raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); + if (errorCode) { + errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; + assert(errorCode); + } + + uint64_t zero = 0; + for (int64_t row = 0; row < xbarSize; row++) { + for (int64_t col = 0; col < xbarSize; col++) { + if (row < numRows && col < numCols) { + int64_t index = row * numCols + col; + APInt bits = denseAttr.getValues()[index].bitcastToAPInt(); + uint64_t word = bits.getZExtValue(); + weightFileStream.write(reinterpret_cast(&word), elementByteWidth); + } + else { + weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); + } + } + } + + weightFileStream.close(); + mapGlobalOpToFileName.insert({globalOp, newFileName}); + mapCoreWeightToFileName[coreOp].insert({weight, newFileName}); + } + } + return mapCoreWeightToFileName; +} + /// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses). static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp, PimAcceleratorMemory& memory, @@ -729,6 +817,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: json::Object xbarsPerArrayGroup; size_t coreCount = 0; + // Create Weight Folder + auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath); + for (auto coreOp : funcOp.getOps()) { auto coreId = coreOp.getCoreId(); coreCount++; @@ -762,9 +853,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: return InvalidOutputFileAccess; } + auto& mapWeightToFile = mapCoreWeightToFileName[coreOp]; json::Array xbarsPerGroup; - if (auto err = writeCrossbarWeights(moduleOp, coreOp, coreWeightsDirPath, xbarsPerGroup)) - return err; + for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) { + xbarsPerGroup.push_back(index); + assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); + auto& fileName = mapWeightToFile[weight]; + if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, + coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) { + errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " + << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message() + << '\n'; + return InvalidOutputFileAccess; + } + } + xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup); }