generic gemm now works :)
This commit is contained in:
@@ -9,47 +9,41 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
using Value = mlir::Value;
|
||||
using Type = mlir::Type;
|
||||
using FunctionType = mlir::FunctionType;
|
||||
|
||||
struct MemEntry {
|
||||
size_t address;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class PimMemory {
|
||||
SmallVector<std::pair<MemEntry, Value>, 32> memEntries;
|
||||
SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap;
|
||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||
|
||||
size_t maxSize = 0; // 0 for unbounded memory
|
||||
size_t startAddress = 0;
|
||||
size_t minAlignment = 4;
|
||||
size_t firstAvailableAddress = 0;
|
||||
|
||||
MemEntry* gatherMemEntry(Value value);
|
||||
void allocateMemoryForValue(Value value, MemEntry& memEntry);
|
||||
MemEntry* gatherMemEntry(mlir::Value value);
|
||||
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
||||
|
||||
public:
|
||||
PimMemory(SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap)
|
||||
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
|
||||
: globalMemEntriesMap(globalMemEntriesMap) {}
|
||||
|
||||
void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp);
|
||||
void allocateCore(Operation* op);
|
||||
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
||||
void allocateCore(mlir::Operation* op);
|
||||
|
||||
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
||||
MemEntry getMemEntry(Value value) const;
|
||||
MemEntry getMemEntry(mlir::Value value) const;
|
||||
};
|
||||
|
||||
class PimAcceleratorMemory {
|
||||
public:
|
||||
SmallDenseMap<Value, MemEntry, 32> memEntriesMap;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
|
||||
PimMemory hostMem;
|
||||
|
||||
private:
|
||||
SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||
|
||||
public:
|
||||
PimAcceleratorMemory()
|
||||
@@ -57,15 +51,15 @@ public:
|
||||
|
||||
PimMemory getOrCreateDeviceMem(size_t id);
|
||||
|
||||
size_t getValueAddress(Value value) const;
|
||||
size_t getValueAddress(mlir::Value value) const;
|
||||
};
|
||||
|
||||
class PimCodeGen {
|
||||
PimAcceleratorMemory& memory;
|
||||
raw_fd_ostream& coreFileStream;
|
||||
llvm::raw_fd_ostream& coreFileStream;
|
||||
|
||||
static json::Object createEmptyOffset();
|
||||
void emitInstruction(json::Object instruction) const;
|
||||
static llvm::json::Object createEmptyOffset();
|
||||
void emitInstruction(llvm::json::Object instruction) const;
|
||||
|
||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||
@@ -73,13 +67,18 @@ class PimCodeGen {
|
||||
void setupRdRs1Rs2(
|
||||
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
|
||||
|
||||
void
|
||||
emitMemCopyOp(StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const;
|
||||
void emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
|
||||
void emitMemCopyOp(mlir::StringRef opName,
|
||||
size_t rdAddr,
|
||||
size_t rdOffset,
|
||||
size_t rs1Addr,
|
||||
size_t rs1Offset,
|
||||
size_t size,
|
||||
mlir::StringRef sizeFieldName = "size") const;
|
||||
void emitCommunicationOp(mlir::StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
|
||||
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
||||
|
||||
public:
|
||||
PimCodeGen(PimAcceleratorMemory& memory, raw_fd_ostream& coreJson)
|
||||
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
||||
: memory(memory), coreFileStream(coreJson) {}
|
||||
|
||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
|
||||
@@ -98,6 +97,6 @@ public:
|
||||
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
|
||||
};
|
||||
|
||||
OnnxMlirCompilerErrorCodes compileToPimJson(ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user