generic gemm now works :)

This commit is contained in:
NiccoloN
2026-03-06 18:23:27 +01:00
parent 825188cc89
commit 1348bb1c97
5 changed files with 66 additions and 45 deletions

View File

@@ -1123,7 +1123,7 @@ impl Trace {
if prefix == "Pre" { if prefix == "Pre" {
writeln!( writeln!(
file, file,
"Inst: lvm {} {} {} {{ {} {} }}", "Inst: lmv {} {} {} {{ {} {} }}",
rd, r1, imm_len, offset_select, offset_value rd, r1, imm_len, offset_select, offset_value
); );
} else { } else {
@@ -1141,13 +1141,15 @@ impl Trace {
); );
let r1_val = add_offset_r1(r1_val, offset_select, offset_value); let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let rd_val = add_offset_rd(rd_val, offset_select, offset_value); let rd_val = add_offset_rd(rd_val, offset_select, offset_value);
let core_memory = core.load::<u8>(r1_val, imm_len).unwrap(); let core_memory = core
let global_memory = host.load::<u8>(rd_val, imm_len).unwrap(); .reserve_load(r1_val, imm_len).unwrap()
.reserve_load(rd_val, imm_len).unwrap()
.execute_load::<u8>().unwrap();
writeln!(file, "{} Memory:", prefix); writeln!(file, "{} Memory:", prefix);
writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,); writeln!(file, "\tLocal[{}:{}]: ", r1_val, r1_val + imm_len as usize,);
pretty_print::print_slice::<_,f32>(file, core_memory[0], 30); pretty_print::print_slice::<_,f32>(file, core_memory[0], 30);
writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,); writeln!(file, "\tLocal[{}:{}]: ", rd_val, rd_val + imm_len as usize,);
pretty_print::print_slice::<_,f32>(file, global_memory[0], 30); pretty_print::print_slice::<_,f32>(file, core_memory[1], 30);
if prefix == "Post" { if prefix == "Post" {
writeln!(file, "\n###############################################\n"); writeln!(file, "\n###############################################\n");

View File

@@ -21,9 +21,11 @@
#include "src/Compiler/CompilerPasses.hpp" #include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir; using namespace onnx_mlir;
MemEntry* PimMemory::gatherMemEntry(Value value) { MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType()); auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape()); assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8; size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
@@ -31,7 +33,7 @@ MemEntry* PimMemory::gatherMemEntry(Value value) {
return &memEntries.emplace_back(memEntry, value).first; return &memEntries.emplace_back(memEntry, value).first;
} }
void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) { void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
memEntry.address = firstAvailableAddress; memEntry.address = firstAvailableAddress;
firstAvailableAddress += memEntry.size; firstAvailableAddress += memEntry.size;
// Alignment // Alignment
@@ -59,7 +61,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
} }
}); });
for (Value arg : funcOp.getArguments()) for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg); gatherMemEntry(arg);
allocateCore(funcOp); allocateCore(funcOp);
@@ -73,7 +75,7 @@ void PimMemory::allocateCore(Operation* op) {
allocateMemoryForValue(value, memEntry); allocateMemoryForValue(value, memEntry);
} }
MemEntry PimMemory::getMemEntry(Value value) const { MemEntry PimMemory::getMemEntry(mlir::Value value) const {
auto iter = globalMemEntriesMap.find(value); auto iter = globalMemEntriesMap.find(value);
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end()); assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
return iter->second; return iter->second;
@@ -83,7 +85,8 @@ PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second; return deviceMem.try_emplace(id, memEntriesMap).first->second;
} }
size_t PimAcceleratorMemory::getValueAddress(Value value) const { size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
size_t offset = 0;
while (true) { while (true) {
auto definingOp = value.getDefiningOp(); auto definingOp = value.getDefiningOp();
if (!definingOp) if (!definingOp)
@@ -101,12 +104,18 @@ size_t PimAcceleratorMemory::getValueAddress(Value value) const {
auto subviewSizes = subviewDefiningOp.getStaticSizes(); auto subviewSizes = subviewDefiningOp.getStaticSizes();
auto subviewStrides = subviewDefiningOp.getStaticStrides(); auto subviewStrides = subviewDefiningOp.getStaticStrides();
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides)); assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
size_t localOffset = subviewOffsets[i];
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
localOffset *= subviewSizes[j];
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
}
value = source; value = source;
} }
else else
break; break;
} }
return memEntriesMap.at(value).address; return memEntriesMap.at(value).address + offset;
} }
json::Object PimCodeGen::createEmptyOffset() { json::Object PimCodeGen::createEmptyOffset() {
@@ -144,15 +153,20 @@ void PimCodeGen::setupRdRs1Rs2(
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset); genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
} }
void PimCodeGen::emitMemCopyOp( void PimCodeGen::emitMemCopyOp(StringRef opName,
StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const { size_t rdAddr,
size_t rdOffset,
size_t rs1Addr,
size_t rs1Offset,
size_t size,
StringRef sizeFieldName) const {
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset); setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
json::Object json; json::Object json;
json["op"] = opName; json["op"] = opName;
json["rd"] = 0; json["rd"] = 0;
json["rs1"] = 1; json["rs1"] = 1;
json["size"] = size; json[sizeFieldName] = size;
json["offset"] = createEmptyOffset(); json["offset"] = createEmptyOffset();
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
@@ -206,7 +220,8 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
lmvOp.getDstOffset(), lmvOp.getDstOffset(),
memory.getValueAddress(lmvOp.getSrc()), memory.getValueAddress(lmvOp.getSrc()),
lmvOp.getSrcOffset(), lmvOp.getSrcOffset(),
lmvOp.getSize()); lmvOp.getSize(),
"len");
} }
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const { void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
@@ -549,7 +564,7 @@ static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
json::Array outputsAddresses; json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>()) for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (Value output : returnOp.getOperands()) for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output)); outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses); configJson["outputs_addresses"] = std::move(outputsAddresses);

View File

@@ -9,47 +9,41 @@
namespace onnx_mlir { namespace onnx_mlir {
using namespace llvm;
using namespace mlir;
using Value = mlir::Value;
using Type = mlir::Type;
using FunctionType = mlir::FunctionType;
struct MemEntry { struct MemEntry {
size_t address; size_t address;
size_t size; size_t size;
}; };
class PimMemory { class PimMemory {
SmallVector<std::pair<MemEntry, Value>, 32> memEntries; llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap; llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0; size_t startAddress = 0;
size_t minAlignment = 4; size_t minAlignment = 4;
size_t firstAvailableAddress = 0; size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(Value value); MemEntry* gatherMemEntry(mlir::Value value);
void allocateMemoryForValue(Value value, MemEntry& memEntry); void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
public: public:
PimMemory(SmallDenseMap<Value, MemEntry, 32>& globalMemEntriesMap) PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {} : globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(ModuleOp moduleOp, func::FuncOp funcOp); void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(Operation* op); void allocateCore(mlir::Operation* op);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; } size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(Value value) const; MemEntry getMemEntry(mlir::Value value) const;
}; };
class PimAcceleratorMemory { class PimAcceleratorMemory {
public: public:
SmallDenseMap<Value, MemEntry, 32> memEntriesMap; llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
PimMemory hostMem; PimMemory hostMem;
private: private:
SmallDenseMap<size_t, PimMemory> deviceMem; llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
public: public:
PimAcceleratorMemory() PimAcceleratorMemory()
@@ -57,15 +51,15 @@ public:
PimMemory getOrCreateDeviceMem(size_t id); PimMemory getOrCreateDeviceMem(size_t id);
size_t getValueAddress(Value value) const; size_t getValueAddress(mlir::Value value) const;
}; };
class PimCodeGen { class PimCodeGen {
PimAcceleratorMemory& memory; PimAcceleratorMemory& memory;
raw_fd_ostream& coreFileStream; llvm::raw_fd_ostream& coreFileStream;
static json::Object createEmptyOffset(); static llvm::json::Object createEmptyOffset();
void emitInstruction(json::Object instruction) const; void emitInstruction(llvm::json::Object instruction) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const; void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
void setupRd(size_t rdAddress, size_t rdOffset) const; void setupRd(size_t rdAddress, size_t rdOffset) const;
@@ -73,13 +67,18 @@ class PimCodeGen {
void setupRdRs1Rs2( void setupRdRs1Rs2(
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const; size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const;
void void emitMemCopyOp(mlir::StringRef opName,
emitMemCopyOp(StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const; size_t rdAddr,
void emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const; size_t rdOffset,
size_t rs1Addr,
size_t rs1Offset,
size_t size,
mlir::StringRef sizeFieldName = "size") const;
void emitCommunicationOp(mlir::StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const;
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const; void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public: public:
PimCodeGen(PimAcceleratorMemory& memory, raw_fd_ostream& coreJson) PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {} : memory(memory), coreFileStream(coreJson) {}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const; void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp) const;
@@ -98,6 +97,6 @@ public:
void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const; void codeGenApplyFiltersOp(pim::PimApplyFiltersOp applyFiltersOp) const;
}; };
OnnxMlirCompilerErrorCodes compileToPimJson(ModuleOp& moduleOpRef, std::string& outputDirName); OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -24,8 +24,6 @@ namespace onnx_mlir {
namespace spatial { namespace spatial {
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Running ONNXToSpatialLoweringPass\n";
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();

View File

@@ -257,9 +257,16 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
outputTensors.reserve(returnOp->getNumOperands()); outputTensors.reserve(returnOp->getNumOperands());
rewriter.setInsertionPointToStart(returnOp->getBlock()); rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) { for (auto returnValue : returnOp->getOperands()) {
auto newOutputTensor = Operation* returnValueDefiningOp = returnValue.getDefiningOp();
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType())); if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
outputTensors.push_back(newOutputTensor); assert(!returnValueDefiningOp->hasAttr("weightAlways"));
outputTensors.push_back(returnValue);
}
else {
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
}
} }
} }