Pim backend
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-04-21 11:12:44 +02:00
parent fbf898e11c
commit f4c6da8f10

View File

@@ -1,10 +1,10 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/Support/FileSystem.h" #include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
@@ -12,13 +12,14 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <string>
#include <utility>
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp" #include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Compiler/CompilerPasses.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
@@ -644,6 +645,93 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
return CompilerSuccess; return CompilerSuccess;
} }
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(!getGlobalOp && "Weight is not from a memref.get_global");
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
assert(!globalOp && "Could not find memref.global");
}
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
assert(!initialValue && "memref.global has no initial value");
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
assert(!denseAttr && "memref.global initial value is not dense");
}
if (mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
mapCoreWeightToFileName[coreOp].insert(weightToFile);
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreOp].insert({weight, newFileName});
}
}
return mapCoreWeightToFileName;
}
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses). /// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp, static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory, PimAcceleratorMemory& memory,
@@ -729,6 +817,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
json::Object xbarsPerArrayGroup; json::Object xbarsPerArrayGroup;
size_t coreCount = 0; size_t coreCount = 0;
// Create Weight Folder
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) { for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
auto coreId = coreOp.getCoreId(); auto coreId = coreOp.getCoreId();
coreCount++; coreCount++;
@@ -762,9 +853,21 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
return InvalidOutputFileAccess; return InvalidOutputFileAccess;
} }
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
json::Array xbarsPerGroup; json::Array xbarsPerGroup;
if (auto err = writeCrossbarWeights(moduleOp, coreOp, coreWeightsDirPath, xbarsPerGroup)) for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
return err; xbarsPerGroup.push_back(index);
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
auto& fileName = mapWeightToFile[weight];
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message()
<< '\n';
return InvalidOutputFileAccess;
}
}
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup); xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
} }