generic gemm now works :)
This commit is contained in:
@@ -21,9 +21,11 @@
|
||||
#include "src/Compiler/CompilerPasses.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(Value value) {
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
@@ -31,7 +33,7 @@ MemEntry* PimMemory::gatherMemEntry(Value value) {
|
||||
return &memEntries.emplace_back(memEntry, value).first;
|
||||
}
|
||||
|
||||
void PimMemory::allocateMemoryForValue(Value value, MemEntry& memEntry) {
|
||||
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
||||
memEntry.address = firstAvailableAddress;
|
||||
firstAvailableAddress += memEntry.size;
|
||||
// Alignment
|
||||
@@ -59,7 +61,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||
}
|
||||
});
|
||||
|
||||
for (Value arg : funcOp.getArguments())
|
||||
for (mlir::Value arg : funcOp.getArguments())
|
||||
gatherMemEntry(arg);
|
||||
|
||||
allocateCore(funcOp);
|
||||
@@ -73,7 +75,7 @@ void PimMemory::allocateCore(Operation* op) {
|
||||
allocateMemoryForValue(value, memEntry);
|
||||
}
|
||||
|
||||
MemEntry PimMemory::getMemEntry(Value value) const {
|
||||
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||
auto iter = globalMemEntriesMap.find(value);
|
||||
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
|
||||
return iter->second;
|
||||
@@ -83,7 +85,8 @@ PimMemory PimAcceleratorMemory::getOrCreateDeviceMem(size_t id) {
|
||||
return deviceMem.try_emplace(id, memEntriesMap).first->second;
|
||||
}
|
||||
|
||||
size_t PimAcceleratorMemory::getValueAddress(Value value) const {
|
||||
size_t PimAcceleratorMemory::getValueAddress(mlir::Value value) const {
|
||||
size_t offset = 0;
|
||||
while (true) {
|
||||
auto definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
@@ -101,12 +104,18 @@ size_t PimAcceleratorMemory::getValueAddress(Value value) const {
|
||||
auto subviewSizes = subviewDefiningOp.getStaticSizes();
|
||||
auto subviewStrides = subviewDefiningOp.getStaticStrides();
|
||||
assert(isMemoryContiguous(srcShape, subviewOffsets, subviewSizes, subviewStrides));
|
||||
for (unsigned i = 0; i < subviewOffsets.size(); i++) {
|
||||
size_t localOffset = subviewOffsets[i];
|
||||
for (unsigned j = i + 1; j < subviewSizes.size(); j++)
|
||||
localOffset *= subviewSizes[j];
|
||||
offset += localOffset * subviewDefiningOp.getType().getElementTypeBitWidth() / 8;
|
||||
}
|
||||
value = source;
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
return memEntriesMap.at(value).address;
|
||||
return memEntriesMap.at(value).address + offset;
|
||||
}
|
||||
|
||||
json::Object PimCodeGen::createEmptyOffset() {
|
||||
@@ -144,15 +153,20 @@ void PimCodeGen::setupRdRs1Rs2(
|
||||
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
|
||||
}
|
||||
|
||||
void PimCodeGen::emitMemCopyOp(
|
||||
StringRef opName, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset, size_t size) const {
|
||||
void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||
size_t rdAddr,
|
||||
size_t rdOffset,
|
||||
size_t rs1Addr,
|
||||
size_t rs1Offset,
|
||||
size_t size,
|
||||
StringRef sizeFieldName) const {
|
||||
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = opName;
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["size"] = size;
|
||||
json[sizeFieldName] = size;
|
||||
json["offset"] = createEmptyOffset();
|
||||
emitInstruction(std::move(json));
|
||||
}
|
||||
@@ -206,7 +220,8 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp) const {
|
||||
lmvOp.getDstOffset(),
|
||||
memory.getValueAddress(lmvOp.getSrc()),
|
||||
lmvOp.getSrcOffset(),
|
||||
lmvOp.getSize());
|
||||
lmvOp.getSize(),
|
||||
"len");
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp) const {
|
||||
@@ -549,7 +564,7 @@ static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||
|
||||
json::Array outputsAddresses;
|
||||
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
||||
for (Value output : returnOp.getOperands())
|
||||
for (mlir::Value output : returnOp.getOperands())
|
||||
outputsAddresses.push_back(memory.getValueAddress(output));
|
||||
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user