Files
Raptor/src/PIM/Compiler/PimCodeGen.cpp
T
NiccoloN ff36729140 centralize logic for materializing contiguous memory into bufferization
fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
2026-05-30 16:09:58 +02:00

1386 lines
55 KiB
C++

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Threading.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <absl/types/compare.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm;
namespace {
// Decides whether a memory value must be keyed by lane.
// The supplied lane is returned only when one was given AND `value` is the result of a
// memref.alloc nested inside a pim.core_batch op; in every other case std::nullopt is returned.
static std::optional<unsigned> getLaneForMemoryValue(mlir::Value value, std::optional<unsigned> lane) {
  if (lane.has_value()) {
    if (auto batchAlloc = value.getDefiningOp<memref::AllocOp>()) {
      if (batchAlloc->getParentOfType<pim::PimCoreBatchOp>())
        return lane;
    }
  }
  return std::nullopt;
}
// Follows the alias chain recorded in `knowledge.aliases` starting at `value` and returns the
// first value that has no further alias entry. A value without any alias is returned unchanged.
// NOTE(review): a cyclic alias map would make this loop forever; presumably callers never build one.
static mlir::Value resolveCachedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
  for (auto aliasIt = knowledge.aliases.find(value); aliasIt != knowledge.aliases.end();
       aliasIt = knowledge.aliases.find(value))
    value = aliasIt->second;
  return value;
}
// Builds the lookup key for a memory value. The lane component is filtered through
// getLaneForMemoryValue, so it is only kept for values that are actually lane-specific.
static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigned> lane = std::nullopt) {
  std::optional<unsigned> effectiveLane = getLaneForMemoryValue(value, lane);
  return {value, effectiveLane};
}
} // namespace
// Queues `value` for a later address assignment and returns a pointer to the queued entry.
// The entry starts with address 0 and the byte size of the value's (static) shaped type.
// The returned pointer refers to an element of `memEntries`.
MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional<unsigned> lane) {
  auto type = cast<ShapedType>(value.getType());
  assert("Only static shape is supported" && type.hasStaticShape());

  size_t byteSize = getShapedTypeSizeInBytes(type);
  MemEntry pendingEntry = {0, byteSize};
  auto& queued = memEntries.emplace_back(pendingEntry, getMemoryValueKey(value, lane));
  return &queued.first;
}
// Assigns an address to every entry queued by gatherMemEntry, processing the largest
// allocations first, and then empties the queue.
void PimMemory::allocateGatheredMemory() {
  // Compare through const references: taking the pairs by value copied both entries on
  // every single comparison performed by the sort.
  llvm::sort(memEntries, [](const auto& lhs, const auto& rhs) -> bool { return lhs.first.size > rhs.first.size; });
  for (auto& [memEntry, key] : memEntries)
    allocateMemoryForValue(key, memEntry);
  memEntries.clear();
}
void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry) {
memEntry.address = firstAvailableAddress;
firstAvailableAddress += memEntry.size;
// Alignment
if (size_t remainder = firstAvailableAddress % minAlignment)
firstAvailableAddress += minAlignment - remainder;
ownedMemEntriesMap[key] = memEntry;
globalMemEntriesMap[key] = memEntry;
}
// Lays out host memory for `funcOp`:
//   1. every function argument gets an entry;
//   2. every memref.get_global that is not flagged by hasWeightAlways gets an entry the first
//      time its global is seen; later uses of the same global alias the first one;
//   3. every memref.alloc that is not nested in a pim.core op gets an entry.
// Globals whose name starts with "arg" and whose suffix (from character 4 on) parses as an
// in-range argument index are additionally aliased to that function argument.
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
  SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
  SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
  SmallVector<mlir::Value> args;
  for (mlir::Value arg : funcOp.getArguments()) {
    gatherMemEntry(arg);
    args.push_back(arg);
  }

  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    if (hasWeightAlways(getGlobalOp))
      return;

    auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    StringRef globalName = globalMemrefOp.getName();
    if (globalName.starts_with("arg")) {
      // Only alias when the suffix really is a valid argument index. Previously a failed
      // parse silently aliased argument 0 and an oversized index read past `args`.
      size_t index = 0;
      if (llvm::to_integer(globalName.substr(4), index, 10) && index < args.size())
        globalAliases.push_back({getGlobalOp.getResult(), args[index]});
    }

    auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
    if (inserted)
      gatherMemEntry(getGlobalOp.getResult());
    else
      globalAliases.push_back({getGlobalOp.getResult(), iter->second});
  });

  funcOp.walk([&](memref::AllocOp allocOp) {
    if (!allocOp->getParentOfType<pim::PimCoreOp>())
      gatherMemEntry(allocOp.getResult());
  });

  allocateGatheredMemory();

  // Aliases are resolved in recording order so that an alias of an alias sees the final entry.
  for (const auto& [alias, original] : globalAliases)
    globalMemEntriesMap[getMemoryValueKey(alias)] = getMemEntry(getMemoryValueKey(original));
}
// Queues every memref.alloc found under `op` (keyed with `lane`) and assigns their addresses.
void PimMemory::allocateCore(Operation* op, std::optional<unsigned> lane) {
  op->walk([this, lane](memref::AllocOp allocOp) {
    gatherMemEntry(allocOp, lane);
  });
  allocateGatheredMemory();
}
// Prints the host section of the memory report: global count and total global size.
static void printHostMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
  llvm::SmallVector<ReportField, 2> hostFields;
  hostFields.push_back(ReportField{"Number of globals", std::to_string(row.numGlobal)});
  hostFields.push_back(ReportField{"Global memory", formatReportMemory(row.sizeGlobal)});
  printReportFlatFields(os, hostFields);
}
// Prints the report fields of a single-core entry: alloca count and allocated size.
static void printCoreMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
  const MemoryReportRow& row = entry.row;
  llvm::SmallVector<ReportField, 2> coreFields;
  coreFields.push_back(ReportField{"Number of allocas", std::to_string(row.numAlloca)});
  coreFields.push_back(ReportField{"Allocated memory", formatReportMemory(row.sizeAlloca)});
  printReportFlatFields(os, coreFields);
}
// Prints the report fields of a batch entry: the per-core figures taken from `entry.row`
// followed by the batch-wide totals stored on the entry itself.
static void printBatchMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
  llvm::SmallVector<ReportField, 2> perCoreFields;
  perCoreFields.push_back(ReportField{"Number of allocas", std::to_string(entry.row.numAlloca)});
  perCoreFields.push_back(ReportField{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)});

  llvm::SmallVector<ReportField, 2> totalFields;
  totalFields.push_back(ReportField{"Number of allocas", std::to_string(entry.totalAllocaCount)});
  totalFields.push_back(ReportField{"Batch memory", formatReportMemory(entry.totalAllocaBytes)});

  printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
}
// Returns the field-wise sum of two report rows (alloca and global counts and sizes).
// Any other field of the row keeps the value it has in `lhs`.
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
  MemoryReportRow sum = lhs;
  sum.numAlloca += rhs.numAlloca;
  sum.numGlobal += rhs.numGlobal;
  sum.sizeAlloca += rhs.sizeAlloca;
  sum.sizeGlobal += rhs.sizeGlobal;
  return sum;
}
// Summarises the entries owned by this memory: counts and sizes of values produced by
// memref.alloc and by memref.get_global. Values without a defining op are not counted.
MemoryReportRow PimMemory::getReportRow() const {
  MemoryReportRow summary;
  for (auto& [key, memEntry] : ownedMemEntriesMap) {
    auto producer = key.value.getDefiningOp();
    if (!producer)
      continue;

    if (isa<memref::AllocOp>(producer)) {
      summary.numAlloca++;
      summary.sizeAlloca += memEntry.size;
    }
    else if (isa<memref::GetGlobalOp>(producer)) {
      summary.numGlobal++;
      summary.sizeGlobal += memEntry.size;
    }
  }
  return summary;
}
// Drops every entry whose key refers to `val` (for any lane) from both the owned and the
// global entry maps.
void PimMemory::remove(mlir::Value val) {
  auto eraseEntriesFor = [val](auto& entries) {
    auto it = entries.begin();
    while (it != entries.end()) {
      if (it->first.value == val) {
        // Step past the doomed element before erasing it so `it` stays usable.
        auto doomed = it++;
        entries.erase(doomed);
        continue;
      }
      ++it;
    }
  };

  eraseEntriesFor(ownedMemEntriesMap);
  eraseEntriesFor(globalMemEntriesMap);
}
// Returns a copy of the entry registered for `key` in the global entry map.
// The key must be present; this is only checked by an assertion.
MemEntry PimMemory::getMemEntry(const MemoryValueKey& key) const {
  auto iter = globalMemEntriesMap.find(key);
  assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
  const MemEntry& found = iter->second;
  return found;
}
// Returns the device memory registered under `id`, constructing it from `memEntriesMap`
// on first use.
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
  auto insertion = deviceMem.try_emplace(id, memEntriesMap);
  auto& slot = insertion.first;
  return slot->second;
}
// Resolves `value` to an absolute address.
// Steps: follow cached aliases, fetch (or compile and cache) the contiguous address expression
// for the value, evaluate it with `knowledge`/`lane` to get a base value plus byte offset, then
// look the base up in `memEntriesMap` and return entry address + byte offset.
// Any failure along the way prints diagnostics to errs() and hits llvm_unreachable.
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
                                             const StaticValueKnowledge& knowledge,
                                             std::optional<unsigned> lane) const {
  // Shared diagnostic tail: prints the op producing `subject`, if there is one.
  auto dumpDefiningOp = [](mlir::Value subject) {
    if (auto* definingOp = subject.getDefiningOp()) {
      errs() << "Defining op:\n";
      definingOp->print(errs());
      errs() << "\n";
    }
  };

  value = resolveCachedAlias(value, knowledge);

  auto exprIt = compiledAddressExprs.find(value);
  if (exprIt == compiledAddressExprs.end()) {
    auto freshExpr = compileContiguousAddressExpr(value);
    if (failed(freshExpr)) {
      errs() << "Failed to compile contiguous address for value: ";
      value.print(errs());
      errs() << "\n";
      llvm_unreachable("Failed to compile contiguous address");
    }
    exprIt = compiledAddressExprs.try_emplace(value, *freshExpr).first;
  }

  auto resolvedAddress = exprIt->second.evaluate(knowledge, lane);
  if (failed(resolvedAddress)) {
    errs() << "Failed to evaluate contiguous address for value: ";
    value.print(errs());
    errs() << "\n";
    dumpDefiningOp(value);
    llvm_unreachable("Failed to resolve contiguous address");
  }

  MemoryValueKey key = getMemoryValueKey(resolvedAddress->base, lane);
  auto entryIt = memEntriesMap.find(key);
  if (entryIt != memEntriesMap.end())
    return entryIt->second.address + resolvedAddress->byteOffset;

  errs() << "Missing mem entry for value: ";
  resolvedAddress->base.print(errs());
  errs() << "\n";
  if (key.lane)
    errs() << "Lane: " << *key.lane << "\n";
  dumpDefiningOp(resolvedAddress->base);
  llvm_unreachable("Missing mem entry");
}
// Evaluates `value` as a static index. The value is first resolved through cached aliases;
// its compiled index expression is taken from the cache or compiled and cached on demand.
// Returns failure when the expression cannot be compiled, otherwise the evaluation result.
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
                                                             const StaticValueKnowledge& knowledge) const {
  value = resolveCachedAlias(value, knowledge);

  auto cachedExpr = compiledIndexExprs.find(value);
  if (cachedExpr != compiledIndexExprs.end())
    return cachedExpr->second.evaluate(knowledge);

  auto freshExpr = compileIndexExpr(value);
  if (failed(freshExpr))
    return mlir::failure();

  auto insertedExpr = compiledIndexExprs.try_emplace(value, *freshExpr).first;
  return insertedExpr->second.evaluate(knowledge);
}
// Snapshots the host memory summary so flushReport can print it later.
void PimAcceleratorMemory::reportHost() {
  hostReportRow = hostMem.getReportRow();
}
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
reportEntries.push_back(
{MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row, row.numAlloca, row.sizeAlloca});
}
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
ArrayRef<int32_t> coreIds,
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes) {
MemoryReportEntry entry;
entry.kind = MemoryReportEntry::Kind::Batch;
entry.id = batchId;
llvm::append_range(entry.coreIds, coreIds);
entry.row = perCoreRow;
entry.totalAllocaCount = totalAllocaCount;
entry.totalAllocaBytes = totalAllocaBytes;
reportEntries.push_back(std::move(entry));
}
// Writes the collected memory report to `fileReport` and closes the file.
// Does nothing when the report file is not open. Output layout: a totals block, an optional
// "Host:" section, then one section per run of consecutive report entries that share the same
// kind, row and totals (entries are sorted by sortReportEntriesByFirstCore first).
void PimAcceleratorMemory::flushReport() {
if (!fileReport.is_open())
return;
llvm::raw_os_ostream os(fileReport);
// Totals: global memory comes from the host row (0 if none was recorded); cores memory is the
// sum of totalAllocaBytes over all recorded core/batch entries.
uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0;
uint64_t totalCoresMemory = 0;
for (const MemoryReportEntry& entry : reportEntries)
totalCoresMemory += entry.totalAllocaBytes;
llvm::SmallVector<ReportField, 2> totalFields = {
{"Global memory", formatReportMemory(totalGlobalMemory)},
{"Cores memory", formatReportMemory(totalCoresMemory) }
};
printReportTotalsBlock(os, totalFields);
if (hostReportRow.has_value()) {
os << "\nHost:\n";
printHostMemoryReportRow(os, *hostReportRow);
}
if (!reportEntries.empty()) {
if (hostReportRow.has_value())
os << "\n";
sortReportEntriesByFirstCore(reportEntries);
for (size_t index = 0; index < reportEntries.size();) {
// Extend the run [index, runEnd) while the following entries are indistinguishable from the
// first one in kind, row and totals; the whole run is printed as a single section.
size_t runEnd = index + 1;
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
&& reportEntries[runEnd].row == reportEntries[index].row
&& reportEntries[runEnd].totalAllocaCount == reportEntries[index].totalAllocaCount
&& reportEntries[runEnd].totalAllocaBytes == reportEntries[index].totalAllocaBytes) {
++runEnd;
}
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
// Batch header: every batch of the run is listed with its id and its core ids.
os << "Batch ";
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
if (batchIndex != index)
os << ",\n ";
os << reportEntries[batchIndex].id << " (cores ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
os << ")";
}
}
else {
// Core header: the first core id of every entry in the run, printed in compressed form.
llvm::SmallVector<int32_t, 8> coreIds;
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
os << "Core ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
}
os << ":\n";
// The body is printed once per run, using the first entry as the representative.
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch)
printBatchMemoryReportRow(os, reportEntries[index]);
else
printCoreMemoryReportRow(os, reportEntries[index]);
// The second argument tells the separator printer whether more sections follow.
printReportEntrySeparator(os, runEnd < reportEntries.size());
index = runEnd;
}
}
// Flush the adaptor stream before closing the underlying file.
os.flush();
fileReport.close();
}
// Forgets the memory entries of every result of `op`, in host memory and in all device memories.
void PimAcceleratorMemory::clean(mlir::Operation* op) {
  for (mlir::Value result : op->getResults()) {
    hostMem.remove(result);
    for (auto& deviceEntry : deviceMem)
      deviceEntry.second.remove(result);
  }
}
// Translates a core id into the id recorded for it in `emittedCoreIds`.
// The id must have a mapping; this is only checked by an assertion.
size_t PimCodeGen::remapCoreId(size_t coreId) const {
  auto it = emittedCoreIds.find(coreId);
  assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
  const size_t emittedId = it->second;
  return emittedId;
}
// Writes one instruction record to the core binary stream and bumps the emitted-instruction
// counter. When a JSON stream is attached, the instruction is also written there followed by
// a comma.
void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instruction) const {
  pim_binary::writeInstructionRecord(coreBinaryStream, instruction);
  ++emittedInstructionCount;

  if (!coreJsonStream)
    return;
  *coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
}
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::Opcode::sldi;
instruction.rd = static_cast<uint8_t>(registerNumber);
instruction.r2OrImm = static_cast<int32_t>(immediate);
emitInstruction(instruction);
}
// Loads register 0 with the destination address (base + offset).
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
  const size_t destination = rdAddress + rdOffset;
  genSetRegisterImmediateUnsigned(0, destination);
}
// Loads register 0 with the destination address and register 1 with the first source
// address (each base + offset), in that order.
void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const {
  const size_t destination = rdAddress + rdOffset;
  const size_t firstSource = rs1Address + rs1Offset;
  genSetRegisterImmediateUnsigned(0, destination);
  genSetRegisterImmediateUnsigned(1, firstSource);
}
// Loads registers 0, 1 and 2 with the destination, first source and second source addresses
// (each base + offset), in that order.
void PimCodeGen::setupRdRs1Rs2(
    size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const {
  const size_t destination = rdAddress + rdOffset;
  const size_t firstSource = rs1Address + rs1Offset;
  const size_t secondSource = rs2Address + rs2Offset;
  genSetRegisterImmediateUnsigned(0, destination);
  genSetRegisterImmediateUnsigned(1, firstSource);
  genSetRegisterImmediateUnsigned(2, secondSource);
}
// Emits a memory-copy style instruction named `opName`.
// Registers 0 and 1 are first loaded with the destination and source addresses; the copy
// instruction then references them and carries `size` in generic3.
// `sizeFieldName` is accepted but not used by this implementation.
void PimCodeGen::emitMemCopyOp(StringRef opName,
                               size_t rdAddr,
                               size_t rdOffset,
                               size_t rs1Addr,
                               size_t rs1Offset,
                               size_t size,
                               StringRef sizeFieldName) const {
  (void) sizeFieldName;
  setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);

  pim_binary::InstructionRecord copyRecord;
  copyRecord.opcode = pim_binary::opcodeFromString(opName);
  copyRecord.rd = 0;
  copyRecord.r1 = 1;
  copyRecord.generic1 = 0;
  copyRecord.generic2 = 0;
  copyRecord.generic3 = static_cast<int32_t>(size);
  emitInstruction(copyRecord);
}
// Emits a communication instruction named `opName`.
// Register 0 is loaded with the buffer address; the instruction carries the remapped core id
// in r2OrImm and `size` in generic3.
void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const {
  setupRd(bufferAddr, 0);

  pim_binary::InstructionRecord commRecord;
  commRecord.opcode = pim_binary::opcodeFromString(opName);
  commRecord.rd = 0;
  commRecord.r2OrImm = static_cast<int32_t>(remapCoreId(coreId));
  commRecord.generic1 = 0;
  commRecord.generic2 = 0;
  commRecord.generic3 = static_cast<int32_t>(size);
  emitInstruction(commRecord);
}
void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const {
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::Opcode::mvmul;
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 8;
instruction.generic1 = 0;
instruction.generic2 = static_cast<int32_t>(groupId);
emitInstruction(instruction);
}
// Lowers pim.memcp_hd (host -> device copy) to an "ld" instruction.
// Both offsets must be statically resolvable; this is only checked by an assertion.
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
  auto deviceTargetOffset = indexOf(loadOp.getDeviceTargetOffset(), knowledge);
  auto hostSourceOffset = indexOf(loadOp.getHostSourceOffset(), knowledge);
  assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
         && "pim.memcp_hd offsets must be statically resolvable during codegen");

  const size_t deviceAddr = addressOf(loadOp.getDeviceTarget(), knowledge);
  const size_t hostAddr = addressOf(loadOp.getHostSource(), knowledge);
  emitMemCopyOp("ld", deviceAddr, *deviceTargetOffset, hostAddr, *hostSourceOffset, loadOp.getSize());
}
// Lowers pim.memcp_dh (device -> host copy) to an "st" instruction.
// Both offsets must be statically resolvable; this is only checked by an assertion.
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
  auto hostTargetOffset = indexOf(storeOp.getHostTargetOffset(), knowledge);
  auto deviceSourceOffset = indexOf(storeOp.getDeviceSourceOffset(), knowledge);
  assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
         && "pim.memcp_dh offsets must be statically resolvable during codegen");

  const size_t hostAddr = addressOf(storeOp.getHostTarget(), knowledge);
  const size_t deviceAddr = addressOf(storeOp.getDeviceSource(), knowledge);
  emitMemCopyOp("st", hostAddr, *hostTargetOffset, deviceAddr, *deviceSourceOffset, storeOp.getSize());
}
// Lowers a pim memory copy op to an "lmv" instruction using the op's own target/source
// offsets and size.
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
  const size_t targetAddr = addressOf(lmvOp.getTarget(), knowledge);
  const size_t sourceAddr = addressOf(lmvOp.getSource(), knowledge);
  emitMemCopyOp(
      "lmv", targetAddr, lmvOp.getTargetOffset(), sourceAddr, lmvOp.getSourceOffset(), lmvOp.getSize(), "len");
}
// Lowers pim.receive to a "recv" instruction targeting the op's output buffer.
// The source core id must be statically resolvable; this is only checked by an assertion.
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
  auto sourceCoreId = indexOf(receiveOp.getSourceCoreId(), knowledge);
  assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");

  const size_t bufferAddr = addressOf(receiveOp.getOutputBuffer(), knowledge);
  emitCommunicationOp("recv", bufferAddr, *sourceCoreId, receiveOp.getSize());
}
// Lowers pim.send to a "send" instruction reading from the op's input buffer.
// The target core id must be statically resolvable; this is only checked by an assertion.
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
  auto targetCoreId = indexOf(sendOp.getTargetCoreId(), knowledge);
  assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");

  const size_t bufferAddr = addressOf(sendOp.getInput(), knowledge);
  emitCommunicationOp("send", bufferAddr, *targetCoreId, sendOp.getSize());
}
// Lowers pim concat to a series of "lmv" copies.
// The output is viewed as [outer, concatDim, inner] around the concat axis. For every input
// and every outer index, one contiguous block of inputConcatDim * inner elements is copied to
// its slot in the output; `concatOffset` tracks how far along the axis earlier inputs reached.
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
  auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
  assert(outputType.hasStaticShape() && "concat codegen requires static output shape");

  int64_t axis = concatOp.getAxis();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
  size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);

  // Product of the dimensions before the axis.
  size_t outerCount = 1;
  for (int64_t leading = 0; leading < axis; ++leading)
    outerCount *= static_cast<size_t>(outputShape[leading]);

  // Product of the dimensions after the axis.
  size_t innerCount = 1;
  for (size_t trailing = static_cast<size_t>(axis) + 1; trailing < outputShape.size(); ++trailing)
    innerCount *= static_cast<size_t>(outputShape[trailing]);

  size_t outputConcatDim = static_cast<size_t>(outputShape[axis]);
  size_t concatOffset = 0;
  for (mlir::Value input : concatOp.getInputs()) {
    auto inputType = cast<ShapedType>(input.getType());
    assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");

    size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
    size_t blockSizeInBytes = inputConcatDim * innerCount * elementSize;
    size_t inputAddr = addressOf(input, knowledge);

    size_t outerIndex = 0;
    while (outerIndex < outerCount) {
      size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * innerCount * elementSize;
      size_t srcOffset = outerIndex * inputConcatDim * innerCount * elementSize;
      emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len");
      ++outerIndex;
    }
    concatOffset += inputConcatDim;
  }
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
MVMTy mvmLikeOp,
bool transposeMatrix,
const StaticValueKnowledge& knowledge) {
emitMvmOp(mvmId, addressOf(mvmLikeOp.getOutputBuffer(), knowledge), 0, addressOf(mvmLikeOp.getInput(), knowledge), 0);
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
// Lowers pim vvadd: registers 0/1/2 hold output, lhs and rhs addresses; generic3 carries the
// byte size of the lhs operand.
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
  const auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
  const auto rhsAddr = addressOf(vvaddOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, lhsAddr, 0, rhsAddr, 0);

  auto lhsType = cast<ShapedType>(vvaddOp.getLhs().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vvadd;
  record.rd = 0;
  record.r1 = 1;
  record.r2OrImm = 2;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(lhsType));
  emitInstruction(record);
}
// Lowers pim vvsub: registers 0/1/2 hold output, lhs and rhs addresses; generic3 carries the
// byte size of the lhs operand.
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vvsubOp.getOutputBuffer(), knowledge);
  const auto lhsAddr = addressOf(vvsubOp.getLhs(), knowledge);
  const auto rhsAddr = addressOf(vvsubOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, lhsAddr, 0, rhsAddr, 0);

  auto lhsType = cast<ShapedType>(vvsubOp.getLhs().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vvsub;
  record.rd = 0;
  record.r1 = 1;
  record.r2OrImm = 2;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(lhsType));
  emitInstruction(record);
}
// Lowers pim vvmul: registers 0/1/2 hold output, lhs and rhs addresses; generic3 carries the
// byte size of the lhs operand.
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vvmulOp.getOutputBuffer(), knowledge);
  const auto lhsAddr = addressOf(vvmulOp.getLhs(), knowledge);
  const auto rhsAddr = addressOf(vvmulOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, lhsAddr, 0, rhsAddr, 0);

  auto lhsType = cast<ShapedType>(vvmulOp.getLhs().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vvmul;
  record.rd = 0;
  record.r1 = 1;
  record.r2OrImm = 2;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(lhsType));
  emitInstruction(record);
}
// Lowers pim vvmax: registers 0/1/2 hold output, lhs and rhs addresses; generic3 carries the
// byte size of the lhs operand.
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vvmaxOp.getOutputBuffer(), knowledge);
  const auto lhsAddr = addressOf(vvmaxOp.getLhs(), knowledge);
  const auto rhsAddr = addressOf(vvmaxOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, lhsAddr, 0, rhsAddr, 0);

  auto lhsType = cast<ShapedType>(vvmaxOp.getLhs().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vvmax;
  record.rd = 0;
  record.r1 = 1;
  record.r2OrImm = 2;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(lhsType));
  emitInstruction(record);
}
// Lowers pim vvdmul: registers 0/1/2 hold output, lhs and rhs addresses; generic3 carries the
// byte size of the lhs operand.
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vvdmulOp.getOutputBuffer(), knowledge);
  const auto lhsAddr = addressOf(vvdmulOp.getLhs(), knowledge);
  const auto rhsAddr = addressOf(vvdmulOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, lhsAddr, 0, rhsAddr, 0);

  auto lhsType = cast<ShapedType>(vvdmulOp.getLhs().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vvdmul;
  record.rd = 0;
  record.r1 = 1;
  record.r2OrImm = 2;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(lhsType));
  emitInstruction(record);
}
// Lowers pim vavg: registers 0/1 hold output and input addresses; generic3 carries the byte
// size of the input.
// NOTE(review): r2OrImm and generic1 are both hard-coded to 1; their meaning is not visible here.
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vavgOp.getOutputBuffer(), knowledge);
  const auto srcAddr = addressOf(vavgOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  auto inputType = cast<ShapedType>(vavgOp.getInput().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vavg;
  record.rd = 0;
  record.r1 = 1;
  record.r2OrImm = 1;
  record.generic1 = 1;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(inputType));
  emitInstruction(record);
}
// Lowers pim vrelu: registers 0/1 hold output and input addresses; generic3 carries the byte
// size of the input.
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vreluOp.getOutputBuffer(), knowledge);
  const auto srcAddr = addressOf(vreluOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  auto inputType = cast<ShapedType>(vreluOp.getInput().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vrelu;
  record.rd = 0;
  record.r1 = 1;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(inputType));
  emitInstruction(record);
}
// Lowers pim vtanh: registers 0/1 hold output and input addresses; generic3 carries the byte
// size of the input.
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vtanhOp.getOutputBuffer(), knowledge);
  const auto srcAddr = addressOf(vtanhOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  auto inputType = cast<ShapedType>(vtanhOp.getInput().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vtanh;
  record.rd = 0;
  record.r1 = 1;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(inputType));
  emitInstruction(record);
}
// Lowers pim vsigm: registers 0/1 hold output and input addresses; generic3 carries the byte
// size of the input.
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vsigmOp.getOutputBuffer(), knowledge);
  const auto srcAddr = addressOf(vsigmOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  auto inputType = cast<ShapedType>(vsigmOp.getInput().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vsigm;
  record.rd = 0;
  record.r1 = 1;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(inputType));
  emitInstruction(record);
}
// Lowers pim vsoftmax: registers 0/1 hold output and input addresses; generic3 carries the
// byte size of the input.
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const {
  const auto dstAddr = addressOf(vsoftmaxOp.getOutputBuffer(), knowledge);
  const auto srcAddr = addressOf(vsoftmaxOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  auto inputType = cast<ShapedType>(vsoftmaxOp.getInput().getType());
  pim_binary::InstructionRecord record;
  record.opcode = pim_binary::Opcode::vsoftmax;
  record.rd = 0;
  record.r1 = 1;
  record.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(inputType));
  emitInstruction(record);
}
// memref.get_global needs no instruction: this handler intentionally emits nothing.
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {
}
// Lowers pim transpose to "lmv" copies.
// Destination dimension i corresponds to source dimension perm[i]. When the permutation
// leaves every element at the same flat position, a single bulk copy is emitted; otherwise
// one copy per element is emitted with transposed addressing.
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
  auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
  auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);

  auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
  auto srcShape = srcType.getShape();
  size_t rank = srcShape.size();
  size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
  size_t totalElements = srcType.getNumElements();

  // Read permutation. Destination dim i corresponds to source dim perm[i].
  SmallVector<int64_t> perm = map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(),
                                            [](auto attr) -> int64_t { return attr.getInt(); });

  // Destination shape: dstShape[i] = srcShape[perm[i]]
  SmallVector<int64_t> dstShape(rank);
  for (size_t dim = 0; dim < rank; dim++)
    dstShape[dim] = srcShape[perm[dim]];

  // Row-major strides for source and destination
  SmallVector<size_t> srcStrides(rank, 1);
  SmallVector<size_t> dstStrides(rank, 1);
  for (int64_t dim = rank - 2; dim >= 0; dim--) {
    srcStrides[dim] = srcStrides[dim + 1] * srcShape[dim + 1];
    dstStrides[dim] = dstStrides[dim + 1] * dstShape[dim + 1];
  }

  // Maps a flat source element index to the flat index of the same element in the destination.
  auto destinationFlatIndex = [&](size_t srcFlat) -> size_t {
    // Decompose flat source index into multi-dimensional index
    SmallVector<size_t> srcIdx(rank);
    size_t remaining = srcFlat;
    for (size_t d = 0; d < rank; d++) {
      srcIdx[d] = remaining / srcStrides[d];
      remaining %= srcStrides[d];
    }
    // Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
    size_t dstFlat = 0;
    for (size_t d = 0; d < rank; d++)
      dstFlat += srcIdx[perm[d]] * dstStrides[d];
    return dstFlat;
  };

  // The transpose preserves storage when no element changes its flat position.
  bool storagePreserving = true;
  for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
    if (destinationFlatIndex(srcFlat) != srcFlat) {
      storagePreserving = false;
      break;
    }
  }

  if (storagePreserving) {
    emitMemCopyOp("lmv", dstAddr, 0, srcAddr, 0, totalElements * elementSize, "len");
    return;
  }

  // Emit element-by-element copy with transposed addressing
  for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
    size_t dstFlat = destinationFlatIndex(srcFlat);
    emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
  }
}
// Returns the larger of the first two dimensions of `matrixShape`.
// Only rank-2 and rank-4 shapes are supported; other ranks trip an assertion.
size_t getMatrixSize(ShapedType matrixShape) {
  const bool supportedRank = matrixShape.getRank() == 2 || matrixShape.getRank() == 4;
  if (!supportedRank)
    assert(false && "Unsupported matrix shape");

  auto rows = matrixShape.getDimSize(0);
  auto cols = matrixShape.getDimSize(1);
  return std::max(rows, cols);
}
// Formats a byte count as a short human-readable string using 1024-based units
// ("Bytes", "KB", "MB", "GB"). The value is truncated to a whole number of the chosen unit.
// Unit boundaries are inclusive: exactly 1024 bytes is "1 KB", not "1024 Bytes".
std::string getMemorySizeAsString(size_t size) {
  constexpr size_t bytesPerKB = 1024;
  constexpr size_t bytesPerMB = bytesPerKB * 1024;
  constexpr size_t bytesPerGB = bytesPerMB * 1024;

  if (size >= bytesPerGB)
    return std::to_string(size / bytesPerGB) + " GB";
  if (size >= bytesPerMB)
    return std::to_string(size / bytesPerMB) + " MB";
  if (size >= bytesPerKB)
    return std::to_string(size / bytesPerKB) + " KB";
  return std::to_string(size) + " Bytes";
}
// Collects, in program order, the pim.core and pim.core_batch ops that sit directly in the
// first block of the function body. Nested ops are not visited.
static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
  SmallVector<Operation*> coreLikeOps;
  Block& entryBlock = funcOp.getBody().front();
  for (Operation& candidate : entryBlock) {
    if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(&candidate))
      coreLikeOps.push_back(&candidate);
  }
  return coreLikeOps;
}
// Result bundle for emitting one core: a status code plus the data gathered along the way.
struct CoreEmissionResult {
// Compiler status; defaults to CompilerSuccess.
OnnxMlirCompilerErrorCodes status = CompilerSuccess;
// Memory report row associated with the core.
MemoryReportRow reportRow;
// Weight views used by the core. NOTE(review): how this list is populated is not visible here.
llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
};
// RAII helper that temporarily overrides entries of a map.
// Every bind() remembers what the map held for that key (or that the key was absent) before
// writing the new value. The destructor undoes all bindings in reverse order, restoring the
// previous values and erasing keys that did not exist before.
// The guard is not copyable: a copy would undo the same bindings a second time.
template <typename MapTy>
class ScopedMapBindings {
  using KeyTy = typename MapTy::key_type;
  using ValueTy = typename MapTy::mapped_type;

  MapTy& map;
  // One record per bind() call; std::nullopt means the key was absent beforehand.
  llvm::SmallVector<std::pair<KeyTy, std::optional<ValueTy>>, 8> savedEntries;

public:
  explicit ScopedMapBindings(MapTy& map)
      : map(map) {}

  ScopedMapBindings(const ScopedMapBindings&) = delete;
  ScopedMapBindings& operator=(const ScopedMapBindings&) = delete;

  // Sets map[key] = value for the lifetime of this guard.
  void bind(const KeyTy& key, const ValueTy& value) {
    auto it = map.find(key);
    if (it == map.end())
      savedEntries.emplace_back(key, std::nullopt);
    else
      savedEntries.emplace_back(key, it->second);
    map[key] = value;
  }

  ~ScopedMapBindings() {
    // Reverse order so that rebinding the same key twice restores the oldest value last.
    for (auto it = savedEntries.rbegin(); it != savedEntries.rend(); ++it) {
      if (it->second)
        map[it->first] = *it->second;
      else
        map.erase(it->first);
    }
  }
};
// Kind tag for a classified core-body operation. classifyCompiledCoreOpKind assigns one of
// these to each supported op so the emission plan can record it alongside the op.
enum class CompiledCoreOpKind : uint8_t {
Load,      // pim::PimMemCopyHostToDevOp
Store,     // pim::PimMemCopyDevToHostOp
Lmv,       // pim::PimMemCopyOp
Receive,   // pim::PimReceiveOp
Send,      // pim::PimSendOp
Concat,    // pim::PimConcatOp
Vmm,       // pim::PimVMMOp
Transpose, // pim::PimTransposeOp
VVAdd,     // pim::PimVVAddOp
VVSub,     // pim::PimVVSubOp
VVMul,     // pim::PimVVMulOp
VVMax,     // pim::PimVVMaxOp
VVDMul,    // pim::PimVVDMulOp
VAvg,      // pim::PimVAvgOp
VRelu,     // pim::PimVReluOp
VTanh,     // pim::PimVTanhOp
VSigm,     // pim::PimVSigmOp
VSoftmax,  // pim::PimVSoftmaxOp
GetGlobal  // memref::GetGlobalOp
};
// One node of a core emission plan: either a single classified operation or an scf.for loop
// with compiled bounds and a nested plan for its body.
struct CompiledCoreNode {
enum class Kind : uint8_t {
Op,  // leaf node wrapping one operation
Loop // scf.for node; uses the bound/step expressions and loopBody
};
Kind kind = Kind::Op;
// The operation this node was built from (the scf.for itself for Loop nodes).
Operation* op = nullptr;
// Classification of `op`; set for Op nodes.
CompiledCoreOpKind opKind = CompiledCoreOpKind::Load;
// Compiled loop bounds and step; set for Loop nodes.
CompiledIndexExpr lowerBound;
CompiledIndexExpr upperBound;
CompiledIndexExpr step;
// Nested plan for the loop body; allocated for Loop nodes, null otherwise.
std::unique_ptr<llvm::SmallVector<CompiledCoreNode, 8>> loopBody;
};
// Maps an operation to its CompiledCoreOpKind. Returns failure for any op that has no
// codegen handler.
static FailureOr<CompiledCoreOpKind> classifyCompiledCoreOpKind(Operation& op) {
  Operation* candidate = &op;

  // Memory movement and communication.
  if (isa<pim::PimMemCopyHostToDevOp>(candidate))
    return CompiledCoreOpKind::Load;
  if (isa<pim::PimMemCopyDevToHostOp>(candidate))
    return CompiledCoreOpKind::Store;
  if (isa<pim::PimMemCopyOp>(candidate))
    return CompiledCoreOpKind::Lmv;
  if (isa<pim::PimReceiveOp>(candidate))
    return CompiledCoreOpKind::Receive;
  if (isa<pim::PimSendOp>(candidate))
    return CompiledCoreOpKind::Send;

  // Layout and matrix ops.
  if (isa<pim::PimConcatOp>(candidate))
    return CompiledCoreOpKind::Concat;
  if (isa<pim::PimVMMOp>(candidate))
    return CompiledCoreOpKind::Vmm;
  if (isa<pim::PimTransposeOp>(candidate))
    return CompiledCoreOpKind::Transpose;

  // Two-operand vector ops.
  if (isa<pim::PimVVAddOp>(candidate))
    return CompiledCoreOpKind::VVAdd;
  if (isa<pim::PimVVSubOp>(candidate))
    return CompiledCoreOpKind::VVSub;
  if (isa<pim::PimVVMulOp>(candidate))
    return CompiledCoreOpKind::VVMul;
  if (isa<pim::PimVVMaxOp>(candidate))
    return CompiledCoreOpKind::VVMax;
  if (isa<pim::PimVVDMulOp>(candidate))
    return CompiledCoreOpKind::VVDMul;

  // Single-operand vector ops.
  if (isa<pim::PimVAvgOp>(candidate))
    return CompiledCoreOpKind::VAvg;
  if (isa<pim::PimVReluOp>(candidate))
    return CompiledCoreOpKind::VRelu;
  if (isa<pim::PimVTanhOp>(candidate))
    return CompiledCoreOpKind::VTanh;
  if (isa<pim::PimVSigmOp>(candidate))
    return CompiledCoreOpKind::VSigm;
  if (isa<pim::PimVSoftmaxOp>(candidate))
    return CompiledCoreOpKind::VSoftmax;

  if (isa<memref::GetGlobalOp>(candidate))
    return CompiledCoreOpKind::GetGlobal;

  return failure();
}
/// Translates the operations of `block` into emission-plan nodes appended to `plan`.
///
/// - PimHaltOp, scf.yield, operations accepted by isCoreStaticAddressOp, and
///   memref.load operations whose result compileIndexExpr can compile produce no node.
/// - scf.for becomes a Loop node holding the compiled bounds and, recursively,
///   the plan of its body; bounds that cannot be compiled are reported as an error.
/// - Any other operation must be classifiable by classifyCompiledCoreOpKind,
///   otherwise an "unsupported codegen" diagnostic is emitted.
///
/// `weightOwner` is only forwarded to the recursive calls.
/// Returns failure() after emitting a diagnostic when the block cannot be planned.
static LogicalResult
compileCoreEmissionPlan(Block& block, Operation* weightOwner, llvm::SmallVectorImpl<CompiledCoreNode>& plan) {
  for (Operation& current : block) {
    const bool producesNoNode =
        isa<pim::PimHaltOp, mlir::scf::YieldOp>(current) || isCoreStaticAddressOp(&current);
    if (producesNoNode)
      continue;

    // Loads that fold into an index expression are consumed through that expression.
    if (auto loadOp = dyn_cast<memref::LoadOp>(current); loadOp && succeeded(compileIndexExpr(loadOp.getResult())))
      continue;

    if (auto forOp = dyn_cast<mlir::scf::ForOp>(current)) {
      auto compiledLower = compileIndexExpr(forOp.getLowerBound());
      auto compiledUpper = compileIndexExpr(forOp.getUpperBound());
      auto compiledStep = compileIndexExpr(forOp.getStep());
      const bool boundsCompiled = succeeded(compiledLower) && succeeded(compiledUpper) && succeeded(compiledStep);
      if (!boundsCompiled) {
        forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
        return failure();
      }

      // Plan the body first; the loop node is only recorded once its body is valid.
      auto bodyPlan = std::make_unique<llvm::SmallVector<CompiledCoreNode, 8>>();
      if (failed(compileCoreEmissionPlan(forOp.getRegion().front(), weightOwner, *bodyPlan)))
        return failure();

      CompiledCoreNode loopNode;
      loopNode.kind = CompiledCoreNode::Kind::Loop;
      loopNode.op = forOp.getOperation();
      loopNode.lowerBound = *compiledLower;
      loopNode.upperBound = *compiledUpper;
      loopNode.step = *compiledStep;
      loopNode.loopBody = std::move(bodyPlan);
      plan.push_back(std::move(loopNode));
      continue;
    }

    auto classified = classifyCompiledCoreOpKind(current);
    if (succeeded(classified)) {
      CompiledCoreNode& opNode = plan.emplace_back();
      opNode.kind = CompiledCoreNode::Kind::Op;
      opNode.op = &current;
      opNode.opKind = *classified;
      continue;
    }

    // Unsupported operation: name it and, when possible, the enclosing core.
    InFlightDiagnostic diag = current.emitError()
                           << "unsupported codegen for op '" << current.getName().getStringRef() << "'";
    if (auto coreOp = current.getParentOfType<pim::PimCoreOp>())
      diag << " inside pim.core " << coreOp.getCoreId();
    else if (auto coreBatchOp = current.getParentOfType<pim::PimCoreBatchOp>())
      diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
    return failure();
  }
  return success();
}
/// Executes a compiled emission plan against `coreCodeGen`.
///
/// Op nodes are dispatched to the emitter selected by their opKind and counted
/// in `processedOperations`. Loop nodes are unrolled: the bounds are evaluated
/// with the current `knowledge`, and for every iteration the induction variable
/// and the loop-carried arguments are bound in `knowledge` for the duration of
/// the nested plan, after which the yielded values become the next iteration's
/// carried values.
///
/// `resolveWeightSlot` provides the weight slot for each PimVMMOp; its failure
/// aborts execution. `batchLane` and `batchLaneCount` are only forwarded to the
/// nested calls.
/// Returns failure() when a loop's bounds cannot be evaluated, its step is not
/// positive, a weight slot cannot be resolved, or a nested plan fails.
static LogicalResult executeCompiledCorePlan(
    const llvm::SmallVectorImpl<CompiledCoreNode>& plan,
    PimCodeGen& coreCodeGen,
    StaticValueKnowledge& knowledge,
    llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, const StaticValueKnowledge&)> resolveWeightSlot,
    size_t& processedOperations,
    std::optional<unsigned> batchLane = std::nullopt,
    std::optional<unsigned> batchLaneCount = std::nullopt) {
  for (const CompiledCoreNode& node : plan) {
    if (node.kind != CompiledCoreNode::Kind::Loop) {
      Operation* planned = node.op;
      switch (node.opKind) {
        case CompiledCoreOpKind::Load:
          coreCodeGen.codeGenLoadOp(cast<pim::PimMemCopyHostToDevOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::Store:
          coreCodeGen.codeGenStoreOp(cast<pim::PimMemCopyDevToHostOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::Lmv:
          coreCodeGen.codeGenLmvOp(cast<pim::PimMemCopyOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::Receive:
          coreCodeGen.codeGenReceiveOp(cast<pim::PimReceiveOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::Send:
          coreCodeGen.codeGenSendOp(cast<pim::PimSendOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::Concat:
          coreCodeGen.codeGenConcatOp(cast<pim::PimConcatOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::Vmm: {
          auto vmmOp = cast<pim::PimVMMOp>(planned);
          auto weightSlot = resolveWeightSlot(vmmOp, knowledge);
          if (failed(weightSlot))
            return failure();
          coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightSlot, vmmOp, true, knowledge);
          break;
        }
        case CompiledCoreOpKind::Transpose:
          coreCodeGen.codeGenTransposeOp(cast<pim::PimTransposeOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VVAdd:
          coreCodeGen.codeGenVVAddOp(cast<pim::PimVVAddOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VVSub:
          coreCodeGen.codeGenVVSubOp(cast<pim::PimVVSubOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VVMul:
          coreCodeGen.codeGenVVMulOp(cast<pim::PimVVMulOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VVMax:
          coreCodeGen.codeGenVVMaxOp(cast<pim::PimVVMaxOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VVDMul:
          coreCodeGen.codeGenVVDMulOp(cast<pim::PimVVDMulOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VAvg:
          coreCodeGen.codeGenVAvgOp(cast<pim::PimVAvgOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VRelu:
          coreCodeGen.codeGenVReluOp(cast<pim::PimVReluOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VTanh:
          coreCodeGen.codeGenVTanhOp(cast<pim::PimVTanhOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VSigm:
          coreCodeGen.codeGenVSigmOp(cast<pim::PimVSigmOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::VSoftmax:
          coreCodeGen.codeGenVSoftmaxOp(cast<pim::PimVSoftmaxOp>(planned), knowledge);
          break;
        case CompiledCoreOpKind::GetGlobal:
          coreCodeGen.codeGetGlobalOp(cast<memref::GetGlobalOp>(planned), knowledge);
          break;
      }
      ++processedOperations;
      continue;
    }

    // Loop node: evaluate the bounds with what is currently known, then unroll.
    auto firstValue = node.lowerBound.evaluate(knowledge);
    auto limitValue = node.upperBound.evaluate(knowledge);
    auto strideValue = node.step.evaluate(knowledge);
    auto forOp = cast<mlir::scf::ForOp>(node.op);
    const bool boundsUsable =
        succeeded(firstValue) && succeeded(limitValue) && succeeded(strideValue) && *strideValue > 0;
    if (!boundsUsable) {
      forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
      return failure();
    }

    // Values carried into the next iteration; seeded with the loop's init args.
    auto initArgs = forOp.getInitArgs();
    llvm::SmallVector<mlir::Value> carriedValues(initArgs.begin(), initArgs.end());

    for (int64_t iteration = *firstValue; iteration < *limitValue; iteration += *strideValue) {
      // Both binding scopes end with the iteration, restoring the previous knowledge.
      ScopedMapBindings<decltype(knowledge.indexValues)> indexScope(knowledge.indexValues);
      ScopedMapBindings<decltype(knowledge.aliases)> aliasScope(knowledge.aliases);
      indexScope.bind(forOp.getInductionVar(), iteration);
      for (auto [regionArg, carried] : llvm::zip_equal(forOp.getRegionIterArgs(), carriedValues))
        aliasScope.bind(regionArg, carried);

      LogicalResult bodyResult = executeCompiledCorePlan(
          *node.loopBody, coreCodeGen, knowledge, resolveWeightSlot, processedOperations, batchLane, batchLaneCount);
      if (failed(bodyResult))
        return failure();

      // Resolve the yielded values while the iteration's bindings are still active.
      auto yieldOp = cast<mlir::scf::YieldOp>(forOp.getRegion().front().getTerminator());
      for (unsigned operandIndex = 0; operandIndex < yieldOp.getNumOperands(); ++operandIndex)
        carriedValues[operandIndex] = resolveLoopCarriedAlias(yieldOp.getOperand(operandIndex), knowledge);
    }
  }
  return success();
}
/// Collects the memory entry of each global referenced anywhere in `funcOp`.
///
/// memref.get_global operations for which hasWeightAlways holds are ignored.
/// For every other reference whose global can be looked up in `moduleOp`, the
/// entry registered in `memory.memEntriesMap` for the reference's result is
/// recorded; the first reference with a registered entry wins per global.
static SmallDenseMap<memref::GlobalOp, MemEntry, 16>
collectMaterializedHostGlobals(ModuleOp moduleOp, func::FuncOp funcOp, const PimAcceleratorMemory& memory) {
  SmallDenseMap<memref::GlobalOp, MemEntry, 16> entriesByGlobal;
  const auto& registeredEntries = memory.memEntriesMap;

  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    if (hasWeightAlways(getGlobalOp))
      return;

    auto global = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    if (!global)
      return;
    if (entriesByGlobal.contains(global))
      return;

    auto registered = registeredEntries.find(getMemoryValueKey(getGlobalOp.getResult()));
    if (registered == registeredEntries.end())
      return;
    entriesByGlobal[global] = registered->second;
  });

  return entriesByGlobal;
}
template <typename CoreLikeOpTy>
static void aliasMaterializedHostGlobals(CoreLikeOpTy coreLikeOp,
ModuleOp moduleOp,
const SmallDenseMap<memref::GlobalOp, MemEntry, 16>& materializedHostGlobals,
PimAcceleratorMemory& memory) {
coreLikeOp.walk([&](memref::GetGlobalOp getGlobalOp) {
MemoryValueKey key = getMemoryValueKey(getGlobalOp.getResult());
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(key))
return;
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!targetGlobal)
return;
auto it = materializedHostGlobals.find(targetGlobal);
if (it != materializedHostGlobals.end())
memory.memEntriesMap[key] = it->second;
});
}
/// Generates code for all operations of a core body block.
/// The block is first compiled into an emission plan (compileCoreEmissionPlan)
/// and the plan is then executed (executeCompiledCorePlan) on a private copy of
/// `initialKnowledge`; scf.for loops are statically unrolled during execution.
/// Returns the number of operations processed, or -1 on failure.
static int64_t codeGenCoreOps(
    Block& block,
    PimCodeGen& coreCodeGen,
    const StaticValueKnowledge& initialKnowledge,
    Operation* weightOwner,
    llvm::function_ref<llvm::FailureOr<unsigned>(pim::PimVMMOp, const StaticValueKnowledge&)> resolveWeightSlot,
    std::optional<unsigned> batchLane = std::nullopt,
    std::optional<unsigned> batchLaneCount = std::nullopt) {
  constexpr int64_t kFailure = -1;

  llvm::SmallVector<CompiledCoreNode, 32> emissionPlan;
  if (failed(compileCoreEmissionPlan(block, weightOwner, emissionPlan)))
    return kFailure;

  // Execution mutates the knowledge while unrolling, so work on a copy.
  StaticValueKnowledge workingKnowledge = initialKnowledge;
  size_t operationCount = 0;
  if (failed(executeCompiledCorePlan(
          emissionPlan, coreCodeGen, workingKnowledge, resolveWeightSlot, operationCount, batchLane, batchLaneCount)))
    return kFailure;

  return static_cast<int64_t>(operationCount);
}
/// Emits the PIM artifacts for `moduleOp` into `outputDirPath`.
///
/// The function proceeds in these steps:
///  1. create the output directory (when a path is given), allocate host memory
///     and write the memory binary;
///  2. assign consecutive emitted core ids to the original core ids, in order of
///     first appearance among the top-level core-like operations;
///  3. build one emission job per PimCoreOp and one per distinct core id of each
///     PimCoreBatchOp;
///  4. run all jobs through mlir::parallelFor, each writing `core_<id>.pim` and,
///     when pimEmitJson is set, `core_<id>.json`;
///  5. populate the weights folder, link every core's crossbar files, record the
///     memory reports and write the config JSON.
///
/// Returns InvalidOutputFileAccess when a directory, core file or weight link
/// cannot be created, CompilerFailure when the entry function is missing or code
/// generation fails, any error reported by writeMemoryBinary, and otherwise the
/// result of writeConfigJson.
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::string& outputDirPath) {
  if (!outputDirPath.empty()) {
    if (auto error = sys::fs::create_directory(outputDirPath)) {
      errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
      return InvalidOutputFileAccess;
    }
  }

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc))
    return CompilerFailure;
  auto funcOp = *entryFunc;

  PimAcceleratorMemory memory;
  memory.hostMem.allocateHost(moduleOp, funcOp);
  memory.reportHost();
  if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
    return err;

  json::Object xbarsPerArrayGroup;
  size_t maxCoreId = 0;
  uint64_t nextBatchReportId = 0;
  SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
  SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
      collectMaterializedHostGlobals(moduleOp, funcOp, memory);

  // Pass 1: map every original core id to a dense emitted id, in first-appearance order.
  llvm::DenseMap<size_t, size_t> emittedCoreIds;
  size_t nextEmittedCoreId = 0;
  for (Operation* op : coreLikeOps) {
    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
      size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
      if (!emittedCoreIds.contains(originalCoreId))
        emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
      continue;
    }
    auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
    auto batchCoreIds = getBatchCoreIds(coreBatchOp);
    for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
      size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
      if (!emittedCoreIds.contains(originalCoreId))
        emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
    }
  }

  // Pass 2: build the emission jobs. A batch contributes one job per distinct core id;
  // batchJobIndices remembers which jobs belong to which batch for the reports below.
  SmallVector<CoreEmissionJob> jobs;
  SmallVector<SmallVector<size_t>> batchJobIndices;
  for (Operation* op : coreLikeOps) {
    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
      size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
      CoreEmissionJob job;
      job.coreLikeOp = coreOp;
      job.originalCoreId = originalCoreId;
      job.emittedCoreId = emittedCoreIds.lookup(originalCoreId);
      jobs.push_back(std::move(job));
      continue;
    }
    auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
    auto batchCoreIds = getBatchCoreIds(coreBatchOp);
    llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
    for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
      lanesByCoreId[static_cast<size_t>(batchCoreIds[lane])].push_back(lane);
    SmallVector<size_t> jobIndices;
    // DenseMap iteration order is unspecified, so order the jobs by emitted core id.
    SmallVector<size_t> orderedOriginalCoreIds = llvm::to_vector(lanesByCoreId.keys());
    llvm::sort(orderedOriginalCoreIds,
               [&](size_t lhs, size_t rhs) { return emittedCoreIds.lookup(lhs) < emittedCoreIds.lookup(rhs); });
    for (size_t originalCoreId : orderedOriginalCoreIds) {
      CoreEmissionJob job;
      job.coreLikeOp = coreBatchOp;
      job.originalCoreId = originalCoreId;
      job.emittedCoreId = emittedCoreIds.lookup(originalCoreId);
      job.lanes = lanesByCoreId.lookup(originalCoreId);
      job.batchReportId = nextBatchReportId;
      jobIndices.push_back(jobs.size());
      jobs.push_back(std::move(job));
    }
    batchJobIndices.push_back(std::move(jobIndices));
    ++nextBatchReportId;
  }

  // Creates core_<id>/ and links each weight file into it as crossbar_<slot>.bin,
  // recording the slot indices in xbarsPerGroup.
  auto linkCoreWeights =
      [&](size_t coreId, ArrayRef<std::string> weightFiles, json::Array& xbarsPerGroup) -> OnnxMlirCompilerErrorCodes {
    auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
    if (auto error = sys::fs::create_directory(coreWeightsDirPath); error && error != std::errc::file_exists) {
      errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
      return InvalidOutputFileAccess;
    }
    for (auto [slot, fileName] : llvm::enumerate(weightFiles)) {
      xbarsPerGroup.push_back(static_cast<int64_t>(slot));
      std::string sourcePath = outputDirPath + "/weights/" + fileName;
      std::string targetPath = coreWeightsDirPath + "/crossbar_" + std::to_string(slot) + ".bin";
      // Best-effort removal of a link left by a previous run; a real problem
      // surfaces through create_link below.
      sys::fs::remove(targetPath);
      if (auto error = sys::fs::create_link(sourcePath, targetPath)) {
        errs() << "Error creating link file: " << sourcePath << " to " << targetPath << "\nError:" << error.message()
               << '\n';
        return InvalidOutputFileAccess;
      }
    }
    return CompilerSuccess;
  };

  // Emits the binary (and optional JSON) instruction stream of a single job.
  auto emitJob = [&](const CoreEmissionJob& job) -> CoreEmissionResult {
    CoreEmissionResult result;
    PimAcceleratorMemory jobMemory(memory.memEntriesMap, false);
    llvm::SmallVector<ResolvedWeightView, 8> usedWeights;
    // Weight slots are the positions of the distinct weight views in usedWeights.
    auto resolveWeightSlot = [&](pim::PimVMMOp vmmOp,
                                 const StaticValueKnowledge& knowledge) -> llvm::FailureOr<unsigned> {
      auto weightView = onnx_mlir::resolveWeightView(job.coreLikeOp, vmmOp.getWeight(), knowledge);
      if (failed(weightView)) {
        vmmOp.emitOpError("requires a statically resolvable dense global weight view during PIM codegen");
        return failure();
      }
      if (auto it = llvm::find(usedWeights, *weightView); it != usedWeights.end())
        return static_cast<unsigned>(std::distance(usedWeights.begin(), it));
      usedWeights.push_back(*weightView);
      return static_cast<unsigned>(usedWeights.size() - 1);
    };

    std::error_code errorCode;
    auto outputCorePath = outputDirPath + "/core_" + std::to_string(job.emittedCoreId) + ".pim";
    raw_fd_ostream coreBinaryStream(outputCorePath, errorCode, sys::fs::OF_None);
    if (errorCode) {
      errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
      result.status = InvalidOutputFileAccess;
      return result;
    }

    std::unique_ptr<raw_fd_ostream> coreJsonStream;
    if (pimEmitJson.getValue()) {
      std::string outputCoreJsonPath = outputDirPath + "/core_" + std::to_string(job.emittedCoreId) + ".json";
      errorCode = std::error_code();
      coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
      if (errorCode) {
        errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message() << '\n';
        result.status = InvalidOutputFileAccess;
        return result;
      }
      *coreJsonStream << '[';
    }

    pim_binary::writeHeader(coreBinaryStream);
    PimCodeGen coreCodeGen(jobMemory, coreBinaryStream, coreJsonStream.get(), emittedCoreIds);

    if (auto coreOp = dyn_cast<pim::PimCoreOp>(job.coreLikeOp)) {
      aliasMaterializedHostGlobals(coreOp, moduleOp, materializedHostGlobals, jobMemory);
      auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
      deviceMemory.allocateCore(coreOp);
      int64_t processedOperations = codeGenCoreOps(
          coreOp.getBody().front(), coreCodeGen, StaticValueKnowledge {}, coreOp.getOperation(), resolveWeightSlot);
      if (processedOperations < 0) {
        result.status = CompilerFailure;
        return result;
      }
      assert(processedOperations > 0);
      result.reportRow = deviceMemory.getReportRow();
      result.usedWeights = std::move(usedWeights);
    }
    else {
      auto coreBatchOp = cast<pim::PimCoreBatchOp>(job.coreLikeOp);
      aliasMaterializedHostGlobals(coreBatchOp, moduleOp, materializedHostGlobals, jobMemory);
      auto& deviceMemory = jobMemory.getOrCreateDeviceMem(job.emittedCoreId);
      // The batch body is generated once per lane mapped to this core.
      for (unsigned lane : job.lanes) {
        StaticValueKnowledge knowledge;
        knowledge.indexValues[coreBatchOp.getLaneArgument()] = lane;
        for (unsigned i = 0; i < coreBatchOp.getInputs().size(); ++i)
          knowledge.aliases[coreBatchOp.getInputArgument(i)] = coreBatchOp.getInputs()[i];
        deviceMemory.allocateCore(coreBatchOp, lane);
        coreCodeGen.setBatchLane(lane);
        int64_t processedOperations = codeGenCoreOps(coreBatchOp.getBody().front(),
                                                     coreCodeGen,
                                                     knowledge,
                                                     coreBatchOp.getOperation(),
                                                     resolveWeightSlot,
                                                     lane,
                                                     static_cast<unsigned>(coreBatchOp.getLaneCount()));
        if (processedOperations < 0) {
          result.status = CompilerFailure;
          return result;
        }
        assert(processedOperations > 0);
      }
      result.reportRow = deviceMemory.getReportRow();
      result.usedWeights = std::move(usedWeights);
    }

    pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount());
    coreBinaryStream.close();
    if (coreJsonStream) {
      // The closing bracket overwrites the last byte written by the code generator
      // (presumably a trailing separator after each instruction). When nothing was
      // written after the opening '[', rewinding would overwrite the bracket itself
      // and leave the invalid document "]", so only rewind when there is content.
      if (coreJsonStream->tell() > 1)
        coreJsonStream->seek(coreJsonStream->tell() - 1);
      *coreJsonStream << ']';
      coreJsonStream->close();
    }
    return result;
  };

  std::vector<CoreEmissionResult> jobResults(jobs.size());
  mlir::parallelFor(
      moduleOp.getContext(), 0, jobs.size(), [&](size_t index) { jobResults[index] = emitJob(jobs[index]); });
  // Report the first failing job, in job order.
  for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex)
    if (jobResults[jobIndex].status != CompilerSuccess)
      return jobResults[jobIndex].status;

  llvm::SmallVector<WeightFileRequest, 8> weightRequests;
  weightRequests.reserve(jobs.size());
  for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) {
    WeightFileRequest request;
    request.coreId = jobs[jobIndex].emittedCoreId;
    request.weights = jobResults[jobIndex].usedWeights;
    weightRequests.push_back(std::move(request));
  }
  auto mapCoreWeightToFileName = createAndPopulateWeightFolder(weightRequests, outputDirPath);

  // Single-core jobs: link the weights and record one report row per core.
  for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) {
    const CoreEmissionJob& job = jobs[jobIndex];
    if (!isa<pim::PimCoreOp>(job.coreLikeOp))
      continue;
    const CoreEmissionResult& result = jobResults[jobIndex];
    json::Array xbarsPerGroup;
    if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
      return err;
    xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
    memory.recordCoreReport(job.emittedCoreId, result.reportRow);
  }

  // Batch jobs: link the weights of every core and record one aggregated report per batch.
  for (const SmallVector<size_t>& group : batchJobIndices) {
    // A batch without lanes produced no jobs: there is nothing to link or report,
    // and group.front() below would be invalid on an empty group.
    if (group.empty())
      continue;
    SmallVector<int32_t> reportedCoreIds;
    MemoryReportRow batchRow;
    std::optional<MemoryReportRow> batchPerCoreRow;
    for (size_t jobIndex : group) {
      const CoreEmissionJob& job = jobs[jobIndex];
      const CoreEmissionResult& result = jobResults[jobIndex];
      json::Array xbarsPerGroup;
      if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
        return err;
      xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
      reportedCoreIds.push_back(static_cast<int32_t>(job.emittedCoreId));
      // The first job's row is used as the representative per-core row.
      if (!batchPerCoreRow)
        batchPerCoreRow = result.reportRow;
      batchRow = addMemoryReportRows(batchRow, result.reportRow);
    }
    uint64_t batchReportId = jobs[group.front()].batchReportId.value_or(0);
    memory.recordBatchReport(batchReportId,
                             reportedCoreIds,
                             batchPerCoreRow.value_or(MemoryReportRow {}),
                             batchRow.numAlloca,
                             batchRow.sizeAlloca);
  }

  maxCoreId = nextEmittedCoreId == 0 ? 0 : nextEmittedCoreId - 1;
  memory.flushReport();
  return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
}