add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled

add shared checked arithmetic helpers
refactor pim passes into Pim/Transforms
more robust memory coalescing pass
This commit is contained in:
NiccoloN
2026-06-01 16:49:06 +02:00
parent 356be6ccc2
commit 636310d0cb
55 changed files with 2007 additions and 1103 deletions
+44 -28
View File
@@ -25,12 +25,14 @@
#include <cassert>
#include <cstdint>
#include <fstream>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/CheckedArithmetic.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
@@ -71,12 +73,23 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigne
return {value, getLaneForMemoryValue(value, lane)};
}
// Returns the byte size of `type` narrowed to int32_t.
// The size is computed by pim::getCheckedShapedTypeSizeInBytes using an unknown
// location; a failed computation is treated as unreachable. The result is then
// converted through pim::checkedI32OrCrash (presumably rejects values that do
// not fit in int32_t — confirm against CheckedArithmetic.hpp).
static int32_t getVectorByteSizeOrCrash(ShapedType type) {
  auto unknownLoc = UnknownLoc::get(type.getContext());
  auto checkedBytes = pim::getCheckedShapedTypeSizeInBytes(type, unknownLoc, "vector byte size");
  if (failed(checkedBytes))
    llvm_unreachable("Failed to compute checked vector byte size");
  return pim::checkedI32OrCrash(*checkedBytes, "vector byte size");
}
} // namespace
// Appends a new entry for `value` (keyed together with the optional `lane`) to
// memEntries and returns a pointer to the stored MemEntry.
//
// `value` must have a ShapedType with a static shape (asserted). The entry's
// size is the byte size reported by pim::getCheckedShapedTypeSizeInBytes; a
// failed size computation is treated as unreachable.
MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional<unsigned> lane) {
  auto type = cast<ShapedType>(value.getType());
  assert("Only static shape is supported" && type.hasStaticShape());
  // The allocation size is taken from the checked helper only, so `allocSize`
  // is declared exactly once.
  auto checkedAllocSize =
      pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size");
  if (failed(checkedAllocSize))
    llvm_unreachable("Failed to compute checked allocation byte size");
  size_t allocSize = static_cast<size_t>(*checkedAllocSize);
  // First field is initialised to 0 (presumably the address, assigned later — confirm).
  MemEntry memEntry = {0, allocSize};
  return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first;
}
@@ -272,7 +285,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
llvm_unreachable("Missing mem entry");
}
return iter->second.address + resolvedAddress->byteOffset;
size_t byteOffset = pim::checkedSizeOrCrash(resolvedAddress->byteOffset, "resolved PIM byte offset");
return pim::checkedAddOrCrash(iter->second.address, byteOffset, "resolved PIM address");
}
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
@@ -291,8 +305,12 @@ llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
// Stores the report row currently produced by hostMem into hostReportRow.
void PimAcceleratorMemory::reportHost() {
  hostReportRow = hostMem.getReportRow();
}
// Appends one Core-kind entry for `coreId` to reportEntries.
// The entry carries the core id both as size_t and as a single-element list of
// int32_t ids, plus `row` and its numAlloca / sizeAlloca values.
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
  // Exactly one entry is recorded per call; the id is narrowed through the
  // checked helper rather than a plain static_cast.
  reportEntries.push_back({MemoryReportEntry::Kind::Core,
                           coreId,
                           {pim::checkedI32OrCrash(coreId, "memory report core id")},
                           row,
                           row.numAlloca,
                           row.sizeAlloca});
}
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
@@ -402,24 +420,24 @@ void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t i
pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::Opcode::sldi;
instruction.rd = static_cast<uint8_t>(registerNumber);
instruction.r2OrImm = static_cast<int32_t>(immediate);
instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
emitInstruction(instruction);
}
// Loads register 0 with rdAddress + rdOffset.
// The sum goes through pim::checkedAddOrCrash, and the register is written
// once per call.
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
  genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
}
// Loads register 0 with rdAddress + rdOffset and register 1 with
// rs1Address + rs1Offset, in that order.
// Each sum goes through pim::checkedAddOrCrash, and each register is written
// once per call.
void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const {
  genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
  genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address"));
}
// Loads registers 0, 1 and 2 with the rd, rs1 and rs2 addresses respectively,
// each formed as base address + offset, in that order.
// Each sum goes through pim::checkedAddOrCrash, and each register is written
// once per call.
void PimCodeGen::setupRdRs1Rs2(
    size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const {
  genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
  genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address"));
  genSetRegisterImmediateUnsigned(2, pim::checkedAddOrCrash(rs2Address, rs2Offset, "rs2 address"));
}
void PimCodeGen::emitMemCopyOp(StringRef opName,
@@ -437,8 +455,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
instruction.r1 = 1;
instruction.generic1 = 0;
instruction.generic2 = 0;
instruction.generic3 = static_cast<int32_t>(size);
(void) sizeFieldName;
instruction.generic3 = pim::checkedI32OrCrash(size, sizeFieldName);
emitInstruction(instruction);
}
@@ -448,10 +465,10 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::opcodeFromString(opName);
instruction.rd = 0;
instruction.r2OrImm = static_cast<int32_t>(remapCoreId(coreId));
instruction.r2OrImm = pim::checkedI32OrCrash(remapCoreId(coreId), "communication core id");
instruction.generic1 = 0;
instruction.generic2 = 0;
instruction.generic3 = static_cast<int32_t>(size);
instruction.generic3 = pim::checkedI32OrCrash(size, "communication byte size");
emitInstruction(instruction);
}
@@ -464,7 +481,7 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
instruction.r1 = 1;
instruction.r2OrImm = 8;
instruction.generic1 = 0;
instruction.generic2 = static_cast<int32_t>(groupId);
instruction.generic2 = pim::checkedI32OrCrash(groupId, "mvm group id");
emitInstruction(instruction);
}
@@ -578,7 +595,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvaddOp.getLhs().getType()));
emitInstruction(instruction);
}
@@ -593,7 +610,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvsubOp.getLhs().getType()));
emitInstruction(instruction);
}
@@ -608,7 +625,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvmulOp.getLhs().getType()));
emitInstruction(instruction);
}
@@ -623,7 +640,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvmaxOp.getLhs().getType()));
emitInstruction(instruction);
}
@@ -638,7 +655,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
instruction.rd = 0;
instruction.r1 = 1;
instruction.r2OrImm = 2;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvdmulOp.getLhs().getType()));
emitInstruction(instruction);
}
@@ -653,7 +670,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
instruction.r1 = 1;
instruction.r2OrImm = 1;
instruction.generic1 = 1;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vavgOp.getInput().getType()));
emitInstruction(instruction);
}
@@ -666,7 +683,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vrelu;
instruction.rd = 0;
instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vreluOp.getInput().getType()));
emitInstruction(instruction);
}
@@ -679,7 +696,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vtanh;
instruction.rd = 0;
instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vtanhOp.getInput().getType()));
emitInstruction(instruction);
}
@@ -692,7 +709,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
instruction.opcode = pim_binary::Opcode::vsigm;
instruction.rd = 0;
instruction.r1 = 1;
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vsigmOp.getInput().getType()));
emitInstruction(instruction);
}
@@ -705,8 +722,7 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
instruction.opcode = pim_binary::Opcode::vsoftmax;
instruction.rd = 0;
instruction.r1 = 1;
instruction.generic3 =
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vsoftmaxOp.getInput().getType()));
emitInstruction(instruction);
}
@@ -1370,7 +1386,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
return err;
xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
reportedCoreIds.push_back(static_cast<int32_t>(job.emittedCoreId));
reportedCoreIds.push_back(pim::checkedI32OrCrash(job.emittedCoreId, "batch report core id"));
if (!batchPerCoreRow)
batchPerCoreRow = result.reportRow;
batchRow = addMemoryReportRows(batchRow, result.reportRow);