375 lines
11 KiB
C++
375 lines
11 KiB
C++
#pragma once
|
|
|
|
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
|
|
|
|
namespace onnx_mlir::pim_binary {
|
|
|
|
/// Magic bytes at offset 0 of every PIM binary file.
inline constexpr char kMagic[4] = {'P', 'I', 'M', 'B'};

/// Format version, written as a little-endian uint32 right after the magic.
inline constexpr uint32_t kVersion = 1;

/// Byte offset of the little-endian uint32 instruction count in the header:
/// it follows the magic and the version word. writeHeader writes a 0 there and
/// patchInstructionCount overwrites it in place.
inline constexpr uint64_t kCountOffset = sizeof(kMagic) + sizeof(uint32_t);

/// Total header size: magic + version word + instruction-count word.
inline constexpr size_t kHeaderSize = kCountOffset + sizeof(uint32_t);

/// Size of one serialized instruction (see writeInstructionRecord): four
/// single-byte fields (opcode, rd, r1, flags) followed by four little-endian
/// int32 fields (r2OrImm, generic1, generic2, generic3).
inline constexpr size_t kRecordSize = 4 * sizeof(uint8_t) + 4 * sizeof(int32_t);

// These values are part of the on-disk format. If a derivation above changes,
// these asserts force the format change to be made deliberately.
static_assert(kCountOffset == 8, "PIM binary header layout changed");
static_assert(kHeaderSize == 12, "PIM binary header layout changed");
static_assert(kRecordSize == 20, "PIM binary record layout changed");
|
|
|
|
/// Instruction opcodes of the PIM binary format.
///
/// writeInstructionRecord serializes these numeric values (truncated to one
/// byte), so the explicit values below are part of the on-disk encoding.
enum class Opcode : uint32_t {
  // No operation.
  nop = 0,

  // s-prefixed mnemonics.
  sldi = 1,
  sld = 2,
  sadd = 3,
  ssub = 4,
  smul = 5,
  saddi = 6,
  smuli = 7,

  // Configuration and matrix-vector multiply.
  setbw = 8,
  mvmul = 9,

  // v-prefixed mnemonics.
  vvadd = 10,
  vvsub = 11,
  vvmul = 12,
  vvdmul = 13,
  vvmax = 14,
  vvsll = 15,
  vvsra = 16,
  vavg = 17,
  vrelu = 18,
  vtanh = 19,
  vsigm = 20,
  vsoftmax = 21,
  vmv = 22,
  vrsu = 23,
  vrsl = 24,

  // Load/store and local moves.
  ld = 25,
  st = 26,
  lldi = 27,
  lmv = 28,

  // Inter-core transfer and synchronization.
  send = 29,
  recv = 30,
  wait = 31,
  sync = 32,
};
|
|
|
|
struct InstructionRecord {
|
|
Opcode opcode = Opcode::nop;
|
|
uint8_t rd = 0;
|
|
uint8_t r1 = 0;
|
|
int32_t r2OrImm = 0;
|
|
int32_t generic1 = 0;
|
|
int32_t generic2 = 0;
|
|
int32_t generic3 = 0;
|
|
uint8_t flags = 0;
|
|
};
|
|
|
|
/// Appends `value` to `os` as exactly four little-endian bytes.
inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
  char encoded[sizeof(value)];
  llvm::support::endian::write32le(encoded, value);
  os.write(encoded, sizeof(encoded));
}
|
|
|
|
/// Appends `value` to `os` as four little-endian bytes. The value is the
/// two's-complement bit pattern, obtained by converting to uint32_t.
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) {
  const auto bits = static_cast<uint32_t>(value);
  writeUint32LE(os, bits);
}
|
|
|
|
/// Writes the fixed-size file header: the magic bytes, the format version, and
/// a zero in the instruction-count slot at kCountOffset.
/// NOTE(review): the zero is a placeholder that patchInstructionCount can later
/// overwrite; this assumes the header is the first thing written to the stream.
inline void writeHeader(llvm::raw_ostream& os) {
  constexpr uint32_t placeholderCount = 0;
  os.write(kMagic, sizeof(kMagic));
  writeUint32LE(os, kVersion);
  writeUint32LE(os, placeholderCount);
}
|
|
|
|
/// Overwrites the header's instruction-count field at kCountOffset with
/// `instructionCount`, encoded as a little-endian uint32.
/// raw_pwrite_stream::pwrite cannot extend the stream, so the header must
/// already have been written to `os`.
inline void patchInstructionCount(llvm::raw_pwrite_stream& os, uint32_t instructionCount) {
  char encoded[sizeof(instructionCount)];
  llvm::support::endian::write32le(encoded, instructionCount);
  os.pwrite(encoded, sizeof(encoded), kCountOffset);
}
|
|
|
|
/// Serializes `record` as a 20-byte (kRecordSize) record:
///   byte 0: opcode (low 8 bits), byte 1: rd, byte 2: r1, byte 3: flags,
///   bytes 4-19: r2OrImm, generic1, generic2, generic3 as little-endian int32.
inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecord& record) {
  // flags goes out as the fourth byte even though it is the last struct member.
  const char prefix[] = {
      static_cast<char>(static_cast<uint8_t>(record.opcode)),
      static_cast<char>(record.rd),
      static_cast<char>(record.r1),
      static_cast<char>(record.flags),
  };
  os.write(prefix, sizeof(prefix));

  const std::array<int32_t, 4> words{record.r2OrImm, record.generic1, record.generic2, record.generic3};
  for (int32_t word : words)
    writeInt32LE(os, word);
}
|
|
|
|
/// Narrows `value` to int32_t for storage in a record field.
/// The range check is an assert() only; in NDEBUG builds no diagnostic is
/// issued and static_cast performs the narrowing.
inline int32_t toI32(int64_t value) {
  assert(!(value < std::numeric_limits<int32_t>::min() || value > std::numeric_limits<int32_t>::max())
         && "PIM binary field out of int32 range");
  return static_cast<int32_t>(value);
}
|
|
|
|
/// Narrows `value` to uint8_t (register indices rd/r1).
/// The [0, 255] range is checked only by assert(); NDEBUG builds skip it.
inline uint8_t toU8(int64_t value) {
  assert(!(value < 0 || value > std::numeric_limits<uint8_t>::max()) && "PIM binary field out of uint8 range");
  return static_cast<uint8_t>(value);
}
|
|
|
|
/// Returns object[key] narrowed to int32_t when json::Object::getInteger
/// yields a value; otherwise returns `defaultValue`. getInteger yields nothing
/// for a missing key or a non-integral value.
/// Values outside int32 range trip toI32's assert.
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
  const std::optional<int64_t> raw = object.getInteger(key);
  return raw.has_value() ? toI32(*raw) : defaultValue;
}
|
|
|
|
/// Maps an instruction mnemonic (the JSON "op" field) to its Opcode; this is
/// the inverse of opcodeToString. Unknown mnemonics reach llvm_unreachable.
inline Opcode opcodeFromString(llvm::StringRef opName) {
  struct MnemonicEntry {
    const char* mnemonic;
    Opcode opcode;
  };
  static constexpr MnemonicEntry kMnemonics[] = {
      {"nop", Opcode::nop},       {"sldi", Opcode::sldi},         {"sld", Opcode::sld},
      {"sadd", Opcode::sadd},     {"ssub", Opcode::ssub},         {"smul", Opcode::smul},
      {"saddi", Opcode::saddi},   {"smuli", Opcode::smuli},       {"setbw", Opcode::setbw},
      {"mvmul", Opcode::mvmul},   {"vvadd", Opcode::vvadd},       {"vvsub", Opcode::vvsub},
      {"vvmul", Opcode::vvmul},   {"vvdmul", Opcode::vvdmul},     {"vvmax", Opcode::vvmax},
      {"vvsll", Opcode::vvsll},   {"vvsra", Opcode::vvsra},       {"vavg", Opcode::vavg},
      {"vrelu", Opcode::vrelu},   {"vtanh", Opcode::vtanh},       {"vsigm", Opcode::vsigm},
      {"vsoftmax", Opcode::vsoftmax}, {"vmv", Opcode::vmv},       {"vrsu", Opcode::vrsu},
      {"vrsl", Opcode::vrsl},     {"ld", Opcode::ld},             {"st", Opcode::st},
      {"lldi", Opcode::lldi},     {"lmv", Opcode::lmv},           {"send", Opcode::send},
      {"recv", Opcode::recv},     {"wait", Opcode::wait},         {"sync", Opcode::sync},
  };

  for (const MnemonicEntry& entry : kMnemonics) {
    if (opName == entry.mnemonic)
      return entry.opcode;
  }
  llvm_unreachable("Unsupported PIM binary opcode");
}
|
|
|
|
/// Returns the mnemonic used in the JSON "op" field for `opcode`; this is the
/// inverse of opcodeFromString. A value outside the enumerators (e.g. one
/// produced by a cast) reaches llvm_unreachable. The switch deliberately has
/// no default label, so the compiler flags any unhandled enumerator.
inline llvm::StringRef opcodeToString(Opcode opcode) {
  switch (opcode) {
  case Opcode::nop:
    return "nop";

  // s-prefixed mnemonics.
  case Opcode::sldi:
    return "sldi";
  case Opcode::sld:
    return "sld";
  case Opcode::sadd:
    return "sadd";
  case Opcode::ssub:
    return "ssub";
  case Opcode::smul:
    return "smul";
  case Opcode::saddi:
    return "saddi";
  case Opcode::smuli:
    return "smuli";

  // Configuration and matrix-vector multiply.
  case Opcode::setbw:
    return "setbw";
  case Opcode::mvmul:
    return "mvmul";

  // v-prefixed mnemonics.
  case Opcode::vvadd:
    return "vvadd";
  case Opcode::vvsub:
    return "vvsub";
  case Opcode::vvmul:
    return "vvmul";
  case Opcode::vvdmul:
    return "vvdmul";
  case Opcode::vvmax:
    return "vvmax";
  case Opcode::vvsll:
    return "vvsll";
  case Opcode::vvsra:
    return "vvsra";
  case Opcode::vavg:
    return "vavg";
  case Opcode::vrelu:
    return "vrelu";
  case Opcode::vtanh:
    return "vtanh";
  case Opcode::vsigm:
    return "vsigm";
  case Opcode::vsoftmax:
    return "vsoftmax";
  case Opcode::vmv:
    return "vmv";
  case Opcode::vrsu:
    return "vrsu";
  case Opcode::vrsl:
    return "vrsl";

  // Load/store and local moves.
  case Opcode::ld:
    return "ld";
  case Opcode::st:
    return "st";
  case Opcode::lldi:
    return "lldi";
  case Opcode::lmv:
    return "lmv";

  // Inter-core transfer and synchronization.
  case Opcode::send:
    return "send";
  case Opcode::recv:
    return "recv";
  case Opcode::wait:
    return "wait";
  case Opcode::sync:
    return "sync";
  }
  llvm_unreachable("Unsupported PIM binary opcode");
}
|
|
|
|
/// Builds an InstructionRecord from one JSON instruction object.
///
/// Requires an "op" string (asserted). "rd" and "rs1" always map to rd/r1.
/// r2OrImm and generic1-3 are filled according to the opcode:
///  - sldi/saddi/smuli/lldi: r2OrImm = "imm"
///  - mvmul: r2OrImm = "mbiw", generic1 = "relu", generic2 = "group"
///  - setbw: generic1 = "ibiw", generic2 = "obiw"
///  - send/recv: r2OrImm = "core", generic3 = "size"
///  - otherwise: r2OrImm = "rs2"
/// For every opcode except mvmul/setbw, an "offset" object supplies
/// generic1/generic2. Finally, "len" (if present) overrides generic3, or else
/// "size" fills it. Missing integer fields default to 0.
inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruction) {
  const std::optional<llvm::StringRef> mnemonic = instruction.getString("op");
  assert(mnemonic && "Missing op field in PIM instruction");

  InstructionRecord result;
  result.opcode = opcodeFromString(*mnemonic);
  result.rd = toU8(getOptionalInt(instruction, "rd"));
  result.r1 = toU8(getOptionalInt(instruction, "rs1"));

  const Opcode op = result.opcode;
  // mvmul and setbw use generic1/generic2 for their own fields, so any
  // "offset" object is ignored for them.
  const bool genericsHoldConfig = op == Opcode::mvmul || op == Opcode::setbw;
  // send/recv already take generic3 from "size" in the switch below.
  const bool isTransfer = op == Opcode::send || op == Opcode::recv;

  switch (op) {
  case Opcode::sldi:
  case Opcode::saddi:
  case Opcode::smuli:
  case Opcode::lldi:
    result.r2OrImm = getOptionalInt(instruction, "imm");
    break;
  case Opcode::mvmul:
    result.r2OrImm = getOptionalInt(instruction, "mbiw");
    result.generic1 = getOptionalInt(instruction, "relu");
    result.generic2 = getOptionalInt(instruction, "group");
    break;
  case Opcode::setbw:
    result.generic1 = getOptionalInt(instruction, "ibiw");
    result.generic2 = getOptionalInt(instruction, "obiw");
    break;
  case Opcode::send:
  case Opcode::recv:
    result.r2OrImm = getOptionalInt(instruction, "core");
    result.generic3 = getOptionalInt(instruction, "size");
    break;
  default:
    result.r2OrImm = getOptionalInt(instruction, "rs2");
    break;
  }

  if (!genericsHoldConfig) {
    if (const llvm::json::Object* offset = instruction.getObject("offset")) {
      result.generic1 = getOptionalInt(*offset, "offset_select");
      result.generic2 = getOptionalInt(*offset, "offset_value");
    }
  }

  // "len" takes precedence over "size" whenever it is present.
  if (instruction.get("len") != nullptr)
    result.generic3 = getOptionalInt(instruction, "len");
  else if (!isTransfer && instruction.get("size") != nullptr)
    result.generic3 = getOptionalInt(instruction, "size");

  return result;
}
|
|
|
|
/// Converts an InstructionRecord back into its JSON instruction object.
/// Only the fields meaningful for the record's opcode are emitted, and flags is
/// never emitted. Register fields are emitted as int64, all others as int32.
inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
  llvm::json::Object out;
  out["op"] = opcodeToString(record.opcode).str();

  // Emitters for field groups shared by several opcodes. Each case calls them
  // in the same key order, so every opcode inserts its keys in a fixed order.
  auto emitRd = [&] { out["rd"] = static_cast<int64_t>(record.rd); };
  auto emitRs1 = [&] { out["rs1"] = static_cast<int64_t>(record.r1); };
  auto emitOffset = [&] {
    llvm::json::Object offset;
    offset["offset_select"] = record.generic1;
    offset["offset_value"] = record.generic2;
    out["offset"] = std::move(offset);
  };

  switch (record.opcode) {
  case Opcode::sldi:
    emitRd();
    out["imm"] = record.r2OrImm;
    break;

  case Opcode::sld:
    emitRd();
    emitRs1();
    emitOffset();
    break;

  case Opcode::sadd:
  case Opcode::ssub:
  case Opcode::smul:
    emitRd();
    emitRs1();
    out["rs2"] = record.r2OrImm;
    break;

  case Opcode::saddi:
  case Opcode::smuli:
    emitRd();
    emitRs1();
    out["imm"] = record.r2OrImm;
    break;

  case Opcode::setbw:
    out["ibiw"] = record.generic1;
    out["obiw"] = record.generic2;
    break;

  case Opcode::mvmul:
    emitRd();
    emitRs1();
    out["mbiw"] = record.r2OrImm;
    out["relu"] = record.generic1;
    out["group"] = record.generic2;
    break;

  case Opcode::vvadd:
  case Opcode::vvsub:
  case Opcode::vvmul:
  case Opcode::vvdmul:
  case Opcode::vvmax:
  case Opcode::vvsll:
  case Opcode::vvsra:
  case Opcode::vavg:
  case Opcode::vmv:
  case Opcode::vrsu:
  case Opcode::vrsl:
    emitRd();
    emitRs1();
    out["rs2"] = record.r2OrImm;
    emitOffset();
    out["len"] = record.generic3;
    break;

  // These share one field layout: rd, rs1, offset, len.
  case Opcode::vrelu:
  case Opcode::vtanh:
  case Opcode::vsigm:
  case Opcode::vsoftmax:
  case Opcode::lmv:
    emitRd();
    emitRs1();
    emitOffset();
    out["len"] = record.generic3;
    break;

  case Opcode::ld:
  case Opcode::st:
    emitRd();
    emitRs1();
    emitOffset();
    out["size"] = record.generic3;
    break;

  case Opcode::lldi:
    emitRd();
    out["imm"] = record.r2OrImm;
    emitOffset();
    out["len"] = record.generic3;
    break;

  case Opcode::send:
  case Opcode::recv:
    emitRd();
    out["core"] = record.r2OrImm;
    emitOffset();
    out["size"] = record.generic3;
    break;

  case Opcode::wait:
  case Opcode::sync:
  case Opcode::nop:
    break;
  }

  return out;
}
|
|
|
|
} // namespace onnx_mlir::pim_binary
|