Generic GEMM now works :)

This commit is contained in:
NiccoloN
2026-03-06 18:23:27 +01:00
parent 825188cc89
commit 1348bb1c97
5 changed files with 66 additions and 45 deletions

View File

@@ -21,9 +21,11 @@
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Compiler/CompilerUtils.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
MemEntry* PimMemory::gatherMemEntry(Value value) {
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
@@ -31,7 +33,7 @@ MemEntry* PimMemory::gatherMemEntry(Value value) {
return &memEntries.emplace_back(memEntry, value).first;
}
void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) {
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
memEntry.address = firstAvailableAddress;
firstAvailableAddress += memEntry.size;
// Alignment
@@ -59,7 +61,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
}
});
for (Value arg : funcOp.getArguments())
for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg);
allocateCore(funcOp);
@@ -73,7 +75,7 @@ void PimMemory::allocateCore(Operation* op) {
allocateMemoryForValue(value, memEntry);
}
MemEntry PimMemory::getMemEntry(Value value) const {
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
auto iter = globalMemEntriesMap.find(value);
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
return iter->second;
@@ -83,7 +85,8 @@ PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
return deviceMem.try_emplace(id, memEntriesMap).first->second;
}
size_t PimAcceleratorMemory::getValueAddress(Value value) const {
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
size_t offset = 0;
while (true) {
auto definingOp = value.getDefiningOp();
if (!definingOp)
@@ -101,12 +104,18 @@ size_t PimAcceleratorMemory::getValueAddress(Value value) const {
auto subviewSizes = subviewDefiningOp.getStaticSizes();
auto subviewStrides = subviewDefiningOp.getStaticStrides();
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
size_t localOffset = subviewOffsets[i];
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
localOffset *= subviewSizes[j];
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
}
value = source;
}
else
break;
}
return memEntriesMap.at(value).address;
return memEntriesMap.at(value).address + offset;
}
json::Object PimCodeGen::createEmptyOffset() {
@@ -144,15 +153,20 @@ void PimCodeGen::setupRdRs1Rs2(
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
}
void PimCodeGen::emitMemCopyOp(
StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const {
void PimCodeGen::emitMemCopyOp(StringRef opName,
size_t rdAddr,
size_t rdOffset,
size_t rs1Addr,
size_t rs1Offset,
size_t size,
StringRef sizeFieldName) const {
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
json::Object json;
json["op"] = opName;
json["rd"] = 0;
json["rs1"] = 1;
json["size"] = size;
json[sizeFieldName] = size;
json["offset"] = createEmptyOffset();
emitInstruction(std::move(json));
}
@@ -206,7 +220,8 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
lmvOp.getDstOffset(),
memory.getValueAddress(lmvOp.getSrc()),
lmvOp.getSrcOffset(),
lmvOp.getSize());
lmvOp.getSize(),
"len");
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
@@ -549,7 +564,7 @@ static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (Value output : returnOp.getOperands())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);