c77ffa9c56
support for tensors of index values
1072 lines
44 KiB
C++
1072 lines
44 KiB
C++
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/IR/Verifier.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/FileSystem.h"
|
|
#include "llvm/Support/Format.h"
|
|
#include "llvm/Support/JSON.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
#include <absl/types/compare.h>
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
#include <cstdint>
|
|
#include <fstream>
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
#include "Common/IR/CompactAsmUtils.hpp"
|
|
#include "Common/PimCommon.hpp"
|
|
#include "Common/Support/ReportUtils.hpp"
|
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
|
|
using namespace llvm;
|
|
using namespace mlir;
|
|
using namespace onnx_mlir;
|
|
using namespace onnx_mlir::compact_asm;
|
|
|
|
/// Queues `value` for allocation in this memory. The entry's size comes from
/// the value's (statically shaped) type; its address stays 0 until
/// allocateGatheredMemory() runs.
///
/// NOTE(review): the returned pointer refers to an element of `memEntries`,
/// which allocateGatheredMemory() re-sorts; do not rely on it afterwards.
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
  auto shapedType = cast<ShapedType>(value.getType());
  assert("Only static shape is supported" && shapedType.hasStaticShape());

  const size_t bytesNeeded = getShapedTypeSizeInBytes(shapedType);
  MemEntry pendingEntry = {0, bytesNeeded};
  auto& gathered = memEntries.emplace_back(pendingEntry, value);
  return &gathered.first;
}
|
|
|
|
void PimMemory::allocateGatheredMemory() {
|
|
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
|
|
for (auto& [memEntry, value] : memEntries)
|
|
allocateMemoryForValue(value, memEntry);
|
|
}
|
|
|
|
/// Bump-allocates `memEntry` at the current free address, records the placed
/// entry for `value` in both the owned and the global entry maps, then moves
/// the free address past it, rounded up to a multiple of `minAlignment`.
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
  memEntry.address = firstAvailableAddress;
  globalMemEntriesMap[value] = memEntry;
  ownedMemEntriesMap[value] = memEntry;

  // Advance past this entry, then pad so the next entry starts aligned.
  firstAvailableAddress += memEntry.size;
  const size_t misalignment = firstAvailableAddress % minAlignment;
  if (misalignment != 0)
    firstAvailableAddress += minAlignment - misalignment;
}
|
|
|
|
/// Gathers and assigns host addresses for the host-side buffers of `funcOp`:
///  - one entry per function argument;
///  - one entry per distinct memref.global reached through a memref.get_global
///    for which hasWeightAlways() is false (later get_global ops of the same
///    global alias the first one's entry);
///  - one entry per memref.alloc that is not nested inside a pim.core.
/// A global whose name starts with "arg" also gets its own entry gathered, but
/// after allocation its get_global result is re-pointed at the entry of the
/// function argument whose index is encoded in the name.
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
  SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
  // (alias, original) pairs: once allocation is done, `alias` gets `original`'s entry.
  SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
  SmallVector<mlir::Value> args;

  for (mlir::Value arg : funcOp.getArguments()) {
    gatherMemEntry(arg);
    args.push_back(arg);
  }

  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    if (hasWeightAlways(getGlobalOp))
      return;

    auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    StringRef globalName = globalMemrefOp.getName();
    if (globalName.starts_with("arg")) {
      // NOTE(review): the index is parsed after a 4-character prefix, so names
      // are presumably of the form "arg_<N>" -- TODO confirm. An unparsable
      // suffix leaves the index at 0.
      int index = 0;
      (void) llvm::to_integer(globalName.substr(4), index, 10);
      // Guard the lookup: an out-of-range index would read past the end of `args`.
      if (index < 0 || static_cast<size_t>(index) >= args.size())
        llvm::report_fatal_error(Twine("global '") + globalName + "' refers to function argument " + Twine(index)
                                 + ", but the function has " + Twine(args.size()) + " argument(s)");
      globalAliases.push_back({getGlobalOp.getResult(), args[index]});
    }

    // Only the first get_global of each global gets its own entry; later ones alias it.
    auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
    if (inserted)
      gatherMemEntry(getGlobalOp.getResult());
    else
      globalAliases.push_back({getGlobalOp.getResult(), iter->second});
  });

  funcOp.walk([&](memref::AllocOp allocOp) {
    if (!allocOp->getParentOfType<pim::PimCoreOp>())
      gatherMemEntry(allocOp.getResult());
  });

  allocateGatheredMemory();

  for (auto [alias, original] : globalAliases)
    globalMemEntriesMap[alias] = getMemEntry(original);
}
|
|
|
|
/// Gathers every memref.alloc nested anywhere under `op` and assigns each one
/// an address in this memory.
void PimMemory::allocateCore(Operation* op) {
  op->walk([this](memref::AllocOp allocOp) { gatherMemEntry(allocOp.getResult()); });
  allocateGatheredMemory();
}
|
|
|
|
/// Prints the host section of the report: global count and global bytes.
static void printHostMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
  llvm::SmallVector<ReportField, 2> hostFields;
  hostFields.push_back({"Number of globals", std::to_string(row.numGlobal)});
  hostFields.push_back({"Global memory", formatReportMemory(row.sizeGlobal)});
  printReportFlatFields(os, hostFields);
}
|
|
|
|
/// Prints a single-core section of the report: alloca count and alloca bytes.
static void printCoreMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
  const auto& coreRow = entry.row;
  llvm::SmallVector<ReportField, 2> coreFields;
  coreFields.push_back({"Number of allocas", std::to_string(coreRow.numAlloca)});
  coreFields.push_back({"Allocated memory", formatReportMemory(coreRow.sizeAlloca)});
  printReportFlatFields(os, coreFields);
}
|
|
|
|
/// Prints a batch section of the report: the per-core alloca figures next to
/// the batch-wide totals recorded with the entry.
static void printBatchMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
  const auto& perCore = entry.row;
  llvm::SmallVector<ReportField, 2> perCoreFields;
  perCoreFields.push_back({"Number of allocas", std::to_string(perCore.numAlloca)});
  perCoreFields.push_back({"Allocated memory", formatReportMemory(perCore.sizeAlloca)});

  llvm::SmallVector<ReportField, 2> batchTotals;
  batchTotals.push_back({"Number of allocas", std::to_string(entry.totalAllocaCount)});
  batchTotals.push_back({"Batch memory", formatReportMemory(entry.totalAllocaBytes)});

  printReportPerCoreAndTotalFields(os, perCoreFields, batchTotals);
}
|
|
|
|
/// Returns the field-wise sum of two report rows (counts and byte sizes of
/// allocas and globals); any other fields are taken from `lhs`.
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
  MemoryReportRow sum = lhs;
  sum.numGlobal += rhs.numGlobal;
  sum.sizeGlobal += rhs.sizeGlobal;
  sum.numAlloca += rhs.numAlloca;
  sum.sizeAlloca += rhs.sizeAlloca;
  return sum;
}
|
|
|
|
/// Summarizes the entries owned by this memory: values defined by
/// memref.alloc count as allocas, values defined by memref.get_global count as
/// globals. Values without a defining op (block arguments) are not counted.
MemoryReportRow PimMemory::getReportRow() const {
  MemoryReportRow row;
  for (const auto& [ownedValue, ownedEntry] : ownedMemEntriesMap) {
    Operation* producer = ownedValue.getDefiningOp();
    if (!producer)
      continue;

    if (isa<memref::AllocOp>(producer)) {
      row.numAlloca++;
      row.sizeAlloca += ownedEntry.size;
    }
    else if (isa<memref::GetGlobalOp>(producer)) {
      row.numGlobal++;
      row.sizeGlobal += ownedEntry.size;
    }
  }
  return row;
}
|
|
|
|
/// Forgets `val` in both the owned and the global entry maps; a value absent
/// from either map is simply ignored there.
void PimMemory::remove(mlir::Value val) {
  ownedMemEntriesMap.erase(val);
  globalMemEntriesMap.erase(val);
}
|
|
|
|
/// Returns the entry recorded for `value` in the global entry map. The value
/// must already have been allocated (asserted in debug builds).
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
  const auto entryIt = globalMemEntriesMap.find(value);
  const bool isKnown = entryIt != globalMemEntriesMap.end();
  assert("Missing memEntry for value" && isKnown);
  (void) isKnown;
  return entryIt->second;
}
|
|
|
|
/// Returns the device memory for core `id`, constructing it from
/// `memEntriesMap` the first time the id is requested.
PimMemory& PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
  auto [memIt, created] = deviceMem.try_emplace(id, memEntriesMap);
  (void) created;
  return memIt->second;
}
|
|
|
|
/// Returns the absolute address of `value`: the allocated address of the base
/// buffer found by resolveContiguousAddress() plus the resolved byte offset.
/// If the address cannot be resolved, or the base buffer has no entry, the
/// offending value (and its defining op, if any) is dumped to stderr before
/// hitting llvm_unreachable.
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge) const {
  auto dumpValue = [](StringRef header, mlir::Value offending) {
    errs() << header;
    offending.print(errs());
    errs() << "\n";
    if (Operation* producer = offending.getDefiningOp()) {
      errs() << "Defining op:\n";
      producer->print(errs());
      errs() << "\n";
    }
  };

  auto resolved = resolveContiguousAddress(value, knowledge);
  if (failed(resolved)) {
    dumpValue("Failed to resolve contiguous address for value: ", value);
    llvm_unreachable("Failed to resolve contiguous address");
  }

  auto baseEntry = memEntriesMap.find(resolved->base);
  if (baseEntry == memEntriesMap.end()) {
    dumpValue("Missing mem entry for value: ", resolved->base);
    llvm_unreachable("Missing mem entry");
  }

  return baseEntry->second.address + resolved->byteOffset;
}
|
|
|
|
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
|
|
|
|
/// Records the memory figures of a single, non-batched core. The entry's
/// "total" alloca count and bytes are that core's own figures.
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
  const auto coreIdAsInt = static_cast<int32_t>(coreId);
  MemoryReportEntry coreEntry = {
      MemoryReportEntry::Kind::Core, coreId, {coreIdAsInt}, row, row.numAlloca, row.sizeAlloca};
  reportEntries.push_back(std::move(coreEntry));
}
|
|
|
|
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
|
|
ArrayRef<int32_t> coreIds,
|
|
const MemoryReportRow& perCoreRow,
|
|
uint64_t totalAllocaCount,
|
|
uint64_t totalAllocaBytes) {
|
|
MemoryReportEntry entry;
|
|
entry.kind = MemoryReportEntry::Kind::Batch;
|
|
entry.id = batchId;
|
|
llvm::append_range(entry.coreIds, coreIds);
|
|
entry.row = perCoreRow;
|
|
entry.totalAllocaCount = totalAllocaCount;
|
|
entry.totalAllocaBytes = totalAllocaBytes;
|
|
reportEntries.push_back(std::move(entry));
|
|
}
|
|
|
|
void PimAcceleratorMemory::flushReport() {
|
|
if (!fileReport.is_open())
|
|
return;
|
|
|
|
llvm::raw_os_ostream os(fileReport);
|
|
uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0;
|
|
uint64_t totalCoresMemory = 0;
|
|
for (const MemoryReportEntry& entry : reportEntries)
|
|
totalCoresMemory += entry.totalAllocaBytes;
|
|
|
|
llvm::SmallVector<ReportField, 2> totalFields = {
|
|
{"Global memory", formatReportMemory(totalGlobalMemory)},
|
|
{"Cores memory", formatReportMemory(totalCoresMemory) }
|
|
};
|
|
printReportTotalsBlock(os, totalFields);
|
|
|
|
if (hostReportRow.has_value()) {
|
|
os << "\nHost:\n";
|
|
printHostMemoryReportRow(os, *hostReportRow);
|
|
}
|
|
|
|
if (!reportEntries.empty()) {
|
|
if (hostReportRow.has_value())
|
|
os << "\n";
|
|
sortReportEntriesByFirstCore(reportEntries);
|
|
|
|
for (size_t index = 0; index < reportEntries.size();) {
|
|
size_t runEnd = index + 1;
|
|
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
|
|
&& reportEntries[runEnd].row == reportEntries[index].row
|
|
&& reportEntries[runEnd].totalAllocaCount == reportEntries[index].totalAllocaCount
|
|
&& reportEntries[runEnd].totalAllocaBytes == reportEntries[index].totalAllocaBytes) {
|
|
++runEnd;
|
|
}
|
|
|
|
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
|
|
os << "Batch ";
|
|
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
|
|
if (batchIndex != index)
|
|
os << ",\n ";
|
|
os << reportEntries[batchIndex].id << " (cores ";
|
|
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
|
|
os << ")";
|
|
}
|
|
}
|
|
else {
|
|
llvm::SmallVector<int32_t, 8> coreIds;
|
|
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
|
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
|
|
os << "Core ";
|
|
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
|
|
}
|
|
os << ":\n";
|
|
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch)
|
|
printBatchMemoryReportRow(os, reportEntries[index]);
|
|
else
|
|
printCoreMemoryReportRow(os, reportEntries[index]);
|
|
printReportEntrySeparator(os, runEnd < reportEntries.size());
|
|
|
|
index = runEnd;
|
|
}
|
|
}
|
|
|
|
os.flush();
|
|
fileReport.close();
|
|
}
|
|
|
|
/// Drops every result of `op` from the host memory and from each device
/// memory created so far.
void PimAcceleratorMemory::clean(mlir::Operation* op) {
  for (mlir::Value result : op->getResults()) {
    hostMem.remove(result);
    for (auto& idAndMemory : deviceMem)
      idAndMemory.second.remove(result);
  }
}
|
|
|
|
/// Translates a core id to the id it was emitted under, via emittedCoreIds.
/// The id must have a mapping (asserted in debug builds).
size_t PimCodeGen::remapCoreId(size_t coreId) const {
  const auto mapping = emittedCoreIds.find(coreId);
  assert(mapping != emittedCoreIds.end() && "Missing emitted core id remapping");
  return mapping->second;
}
|
|
|
|
/// Appends one instruction record to the core's binary stream and bumps the
/// emitted-instruction counter. When a JSON stream is attached, the record is
/// also written there as JSON followed by a ',' separator.
void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instruction) const {
  pim_binary::writeInstructionRecord(coreBinaryStream, instruction);
  ++emittedInstructionCount;

  if (!coreJsonStream)
    return;
  json::Value asJson(pim_binary::makeInstructionJson(instruction));
  *coreJsonStream << asJson << ',';
}
|
|
|
|
/// Emits an `sldi` instruction that loads `immediate` into register
/// `registerNumber`. The setupRd* helpers use it to place operand addresses in
/// registers 0..2.
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
  // Both values are narrowed below (to uint8_t / int32_t); catch silent truncation.
  assert(registerNumber <= UINT8_MAX && "register number does not fit in 8 bits");
  assert(immediate <= UINT32_MAX && "sldi immediate does not fit in 32 bits");

  // Value-initialize so the fields not set here (r1, generic*) are zero rather
  // than indeterminate, whatever InstructionRecord's definition looks like.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::sldi;
  instruction.rd = static_cast<uint8_t>(registerNumber);
  instruction.r2OrImm = static_cast<int32_t>(immediate);
  emitInstruction(instruction);
}
|
|
|
|
/// Loads the destination address (base + offset) into register 0.
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
  const size_t effectiveRd = rdAddress + rdOffset;
  genSetRegisterImmediateUnsigned(/*registerNumber=*/0, effectiveRd);
}
|
|
|
|
/// Loads the destination address into register 0 and the first source
/// address into register 1 (each as base + offset).
void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const {
  const size_t effectiveRd = rdAddress + rdOffset;
  const size_t effectiveRs1 = rs1Address + rs1Offset;
  genSetRegisterImmediateUnsigned(/*registerNumber=*/0, effectiveRd);
  genSetRegisterImmediateUnsigned(/*registerNumber=*/1, effectiveRs1);
}
|
|
|
|
void PimCodeGen::setupRdRs1Rs2(
|
|
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const {
|
|
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
|
|
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
|
|
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
|
|
}
|
|
|
|
/// Emits a memory-copy style instruction named `opName` (e.g. "ld", "st",
/// "lmv"): destination address (rdAddr + rdOffset) goes to register 0, source
/// address (rs1Addr + rs1Offset) to register 1, and `size` in bytes is
/// encoded in generic3.
/// `sizeFieldName` is not used by this binary encoder.
void PimCodeGen::emitMemCopyOp(StringRef opName,
                               size_t rdAddr,
                               size_t rdOffset,
                               size_t rs1Addr,
                               size_t rs1Offset,
                               size_t size,
                               StringRef sizeFieldName) const {
  setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);

  // Value-initialize so fields not set below (e.g. r2OrImm) are zero rather
  // than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::opcodeFromString(opName);
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.generic1 = 0;
  instruction.generic2 = 0;
  instruction.generic3 = static_cast<int32_t>(size);
  (void) sizeFieldName;
  emitInstruction(instruction);
}
|
|
|
|
/// Emits a core-to-core communication instruction named `opName` ("send" or
/// "recv" in this file). The buffer address goes to register 0; the peer core
/// id (after remapCoreId) is encoded in r2OrImm and the byte count in generic3.
void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const {
  setupRd(bufferAddr, 0);

  // Value-initialize so r1, which this instruction does not set, is zero
  // rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::opcodeFromString(opName);
  instruction.rd = 0;
  instruction.r2OrImm = static_cast<int32_t>(remapCoreId(coreId));
  instruction.generic1 = 0;
  instruction.generic2 = 0;
  instruction.generic3 = static_cast<int32_t>(size);
  emitInstruction(instruction);
}
|
|
|
|
/// Emits an `mvmul` instruction: output address in register 0, input address
/// in register 1, and the crossbar group id in generic2.
void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const {
  setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);

  // Value-initialize so generic3, which mvmul does not set, is zero rather
  // than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::mvmul;
  instruction.rd = 0;
  instruction.r1 = 1;
  // NOTE(review): the meaning of the constant 8 in r2OrImm is not visible in
  // this file -- confirm against the ISA definition.
  instruction.r2OrImm = 8;
  instruction.generic1 = 0;
  instruction.generic2 = static_cast<int32_t>(groupId);
  emitInstruction(instruction);
}
|
|
|
|
/// Lowers a host-to-device copy to an `ld` instruction. Both offsets must be
/// statically resolvable from `knowledge` (asserted in debug builds).
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
  const auto deviceOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
  const auto hostOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
  assert(succeeded(deviceOffset) && succeeded(hostOffset)
         && "pim.memcp_hd offsets must be statically resolvable during codegen");

  const size_t deviceAddr = addressOf(loadOp.getDeviceTarget(), knowledge);
  const size_t hostAddr = addressOf(loadOp.getHostSource(), knowledge);
  emitMemCopyOp("ld", deviceAddr, *deviceOffset, hostAddr, *hostOffset, loadOp.getSize());
}
|
|
|
|
/// Lowers a device-to-host copy to an `st` instruction. Both offsets must be
/// statically resolvable from `knowledge` (asserted in debug builds).
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
  const auto hostOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
  const auto deviceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
  assert(succeeded(hostOffset) && succeeded(deviceOffset)
         && "pim.memcp_dh offsets must be statically resolvable during codegen");

  const size_t hostAddr = addressOf(storeOp.getHostTarget(), knowledge);
  const size_t deviceAddr = addressOf(storeOp.getDeviceSource(), knowledge);
  emitMemCopyOp("st", hostAddr, *hostOffset, deviceAddr, *deviceOffset, storeOp.getSize());
}
|
|
|
|
/// Lowers a local (device-to-device) copy to an `lmv` instruction using the
/// op's own target/source offsets and size.
void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const {
  const size_t targetAddr = addressOf(lmvOp.getTarget(), knowledge);
  const size_t sourceAddr = addressOf(lmvOp.getSource(), knowledge);
  emitMemCopyOp(
      "lmv", targetAddr, lmvOp.getTargetOffset(), sourceAddr, lmvOp.getSourceOffset(), lmvOp.getSize(), "len");
}
|
|
|
|
/// Lowers a receive into a single `recv` from the statically resolved source
/// core into the op's output buffer.
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
  const auto sourceCore = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
  assert(succeeded(sourceCore) && "pim.receive source core id must be statically resolvable during codegen");

  const size_t bufferAddr = addressOf(receiveOp.getOutputBuffer(), knowledge);
  emitCommunicationOp("recv", bufferAddr, *sourceCore, receiveOp.getSize());
}
|
|
|
|
/// Lowers a tensor receive: the output buffer is split into equal contiguous
/// chunks, one per source core (in source-core order), and one `recv` is
/// emitted per chunk.
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
                                        const StaticValueKnowledge& knowledge) const {
  auto sourceCoreIds = receiveTensorOp.getSourceCoreIds();
  const size_t sourceCount = sourceCoreIds.size();
  // No sources means nothing to receive; bail out before dividing by zero.
  if (sourceCount == 0)
    return;

  size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
  size_t totalBytes = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()));
  // A remainder would silently be left unreceived at the end of the buffer.
  assert(totalBytes % sourceCount == 0 && "receive-tensor output must split evenly across source cores");
  size_t chunkSize = totalBytes / sourceCount;

  for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(sourceCoreIds))
    emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
|
|
|
|
/// Lowers a send into a single `send` of the op's input buffer to the
/// statically resolved target core.
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
  const auto targetCore = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
  assert(succeeded(targetCore) && "pim.send target core id must be statically resolvable during codegen");

  const size_t bufferAddr = addressOf(sendOp.getInput(), knowledge);
  emitCommunicationOp("send", bufferAddr, *targetCore, sendOp.getSize());
}
|
|
|
|
/// Lowers a tensor send: the input buffer is split into equal contiguous
/// chunks, one per target core (in target-core order), and one `send` is
/// emitted per chunk.
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
  auto targetCoreIds = sendTensorOp.getTargetCoreIds();
  const size_t targetCount = targetCoreIds.size();
  // No targets means nothing to send; bail out before dividing by zero.
  if (targetCount == 0)
    return;

  size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
  size_t totalBytes = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()));
  // A remainder would silently be left unsent at the end of the buffer.
  assert(totalBytes % targetCount == 0 && "send-tensor input must split evenly across target cores");
  size_t chunkSize = totalBytes / targetCount;

  for (auto [chunkIndex, targetCoreId] : llvm::enumerate(targetCoreIds))
    emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
|
|
|
|
/// Lowers a concat into `lmv` copies. Around `axis`, the output is viewed as
/// [outer, concatDim, inner]. For every input and every outer index, one
/// contiguous block of inputConcatDim * inner elements is copied into its
/// slot in the output; inputs are laid out along the axis in operand order.
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
  auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
  assert(outputType.hasStaticShape() && "concat codegen requires static output shape");

  ArrayRef<int64_t> outputShape = outputType.getShape();
  const int64_t rank = static_cast<int64_t>(outputShape.size());
  int64_t axis = concatOp.getAxis();
  // Normalize a negative (counted-from-the-end) axis; indexing outputShape
  // with it directly would read out of bounds.
  if (axis < 0)
    axis += rank;
  assert(axis >= 0 && axis < rank && "concat axis out of range");

  size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
  size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);

  size_t outerCount = 1;
  for (int64_t dim = 0; dim < axis; ++dim)
    outerCount *= static_cast<size_t>(outputShape[dim]);

  size_t innerCount = 1;
  for (int64_t dim = axis + 1; dim < rank; ++dim)
    innerCount *= static_cast<size_t>(outputShape[dim]);

  // Bytes covered by one step along the concat axis.
  const size_t sliceBytes = innerCount * elementSize;
  const size_t outputConcatDim = static_cast<size_t>(outputShape[axis]);
  size_t concatOffset = 0;
  for (mlir::Value input : concatOp.getInputs()) {
    auto inputType = cast<ShapedType>(input.getType());
    assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");

    size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
    size_t blockSizeInBytes = inputConcatDim * sliceBytes;
    size_t inputAddr = addressOf(input, knowledge);

    for (size_t outerIndex = 0; outerIndex < outerCount; ++outerIndex) {
      size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * sliceBytes;
      size_t srcOffset = outerIndex * inputConcatDim * sliceBytes;
      emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len");
    }

    concatOffset += inputConcatDim;
  }
}
|
|
|
|
template <typename MVMTy>
|
|
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
|
|
MVMTy mvmLikeOp,
|
|
bool transposeMatrix,
|
|
const StaticValueKnowledge& knowledge) {
|
|
emitMvmOp(mvmId, addressOf(mvmLikeOp.getOutputBuffer(), knowledge), 0, addressOf(mvmLikeOp.getInput(), knowledge), 0);
|
|
|
|
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
|
|
}
|
|
|
|
/// Emits `vvadd` for a vv-add op. Output, lhs and rhs addresses are loaded
/// into registers 0, 1 and 2; generic3 carries the byte size of the lhs type.
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
  auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
  auto rhsAddr = addressOf(vvaddOp.getRhs(), knowledge);
  setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);

  // Value-initialize so generic1/generic2, which are not set here, are zero
  // rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vvadd;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.r2OrImm = 2;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vvsub` for a vv-sub op. Output, lhs and rhs addresses are loaded
/// into registers 0, 1 and 2; generic3 carries the byte size of the lhs type.
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vvsubOp.getOutputBuffer(), knowledge);
  auto lhsAddr = addressOf(vvsubOp.getLhs(), knowledge);
  auto rhsAddr = addressOf(vvsubOp.getRhs(), knowledge);
  setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);

  // Value-initialize so generic1/generic2, which are not set here, are zero
  // rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vvsub;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.r2OrImm = 2;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vvmul` for a vv-mul op. Output, lhs and rhs addresses are loaded
/// into registers 0, 1 and 2; generic3 carries the byte size of the lhs type.
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vvmulOp.getOutputBuffer(), knowledge);
  auto lhsAddr = addressOf(vvmulOp.getLhs(), knowledge);
  auto rhsAddr = addressOf(vvmulOp.getRhs(), knowledge);
  setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);

  // Value-initialize so generic1/generic2, which are not set here, are zero
  // rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vvmul;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.r2OrImm = 2;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vvmax` for a vv-max op. Output, lhs and rhs addresses are loaded
/// into registers 0, 1 and 2; generic3 carries the byte size of the lhs type.
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vvmaxOp.getOutputBuffer(), knowledge);
  auto lhsAddr = addressOf(vvmaxOp.getLhs(), knowledge);
  auto rhsAddr = addressOf(vvmaxOp.getRhs(), knowledge);
  setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);

  // Value-initialize so generic1/generic2, which are not set here, are zero
  // rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vvmax;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.r2OrImm = 2;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vvdmul` for a vv-dmul op. Output, lhs and rhs addresses are loaded
/// into registers 0, 1 and 2; generic3 carries the byte size of the lhs type.
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vvdmulOp.getOutputBuffer(), knowledge);
  auto lhsAddr = addressOf(vvdmulOp.getLhs(), knowledge);
  auto rhsAddr = addressOf(vvdmulOp.getRhs(), knowledge);
  setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);

  // Value-initialize so generic1/generic2, which are not set here, are zero
  // rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vvdmul;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.r2OrImm = 2;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vavg` for a v-avg op. Output and input addresses are loaded into
/// registers 0 and 1; generic3 carries the byte size of the input type.
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vavgOp.getOutputBuffer(), knowledge);
  auto inputAddr = addressOf(vavgOp.getInput(), knowledge);
  setupRdRs1(outputBufferAddr, 0, inputAddr, 0);

  // Value-initialize so generic2, which is not set here, is zero rather than
  // indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vavg;
  instruction.rd = 0;
  instruction.r1 = 1;
  // NOTE(review): the meaning of r2OrImm = 1 and generic1 = 1 for vavg is not
  // visible in this file -- confirm against the ISA definition.
  instruction.r2OrImm = 1;
  instruction.generic1 = 1;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vrelu` for a v-relu op. Output and input addresses are loaded into
/// registers 0 and 1; generic3 carries the byte size of the input type.
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vreluOp.getOutputBuffer(), knowledge);
  auto inputAddr = addressOf(vreluOp.getInput(), knowledge);
  setupRdRs1(outputBufferAddr, 0, inputAddr, 0);

  // Value-initialize so r2OrImm/generic1/generic2, which are not set here,
  // are zero rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vrelu;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vtanh` for a v-tanh op. Output and input addresses are loaded into
/// registers 0 and 1; generic3 carries the byte size of the input type.
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vtanhOp.getOutputBuffer(), knowledge);
  auto inputAddr = addressOf(vtanhOp.getInput(), knowledge);
  setupRdRs1(outputBufferAddr, 0, inputAddr, 0);

  // Value-initialize so r2OrImm/generic1/generic2, which are not set here,
  // are zero rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vtanh;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vsigm` for a v-sigm op. Output and input addresses are loaded into
/// registers 0 and 1; generic3 carries the byte size of the input type.
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vsigmOp.getOutputBuffer(), knowledge);
  auto inputAddr = addressOf(vsigmOp.getInput(), knowledge);
  setupRdRs1(outputBufferAddr, 0, inputAddr, 0);

  // Value-initialize so r2OrImm/generic1/generic2, which are not set here,
  // are zero rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vsigm;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits `vsoftmax` for a v-softmax op. Output and input addresses are loaded
/// into registers 0 and 1; generic3 carries the byte size of the input type.
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const {
  auto outputBufferAddr = addressOf(vsoftmaxOp.getOutputBuffer(), knowledge);
  auto inputAddr = addressOf(vsoftmaxOp.getInput(), knowledge);
  setupRdRs1(outputBufferAddr, 0, inputAddr, 0);

  // Value-initialize so r2OrImm/generic1/generic2, which are not set here,
  // are zero rather than indeterminate.
  pim_binary::InstructionRecord instruction{};
  instruction.opcode = pim_binary::Opcode::vsoftmax;
  instruction.rd = 0;
  instruction.r1 = 1;
  instruction.generic3 =
      static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
  emitInstruction(instruction);
}
|
|
|
|
/// Emits nothing: memref.get_global produces no instructions.
/// NOTE(review): presumably its result is addressed through the memory entry
/// maps built by PimMemory -- confirm.
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {
  (void) getGlobalOp;
  (void) knowledge;
}
|
|
|
|
/// Lowers a transpose into `lmv` copies. If the permutation leaves the
/// row-major storage order unchanged, a single whole-buffer copy is emitted.
/// Otherwise one element-sized copy is emitted per source element, in source
/// order, at its permuted destination offset.
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
  auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
  auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);

  auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
  auto srcShape = srcType.getShape();
  size_t rank = srcShape.size();
  size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
  size_t totalElements = srcType.getNumElements();

  // Read permutation. Destination dim i corresponds to source dim perm[i].
  SmallVector<int64_t> perm = map_to_vector(transposeOp.getPermutation().getAsRange<IntegerAttr>(),
                                            [](auto attr) -> int64_t { return attr.getInt(); });
  assert(perm.size() == rank && "transpose permutation must have one entry per input dimension");

  // Destination shape: dstShape[i] = srcShape[perm[i]]
  SmallVector<int64_t> dstShape(rank);
  for (size_t i = 0; i < rank; i++)
    dstShape[i] = srcShape[perm[i]];

  // Row-major strides for source and destination, built from the innermost
  // dimension outwards (unsigned loop; no `rank - 2` wrap-around).
  SmallVector<size_t> srcStrides(rank, 1);
  SmallVector<size_t> dstStrides(rank, 1);
  for (size_t i = rank; i > 1; --i) {
    srcStrides[i - 2] = srcStrides[i - 1] * srcShape[i - 1];
    dstStrides[i - 2] = dstStrides[i - 1] * dstShape[i - 1];
  }

  // The element mapping is the identity exactly when the non-unit source
  // dimensions keep their relative order in the permutation (unit dimensions
  // always have index 0 and do not move data). Checking that is O(rank),
  // instead of mapping every element. An empty tensor is trivially a copy.
  bool storagePreserving = true;
  int64_t lastNonUnitDim = -1;
  for (int64_t srcDim : perm) {
    if (srcShape[srcDim] == 1)
      continue;
    if (srcDim < lastNonUnitDim) {
      storagePreserving = false;
      break;
    }
    lastNonUnitDim = srcDim;
  }
  if (totalElements == 0)
    storagePreserving = true;

  if (storagePreserving) {
    emitMemCopyOp("lmv", dstAddr, 0, srcAddr, 0, totalElements * elementSize, "len");
    return;
  }

  // Emit element-by-element copy with transposed addressing.
  SmallVector<size_t> srcIdx(rank);
  for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
    // Decompose flat source index into multi-dimensional index
    size_t remaining = srcFlat;
    for (size_t d = 0; d < rank; d++) {
      srcIdx[d] = remaining / srcStrides[d];
      remaining %= srcStrides[d];
    }

    // Compute flat destination index: dstIdx[d] = srcIdx[perm[d]]
    size_t dstFlat = 0;
    for (size_t d = 0; d < rank; d++)
      dstFlat += srcIdx[perm[d]] * dstStrides[d];

    emitMemCopyOp("lmv", dstAddr, dstFlat * elementSize, srcAddr, srcFlat * elementSize, elementSize, "len");
  }
}
|
|
|
|
/// Returns the larger of the first two dimensions of a rank-2 or rank-4
/// matrix type (other ranks are rejected by an assert in debug builds).
size_t getMatrixSize(ShapedType matrixShape) {
  const int64_t rank = matrixShape.getRank();
  const bool supportedRank = rank == 2 || rank == 4;
  if (!supportedRank)
    assert(false && "Unsupported matrix shape");

  const int64_t rows = matrixShape.getDimSize(0);
  const int64_t cols = matrixShape.getDimSize(1);
  return std::max(rows, cols);
}
|
|
|
|
/// Formats a byte count in the largest binary unit (GB/MB/KB) it reaches,
/// truncating toward zero: 1023 -> "1023 Bytes", 1024 -> "1 KB",
/// 1536 -> "1 KB". Uses `>=` so exact powers of 1024 show in the larger unit
/// (previously 1024 printed as "1024 Bytes").
std::string getMemorySizeAsString(size_t size) {
  constexpr size_t kKiB = 1024;
  constexpr size_t kMiB = kKiB * 1024;
  constexpr size_t kGiB = kMiB * 1024;

  if (size >= kGiB)
    return std::to_string(size / kGiB) + " GB";
  if (size >= kMiB)
    return std::to_string(size / kMiB) + " MB";
  if (size >= kKiB)
    return std::to_string(size / kKiB) + " KB";
  return std::to_string(size) + " Bytes";
}
|
|
|
|
/// Returns, sorted and without duplicates, the indices of the weights of the
/// enclosing pim.core that are used as the weight operand of some pim.vmm in
/// `block`. Returns an empty list when the block's parent is not a pim.core.
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
  SmallVector<unsigned, 8> usedIndices;
  auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
  if (!coreOp)
    return usedIndices;

  const auto weightCount = coreOp.getWeights().size();
  block.walk([&](pim::PimVMMOp vmmOp) {
    mlir::Value vmmWeight = vmmOp.getWeight();
    for (unsigned candidate = 0; candidate < weightCount; ++candidate) {
      if (coreOp.getWeightArgument(candidate) == vmmWeight) {
        if (!llvm::is_contained(usedIndices, candidate))
          usedIndices.push_back(candidate);
        break;
      }
    }
  });

  llvm::sort(usedIndices);
  return usedIndices;
}
|
|
|
|
/// Convenience overload: weight indices used inside the core's body block.
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
  Block& bodyBlock = coreOp.getBody().front();
  return getUsedWeightIndices(bodyBlock);
}
|
|
|
|
/// Copies the core ids stored in a pim.core_batch's coreIds attribute
/// (required; asserted in debug builds).
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
  auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
  assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");

  SmallVector<int32_t> batchCoreIds;
  llvm::append_range(batchCoreIds, coreIdsAttr.asArrayRef());
  return batchCoreIds;
}
|
|
|
|
/// Returns, in program order, the pim.core and pim.core_batch ops that sit
/// directly in the function's entry block (nested ones are not included).
static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
  SmallVector<Operation*> coreLikeOps;
  Block& entryBlock = funcOp.getBody().front();
  for (Operation& candidate : entryBlock) {
    if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(candidate))
      coreLikeOps.push_back(&candidate);
  }
  return coreLikeOps;
}
|
|
|
|
/// Build a map from memref.global to the host MemEntry already allocated for
/// one of its memref.get_global users inside `funcOp`.
///
/// get_global ops for which `hasWeightAlways` is true are ignored, as are
/// get_globals whose target global cannot be looked up or that have no entry in
/// `memory.memEntriesMap`. When several get_globals reference the same global,
/// the first one (in walk order) that has an entry wins.
static SmallDenseMap<memref::GlobalOp, MemEntry, 16>
collectMaterializedHostGlobals(ModuleOp moduleOp, func::FuncOp funcOp, const PimAcceleratorMemory& memory) {
  SmallDenseMap<memref::GlobalOp, MemEntry, 16> globalsWithHostEntry;

  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    if (hasWeightAlways(getGlobalOp))
      return;

    auto global = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
    bool stillMissing = global && !globalsWithHostEntry.contains(global);
    if (!stillMissing)
      return;

    auto entryIt = memory.memEntriesMap.find(getGlobalOp.getResult());
    if (entryIt == memory.memEntriesMap.end())
      return;

    globalsWithHostEntry.try_emplace(global, entryIt->second);
  });

  return globalsWithHostEntry;
}
|
|
|
|
/// For each memref.get_global inside `coreOp` that has no entry in
/// `memory.memEntriesMap` yet, reuse the MemEntry recorded for its target
/// memref.global in `materializedHostGlobals`, so the core addresses the same
/// memory as the host-side get_global.
///
/// get_global ops for which `hasWeightAlways` is true are skipped, as are
/// globals that cannot be looked up or are not in `materializedHostGlobals`.
static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
                                         pim::PimCoreOp coreOp,
                                         const SmallDenseMap<memref::GlobalOp, MemEntry, 16>& materializedHostGlobals,
                                         PimAcceleratorMemory& memory) {
  coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    mlir::Value globalValue = getGlobalOp.getResult();
    bool alreadyHandled = hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(globalValue);
    if (alreadyHandled)
      return;

    if (auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp)) {
      auto hostEntry = materializedHostGlobals.find(targetGlobal);
      if (hostEntry != materializedHostGlobals.end())
        memory.memEntriesMap[globalValue] = hostEntry->second;
    }
  });
}
|
|
|
|
/// Dispatch all operations in a core region to the appropriate code generator.
|
|
/// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is
|
|
/// fully resolved before the JSON instructions are emitted.
|
|
/// Returns the number of emitted instructions, or -1 on failure.
|
|
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
|
|
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
|
|
if (!coreOp)
|
|
return std::nullopt;
|
|
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
|
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
|
return weightIndex;
|
|
return std::nullopt;
|
|
};
|
|
size_t processedOperations = 0;
|
|
auto result =
|
|
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
|
if (auto loadOp = dyn_cast<pim::PimMemCopyHostToDevOp>(op))
|
|
coreCodeGen.codeGenLoadOp(loadOp, knowledge);
|
|
else if (auto storeOp = dyn_cast<pim::PimMemCopyDevToHostOp>(op))
|
|
coreCodeGen.codeGenStoreOp(storeOp, knowledge);
|
|
else if (auto lmvOp = dyn_cast<pim::PimMemCopyOp>(op))
|
|
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
|
|
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
|
|
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
|
|
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
|
|
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
|
|
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
|
|
coreCodeGen.codeGenSendOp(sendOp, knowledge);
|
|
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
|
|
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
|
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
|
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
|
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
|
|
auto weightIndex = resolveWeightIndex(vmmOp);
|
|
if (!weightIndex)
|
|
return failure();
|
|
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
|
|
}
|
|
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
|
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
|
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
|
coreCodeGen.codeGenVVAddOp(vvaddOp, knowledge);
|
|
else if (auto vvsubOp = dyn_cast<pim::PimVVSubOp>(op))
|
|
coreCodeGen.codeGenVVSubOp(vvsubOp, knowledge);
|
|
else if (auto vvmulOp = dyn_cast<pim::PimVVMulOp>(op))
|
|
coreCodeGen.codeGenVVMulOp(vvmulOp, knowledge);
|
|
else if (auto vvmaxOp = dyn_cast<pim::PimVVMaxOp>(op))
|
|
coreCodeGen.codeGenVVMaxOp(vvmaxOp, knowledge);
|
|
else if (auto vvdmulOp = dyn_cast<pim::PimVVDMulOp>(op))
|
|
coreCodeGen.codeGenVVDMulOp(vvdmulOp, knowledge);
|
|
else if (auto vavgOp = dyn_cast<pim::PimVAvgOp>(op))
|
|
coreCodeGen.codeGenVAvgOp(vavgOp, knowledge);
|
|
else if (auto vreluOp = dyn_cast<pim::PimVReluOp>(op))
|
|
coreCodeGen.codeGenVReluOp(vreluOp, knowledge);
|
|
else if (auto vtanhOp = dyn_cast<pim::PimVTanhOp>(op))
|
|
coreCodeGen.codeGenVTanhOp(vtanhOp, knowledge);
|
|
else if (auto vsigmOp = dyn_cast<pim::PimVSigmOp>(op))
|
|
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
|
|
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
|
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
|
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
|
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
|
else {
|
|
InFlightDiagnostic diag = op.emitError()
|
|
<< "unsupported codegen for op '" << op.getName().getStringRef() << "'";
|
|
if (auto coreOp = op.getParentOfType<pim::PimCoreOp>())
|
|
diag << " inside pim.core " << coreOp.getCoreId();
|
|
else if (auto coreBatchOp = op.getParentOfType<pim::PimCoreBatchOp>())
|
|
diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
|
|
return failure();
|
|
}
|
|
processedOperations++;
|
|
return success();
|
|
});
|
|
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
|
}
|
|
|
|
/// Compile the PIM entry function of `moduleOp` into per-core artifacts under `outputDirPath`.
///
/// Steps:
///   1. Allocate and report host memory, then write it via writeMemoryBinary.
///   2. Populate the weight folder (createAndPopulateWeightFolder).
///   3. Map every original core id used by the top-level pim.core / pim.core_batch ops to a
///      dense emitted id, in order of first appearance.
///   4. For each core, write `core_<id>.pim`, optionally `core_<id>.json` (when pimEmitJson is
///      set), and link `core_<id>/crossbar_<index>.bin` to its file under `weights/`.
///      pim.core_batch lanes are grouped by core id and emitted through the scalar pim.core
///      provided by withScalarCoreFromBatchLanes.
///   5. Flush the memory report and write the configuration via writeConfigJson.
///
/// NOTE(review): every artifact path is built as `outputDirPath + "/..."`, so an empty
/// outputDirPath resolves to the filesystem root; confirm callers always pass a directory.
///
/// @return CompilerSuccess on success; InvalidOutputFileAccess on filesystem errors;
///         CompilerFailure when the entry function is missing or code generation fails;
///         otherwise the error code produced by writeMemoryBinary / writeConfigJson.
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::string& outputDirPath) {
  if (!outputDirPath.empty()) {
    if (auto error = sys::fs::create_directory(outputDirPath)) {
      errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
      return InvalidOutputFileAccess;
    }
  }

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc))
    return CompilerFailure;
  auto funcOp = *entryFunc;

  PimAcceleratorMemory memory;
  memory.hostMem.allocateHost(moduleOp, funcOp);
  memory.reportHost();

  if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
    return err;

  json::Object xbarsPerArrayGroup;
  size_t maxCoreId = 0;
  uint64_t nextBatchReportId = 0;

  // Create Weight Folder
  auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);

  SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
  SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
      collectMaterializedHostGlobals(moduleOp, funcOp, memory);

  // Assign dense emitted core ids (0, 1, 2, ...) to original core ids in order of first
  // appearance, so that every core can be resolved before any core is emitted.
  llvm::DenseMap<size_t, size_t> emittedCoreIds;
  size_t nextEmittedCoreId = 0;
  auto assignEmittedCoreId = [&](size_t originalCoreId) {
    if (emittedCoreIds.try_emplace(originalCoreId, nextEmittedCoreId).second)
      ++nextEmittedCoreId;
  };
  for (Operation* op : coreLikeOps) {
    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
      assignEmittedCoreId(static_cast<size_t>(coreOp.getCoreId()));
      continue;
    }
    auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
    auto batchCoreIds = getBatchCoreIds(coreBatchOp);
    for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
      assignEmittedCoreId(static_cast<size_t>(batchCoreIds[lane]));
  }

  // Emit one scalar core: binary + optional JSON instruction streams and the crossbar weight
  // links. `temporaryCore` releases the core's memory entries afterwards (used for the cores
  // built from pim.core_batch lanes); `reportRow`, if non-null, receives the core's device
  // memory report row. Defined once, outside the core loop.
  auto emitCore = [&](pim::PimCoreOp coreOp,
                      bool temporaryCore,
                      MemoryReportRow* reportRow = nullptr) -> OnnxMlirCompilerErrorCodes {
    size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
    size_t coreId = emittedCoreIds.lookup(originalCoreId);
    maxCoreId = std::max(maxCoreId, coreId);

    std::error_code errorCode;
    auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".pim";
    raw_fd_ostream coreBinaryStream(outputCorePath, errorCode, sys::fs::OF_None);
    if (errorCode) {
      errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
      return InvalidOutputFileAccess;
    }

    std::unique_ptr<raw_fd_ostream> coreJsonStream;
    if (pimEmitJson.getValue()) {
      std::string outputCoreJsonPath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
      errorCode = std::error_code();
      coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
      if (errorCode) {
        errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message()
               << '\n';
        return InvalidOutputFileAccess;
      }
      *coreJsonStream << '[';
    }

    pim_binary::writeHeader(coreBinaryStream);

    PimCodeGen coreCodeGen(memory, coreBinaryStream, coreJsonStream.get(), emittedCoreIds);
    aliasMaterializedHostGlobals(moduleOp, coreOp, materializedHostGlobals, memory);
    auto& deviceMemory = memory.getOrCreateDeviceMem(coreId);
    deviceMemory.allocateCore(coreOp);

    int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
    if (processedOperations < 0)
      return CompilerFailure;
    assert(processedOperations > 0);

    if (reportRow)
      *reportRow = deviceMemory.getReportRow();

    // The instruction count is only known after code generation; patch it into the header.
    pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount());
    coreBinaryStream.close();

    if (coreJsonStream) {
      // The last character written is presumably the separator after the final instruction;
      // overwrite it with ']'. If nothing followed the opening '[' (possible even with
      // processed operations, and unguarded when asserts are disabled), append ']' instead so
      // the file is "[]" rather than a lone "]".
      if (coreJsonStream->tell() > 1)
        coreJsonStream->seek(coreJsonStream->tell() - 1);
      *coreJsonStream << ']';
      coreJsonStream->close();
    }

    auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
    if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
      errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
      return InvalidOutputFileAccess;
    }

    auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
    json::Array xbarsPerGroup;
    for (unsigned index : getUsedWeightIndices(coreOp)) {
      if (index >= coreOp.getWeights().size()) {
        coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
        assert(index < coreOp.getWeights().size() && "Weight index is out of range");
      }
      mlir::Value weight = coreOp.getWeights()[index];
      xbarsPerGroup.push_back(index);
      assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
      auto& fileName = mapWeightToFile[weight];

      // Link the shared weight file into this core's directory under its crossbar index.
      std::string weightFilePath = outputDirPath + "/weights/" + fileName;
      std::string crossbarLinkPath = coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin";
      if (auto error = sys::fs::create_link(weightFilePath, crossbarLinkPath)) {
        errs() << "Error creating link file: " << weightFilePath << " to " << crossbarLinkPath
               << "\nError:" << error.message() << '\n';
        return InvalidOutputFileAccess;
      }
    }

    xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
    if (temporaryCore)
      coreOp.walk([&memory](Operation* nestedOp) { memory.clean(nestedOp); });
    return CompilerSuccess;
  };

  for (Operation* op : coreLikeOps) {
    if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
      MemoryReportRow coreRow {};
      if (auto err = emitCore(coreOp, false, &coreRow))
        return err;
      memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())), coreRow);
      continue;
    }

    auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
    auto batchCoreIds = getBatchCoreIds(coreBatchOp);
    SmallVector<int32_t> reportedCoreIds;
    reportedCoreIds.reserve(batchCoreIds.size());
    // Value-initialized: this row accumulates the lane rows through addMemoryReportRows, so it
    // must start from zero even if MemoryReportRow has no default member initializers.
    MemoryReportRow batchRow {};
    std::optional<MemoryReportRow> batchPerCoreRow;

    // Group lanes by original core id, keeping the order in which the ids first appear.
    llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
    SmallVector<size_t> orderedOriginalCoreIds;
    for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
      size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
      auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
      if (inserted)
        orderedOriginalCoreIds.push_back(originalCoreId);
      it->second.push_back(lane);
    }

    for (size_t originalCoreId : orderedOriginalCoreIds) {
      // Carries emitCore's specific error code out of the LogicalResult-based callback.
      OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
      if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
            size_t coreId = emittedCoreIds.lookup(originalCoreId);
            reportedCoreIds.push_back(static_cast<int32_t>(coreId));
            MemoryReportRow laneRow {};
            laneResult = emitCore(coreOp, true, &laneRow);
            if (laneResult == CompilerSuccess) {
              if (!batchPerCoreRow.has_value())
                batchPerCoreRow = laneRow;
              batchRow = addMemoryReportRows(batchRow, laneRow);
            }
            return laneResult == CompilerSuccess ? success() : failure();
          })))
        return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
    }

    memory.recordBatchReport(nextBatchReportId++,
                             reportedCoreIds,
                             batchPerCoreRow.value_or(MemoryReportRow {}),
                             batchRow.numAlloca,
                             batchRow.sizeAlloca);
  }

  memory.flushReport();
  return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
}
|