#pragma once #include "llvm/ADT/StringRef.h" #include "llvm/Support/Endian.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" #include #include #include namespace onnx_mlir::pim_binary { inline constexpr char kMagic[4] = {'P', 'I', 'M', 'B'}; inline constexpr uint32_t kVersion = 1; inline constexpr uint64_t kCountOffset = 8; inline constexpr size_t kHeaderSize = 12; inline constexpr size_t kRecordSize = 20; enum class Opcode : uint32_t { nop = 0, sldi = 1, sld = 2, sadd = 3, ssub = 4, smul = 5, saddi = 6, smuli = 7, setbw = 8, mvmul = 9, vvadd = 10, vvsub = 11, vvmul = 12, vvdmul = 13, vvmax = 14, vvsll = 15, vvsra = 16, vavg = 17, vrelu = 18, vtanh = 19, vsigm = 20, vsoftmax = 21, vmv = 22, vrsu = 23, vrsl = 24, ld = 25, st = 26, lldi = 27, lmv = 28, send = 29, recv = 30, wait = 31, sync = 32, }; struct InstructionRecord { Opcode opcode = Opcode::nop; uint8_t rd = 0; uint8_t r1 = 0; int32_t r2OrImm = 0; int32_t generic1 = 0; int32_t generic2 = 0; int32_t generic3 = 0; uint8_t flags = 0; }; inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) { std::array bytes; llvm::support::endian::write32le(bytes.data(), value); os.write(bytes.data(), bytes.size()); } inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast(value)); } inline void writeHeader(llvm::raw_ostream& os) { os.write(kMagic, sizeof(kMagic)); writeUint32LE(os, kVersion); writeUint32LE(os, 0); } inline void patchInstructionCount(llvm::raw_pwrite_stream& os, uint32_t instructionCount) { std::array bytes; llvm::support::endian::write32le(bytes.data(), instructionCount); os.pwrite(bytes.data(), bytes.size(), kCountOffset); } inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecord& record) { os << static_cast(static_cast(record.opcode)); os << static_cast(record.rd); os << static_cast(record.r1); os << static_cast(record.flags); writeInt32LE(os, record.r2OrImm); writeInt32LE(os, record.generic1); writeInt32LE(os, record.generic2); writeInt32LE(os, record.generic3); } inline int32_t toI32(int64_t value) { assert(value >= std::numeric_limits::min() && value <= std::numeric_limits::max() && "PIM binary field out of int32 range"); return static_cast(value); } inline uint8_t toU8(int64_t value) { assert(value >= 0 && value <= std::numeric_limits::max() && "PIM binary field out of uint8 range"); return static_cast(value); } inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) { if (std::optional value = object.getInteger(key)) return toI32(*value); return defaultValue; } inline Opcode opcodeFromString(llvm::StringRef opName) { if (opName == "nop") return Opcode::nop; if (opName == "sldi") return Opcode::sldi; if (opName == "sld") return Opcode::sld; if (opName == "sadd") return Opcode::sadd; if (opName == "ssub") return Opcode::ssub; if (opName == "smul") return Opcode::smul; if (opName == "saddi") return Opcode::saddi; if (opName == "smuli") return Opcode::smuli; if (opName == "setbw") return Opcode::setbw; if (opName == "mvmul") return Opcode::mvmul; if (opName == "vvadd") return Opcode::vvadd; if (opName == "vvsub") return Opcode::vvsub; if (opName == "vvmul") return Opcode::vvmul; if (opName == "vvdmul") return Opcode::vvdmul; if (opName == "vvmax") return Opcode::vvmax; if (opName == "vvsll") return Opcode::vvsll; if (opName == "vvsra") return Opcode::vvsra; if (opName == "vavg") return Opcode::vavg; if (opName == "vrelu") return Opcode::vrelu; if (opName == "vtanh") return Opcode::vtanh; if (opName == "vsigm") return Opcode::vsigm; if (opName == "vsoftmax") return Opcode::vsoftmax; if (opName == "vmv") return Opcode::vmv; if (opName == "vrsu") return Opcode::vrsu; if (opName == "vrsl") return Opcode::vrsl; if (opName == "ld") return Opcode::ld; if (opName == "st") return Opcode::st; if (opName == "lldi") return Opcode::lldi; if (opName == "lmv") return Opcode::lmv; if (opName == "send") return Opcode::send; if (opName == "recv") return Opcode::recv; if (opName == "wait") return Opcode::wait; if (opName == "sync") return Opcode::sync; llvm_unreachable("Unsupported PIM binary opcode"); } inline llvm::StringRef opcodeToString(Opcode opcode) { switch (opcode) { case Opcode::nop: return "nop"; case Opcode::sldi: return "sldi"; case Opcode::sld: return "sld"; case Opcode::sadd: return "sadd"; case Opcode::ssub: return "ssub"; case Opcode::smul: return "smul"; case Opcode::saddi: return "saddi"; case Opcode::smuli: return "smuli"; case Opcode::setbw: return "setbw"; case Opcode::mvmul: return "mvmul"; case Opcode::vvadd: return "vvadd"; case Opcode::vvsub: return "vvsub"; case Opcode::vvmul: return "vvmul"; case Opcode::vvdmul: return "vvdmul"; case Opcode::vvmax: return "vvmax"; case Opcode::vvsll: return "vvsll"; case Opcode::vvsra: return "vvsra"; case Opcode::vavg: return "vavg"; case Opcode::vrelu: return "vrelu"; case Opcode::vtanh: return "vtanh"; case Opcode::vsigm: return "vsigm"; case Opcode::vsoftmax: return "vsoftmax"; case Opcode::vmv: return "vmv"; case Opcode::vrsu: return "vrsu"; case Opcode::vrsl: return "vrsl"; case Opcode::ld: return "ld"; case Opcode::st: return "st"; case Opcode::lldi: return "lldi"; case Opcode::lmv: return "lmv"; case Opcode::send: return "send"; case Opcode::recv: return "recv"; case Opcode::wait: return "wait"; case Opcode::sync: return "sync"; } llvm_unreachable("Unsupported PIM binary opcode"); } inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruction) { InstructionRecord record; std::optional opName = instruction.getString("op"); assert(opName && "Missing op field in PIM instruction"); record.opcode = opcodeFromString(*opName); record.rd = toU8(getOptionalInt(instruction, "rd")); record.r1 = toU8(getOptionalInt(instruction, "rs1")); switch (record.opcode) { case Opcode::sldi: case Opcode::saddi: case Opcode::smuli: case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break; case Opcode::mvmul: record.r2OrImm = getOptionalInt(instruction, "mbiw"); record.generic1 = getOptionalInt(instruction, "relu"); record.generic2 = getOptionalInt(instruction, "group"); break; case Opcode::setbw: record.generic1 = getOptionalInt(instruction, "ibiw"); record.generic2 = getOptionalInt(instruction, "obiw"); break; case Opcode::send: case Opcode::recv: record.r2OrImm = getOptionalInt(instruction, "core"); record.generic3 = getOptionalInt(instruction, "size"); break; default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break; } if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) { if (auto* offsetValue = instruction.getObject("offset")) { record.generic1 = getOptionalInt(*offsetValue, "offset_select"); record.generic2 = getOptionalInt(*offsetValue, "offset_value"); } } if (instruction.get("len")) record.generic3 = getOptionalInt(instruction, "len"); else if (instruction.get("size") && record.opcode != Opcode::send && record.opcode != Opcode::recv) record.generic3 = getOptionalInt(instruction, "size"); return record; } inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) { llvm::json::Object instruction; instruction["op"] = opcodeToString(record.opcode).str(); auto addOffset = [&](int32_t offsetSelect, int32_t offsetValue) { llvm::json::Object offset; offset["offset_select"] = offsetSelect; offset["offset_value"] = offsetValue; instruction["offset"] = std::move(offset); }; switch (record.opcode) { case Opcode::sldi: instruction["rd"] = static_cast(record.rd); instruction["imm"] = record.r2OrImm; break; case Opcode::sld: instruction["rd"] = static_cast(record.rd); instruction["rs1"] = static_cast(record.r1); addOffset(record.generic1, record.generic2); break; case Opcode::sadd: case Opcode::ssub: case Opcode::smul: instruction["rd"] = static_cast(record.rd); instruction["rs1"] = static_cast(record.r1); instruction["rs2"] = record.r2OrImm; break; case Opcode::saddi: case Opcode::smuli: instruction["rd"] = static_cast(record.rd); instruction["rs1"] = static_cast(record.r1); instruction["imm"] = record.r2OrImm; break; case Opcode::setbw: instruction["ibiw"] = record.generic1; instruction["obiw"] = record.generic2; break; case Opcode::mvmul: instruction["rd"] = static_cast(record.rd); instruction["rs1"] = static_cast(record.r1); instruction["mbiw"] = record.r2OrImm; instruction["relu"] = record.generic1; instruction["group"] = record.generic2; break; case Opcode::vvadd: case Opcode::vvsub: case Opcode::vvmul: case Opcode::vvdmul: case Opcode::vvmax: case Opcode::vvsll: case Opcode::vvsra: case Opcode::vavg: case Opcode::vmv: case Opcode::vrsu: case Opcode::vrsl: instruction["rd"] = static_cast(record.rd); instruction["rs1"] = static_cast(record.r1); instruction["rs2"] = record.r2OrImm; addOffset(record.generic1, record.generic2); instruction["len"] = record.generic3; break; case Opcode::vrelu: case Opcode::vtanh: case Opcode::vsigm: case Opcode::vsoftmax: instruction["rd"] = static_cast(record.rd); instruction["rs1"] = static_cast(record.r1); addOffset(record.generic1, record.generic2); instruction["len"] = record.generic3; break; case Opcode::ld: case Opcode::st: instruction["rd"] = static_cast(record.rd); instruction["rs1"] = static_cast(record.r1); addOffset(record.generic1, record.generic2); instruction["size"] = record.generic3; break; case Opcode::lldi: instruction["rd"] = static_cast(record.rd); instruction["imm"] = record.r2OrImm; addOffset(record.generic1, record.generic2); instruction["len"] = record.generic3; break; case Opcode::lmv: instruction["rd"] = static_cast(record.rd); instruction["rs1"] = static_cast(record.r1); addOffset(record.generic1, record.generic2); instruction["len"] = record.generic3; break; case Opcode::send: case Opcode::recv: instruction["rd"] = static_cast(record.rd); instruction["core"] = record.r2OrImm; addOffset(record.generic1, record.generic2); instruction["size"] = record.generic3; break; case Opcode::wait: case Opcode::sync: case Opcode::nop: break; } return instruction; } } // namespace onnx_mlir::pim_binary