Files
Raptor/src/PIM/Compiler/PimCodeGen.cpp
2026-03-23 21:25:51 +01:00

743 lines
27 KiB
C++

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Compiler/CompilerPasses.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
/// Registers `value` as needing storage and returns a pointer to its new entry.
/// The entry is created with address 0 and a size in bytes derived from the
/// value's static shape; the address is filled in later by allocation.
/// The returned pointer refers to an element stored inside `memEntries`.
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  assert("Only static shape is supported" && shapedType.hasStaticShape());
  // Element count times element bit width, truncated to whole bytes.
  size_t byteSize = shapedType.getNumElements() * shapedType.getElementType().getIntOrFloatBitWidth() / 8;
  memEntries.emplace_back(MemEntry {0, byteSize}, value);
  return &memEntries.back().first;
}
/// Places `memEntry` at the current first free address, advances the free
/// pointer past it (rounded up to a multiple of `minAlignment`), and records
/// the placed entry for `value` in the global map.
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
  memEntry.address = firstAvailableAddress;
  firstAvailableAddress += memEntry.size;
  // Round the next free address up to the alignment boundary.
  const size_t misalignment = firstAvailableAddress % minAlignment;
  if (misalignment != 0)
    firstAvailableAddress += minAlignment - misalignment;
  globalMemEntriesMap[value] = memEntry;
}
/// Gathers and allocates host memory for the entry function.
///
/// Collected values are: the result of every `memref.get_global` that is not
/// flagged by `hasWeightAlways`, every function argument, and (through
/// `allocateCore`) every `memref.alloc` reachable from `funcOp`.
///
/// Several SSA values may refer to the same global constant. Only the first
/// one seen is gathered, so the constant is allocated once; the others are
/// mapped to that entry.
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
  // A later SSA value that refers to an already-seen global.
  struct GlobalAlias {
    mlir::Value alias;
    mlir::Value canonical;
  };

  // First SSA value encountered for each global constant.
  SmallDenseMap<memref::GlobalOp, mlir::Value, 8> canonicalValueOfGlobal;
  SmallVector<GlobalAlias, 8> aliases;

  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    if (hasWeightAlways(getGlobalOp))
      return;
    mlir::Value result = getGlobalOp.getResult();
    auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    auto insertion = canonicalValueOfGlobal.try_emplace(globalMemrefOp, result);
    if (insertion.second)
      gatherMemEntry(result);
    else
      aliases.push_back({result, insertion.first->second});
  });

  for (mlir::Value arg : funcOp.getArguments())
    gatherMemEntry(arg);

  allocateCore(funcOp);

  // Addresses are only assigned by allocateCore. Copying the entry earlier
  // would record the aliases with the placeholder address 0, so the aliases
  // are resolved here, once the canonical entries have been placed.
  for (const GlobalAlias& entry : aliases)
    globalMemEntriesMap[entry.alias] = getMemEntry(entry.canonical);
}
/// Gathers every `memref.alloc` nested under `op`, then assigns addresses to
/// all gathered entries in order of decreasing size.
void PimMemory::allocateCore(Operation* op) {
  op->walk([this](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });

  auto largerFirst = [](const auto& lhs, const auto& rhs) -> bool { return lhs.first.size > rhs.first.size; };
  llvm::sort(memEntries, largerFirst);

  for (auto& entryAndValue : memEntries)
    allocateMemoryForValue(entryAndValue.second, entryAndValue.first);
}
/// Returns a copy of the entry recorded for `value` in the global map.
/// A missing entry is treated as a programming error (assert).
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
  const auto found = globalMemEntriesMap.find(value);
  assert("Missing memEntry for value" && found != globalMemEntriesMap.end());
  return found->second;
}
/// Returns the memory object registered under `id`, creating it from
/// `memEntriesMap` if none exists yet.
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
  auto insertion = deviceMem.try_emplace(id, memEntriesMap);
  return insertion.first->second;
}
/// Resolves `value` to a base value plus byte offset and returns the base's
/// recorded address plus that offset.
/// If the value cannot be resolved, or its base has no recorded entry, the
/// offending value (and its defining op, if any) is printed to stderr and the
/// function hits `llvm_unreachable`.
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
  // Prints a header, the value itself, and its defining operation when present.
  auto describe = [](StringRef header, mlir::Value culprit) {
    errs() << header;
    culprit.print(errs());
    errs() << "\n";
    if (auto* producer = culprit.getDefiningOp()) {
      errs() << "Defining op:\n";
      producer->print(errs());
      errs() << "\n";
    }
  };

  auto resolved = resolveContiguousAddress(value);
  if (failed(resolved)) {
    describe("Failed to resolve contiguous address for value: ", value);
    llvm_unreachable("Failed to resolve contiguous address");
  }

  auto entryIt = memEntriesMap.find(resolved->base);
  if (entryIt == memEntriesMap.end()) {
    describe("Missing mem entry for value: ", resolved->base);
    llvm_unreachable("Missing mem entry");
  }

  return entryIt->second.address + resolved->byteOffset;
}
/// Builds the `offset` sub-object used by emitted instructions, with both
/// `offset_select` and `offset_value` set to 0.
json::Object PimCodeGen::createEmptyOffset() {
  return json::Object {
    {"offset_select", 0},
    {"offset_value",  0}
  };
}
/// Serializes one instruction to the core file stream, followed by a comma.
void PimCodeGen::emitInstruction(json::Object instruction) const {
  json::Value serialized(std::move(instruction));
  coreFileStream << serialized;
  coreFileStream << ',';
}
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
json::Object json;
json["op"] = "sldi";
json["rd"] = registerNumber;
json["imm"] = immediate;
emitInstruction(std::move(json));
}
/// Sets register 0 to `rdAddress + rdOffset`.
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
  const size_t rdEffective = rdAddress + rdOffset;
  genSetRegisterImmediateUnsigned(/*registerNumber=*/0, rdEffective);
}
void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const {
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
}
void PimCodeGen::setupRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const {
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
}
/// Emits a two-operand copy instruction named `opName`.
/// Registers 0 and 1 are first loaded with the destination and source
/// effective addresses; the byte count is stored under the key given by
/// `sizeFieldName`.
void PimCodeGen::emitMemCopyOp(StringRef opName,
                               size_t rdAddr,
                               size_t rdOffset,
                               size_t rs1Addr,
                               size_t rs1Offset,
                               size_t size,
                               StringRef sizeFieldName) const {
  setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);

  json::Object copyInstr;
  copyInstr["op"] = opName;
  copyInstr["rd"] = 0;
  copyInstr["rs1"] = 1;
  copyInstr[sizeFieldName] = size;
  copyInstr["offset"] = createEmptyOffset();
  emitInstruction(std::move(copyInstr));
}
/// Emits a communication instruction named `opName` for the buffer at
/// `bufferAddr` (loaded into register 0), with the peer core id and the size.
void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const {
  setupRd(bufferAddr, 0);

  json::Object commInstr;
  commInstr["op"] = opName;
  commInstr["rd"] = 0;
  commInstr["core"] = coreId;
  commInstr["size"] = size;
  commInstr["offset"] = createEmptyOffset();
  emitInstruction(std::move(commInstr));
}
void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const {
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
json::Object json;
json["op"] = "mvmul";
json["rd"] = 0;
json["rs1"] = 1;
json["group"] = groupId;
json["relu"] = 0;
json["mbiw"] = 8;
emitInstruction(std::move(json));
}
/// Emits an `ld` copy from the op's host source to its device target.
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const {
  const size_t deviceAddr = memory.getValueAddress(loadOp.getDeviceTarget());
  const size_t deviceOffset = loadOp.getDeviceTargetOffset();
  const size_t hostAddr = memory.getValueAddress(loadOp.getHostSource());
  const size_t hostOffset = loadOp.getHostSourceOffset();
  const size_t byteCount = loadOp.getSize();
  emitMemCopyOp("ld", deviceAddr, deviceOffset, hostAddr, hostOffset, byteCount);
}
/// Emits an `st` copy from the op's device source to its host target.
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const {
  const size_t hostAddr = memory.getValueAddress(storeOp.getHostTarget());
  const size_t hostOffset = storeOp.getHostTargetOffset();
  const size_t deviceAddr = memory.getValueAddress(storeOp.getDeviceSource());
  const size_t deviceOffset = storeOp.getDeviceSourceOffset();
  const size_t byteCount = storeOp.getSize();
  emitMemCopyOp("st", hostAddr, hostOffset, deviceAddr, deviceOffset, byteCount);
}
/// Emits an `lmv` copy from the op's source to its target; the size is
/// written under the `len` key.
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
  const size_t targetAddr = memory.getValueAddress(lmvOp.getTarget());
  const size_t targetOffset = lmvOp.getTargetOffset();
  const size_t sourceAddr = memory.getValueAddress(lmvOp.getSource());
  const size_t sourceOffset = lmvOp.getSourceOffset();
  const size_t byteCount = lmvOp.getSize();
  emitMemCopyOp("lmv", targetAddr, targetOffset, sourceAddr, sourceOffset, byteCount, "len");
}
/// Emits a `recv` into the op's output buffer from its source core.
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
  const size_t bufferAddr = memory.getValueAddress(receiveOp.getOutputBuffer());
  const size_t sourceCore = receiveOp.getSourceCoreId();
  const size_t byteCount = receiveOp.getSize();
  emitCommunicationOp("recv", bufferAddr, sourceCore, byteCount);
}
/// Emits a `send` of the op's input buffer to its target core.
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp) const {
  const size_t bufferAddr = memory.getValueAddress(sendOp.getInput());
  const size_t targetCore = sendOp.getTargetCoreId();
  const size_t byteCount = sendOp.getSize();
  emitCommunicationOp("send", bufferAddr, targetCore, byteCount);
}
/// Emits an `mvmul` for an MVM-like op, writing to its output buffer and
/// reading from its input, both at offset 0.
/// `transposeMatrix` is currently unused (see TODO).
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix) {
  const size_t outputAddr = memory.getValueAddress(mvmLikeOp.getOutputBuffer());
  const size_t inputAddr = memory.getValueAddress(mvmLikeOp.getInput());
  emitMvmOp(mvmId, outputAddr, 0, inputAddr, 0);
  // TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
/// Returns the size of a shaped value in bytes: element count times element
/// bit width, divided by 8 (integer division).
static size_t getValueSizeInBytes(mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  auto totalBits = shapedType.getNumElements() * shapedType.getElementTypeBitWidth();
  return totalBits / 8;
}
/// Emits a `vvadd` instruction: registers 0/1/2 hold the output, lhs and rhs
/// addresses, and `len` is the byte size of the lhs operand.
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp) const {
  const size_t dstAddr = memory.getValueAddress(vvaddOp.getOutputBuffer());
  const size_t leftAddr = memory.getValueAddress(vvaddOp.getLhs());
  const size_t rightAddr = memory.getValueAddress(vvaddOp.getRhs());
  setupRdRs1Rs2(dstAddr, 0, leftAddr, 0, rightAddr, 0);

  json::Object instr;
  instr["op"] = "vvadd";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["rs2"] = 2;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vvaddOp.getLhs());
  emitInstruction(std::move(instr));
}
/// Emits a `vvsub` instruction: registers 0/1/2 hold the output, lhs and rhs
/// addresses, and `len` is the byte size of the lhs operand.
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp) const {
  const size_t dstAddr = memory.getValueAddress(vvsubOp.getOutputBuffer());
  const size_t leftAddr = memory.getValueAddress(vvsubOp.getLhs());
  const size_t rightAddr = memory.getValueAddress(vvsubOp.getRhs());
  setupRdRs1Rs2(dstAddr, 0, leftAddr, 0, rightAddr, 0);

  json::Object instr;
  instr["op"] = "vvsub";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["rs2"] = 2;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vvsubOp.getLhs());
  emitInstruction(std::move(instr));
}
/// Emits a `vvmul` instruction: registers 0/1/2 hold the output, lhs and rhs
/// addresses, and `len` is the byte size of the lhs operand.
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp) const {
  const size_t dstAddr = memory.getValueAddress(vvmulOp.getOutputBuffer());
  const size_t leftAddr = memory.getValueAddress(vvmulOp.getLhs());
  const size_t rightAddr = memory.getValueAddress(vvmulOp.getRhs());
  setupRdRs1Rs2(dstAddr, 0, leftAddr, 0, rightAddr, 0);

  json::Object instr;
  instr["op"] = "vvmul";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["rs2"] = 2;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vvmulOp.getLhs());
  emitInstruction(std::move(instr));
}
/// Emits a `vvmax` instruction: registers 0/1/2 hold the output, lhs and rhs
/// addresses, and `len` is the byte size of the lhs operand.
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp) const {
  const size_t dstAddr = memory.getValueAddress(vvmaxOp.getOutputBuffer());
  const size_t leftAddr = memory.getValueAddress(vvmaxOp.getLhs());
  const size_t rightAddr = memory.getValueAddress(vvmaxOp.getRhs());
  setupRdRs1Rs2(dstAddr, 0, leftAddr, 0, rightAddr, 0);

  json::Object instr;
  instr["op"] = "vvmax";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["rs2"] = 2;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vvmaxOp.getLhs());
  emitInstruction(std::move(instr));
}
/// Emits a `vvdmul` instruction: registers 0/1/2 hold the output, lhs and rhs
/// addresses, and `len` is the byte size of the lhs operand.
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp) const {
  const size_t dstAddr = memory.getValueAddress(vvdmulOp.getOutputBuffer());
  const size_t leftAddr = memory.getValueAddress(vvdmulOp.getLhs());
  const size_t rightAddr = memory.getValueAddress(vvdmulOp.getRhs());
  setupRdRs1Rs2(dstAddr, 0, leftAddr, 0, rightAddr, 0);

  json::Object instr;
  instr["op"] = "vvdmul";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["rs2"] = 2;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vvdmulOp.getLhs());
  emitInstruction(std::move(instr));
}
/// Emits a `vavg` instruction: registers 0/1 hold the output and input
/// addresses, and `len` is the byte size of the input.
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp) const {
  const size_t dstAddr = memory.getValueAddress(vavgOp.getOutputBuffer());
  const size_t srcAddr = memory.getValueAddress(vavgOp.getInput());
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  json::Object instr;
  instr["op"] = "vavg";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vavgOp.getInput());
  emitInstruction(std::move(instr));
}
/// Emits a `vrelu` instruction: registers 0/1 hold the output and input
/// addresses, and `len` is the byte size of the input.
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp) const {
  const size_t dstAddr = memory.getValueAddress(vreluOp.getOutputBuffer());
  const size_t srcAddr = memory.getValueAddress(vreluOp.getInput());
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  json::Object instr;
  instr["op"] = "vrelu";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vreluOp.getInput());
  emitInstruction(std::move(instr));
}
/// Emits a `vtanh` instruction: registers 0/1 hold the output and input
/// addresses, and `len` is the byte size of the input.
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp) const {
  const size_t dstAddr = memory.getValueAddress(vtanhOp.getOutputBuffer());
  const size_t srcAddr = memory.getValueAddress(vtanhOp.getInput());
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  json::Object instr;
  instr["op"] = "vtanh";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vtanhOp.getInput());
  emitInstruction(std::move(instr));
}
/// Emits a `vsigm` instruction: registers 0/1 hold the output and input
/// addresses, and `len` is the byte size of the input.
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp) const {
  const size_t dstAddr = memory.getValueAddress(vsigmOp.getOutputBuffer());
  const size_t srcAddr = memory.getValueAddress(vsigmOp.getInput());
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  json::Object instr;
  instr["op"] = "vsigm";
  instr["rd"] = 0;
  instr["rs1"] = 1;
  instr["offset"] = createEmptyOffset();
  instr["len"] = getValueSizeInBytes(vsigmOp.getInput());
  emitInstruction(std::move(instr));
}
/// Lowers a transpose to one `lmv` copy per element.
/// For every source element, its flat row-major index is split into
/// per-dimension indices, permuted, and re-flattened with the destination
/// strides to obtain the destination offset.
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp) const {
  const size_t srcBase = memory.getValueAddress(transposeOp.getInput());
  const size_t dstBase = memory.getValueAddress(transposeOp.getOutputBuffer());

  auto inputType = cast<ShapedType>(transposeOp.getInput().getType());
  auto inputShape = inputType.getShape();
  const size_t rank = inputShape.size();
  const size_t bytesPerElement = inputType.getElementTypeBitWidth() / 8;
  const size_t elementCount = inputType.getNumElements();

  // Destination dimension i is taken from source dimension perm[i].
  SmallVector<int64_t> perm = map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(),
                                            [](auto attr) -> int64_t { return attr.getInt(); });

  SmallVector<int64_t> outputShape(rank);
  for (size_t dim = 0; dim < rank; ++dim)
    outputShape[dim] = inputShape[perm[dim]];

  // Row-major strides: the innermost stride is 1, each outer stride is the
  // next inner stride times the next inner extent.
  SmallVector<size_t> inputStrides(rank, 1);
  SmallVector<size_t> outputStrides(rank, 1);
  for (size_t inner = rank; inner-- > 1;) {
    inputStrides[inner - 1] = inputStrides[inner] * inputShape[inner];
    outputStrides[inner - 1] = outputStrides[inner] * outputShape[inner];
  }

  // Scratch multi-index, fully overwritten for every element.
  SmallVector<size_t> multiIndex(rank);
  for (size_t srcFlat = 0; srcFlat < elementCount; ++srcFlat) {
    size_t rest = srcFlat;
    for (size_t dim = 0; dim < rank; ++dim) {
      multiIndex[dim] = rest / inputStrides[dim];
      rest %= inputStrides[dim];
    }

    size_t dstFlat = 0;
    for (size_t dim = 0; dim < rank; ++dim)
      dstFlat += multiIndex[perm[dim]] * outputStrides[dim];

    emitMemCopyOp(
        "lmv", dstBase, dstFlat * bytesPerElement, srcBase, srcFlat * bytesPerElement, bytesPerElement, "len");
  }
}
/// Returns the larger of the first two dimension sizes of `matrixShape`.
/// Only rank 2 and rank 4 shapes are accepted (checked by assert).
size_t getMatrixSize(ShapedType matrixShape) {
  [[maybe_unused]] const auto rank = matrixShape.getRank();
  assert((rank == 2 || rank == 4) && "Unsupported matrix shape");
  const auto dim0 = matrixShape.getDimSize(0);
  const auto dim1 = matrixShape.getDimSize(1);
  return std::max(dim0, dim1);
}
/// Formats a byte count as a short human-readable string.
///
/// The largest unit among GB, MB and KB (powers of 1024) that is not greater
/// than `size` is used, and the value is truncated to a whole number of that
/// unit. Anything below 1024 is printed in bytes.
///
/// Exact unit boundaries map to the larger unit: 1024 is "1 KB", not
/// "1024 Bytes".
std::string getMemorySizeAsString(size_t size) {
  constexpr size_t kKiB = 1024;
  constexpr size_t kMiB = kKiB * 1024;
  constexpr size_t kGiB = kMiB * 1024;

  if (size >= kGiB)
    return std::to_string(size / kGiB) + " GB";
  if (size >= kMiB)
    return std::to_string(size / kMiB) + " MB";
  if (size >= kKiB)
    return std::to_string(size / kKiB) + " KB";
  return std::to_string(size) + " Bytes";
}
/// Writes `memory.bin` into `outputDirPath`: a zero-filled image as large as
/// the host memory's first available address, with the dense initial value of
/// each non-weight global constant copied to its allocated address.
/// Returns InvalidOutputFileAccess if the file cannot be opened.
static OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
  auto imagePath = (outputDirPath + "/memory.bin").str();
  std::error_code openError;
  raw_fd_ostream imageStream(imagePath, openError, sys::fs::OF_None);
  if (openError) {
    errs() << "Error while opening memory file " << imagePath << ": " << openError.message() << '\n';
    return InvalidOutputFileAccess;
  }

  std::vector<char> image(memory.hostMem.getFirstAvailableAddress(), 0);

  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    // Skip anything that does not carry dense constant data.
    if (hasWeightAlways(getGlobalOp))
      return;
    auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    if (!globalOp)
      return;
    auto initialValue = globalOp.getInitialValue();
    if (!initialValue)
      return;
    auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
    if (!denseAttr)
      return;

    MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
    ArrayRef<char> rawData = denseAttr.getRawData();
    char* dst = image.data() + memEntry.address;

    if (!denseAttr.isSplat()) {
      assert(rawData.size() == memEntry.size && "Data size mismatch");
      std::memcpy(dst, rawData.data(), rawData.size());
      return;
    }

    // Splat: the raw data is treated as one element, repeated across the entry.
    size_t elementSize = rawData.size();
    assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
    size_t offset = 0;
    while (offset < memEntry.size) {
      const size_t chunk = std::min(elementSize, memEntry.size - offset);
      std::memcpy(dst + offset, rawData.data(), chunk);
      offset += elementSize;
    }
  });

  imageStream.write(image.data(), image.size());
  imageStream.close();
  return CompilerSuccess;
}
/// Dispatch every operation in the first block of a core region to the
/// matching code generator. Alloc, halt, subview, expand_shape and
/// collapse_shape ops are skipped without being counted.
/// Returns the number of operations handed to a code generator, or -1 (after
/// emitting an error and dumping the op) when an unsupported op is met.
static int64_t codeGenCoreOps(pim::PimCoreOp coreOp, PimCodeGen& coreCodeGen) {
  // Returns true when `op` was recognized and code was generated for it.
  auto tryGenerate = [&coreCodeGen](Operation& op) -> bool {
    if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
      coreCodeGen.codeGenLoadOp(loadOp);
    else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
      coreCodeGen.codeGenStoreOp(storeOp);
    else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
      coreCodeGen.codeGenLmvOp(lmvOp);
    else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
      coreCodeGen.codeGenReceiveOp(receiveOp);
    else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
      coreCodeGen.codeGenSendOp(sendOp);
    else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
      coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true);
    else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
      coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false);
    else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
      coreCodeGen.codeGenTransposeOp(transposeOp);
    else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
      coreCodeGen.codeGenVVAddOp(vvaddOp);
    else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
      coreCodeGen.codeGenVVSubOp(vvsubOp);
    else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
      coreCodeGen.codeGenVVMulOp(vvmulOp);
    else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
      coreCodeGen.codeGenVVMaxOp(vvmaxOp);
    else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
      coreCodeGen.codeGenVVDMulOp(vvdmulOp);
    else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
      coreCodeGen.codeGenVAvgOp(vavgOp);
    else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
      coreCodeGen.codeGenVReluOp(vreluOp);
    else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
      coreCodeGen.codeGenVTanhOp(vtanhOp);
    else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
      coreCodeGen.codeGenVSigmOp(vsigmOp);
    else
      return false;
    return true;
  };

  int64_t handledOps = 0;
  for (auto& op : coreOp.getBody().front()) {
    const bool producesNoCode =
        isa<memref::AllocOp, pim::PimHaltOp, memref::SubViewOp, memref::ExpandShapeOp, memref::CollapseShapeOp>(op);
    if (producesNoCode)
      continue;

    if (!tryGenerate(op)) {
      op.emitError("Unsupported codegen for this operation");
      op.dump();
      return -1;
    }
    ++handledOps;
  }
  return handledOps;
}
/// Write one `crossbar_<index>.bin` file per core weight into
/// `coreWeightsDirPath`. Each file holds a crossbarSize x crossbarSize
/// row-major grid in which the weight matrix occupies the top-left corner and
/// every other cell is zero.
/// Each weight index is appended to `xbarsPerGroup`, including weights that
/// are skipped with a warning because no dense initial value could be found.
/// Returns InvalidOutputFileAccess if a weight file cannot be opened.
static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
                                                       pim::PimCoreOp coreOp,
                                                       StringRef coreWeightsDirPath,
                                                       json::Array& xbarsPerGroup) {
  const int64_t xbarSize = crossbarSize.getValue();
  std::error_code errorCode;
  size_t nextWeightIndex = 0;

  for (auto weight : coreOp.getWeights()) {
    // Claim the index up front so every exit path below advances it once.
    const size_t weightIndex = nextWeightIndex++;
    xbarsPerGroup.push_back(weightIndex);
    const std::string indexText = std::to_string(weightIndex);

    auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
    if (!getGlobalOp) {
      coreOp.emitWarning("Weight is not from a memref.get_global at index " + indexText);
      continue;
    }

    auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    if (!globalOp) {
      coreOp.emitWarning("Could not find memref.global for weight at index " + indexText);
      continue;
    }

    auto initialValue = globalOp.getInitialValue();
    if (!initialValue) {
      coreOp.emitWarning("memref.global has no initial value at index " + indexText);
      continue;
    }

    auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
    if (!denseAttr) {
      coreOp.emitWarning("memref.global initial value is not dense at index " + indexText);
      continue;
    }

    auto type = denseAttr.getType();
    auto shape = type.getShape();
    assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
    const int64_t numRows = shape[0];
    const int64_t numCols = shape[1];
    assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
    const size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;

    auto weightFilePath = (coreWeightsDirPath + "/crossbar_" + indexText + ".bin").str();
    raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
    if (errorCode) {
      errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
      return InvalidOutputFileAccess;
    }

    for (int64_t row = 0; row < xbarSize; ++row) {
      for (int64_t col = 0; col < xbarSize; ++col) {
        // Cells outside the weight matrix stay zero (padding).
        uint64_t word = 0;
        const bool insideMatrix = row < numRows && col < numCols;
        if (insideMatrix) {
          const int64_t flatIndex = row * numCols + col;
          APInt bits = denseAttr.getValues<APFloat>()[flatIndex].bitcastToAPInt();
          word = bits.getZExtValue();
        }
        weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
      }
    }
    weightFileStream.close();
  }
  return CompilerSuccess;
}
/// Write `config.json` into `outputDirPath`: core count, ADC/cell settings,
/// crossbar configuration, the per-core array group map, and the addresses of
/// the function's arguments and returned values.
/// Returns InvalidOutputFileAccess if the file cannot be opened.
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
                                                  PimAcceleratorMemory& memory,
                                                  size_t coreCount,
                                                  json::Object xbarsPerArrayGroup,
                                                  StringRef outputDirPath) {
  // Memory layout of inputs and outputs
  json::Array inputsAddresses;
  for (BlockArgument input : funcOp.getArguments())
    inputsAddresses.push_back(memory.getValueAddress(input));

  json::Array outputsAddresses;
  for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
    for (mlir::Value output : returnOp.getOperands())
      outputsAddresses.push_back(memory.getValueAddress(output));

  json::Object config;
  config["core_cnt"] = coreCount;
  // TODO: Should this be based on the floating point type used in the model?
  // The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
  // Number of ADC for MVM units
  config["adc_count"] = 16;
  // The bit precision of each ADC
  config["cell_precision"] = 2;
  // Crossbar configuration
  config["xbar_array_count"] = crossbarCountInCore.getValue();
  config["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
  config["array_group_map"] = std::move(xbarsPerArrayGroup);
  config["inputs_addresses"] = std::move(inputsAddresses);
  config["outputs_addresses"] = std::move(outputsAddresses);

  auto configPath = (outputDirPath + "/config.json").str();
  std::error_code openError;
  raw_fd_ostream configStream(configPath, openError);
  if (openError) {
    errs() << "Error while opening config file: " << openError.message() << '\n';
    return InvalidOutputFileAccess;
  }
  configStream << json::Value(std::move(config)) << '\n';
  configStream.close();
  return CompilerSuccess;
}
/// Compiles the PIM entry function of `moduleOp` into the on-disk artifacts:
///   - `memory.bin`            host memory image with global constants
///   - `core_0.json`           empty instruction list for the host core
///   - `core_<id>.json`        instruction list of each `pim.core` op
///   - `core_<id>/`            crossbar weight files of each core
///   - `config.json`           top-level configuration
///
/// `outputDirPath` is created if it does not exist. An empty path is treated
/// as the current working directory.
///
/// Returns CompilerSuccess on success, CompilerFailure if the entry function
/// cannot be found or an op has no code generator, and
/// InvalidOutputFileAccess when a file or directory cannot be created.
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
  if (!outputDirPath.empty()) {
    if (auto error = sys::fs::create_directory(outputDirPath)) {
      errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
      return InvalidOutputFileAccess;
    }
  }
  // An empty path passes the guard above, but concatenating "/file" onto it
  // would target the filesystem root. Use "." for all file paths instead.
  // NOTE(review): assumes "empty" is meant as "current directory" — confirm with callers.
  const std::string outDir = outputDirPath.empty() ? std::string(".") : outputDirPath;

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc))
    return CompilerFailure;
  auto funcOp = *entryFunc;

  PimAcceleratorMemory memory;
  memory.hostMem.allocateHost(moduleOp, funcOp);
  if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outDir))
    return err;

  // Write empty host core file
  std::error_code hostErrorCode;
  auto outputHostCorePath = outDir + "/core_0.json";
  raw_fd_ostream hostFileStream(outputHostCorePath, hostErrorCode);
  if (hostErrorCode) {
    errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << hostErrorCode.message()
           << '\n';
    return InvalidOutputFileAccess;
  }
  hostFileStream << "[]";
  hostFileStream.close();

  // For each core, specify the number of crossbar per array group.
  // This implementation always assigns one crossbar per group.
  json::Object xbarsPerArrayGroup;
  size_t coreCount = 0;
  for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
    auto coreId = coreOp.getCoreId();
    coreCount++;

    std::error_code coreErrorCode;
    auto outputCorePath = outDir + "/core_" + std::to_string(coreId) + ".json";
    raw_fd_ostream coreFileStream(outputCorePath, coreErrorCode);
    if (coreErrorCode) {
      errs() << "Error while opening core file `" << outputCorePath << "`: " << coreErrorCode.message() << '\n';
      return InvalidOutputFileAccess;
    }
    coreFileStream << '[';

    PimCodeGen coreCodeGen(memory, coreFileStream);
    memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
    int64_t processedOperations = codeGenCoreOps(coreOp, coreCodeGen);
    if (processedOperations < 0)
      return CompilerFailure;
    assert(processedOperations > 0);

    // Every emitted instruction is followed by a comma. Drop the last one
    // before closing the array, but only if something was written after the
    // opening '['; otherwise the seek would erase the bracket itself and leave
    // an invalid "]" in builds where the assert above is compiled out.
    if (coreFileStream.tell() > 1)
      coreFileStream.seek(coreFileStream.tell() - 1);
    coreFileStream << ']';
    coreFileStream.close();

    // Write crossbar weights for this core
    auto coreWeightsDirPath = outDir + "/core_" + std::to_string(coreId);
    if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
      errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
      return InvalidOutputFileAccess;
    }
    json::Array xbarsPerGroup;
    if (auto err = writeCrossbarWeights(moduleOp, coreOp, coreWeightsDirPath, xbarsPerGroup))
      return err;
    xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
  }

  return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outDir);
}