This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/Support/FileSystem.h"
|
#include "llvm/Support/FileSystem.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@@ -12,13 +12,14 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Compiler/CompilerPasses.hpp"
|
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -644,6 +645,93 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
|
|||||||
return CompilerSuccess;
|
return CompilerSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>>
|
||||||
|
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||||
|
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||||
|
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||||
|
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||||
|
assert(!error && "Error creating weights directory");
|
||||||
|
size_t indexFileName = 0;
|
||||||
|
|
||||||
|
int64_t xbarSize = crossbarSize.getValue();
|
||||||
|
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||||
|
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||||
|
|
||||||
|
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||||
|
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||||
|
|
||||||
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
if (!getGlobalOp) {
|
||||||
|
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
||||||
|
assert(!getGlobalOp && "Weight is not from a memref.get_global");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (!globalOp) {
|
||||||
|
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
|
||||||
|
assert(!globalOp && "Could not find memref.global");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto initialValue = globalOp.getInitialValue();
|
||||||
|
if (!initialValue) {
|
||||||
|
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
|
||||||
|
assert(!initialValue && "memref.global has no initial value");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
||||||
|
if (!denseAttr) {
|
||||||
|
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
|
||||||
|
assert(!denseAttr && "memref.global initial value is not dense");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mapGlobalOpToFileName.contains(globalOp)) {
|
||||||
|
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||||
|
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
|
||||||
|
mapCoreWeightToFileName[coreOp].insert(weightToFile);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto type = denseAttr.getType();
|
||||||
|
auto shape = type.getShape();
|
||||||
|
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||||
|
int64_t numRows = shape[0];
|
||||||
|
int64_t numCols = shape[1];
|
||||||
|
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||||
|
|
||||||
|
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
|
|
||||||
|
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||||
|
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||||
|
std::error_code errorCode;
|
||||||
|
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||||
|
if (errorCode) {
|
||||||
|
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||||
|
assert(errorCode);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t zero = 0;
|
||||||
|
for (int64_t row = 0; row < xbarSize; row++) {
|
||||||
|
for (int64_t col = 0; col < xbarSize; col++) {
|
||||||
|
if (row < numRows && col < numCols) {
|
||||||
|
int64_t index = row * numCols + col;
|
||||||
|
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
|
||||||
|
uint64_t word = bits.getZExtValue();
|
||||||
|
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
weightFileStream.close();
|
||||||
|
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||||
|
mapCoreWeightToFileName[coreOp].insert({weight, newFileName});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return mapCoreWeightToFileName;
|
||||||
|
}
|
||||||
|
|
||||||
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
|
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
|
||||||
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||||
PimAcceleratorMemory& memory,
|
PimAcceleratorMemory& memory,
|
||||||
@@ -729,6 +817,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
json::Object xbarsPerArrayGroup;
|
json::Object xbarsPerArrayGroup;
|
||||||
size_t coreCount = 0;
|
size_t coreCount = 0;
|
||||||
|
|
||||||
|
// Create Weight Folder
|
||||||
|
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||||
|
|
||||||
for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||||
auto coreId = coreOp.getCoreId();
|
auto coreId = coreOp.getCoreId();
|
||||||
coreCount++;
|
coreCount++;
|
||||||
@@ -762,9 +853,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
return InvalidOutputFileAccess;
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
|
||||||
json::Array xbarsPerGroup;
|
json::Array xbarsPerGroup;
|
||||||
if (auto err = writeCrossbarWeights(moduleOp, coreOp, coreWeightsDirPath, xbarsPerGroup))
|
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||||
return err;
|
xbarsPerGroup.push_back(index);
|
||||||
|
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
||||||
|
auto& fileName = mapWeightToFile[weight];
|
||||||
|
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
||||||
|
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
|
||||||
|
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
||||||
|
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message()
|
||||||
|
<< '\n';
|
||||||
|
return InvalidOutputFileAccess;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
|
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user