From 1348bb1c97253f22e8e90164ffb606ffa5e29eb5 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 6 Mar 2026 18:23:27 +0100 Subject: [PATCH] generic gemm now works :) --- .../src/lib/tracing/tracing_isa.rs | 10 ++-- src/PIM/Compiler/PimCodeGen.cpp | 37 +++++++++----- src/PIM/Compiler/PimCodeGen.hpp | 49 +++++++++---------- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 2 - .../SpatialToPIM/SpatialToPIMPass.cpp | 13 +++-- 5 files changed, 66 insertions(+), 45 deletions(-) diff --git a/backend-simulators/pim/pim-simulator/src/lib/tracing/tracing_isa.rs b/backend-simulators/pim/pim-simulator/src/lib/tracing/tracing_isa.rs index 8b54420..55026ca 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/tracing/tracing_isa.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/tracing/tracing_isa.rs @@ -1123,7 +1123,7 @@ impl Trace { if prefix == "Pre" { writeln!( file, - "Inst: lvm {} {} {} {{ {} {} }}", + "Inst: lmv {} {} {} {{ {} {} }}", rd, r1, imm_len, offset_select, offset_value ); } else { @@ -1141,13 +1141,15 @@ impl Trace { ); let r1_val = add_offset_r1(r1_val, offset_select, offset_value); let rd_val = add_offset_rd(rd_val, offset_select, offset_value); - let core_memory = core.load::(r1_val, imm_len).unwrap(); - let global_memory = host.load::(rd_val, imm_len).unwrap(); + let core_memory = core + .reserve_load(r1_val, imm_len).unwrap() + .reserve_load(rd_val, imm_len).unwrap() + .execute_load::().unwrap(); writeln!(file, "{} Memory:", prefix); writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,); pretty_print::print_slice::<_,f32>(file, core_memory[0], 30); writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,); - pretty_print::print_slice::<_,f32>(file, global_memory[0], 30); + pretty_print::print_slice::<_,f32>(file, core_memory[1], 30); if prefix == "Post" { writeln!(file, "\n###############################################\n"); diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 1deb83f..362403d 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -21,9 +21,11 @@ #include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerUtils.hpp" +using namespace llvm; +using namespace mlir; using namespace onnx_mlir; -MemEntry* PimMemory::gatherMemEntry(Value value) { +MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { auto type = cast(value.getType()); assert("Only static shape is supported" && type.hasStaticShape()); size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8; @@ -31,7 +33,7 @@ MemEntry* PimMemory::gatherMemEntry(Value value) { return &memEntries.emplace_back(memEntry, value).first; } -void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) { +void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { memEntry.address = firstAvailableAddress; firstAvailableAddress += memEntry.size; // Alignment @@ -59,7 +61,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { } }); - for (Value arg : funcOp.getArguments()) + for (mlir::Value arg : funcOp.getArguments()) gatherMemEntry(arg); allocateCore(funcOp); @@ -73,7 +75,7 @@ void PimMemory::allocateCore(Operation* op) { allocateMemoryForValue(value, memEntry); } -MemEntry PimMemory::getMemEntry(Value value) const { +MemEntry PimMemory::getMemEntry(mlir::Value value) const { auto iter = globalMemEntriesMap.find(value); assert("Missing memEntry for value" && iter != globalMemEntriesMap.end()); return iter->second; @@ -83,7 +85,8 @@ PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) { return deviceMem.try_emplace(id, memEntriesMap).first->second; } -size_t PimAcceleratorMemory::getValueAddress(Value value) const { +size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const { + size_t offset = 0; while (true) { auto definingOp = value.getDefiningOp(); if (!definingOp) @@ -101,12 +104,18 @@ size_t PimAcceleratorMemory::getValueAddress(Value value) const { auto subviewSizes = subviewDefiningOp.getStaticSizes(); auto subviewStrides = subviewDefiningOp.getStaticStrides(); assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides)); + for (unsigned i = 0; i < subviewOffsets.size(); i++) { + size_t localOffset = subviewOffsets[i]; + for (unsigned j = i + 1; j < subviewSizes.size(); j++) + localOffset *= subviewSizes[j]; + offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8; + } value = source; } else break; } - return memEntriesMap.at(value).address; + return memEntriesMap.at(value).address + offset; } json::Object PimCodeGen::createEmptyOffset() { @@ -144,15 +153,20 @@ void PimCodeGen::setupRdRs1Rs2( genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset); } -void PimCodeGen::emitMemCopyOp( - StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const { +void PimCodeGen::emitMemCopyOp(StringRef opName, + size_t rdAddr, + size_t rdOffset, + size_t rs1Addr, + size_t rs1Offset, + size_t size, + StringRef sizeFieldName) const { setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset); json::Object json; json["op"] = opName; json["rd"] = 0; json["rs1"] = 1; - json["size"] = size; + json[sizeFieldName] = size; json["offset"] = createEmptyOffset(); emitInstruction(std::move(json)); } @@ -206,7 +220,8 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const { lmvOp.getDstOffset(), memory.getValueAddress(lmvOp.getSrc()), lmvOp.getSrcOffset(), - lmvOp.getSize()); + lmvOp.getSize(), + "len"); } void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const { @@ -549,7 +564,7 @@ static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp, json::Array outputsAddresses; for (func::ReturnOp returnOp : funcOp.getOps()) - for (Value output : returnOp.getOperands()) + for (mlir::Value output : returnOp.getOperands()) outputsAddresses.push_back(memory.getValueAddress(output)); configJson["outputs_addresses"] = std::move(outputsAddresses); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 6a2effb..24de8c7 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -9,47 +9,41 @@ namespace onnx_mlir { -using namespace llvm; -using namespace mlir; -using Value = mlir::Value; -using Type = mlir::Type; -using FunctionType = mlir::FunctionType; - struct MemEntry { size_t address; size_t size; }; class PimMemory { - SmallVector, 32> memEntries; - SmallDenseMap& globalMemEntriesMap; + llvm::SmallVector, 32> memEntries; + llvm::SmallDenseMap& globalMemEntriesMap; size_t maxSize = 0; // 0 for unbounded memory size_t startAddress = 0; size_t minAlignment = 4; size_t firstAvailableAddress = 0; - MemEntry* gatherMemEntry(Value value); - void allocateMemoryForValue(Value value, MemEntry& memEntry); + MemEntry* gatherMemEntry(mlir::Value value); + void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry); public: - PimMemory(SmallDenseMap& globalMemEntriesMap) + PimMemory(llvm::SmallDenseMap& globalMemEntriesMap) : globalMemEntriesMap(globalMemEntriesMap) {} - void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp); - void allocateCore(Operation* op); + void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp); + void allocateCore(mlir::Operation* op); size_t getFirstAvailableAddress() const { return firstAvailableAddress; } - MemEntry getMemEntry(Value value) const; + MemEntry getMemEntry(mlir::Value value) const; }; class PimAcceleratorMemory { public: - SmallDenseMap memEntriesMap; + llvm::SmallDenseMap memEntriesMap; PimMemory hostMem; private: - SmallDenseMap deviceMem; + llvm::SmallDenseMap deviceMem; public: PimAcceleratorMemory() @@ -57,15 +51,15 @@ public: PimMemory getOrCreateDeviceMem(size_t id); - size_t getValueAddress(Value value) const; + size_t getValueAddress(mlir::Value value) const; }; class PimCodeGen { PimAcceleratorMemory& memory; - raw_fd_ostream& coreFileStream; + llvm::raw_fd_ostream& coreFileStream; - static json::Object createEmptyOffset(); - void emitInstruction(json::Object instruction) const; + static llvm::json::Object createEmptyOffset(); + void emitInstruction(llvm::json::Object instruction) const; void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const; void setupRd(size_t rdAddress, size_t rdOffset) const; @@ -73,13 +67,18 @@ class PimCodeGen { void setupRdRs1Rs2( size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const; - void - emitMemCopyOp(StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const; - void emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const; + void emitMemCopyOp(mlir::StringRef opName, + size_t rdAddr, + size_t rdOffset, + size_t rs1Addr, + size_t rs1Offset, + size_t size, + mlir::StringRef sizeFieldName = "size") const; + void emitCommunicationOp(mlir::StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const; void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const; public: - PimCodeGen(PimAcceleratorMemory& memory, raw_fd_ostream& coreJson) + PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson) : memory(memory), coreFileStream(coreJson) {} void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const; @@ -98,6 +97,6 @@ public: void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; }; -OnnxMlirCompilerErrorCodes compileToPimJson(ModuleOp& moduleOpRef, std::string& outputDirName); +OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 70a8fbe..182c715 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -24,8 +24,6 @@ namespace onnx_mlir { namespace spatial { void ONNXToSpatialPass::runOnOperation() { - llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n"; - ModuleOp moduleOp = getOperation(); MLIRContext* ctx = &getContext(); diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp index 5b590b5..9a8ed76 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp @@ -257,9 +257,16 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew outputTensors.reserve(returnOp->getNumOperands()); rewriter.setInsertionPointToStart(returnOp->getBlock()); for (auto returnValue : returnOp->getOperands()) { - auto newOutputTensor = - createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast(returnValue.getType())); - outputTensors.push_back(newOutputTensor); + Operation* returnValueDefiningOp = returnValue.getDefiningOp(); + if (returnValueDefiningOp->hasTrait()) { + assert(!returnValueDefiningOp->hasAttr("weightAlways")); + outputTensors.push_back(returnValue); + } + else { + auto newOutputTensor = + createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast(returnValue.getType())); + outputTensors.push_back(newOutputTensor); + } } }