refactor PimCodeGen

This commit is contained in:
NiccoloN
2026-02-26 19:13:54 +01:00
parent b26b5754d5
commit a2c31836ae
4 changed files with 416 additions and 496 deletions

View File

@@ -1,16 +1,20 @@
#pragma once
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/Support/JSON.h"
#include "Common/ValueMap.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerPasses.hpp"
namespace onnx_mlir {
using namespace llvm;
using namespace mlir;
using Value = mlir::Value;
using Type = mlir::Type;
using FunctionType = mlir::FunctionType;
struct MemEntry {
size_t address;
size_t size;
@@ -18,7 +22,7 @@ struct MemEntry {
class PimMemory {
SmallVector<std::pair<MemEntry, Value>, 32> memEntries;
llvm::SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap;
SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
@@ -29,7 +33,7 @@ class PimMemory {
void allocateMemoryForValue(Value value, MemEntry& memEntry);
public:
PimMemory(llvm::SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap)
PimMemory(SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp);
@@ -41,11 +45,11 @@ public:
class PimAcceleratorMemory {
public:
llvm::SmallDenseMap<Value, MemEntry, 32> memEntriesMap;
SmallDenseMap<Value, MemEntry, 32> memEntriesMap;
PimMemory hostMem;
private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
SmallDenseMap<size_t, PimMemory> deviceMem;
public:
PimAcceleratorMemory()
@@ -58,40 +62,41 @@ public:
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreFileStream;
raw_fd_ostream& coreFileStream;
static json::Object createEmptyOffset();
void emitInstruction(json::Object instruction) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
void setupRd(size_t rdAddress, size_t rdOffset) const;
void setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const;
void setupRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
void
emitMemCopyOp(StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const;
void emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public:
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
PimCodeGen(PimAcceleratorMemory& memory, raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {}
llvm::json::Object createSetImmediate(size_t targetRegister, size_t immediate);
llvm::json::Object createEmptyOffset();
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate);
void createRd(size_t rdAddress, size_t rdOffset);
void createRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset);
void createRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset);
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp);
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp);
void codeGenReceiveOp(pim::PimReceiveOp receiveOp) const;
void codeGenSendOp(pim::PimSendOp sendOp) const;
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix);
void codeGenReceiveOp(pim::PimReceiveOp receiveOp);
void codeGenSendOp(pim::PimSendOp sendOp);
void codeGenVAddOp(pim::PimVAddOp vaddOp);
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp);
void codeGenVReluOp(pim::PimVReluOp vreluOp);
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp);
void codeGenVAddOp(pim::PimVAddOp vaddOp) const;
void codeGenVMaxOp(pim::PimVMaxOp vmaxOp) const;
void codeGenVReluOp(pim::PimVReluOp vreluOp) const;
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
};
OnnxMlirCompilerErrorCodes compileToPimJson(ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir