Files
Raptor/src/PIM/Compiler/PimCodeGen.cpp
T
2026-05-22 21:52:28 +02:00

1072 lines
44 KiB
C++

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <absl/types/compare.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <string>
#include <utility>
#include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm;
/// Queues `value` for a later allocation pass.
///
/// A new entry initialized as `{0, <byte size of the value's shaped type>}` is
/// appended to `memEntries` together with the value.
///
/// @param value Value whose type must be a statically shaped ShapedType.
/// @return Pointer to the MemEntry stored inside `memEntries`.
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  assert("Only static shape is supported" && shapedType.hasStaticShape());

  size_t byteSize = getShapedTypeSizeInBytes(shapedType);
  MemEntry pendingEntry = {0, byteSize};

  auto& gathered = memEntries.emplace_back(pendingEntry, value);
  return &gathered.first;
}
/// Assigns an address to every gathered entry, visiting entries in order of
/// decreasing size.
void PimMemory::allocateGatheredMemory() {
  auto largerFirst = [](const auto& lhs, const auto& rhs) -> bool { return lhs.first.size > rhs.first.size; };
  llvm::sort(memEntries, largerFirst);

  for (auto& gathered : memEntries)
    allocateMemoryForValue(gathered.second, gathered.first);
}
/// Places `memEntry` at the current first available address and records it
/// for `value` in both the owned and the global entry maps.
///
/// After the entry is placed, the first available address is advanced by the
/// entry size and rounded up to the next multiple of `minAlignment`.
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
  memEntry.address = firstAvailableAddress;

  // Bump past this entry, then pad so the next entry starts aligned.
  firstAvailableAddress += memEntry.size;
  size_t misalignment = firstAvailableAddress % minAlignment;
  if (misalignment != 0)
    firstAvailableAddress += minAlignment - misalignment;

  globalMemEntriesMap[value] = memEntry;
  ownedMemEntriesMap[value] = memEntry;
}
/// Gathers and allocates host memory for a function.
///
/// Entries are gathered for: every function argument, the first
/// `memref.get_global` seen for each non-weight global, and every
/// `memref.alloc` that is not nested inside a `pim.core`. After allocation,
/// recorded aliases are resolved so that an alias shares the entry of the value
/// it refers to.
///
/// @param moduleOp Module used to look up the `memref.global` behind each get_global.
/// @param funcOp   Function whose arguments, globals and allocations are laid out.
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
  SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
  SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
  SmallVector<mlir::Value> args;

  for (mlir::Value arg : funcOp.getArguments()) {
    gatherMemEntry(arg);
    args.push_back(arg);
  }

  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    if (hasWeightAlways(getGlobalOp))
      return;

    auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    // The lookup can fail; other users of lookupGlobalForGetGlobal in this file
    // skip such ops instead of dereferencing a null op.
    if (!globalMemrefOp)
      return;

    StringRef globalName = globalMemrefOp.getName();
    if (globalName.starts_with("arg")) {
      // NOTE(review): the index is read starting at character 4, i.e. the name is
      // assumed to look like "arg<separator><index>" -- confirm against the producer.
      unsigned index = 0;
      bool parsed = llvm::to_integer(globalName.substr(4), index, 10);
      // Only alias when the suffix is a valid argument index; otherwise the
      // global is simply treated like any other host global below.
      if (parsed && index < args.size())
        globalAliases.push_back({getGlobalOp.getResult(), args[index]});
    }

    auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
    if (inserted)
      gatherMemEntry(getGlobalOp.getResult());
    else
      globalAliases.push_back({getGlobalOp.getResult(), iter->second});
  });

  funcOp.walk([&](memref::AllocOp allocOp) {
    if (!allocOp->getParentOfType<pim::PimCoreOp>())
      gatherMemEntry(allocOp.getResult());
  });

  allocateGatheredMemory();

  // Aliases are resolved in recording order, so a later alias for the same
  // value overrides an earlier one.
  for (auto [alias, original] : globalAliases)
    globalMemEntriesMap[alias] = getMemEntry(original);
}
/// Gathers every `memref.alloc` nested under `op` and allocates memory for them.
void PimMemory::allocateCore(Operation* op) {
  auto gatherAlloc = [&](memref::AllocOp allocOp) { gatherMemEntry(allocOp.getResult()); };
  op->walk(gatherAlloc);
  allocateGatheredMemory();
}
/// Prints the host section of the memory report: global count and global bytes.
static void printHostMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
  llvm::SmallVector<ReportField, 2> hostFields;
  hostFields.push_back({"Number of globals", std::to_string(row.numGlobal)});
  hostFields.push_back({"Global memory", formatReportMemory(row.sizeGlobal)});
  printReportFlatFields(os, hostFields);
}
/// Prints the per-core section of the memory report: alloca count and allocated bytes.
static void printCoreMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
  const MemoryReportRow& coreRow = entry.row;
  llvm::SmallVector<ReportField, 2> coreFields;
  coreFields.push_back({"Number of allocas", std::to_string(coreRow.numAlloca)});
  coreFields.push_back({"Allocated memory", formatReportMemory(coreRow.sizeAlloca)});
  printReportFlatFields(os, coreFields);
}
/// Prints a batch section of the memory report with both the per-core figures
/// (from `entry.row`) and the batch-wide totals.
static void printBatchMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
  llvm::SmallVector<ReportField, 2> perCoreFields;
  perCoreFields.push_back({"Number of allocas", std::to_string(entry.row.numAlloca)});
  perCoreFields.push_back({"Allocated memory", formatReportMemory(entry.row.sizeAlloca)});

  llvm::SmallVector<ReportField, 2> totalFields;
  totalFields.push_back({"Number of allocas", std::to_string(entry.totalAllocaCount)});
  totalFields.push_back({"Batch memory", formatReportMemory(entry.totalAllocaBytes)});

  printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
}
/// Returns a copy of `lhs` with the alloca and global counters of `rhs` added.
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
  MemoryReportRow sum = lhs;

  // Global figures.
  sum.numGlobal += rhs.numGlobal;
  sum.sizeGlobal += rhs.sizeGlobal;

  // Alloca figures.
  sum.numAlloca += rhs.numAlloca;
  sum.sizeAlloca += rhs.sizeAlloca;

  return sum;
}
/// Summarizes the entries owned by this memory.
///
/// Values defined by `memref.alloc` are counted as allocas and values defined
/// by `memref.get_global` as globals; values without a defining op (or defined
/// by any other op) are not counted.
MemoryReportRow PimMemory::getReportRow() const {
  MemoryReportRow row;
  for (auto& owned : ownedMemEntriesMap) {
    Operation* definingOp = owned.first.getDefiningOp();
    if (!definingOp)
      continue;

    if (isa<memref::AllocOp>(definingOp)) {
      row.numAlloca++;
      row.sizeAlloca += owned.second.size;
    }
    if (isa<memref::GetGlobalOp>(definingOp)) {
      row.numGlobal++;
      row.sizeGlobal += owned.second.size;
    }
  }
  return row;
}
/// Drops the entries recorded for `val` from both the owned and the global
/// maps; a value with no entry is ignored.
void PimMemory::remove(mlir::Value val) {
  // Key-based erase is a no-op when the key is absent.
  ownedMemEntriesMap.erase(val);
  globalMemEntriesMap.erase(val);
}
/// Returns a copy of the entry recorded for `value` in the global entry map.
/// The value is expected to have an entry; this is only checked by an assertion.
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
  auto found = globalMemEntriesMap.find(value);
  const bool isKnown = found != globalMemEntriesMap.end();
  assert("Missing memEntry for value" && isKnown);
  (void) isKnown;
  return found->second;
}
/// Returns the device memory registered under `id`, constructing it from
/// `memEntriesMap` the first time the id is requested.
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
  auto insertion = deviceMem.try_emplace(id, memEntriesMap);
  auto& slot = *insertion.first;
  return slot.second;
}
/// Resolves `value` to an absolute address.
///
/// The value is first resolved to a base value plus a byte offset; the result
/// is the address of the base's memory entry plus that offset. If the value
/// cannot be resolved, or the base has no entry, the value and its defining op
/// (when there is one) are printed to stderr and the function hits
/// llvm_unreachable.
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge) const {
  // Shared diagnostic dump used by both failure paths.
  auto dumpValue = [](const char* heading, auto subject) {
    errs() << heading;
    subject.print(errs());
    errs() << "\n";
    if (auto* definingOp = subject.getDefiningOp()) {
      errs() << "Defining op:\n";
      definingOp->print(errs());
      errs() << "\n";
    }
  };

  auto resolvedAddress = resolveContiguousAddress(value, knowledge);
  if (failed(resolvedAddress)) {
    dumpValue("Failed to resolve contiguous address for value: ", value);
    llvm_unreachable("Failed to resolve contiguous address");
  }

  auto entryIt = memEntriesMap.find(resolvedAddress->base);
  if (entryIt == memEntriesMap.end()) {
    dumpValue("Missing mem entry for value: ", resolvedAddress->base);
    llvm_unreachable("Missing mem entry");
  }

  return entryIt->second.address + resolvedAddress->byteOffset;
}
/// Captures the current host memory summary for the report.
void PimAcceleratorMemory::reportHost() {
  hostReportRow = hostMem.getReportRow();
}
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
reportEntries.push_back(
{MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row, row.numAlloca, row.sizeAlloca});
}
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
ArrayRef<int32_t> coreIds,
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes) {
MemoryReportEntry entry;
entry.kind = MemoryReportEntry::Kind::Batch;
entry.id = batchId;
llvm::append_range(entry.coreIds, coreIds);
entry.row = perCoreRow;
entry.totalAllocaCount = totalAllocaCount;
entry.totalAllocaBytes = totalAllocaBytes;
reportEntries.push_back(std::move(entry));
}
/// Writes the collected memory report to `fileReport` and closes the file.
///
/// Nothing happens when the report file is not open. The output is a totals
/// block, an optional "Host" section, and then one section per run of
/// neighbouring report entries (after sorting by first core) that share the
/// same kind, per-core row and totals.
void PimAcceleratorMemory::flushReport() {
  if (!fileReport.is_open())
    return;

  llvm::raw_os_ostream os(fileReport);
  const bool hasHostRow = hostReportRow.has_value();

  // Totals block.
  uint64_t totalGlobalMemory = hasHostRow ? hostReportRow->sizeGlobal : 0;
  uint64_t totalCoresMemory = 0;
  for (const MemoryReportEntry& entry : reportEntries)
    totalCoresMemory += entry.totalAllocaBytes;
  llvm::SmallVector<ReportField, 2> totalFields = {
    {"Global memory", formatReportMemory(totalGlobalMemory)},
    {"Cores memory", formatReportMemory(totalCoresMemory)}
  };
  printReportTotalsBlock(os, totalFields);

  // Host section.
  if (hasHostRow) {
    os << "\nHost:\n";
    printHostMemoryReportRow(os, *hostReportRow);
  }

  // Core / batch sections.
  if (!reportEntries.empty()) {
    if (hasHostRow)
      os << "\n";
    sortReportEntriesByFirstCore(reportEntries);

    // Two entries are reported together when all their figures coincide.
    auto belongsToSameRun = [](auto& candidate, auto& head) {
      return candidate.kind == head.kind && candidate.row == head.row
          && candidate.totalAllocaCount == head.totalAllocaCount
          && candidate.totalAllocaBytes == head.totalAllocaBytes;
    };

    const size_t entryCount = reportEntries.size();
    size_t runBegin = 0;
    while (runBegin < entryCount) {
      auto& head = reportEntries[runBegin];
      size_t runEnd = runBegin + 1;
      while (runEnd < entryCount && belongsToSameRun(reportEntries[runEnd], head))
        ++runEnd;

      const bool isBatchRun = head.kind == MemoryReportEntry::Kind::Batch;
      if (isBatchRun) {
        // Batches are listed one by one, each with its own core list.
        os << "Batch ";
        for (size_t batchIndex = runBegin; batchIndex < runEnd; ++batchIndex) {
          if (batchIndex != runBegin)
            os << ",\n ";
          auto& member = reportEntries[batchIndex];
          os << member.id << " (cores ";
          printCompressedIntegerEntries(os, ArrayRef<int32_t>(member.coreIds));
          os << ")";
        }
      }
      else {
        // Cores of the run are merged into a single compressed id list.
        llvm::SmallVector<int32_t, 8> coreIds;
        for (size_t coreIndex = runBegin; coreIndex < runEnd; ++coreIndex)
          coreIds.push_back(reportEntries[coreIndex].coreIds.front());
        os << "Core ";
        printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
      }
      os << ":\n";

      if (isBatchRun)
        printBatchMemoryReportRow(os, head);
      else
        printCoreMemoryReportRow(os, head);

      printReportEntrySeparator(os, runEnd < entryCount);
      runBegin = runEnd;
    }
  }

  os.flush();
  fileReport.close();
}
/// Forgets the memory entries of every result of `op`, in the host memory and
/// in every device memory.
void PimAcceleratorMemory::clean(mlir::Operation* op) {
  for (mlir::Value result : op->getResults()) {
    hostMem.remove(result);
    for (auto& deviceEntry : deviceMem)
      deviceEntry.second.remove(result);
  }
}
/// Translates `coreId` through the `emittedCoreIds` table.
/// The id is expected to be present; this is only checked by an assertion.
size_t PimCodeGen::remapCoreId(size_t coreId) const {
  auto remapped = emittedCoreIds.find(coreId);
  const bool isKnown = remapped != emittedCoreIds.end();
  assert(isKnown && "Missing emitted core id remapping");
  (void) isKnown;
  return remapped->second;
}
/// Writes `instruction` to the core's binary stream and bumps the emitted
/// instruction counter. When a JSON stream is attached, the instruction is
/// also written there, followed by a comma.
void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instruction) const {
  pim_binary::writeInstructionRecord(coreBinaryStream, instruction);
  ++emittedInstructionCount;

  if (!coreJsonStream)
    return;
  auto& jsonOut = *coreJsonStream;
  jsonOut << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
}
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::Opcode::sldi;
instruction.rd = static_cast<uint8_t>(registerNumber);
instruction.r2OrImm = static_cast<int32_t>(immediate);
emitInstruction(instruction);
}
/// Loads register 0 with `rdAddress + rdOffset`.
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
  const size_t destination = rdAddress + rdOffset;
  genSetRegisterImmediateUnsigned(0, destination);
}
void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const {
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
}
void PimCodeGen::setupRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const {
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
}
/// Emits a memory-copy style instruction.
///
/// Registers 0 and 1 are first loaded with the destination and source
/// addresses (address + offset); the instruction itself sets rd/r1 to 0/1,
/// matching those registers, and stores `size` in `generic3`.
///
/// @param opName        Mnemonic translated to an opcode via opcodeFromString.
/// @param sizeFieldName Accepted but not used by this implementation.
void PimCodeGen::emitMemCopyOp(StringRef opName,
                               size_t rdAddr,
                               size_t rdOffset,
                               size_t rs1Addr,
                               size_t rs1Offset,
                               size_t size,
                               StringRef sizeFieldName) const {
  (void) sizeFieldName;
  setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);

  pim_binary::InstructionRecord copyInstruction;
  copyInstruction.opcode = pim_binary::opcodeFromString(opName);
  copyInstruction.rd = 0;
  copyInstruction.r1 = 1;
  copyInstruction.generic1 = 0;
  copyInstruction.generic2 = 0;
  copyInstruction.generic3 = static_cast<int32_t>(size);
  emitInstruction(copyInstruction);
}
/// Emits a communication instruction (`opName`) for a buffer.
///
/// Register 0 is loaded with `bufferAddr`; the instruction stores the remapped
/// `coreId` in `r2OrImm` and `size` in `generic3`.
void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const {
  setupRd(bufferAddr, 0);

  pim_binary::InstructionRecord commInstruction;
  commInstruction.opcode = pim_binary::opcodeFromString(opName);
  commInstruction.rd = 0;
  const size_t emittedCoreId = remapCoreId(coreId);
  commInstruction.r2OrImm = static_cast<int32_t>(emittedCoreId);
  commInstruction.generic1 = 0;
  commInstruction.generic2 = 0;
  commInstruction.generic3 = static_cast<int32_t>(size);
  emitInstruction(commInstruction);
}
void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const {
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::Opcode::mvmul;
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 8;
instruction.generic1 = 0;
instruction.generic2 = static_cast<int32_t>(groupId);
emitInstruction(instruction);
}
/// Emits an `ld` copy for a host-to-device memcpy op. Both offsets must
/// resolve to static values under `knowledge` (checked by assertion).
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
  auto dstOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
  auto srcOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
  assert(succeeded(dstOffset) && succeeded(srcOffset)
         && "pim.memcp_hd offsets must be statically resolvable during codegen");

  const size_t deviceAddr = addressOf(loadOp.getDeviceTarget(), knowledge);
  const size_t hostAddr = addressOf(loadOp.getHostSource(), knowledge);
  emitMemCopyOp("ld", deviceAddr, *dstOffset, hostAddr, *srcOffset, loadOp.getSize());
}
/// Emits an `st` copy for a device-to-host memcpy op. Both offsets must
/// resolve to static values under `knowledge` (checked by assertion).
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
  auto dstOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
  auto srcOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
  assert(succeeded(dstOffset) && succeeded(srcOffset)
         && "pim.memcp_dh offsets must be statically resolvable during codegen");

  const size_t hostAddr = addressOf(storeOp.getHostTarget(), knowledge);
  const size_t deviceAddr = addressOf(storeOp.getDeviceSource(), knowledge);
  emitMemCopyOp("st", hostAddr, *dstOffset, deviceAddr, *srcOffset, storeOp.getSize());
}
/// Emits an `lmv` copy for a memcpy op, using the op's own target/source
/// offsets and size.
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
  const size_t targetAddr = addressOf(lmvOp.getTarget(), knowledge);
  const size_t sourceAddr = addressOf(lmvOp.getSource(), knowledge);
  emitMemCopyOp(
      "lmv", targetAddr, lmvOp.getTargetOffset(), sourceAddr, lmvOp.getSourceOffset(), lmvOp.getSize(), "len");
}
/// Emits a `recv` into the op's output buffer. The source core id must resolve
/// to a static value under `knowledge` (checked by assertion).
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
  auto senderCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
  assert(succeeded(senderCoreId) && "pim.receive source core id must be statically resolvable during codegen");

  const size_t bufferAddr = addressOf(receiveOp.getOutputBuffer(), knowledge);
  emitCommunicationOp("recv", bufferAddr, *senderCoreId, receiveOp.getSize());
}
/// Emits one `recv` per source core, each filling an equally sized chunk of
/// the output buffer.
///
/// The chunk size is the output buffer's byte size divided by the number of
/// source cores; chunk `i` is placed at `outputAddr + i * chunkSize`. An empty
/// source core list emits nothing.
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
                                        const StaticValueKnowledge& knowledge) const {
  size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);

  // Guard the division below: with no source cores there is nothing to receive.
  size_t sourceCount = receiveTensorOp.getSourceCoreIds().size();
  if (sourceCount == 0)
    return;

  size_t chunkSize =
      getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType())) / sourceCount;
  for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
    emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
/// Emits a `send` of the op's input buffer. The target core id must resolve to
/// a static value under `knowledge` (checked by assertion).
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
  auto receiverCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
  assert(succeeded(receiverCoreId) && "pim.send target core id must be statically resolvable during codegen");

  const size_t bufferAddr = addressOf(sendOp.getInput(), knowledge);
  emitCommunicationOp("send", bufferAddr, *receiverCoreId, sendOp.getSize());
}
/// Emits one `send` per target core, each transmitting an equally sized chunk
/// of the input buffer.
///
/// The chunk size is the input's byte size divided by the number of target
/// cores; chunk `i` starts at `inputAddr + i * chunkSize`. An empty target core
/// list emits nothing.
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
  size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);

  // Guard the division below: with no target cores there is nothing to send.
  size_t targetCount = sendTensorOp.getTargetCoreIds().size();
  if (targetCount == 0)
    return;

  size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType())) / targetCount;
  for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
    emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
/// Lowers a concat op to a sequence of `lmv` copies.
///
/// The output is viewed as [outer, concatDim, inner] around the concat axis.
/// For every input and every outer index, one contiguous block of
/// `inputConcatDim * inner` elements is copied to its position in the output.
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
  auto resultType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
  assert(resultType.hasStaticShape() && "concat codegen requires static output shape");

  int64_t axis = concatOp.getAxis();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  size_t elementSize = getElementTypeSizeInBytes(resultType.getElementType());
  size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);

  // Number of slices before the axis and number of elements after it.
  size_t outerCount = 1;
  for (int64_t dim = 0; dim < axis; ++dim)
    outerCount *= static_cast<size_t>(resultShape[dim]);
  size_t innerCount = 1;
  for (size_t dim = static_cast<size_t>(axis) + 1; dim < resultShape.size(); ++dim)
    innerCount *= static_cast<size_t>(resultShape[dim]);
  size_t outputConcatDim = static_cast<size_t>(resultShape[axis]);

  // Bytes occupied by one index along the concat axis.
  size_t innerBytes = innerCount * elementSize;

  size_t axisCursor = 0;
  for (mlir::Value input : concatOp.getInputs()) {
    auto inputType = cast<ShapedType>(input.getType());
    assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");

    size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
    size_t blockBytes = inputConcatDim * innerBytes;
    size_t inputAddr = addressOf(input, knowledge);

    for (size_t outer = 0; outer < outerCount; ++outer) {
      size_t dstOffset = (outer * outputConcatDim + axisCursor) * innerBytes;
      size_t srcOffset = outer * blockBytes;
      emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockBytes, "len");
    }
    axisCursor += inputConcatDim;
  }
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
MVMTy mvmLikeOp,
bool transposeMatrix,
const StaticValueKnowledge& knowledge) {
emitMvmOp(mvmId, addressOf(mvmLikeOp.getOutputBuffer(), knowledge), 0, addressOf(mvmLikeOp.getInput(), knowledge), 0);
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
/// Emits a `vvadd` instruction: registers 0/1/2 receive the output, lhs and
/// rhs addresses, and `generic3` the byte size of the lhs operand.
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
  auto firstAddr = addressOf(vvaddOp.getLhs(), knowledge);
  auto secondAddr = addressOf(vvaddOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, firstAddr, 0, secondAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vvadd;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  vectorOp.r2OrImm = 2;
  auto operandType = cast<ShapedType>(vvaddOp.getLhs().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vvsub` instruction: registers 0/1/2 receive the output, lhs and
/// rhs addresses, and `generic3` the byte size of the lhs operand.
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vvsubOp.getOutputBuffer(), knowledge);
  auto firstAddr = addressOf(vvsubOp.getLhs(), knowledge);
  auto secondAddr = addressOf(vvsubOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, firstAddr, 0, secondAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vvsub;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  vectorOp.r2OrImm = 2;
  auto operandType = cast<ShapedType>(vvsubOp.getLhs().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vvmul` instruction: registers 0/1/2 receive the output, lhs and
/// rhs addresses, and `generic3` the byte size of the lhs operand.
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vvmulOp.getOutputBuffer(), knowledge);
  auto firstAddr = addressOf(vvmulOp.getLhs(), knowledge);
  auto secondAddr = addressOf(vvmulOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, firstAddr, 0, secondAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vvmul;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  vectorOp.r2OrImm = 2;
  auto operandType = cast<ShapedType>(vvmulOp.getLhs().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vvmax` instruction: registers 0/1/2 receive the output, lhs and
/// rhs addresses, and `generic3` the byte size of the lhs operand.
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vvmaxOp.getOutputBuffer(), knowledge);
  auto firstAddr = addressOf(vvmaxOp.getLhs(), knowledge);
  auto secondAddr = addressOf(vvmaxOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, firstAddr, 0, secondAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vvmax;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  vectorOp.r2OrImm = 2;
  auto operandType = cast<ShapedType>(vvmaxOp.getLhs().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vvdmul` instruction: registers 0/1/2 receive the output, lhs and
/// rhs addresses, and `generic3` the byte size of the lhs operand.
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vvdmulOp.getOutputBuffer(), knowledge);
  auto firstAddr = addressOf(vvdmulOp.getLhs(), knowledge);
  auto secondAddr = addressOf(vvdmulOp.getRhs(), knowledge);
  setupRdRs1Rs2(dstAddr, 0, firstAddr, 0, secondAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vvdmul;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  vectorOp.r2OrImm = 2;
  auto operandType = cast<ShapedType>(vvdmulOp.getLhs().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vavg` instruction: registers 0/1 receive the output and input
/// addresses, `r2OrImm` and `generic1` are both set to 1, and `generic3`
/// carries the byte size of the input.
// NOTE(review): the meaning of the two constant 1 fields is not visible here.
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vavgOp.getOutputBuffer(), knowledge);
  auto srcAddr = addressOf(vavgOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vavg;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  vectorOp.r2OrImm = 1;
  vectorOp.generic1 = 1;
  auto operandType = cast<ShapedType>(vavgOp.getInput().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vrelu` instruction: registers 0/1 receive the output and input
/// addresses, and `generic3` the byte size of the input.
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vreluOp.getOutputBuffer(), knowledge);
  auto srcAddr = addressOf(vreluOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vrelu;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  auto operandType = cast<ShapedType>(vreluOp.getInput().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vtanh` instruction: registers 0/1 receive the output and input
/// addresses, and `generic3` the byte size of the input.
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vtanhOp.getOutputBuffer(), knowledge);
  auto srcAddr = addressOf(vtanhOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vtanh;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  auto operandType = cast<ShapedType>(vtanhOp.getInput().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vsigm` instruction: registers 0/1 receive the output and input
/// addresses, and `generic3` the byte size of the input.
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vsigmOp.getOutputBuffer(), knowledge);
  auto srcAddr = addressOf(vsigmOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vsigm;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  auto operandType = cast<ShapedType>(vsigmOp.getInput().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Emits a `vsoftmax` instruction: registers 0/1 receive the output and input
/// addresses, and `generic3` the byte size of the input.
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const {
  auto dstAddr = addressOf(vsoftmaxOp.getOutputBuffer(), knowledge);
  auto srcAddr = addressOf(vsoftmaxOp.getInput(), knowledge);
  setupRdRs1(dstAddr, 0, srcAddr, 0);

  pim_binary::InstructionRecord vectorOp;
  vectorOp.opcode = pim_binary::Opcode::vsoftmax;
  vectorOp.rd = 0;
  vectorOp.r1 = 1;
  auto operandType = cast<ShapedType>(vsoftmaxOp.getInput().getType());
  vectorOp.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(operandType));
  emitInstruction(vectorOp);
}
/// Handler for `memref.get_global`; it emits no instruction.
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {
  // Intentionally empty: nothing is generated for this op.
  (void) getGlobalOp;
  (void) knowledge;
}
/// Lowers a transpose op to `lmv` copies.
///
/// Destination dimension i corresponds to source dimension perm[i]. When the
/// permutation leaves every element at the same flat position, a single bulk
/// copy is emitted; otherwise one copy per element is emitted.
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
  auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
  auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);

  auto inputType = cast<ShapedType>(transposeOp.getInput().getType());
  auto inputShape = inputType.getShape();
  size_t rank = inputShape.size();
  size_t elementSize = getElementTypeSizeInBytes(inputType.getElementType());
  size_t elementCount = inputType.getNumElements();

  SmallVector<int64_t> perm = map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(),
                                            [](auto attr) -> int64_t { return attr.getInt(); });

  // Shape of the destination: outputShape[i] = inputShape[perm[i]].
  SmallVector<int64_t> outputShape(rank);
  for (size_t dim = 0; dim < rank; dim++)
    outputShape[dim] = inputShape[perm[dim]];

  // Row-major strides (in elements) of source and destination.
  SmallVector<size_t> srcStrides(rank, 1);
  SmallVector<size_t> dstStrides(rank, 1);
  for (int64_t dim = rank - 2; dim >= 0; dim--) {
    srcStrides[dim] = srcStrides[dim + 1] * inputShape[dim + 1];
    dstStrides[dim] = dstStrides[dim + 1] * outputShape[dim + 1];
  }

  // Maps a flat source index to the flat destination index of the same element.
  SmallVector<size_t> srcIdx(rank);
  auto destinationIndexOf = [&](size_t srcFlat) -> size_t {
    size_t remaining = srcFlat;
    for (size_t dim = 0; dim < rank; dim++) {
      srcIdx[dim] = remaining / srcStrides[dim];
      remaining %= srcStrides[dim];
    }
    size_t dstFlat = 0;
    for (size_t dim = 0; dim < rank; dim++)
      dstFlat += srcIdx[perm[dim]] * dstStrides[dim];
    return dstFlat;
  };

  // A transpose that keeps every element in place is a plain copy.
  bool storagePreserving = true;
  for (size_t srcFlat = 0; storagePreserving && srcFlat < elementCount; srcFlat++)
    storagePreserving = destinationIndexOf(srcFlat) == srcFlat;

  if (storagePreserving) {
    emitMemCopyOp("lmv", dstAddr, 0, srcAddr, 0, elementCount * elementSize, "len");
    return;
  }

  // General case: move the elements one at a time.
  for (size_t srcFlat = 0; srcFlat < elementCount; srcFlat++) {
    size_t dstFlat = destinationIndexOf(srcFlat);
    emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
  }
}
/// Returns the larger of the first two dimensions of `matrixShape`.
/// Only rank 2 and rank 4 shapes are expected (checked by assertion).
size_t getMatrixSize(ShapedType matrixShape) {
  [[maybe_unused]] const bool supportedRank = matrixShape.getRank() == 2 || matrixShape.getRank() == 4;
  assert(supportedRank && "Unsupported matrix shape");

  const auto dim0 = matrixShape.getDimSize(0);
  const auto dim1 = matrixShape.getDimSize(1);
  return std::max(dim0, dim1);
}
/// Formats a byte count using the largest binary unit that fits.
///
/// The value is divided with integer truncation, so 1536 is rendered as
/// "1 KB". Unit boundaries are inclusive: exactly 1024 bytes is "1 KB",
/// exactly 1024 * 1024 bytes is "1 MB", and so on.
///
/// @param size Size in bytes.
/// @return "<n> GB", "<n> MB", "<n> KB" or "<n> Bytes".
std::string getMemorySizeAsString(size_t size) {
  constexpr size_t kKiB = 1024;
  constexpr size_t kMiB = kKiB * 1024;
  constexpr size_t kGiB = kMiB * 1024;

  // `>=` so that an exact multiple of a unit is shown in that unit rather than
  // as 1024 of the next smaller one.
  if (size >= kGiB)
    return std::to_string(size / kGiB) + " GB";
  if (size >= kMiB)
    return std::to_string(size / kMiB) + " MB";
  if (size >= kKiB)
    return std::to_string(size / kKiB) + " KB";
  return std::to_string(size) + " Bytes";
}
/// Returns the sorted, duplicate-free indices of the core weights that are
/// used by `pim.vmm` ops inside `block`.
///
/// The result is empty when the block's parent op is not a `pim.core`.
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
  SmallVector<unsigned, 8> usedIndices;

  auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
  if (!coreOp)
    return usedIndices;

  block.walk([&](pim::PimVMMOp vmmOp) {
    mlir::Value weight = vmmOp.getWeight();
    // Find the weight argument this op refers to; record its index once.
    for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
      if (coreOp.getWeightArgument(weightIndex) == weight) {
        if (!llvm::is_contained(usedIndices, weightIndex))
          usedIndices.push_back(weightIndex);
        break;
      }
    }
  });

  llvm::sort(usedIndices);
  return usedIndices;
}
/// Convenience overload: inspects the first block of the core's body.
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
  Block& entryBlock = coreOp.getBody().front();
  return getUsedWeightIndices(entryBlock);
}
/// Returns a copy of the core id array attribute attached to `coreBatchOp`.
/// The attribute is required (checked by assertion).
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
  auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");

  ArrayRef<int32_t> ids = coreIdsAttr.asArrayRef();
  return SmallVector<int32_t>(ids.begin(), ids.end());
}
/// Collects, in block order, the `pim.core` and `pim.core_batch` ops found
/// directly in the first block of `funcOp` (nested ops are not visited).
static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
  SmallVector<Operation*> coreLikeOps;
  Block& entryBlock = funcOp.getBody().front();
  for (Operation& candidate : entryBlock) {
    if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(candidate))
      continue;
    coreLikeOps.push_back(&candidate);
  }
  return coreLikeOps;
}
/// Maps each non-weight global referenced in `funcOp` to the memory entry of
/// the first `memref.get_global` result for it that has an entry in
/// `memory.memEntriesMap`.
static SmallDenseMap<memref::GlobalOp, MemEntry, 16>
collectMaterializedHostGlobals(ModuleOp moduleOp, func::FuncOp funcOp, const PimAcceleratorMemory& memory) {
  SmallDenseMap<memref::GlobalOp, MemEntry, 16> materialized;

  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    if (hasWeightAlways(getGlobalOp))
      return;

    auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    if (!targetGlobal)
      return;
    // Keep the first entry found for a global.
    if (materialized.contains(targetGlobal))
      return;

    auto entryIt = memory.memEntriesMap.find(getGlobalOp.getResult());
    if (entryIt == memory.memEntriesMap.end())
      return;
    materialized.try_emplace(targetGlobal, entryIt->second);
  });

  return materialized;
}
/// Gives every non-weight `memref.get_global` inside `coreOp` that has no
/// memory entry yet the entry recorded for its global in
/// `materializedHostGlobals`, when there is one.
static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
                                         pim::PimCoreOp coreOp,
                                         const SmallDenseMap<memref::GlobalOp, MemEntry, 16>& materializedHostGlobals,
                                         PimAcceleratorMemory& memory) {
  coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    if (hasWeightAlways(getGlobalOp))
      return;
    // Values that already have an entry are left untouched.
    if (memory.memEntriesMap.contains(getGlobalOp.getResult()))
      return;

    auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    if (!targetGlobal)
      return;

    auto materializedIt = materializedHostGlobals.find(targetGlobal);
    if (materializedIt == materializedHostGlobals.end())
      return;
    memory.memEntriesMap.try_emplace(getGlobalOp.getResult(), materializedIt->second);
  });
}
/// Dispatch all operations in a core region to the appropriate code generator.
///
/// The block is traversed through walkPimCoreBlock, which invokes the callback
/// with each operation and the StaticValueKnowledge to use when resolving its
/// addresses and index values.
///
/// @param block       Block whose operations are lowered.
/// @param coreCodeGen Code generator that receives one codeGen* call per operation.
/// @return The number of operations that were dispatched successfully (this is
///         an operation count, not an instruction count: a single operation may
///         emit several instructions), or -1 if the walk failed, e.g. because an
///         operation has no code generator or a pim.vmm weight could not be
///         matched to a weight argument of its enclosing pim.core.
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
// Finds the position of the vmm's weight among the weight arguments of the
// enclosing pim.core; yields nullopt when there is no enclosing pim.core or
// no weight argument matches.
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
if (!coreOp)
return std::nullopt;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
return weightIndex;
return std::nullopt;
};
size_t processedOperations = 0;
auto result =
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
// Type-based dispatch: the first matching op kind handles the operation.
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
coreCodeGen.codeGenLoadOp(loadOp, knowledge);
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
coreCodeGen.codeGenStoreOp(storeOp, knowledge);
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
// The weight index doubles as the mvm id passed to the code generator.
auto weightIndex = resolveWeightIndex(vmmOp);
if (!weightIndex)
return failure();
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
}
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
coreCodeGen.codeGenVVAddOp(vvaddOp, knowledge);
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
coreCodeGen.codeGenVVSubOp(vvsubOp, knowledge);
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
coreCodeGen.codeGenVVMulOp(vvmulOp, knowledge);
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
coreCodeGen.codeGenVVMaxOp(vvmaxOp, knowledge);
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
coreCodeGen.codeGenVVDMulOp(vvdmulOp, knowledge);
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
coreCodeGen.codeGenVAvgOp(vavgOp, knowledge);
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
coreCodeGen.codeGenVReluOp(vreluOp, knowledge);
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
coreCodeGen.codeGenVTanhOp(vtanhOp, knowledge);
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else {
// Unknown op: report an error that names the op and its enclosing
// pim.core / pim.core_batch, then abort the walk.
InFlightDiagnostic diag = op.emitError()
<< "unsupported codegen for op '" << op.getName().getStringRef() << "'";
if (auto coreOp = op.getParentOfType<pim::PimCoreOp>())
diag << " inside pim.core " << coreOp.getCoreId();
else if (auto coreBatchOp = op.getParentOfType<pim::PimCoreBatchOp>())
diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
return failure();
}
// Only reached when one of the code generators above handled the op.
processedOperations++;
return success();
});
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
}
/// Compile the PIM entry function of `moduleOp` into on-disk artifacts rooted at
/// `outputDirPath`.
///
/// Steps, in order:
///  1. create the output directory (skipped when the path is empty);
///  2. allocate and report host memory, then write the memory binary;
///  3. create and populate the shared weight folder;
///  4. renumber the original core ids densely, in order of first appearance
///     across the top-level pim.core / pim.core_batch operations;
///  5. for every core write `core_<id>.pim` (plus `core_<id>.json` when
///     `pimEmitJson` is set) and a `core_<id>/` directory that holds one
///     `crossbar_<index>.bin` link per used weight;
///  6. flush the memory report and write the config JSON.
///
/// \param moduleOp      module that contains the PIM entry function.
/// \param outputDirPath directory that receives every generated artifact.
/// \returns CompilerSuccess-equivalent result of writeConfigJson on success;
///          InvalidOutputFileAccess when a file, directory or link cannot be
///          created; CompilerFailure when the entry function is missing, code
///          generation of a core fails, or a used weight is out of range or has
///          no materialized file.
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::string& outputDirPath) {
  if (!outputDirPath.empty()) {
    if (auto error = sys::fs::create_directory(outputDirPath)) {
      errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
      return InvalidOutputFileAccess;
    }
  }

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc))
    return CompilerFailure;
  auto funcOp = *entryFunc;

  PimAcceleratorMemory memory;
  memory.hostMem.allocateHost(moduleOp, funcOp);
  memory.reportHost();
  if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
    return err;

  json::Object xbarsPerArrayGroup;
  size_t maxCoreId = 0;
  uint64_t nextBatchReportId = 0;

  // Create Weight Folder
  auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
  SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
  SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
      collectMaterializedHostGlobals(moduleOp, funcOp, memory);

  // First pass: map every original core id to a dense emitted id, assigned in
  // order of first appearance. All ids are known before any core is emitted
  // because the map is handed to each PimCodeGen below.
  llvm::DenseMap<size_t, size_t> emittedCoreIds;
  size_t nextEmittedCoreId = 0;
  for (Operation* op : coreLikeOps) {
    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
      size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
      if (!emittedCoreIds.contains(originalCoreId))
        emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
      continue;
    }
    auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
    auto batchCoreIds = getBatchCoreIds(coreBatchOp);
    for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
      size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
      if (!emittedCoreIds.contains(originalCoreId))
        emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
    }
  }

  // Second pass: emit the artifacts of every core.
  for (Operation* op : coreLikeOps) {
    // Emits all artifacts of one scalar core. `temporaryCore` marks a core that
    // was produced from batch lanes; its operations are removed from `memory`
    // once emission succeeded. `reportRow`, when given, receives the device
    // memory report row of the core.
    auto emitCore = [&](pim::PimCoreOp coreOp,
                        bool temporaryCore,
                        MemoryReportRow* reportRow = nullptr) -> OnnxMlirCompilerErrorCodes {
      size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
      size_t coreId = emittedCoreIds.lookup(originalCoreId);
      maxCoreId = std::max(maxCoreId, coreId);

      std::error_code errorCode;
      auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".pim";
      raw_fd_ostream coreBinaryStream(outputCorePath, errorCode, sys::fs::OF_None);
      if (errorCode) {
        errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
        return InvalidOutputFileAccess;
      }

      // The JSON dump is optional and only produced when pimEmitJson is set.
      std::unique_ptr<raw_fd_ostream> coreJsonStream;
      if (pimEmitJson.getValue()) {
        std::string outputCoreJsonPath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
        errorCode = std::error_code();
        coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
        if (errorCode) {
          errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message()
                 << '\n';
          return InvalidOutputFileAccess;
        }
        *coreJsonStream << '[';
      }

      pim_binary::writeHeader(coreBinaryStream);
      PimCodeGen coreCodeGen(memory, coreBinaryStream, coreJsonStream.get(), emittedCoreIds);
      aliasMaterializedHostGlobals(moduleOp, coreOp, materializedHostGlobals, memory);
      auto& deviceMemory = memory.getOrCreateDeviceMem(coreId);
      deviceMemory.allocateCore(coreOp);

      int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
      if (processedOperations < 0)
        return CompilerFailure;
      assert(processedOperations > 0);
      if (reportRow)
        *reportRow = deviceMemory.getReportRow();

      // The header was written before the instruction count was known.
      pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount());
      coreBinaryStream.close();
      if (coreJsonStream) {
        // Overwrite the last written character with the closing bracket.
        // NOTE(review): presumably every emitted JSON instruction ends with a
        // separator character; confirm against the PimCodeGen JSON emitter.
        coreJsonStream->seek(coreJsonStream->tell() - 1);
        *coreJsonStream << ']';
        coreJsonStream->close();
      }

      auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
      if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
        errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
        return InvalidOutputFileAccess;
      }

      auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
      json::Array xbarsPerGroup;
      for (unsigned index : getUsedWeightIndices(coreOp)) {
        // Checked at runtime: an assert alone would let release builds index
        // past the end of the weight list.
        if (index >= coreOp.getWeights().size()) {
          coreOp.emitError("Weight index " + std::to_string(index) + " is out of range");
          return CompilerFailure;
        }
        mlir::Value weight = coreOp.getWeights()[index];
        xbarsPerGroup.push_back(index);

        // Checked at runtime: operator[] below would otherwise default-insert
        // an entry and a link to a bogus path would be created.
        if (!mapWeightToFile.contains(weight)) {
          coreOp.emitError("Weight " + std::to_string(index) + " was not materialized into a file");
          return CompilerFailure;
        }
        auto& fileName = mapWeightToFile[weight];

        const std::string weightFilePath = outputDirPath + "/weights/" + fileName;
        const std::string crossbarLinkPath = coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin";
        if (auto error = sys::fs::create_link(weightFilePath, crossbarLinkPath)) {
          errs() << "Error creating link file: " << weightFilePath << " to " << crossbarLinkPath
                 << "\nError:" << error.message() << '\n';
          return InvalidOutputFileAccess;
        }
      }
      xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);

      if (temporaryCore)
        coreOp.walk([&memory](Operation* op) { memory.clean(op); });
      return CompilerSuccess;
    };

    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
      MemoryReportRow coreRow;
      if (auto err = emitCore(coreOp, false, &coreRow))
        return err;
      memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())), coreRow);
      continue;
    }

    auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
    auto batchCoreIds = getBatchCoreIds(coreBatchOp);
    SmallVector<int32_t> reportedCoreIds;
    reportedCoreIds.reserve(batchCoreIds.size());
    MemoryReportRow batchRow;
    std::optional<MemoryReportRow> batchPerCoreRow;

    // Group the lanes of the batch by original core id. DenseMap iteration
    // order is unspecified, so the first-seen order is tracked separately to
    // keep emission deterministic.
    llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
    SmallVector<size_t> orderedOriginalCoreIds;
    for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
      size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
      auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
      if (inserted)
        orderedOriginalCoreIds.push_back(originalCoreId);
      it->second.push_back(lane);
    }

    for (size_t originalCoreId : orderedOriginalCoreIds) {
      OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
      if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
            size_t coreId = emittedCoreIds.lookup(originalCoreId);
            reportedCoreIds.push_back(static_cast<int32_t>(coreId));
            MemoryReportRow laneRow;
            laneResult = emitCore(coreOp, true, &laneRow);
            if (laneResult == CompilerSuccess) {
              // The first successfully emitted core provides the per-core row;
              // every row is accumulated into the batch total.
              if (!batchPerCoreRow.has_value())
                batchPerCoreRow = laneRow;
              batchRow = addMemoryReportRows(batchRow, laneRow);
            }
            return laneResult == CompilerSuccess ? success() : failure();
          })))
        // The helper can fail without emitCore having reported an error; map
        // that case to a generic CompilerFailure.
        return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
    }

    memory.recordBatchReport(nextBatchReportId++,
                             reportedCoreIds,
                             batchPerCoreRow.value_or(MemoryReportRow {}),
                             batchRow.numAlloca,
                             batchRow.sizeAlloca);
  }

  memory.flushReport();
  return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
}