add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers refactor pim passes into Pim/Transforms more robust memory coalescing pass
This commit is contained in:
@@ -25,12 +25,14 @@
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/IR/CompactAsmUtils.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/Support/CheckedArithmetic.hpp"
|
||||
#include "Common/Support/ReportUtils.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
@@ -71,12 +73,23 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigne
|
||||
return {value, getLaneForMemoryValue(value, lane)};
|
||||
}
|
||||
|
||||
static int32_t getVectorByteSizeOrCrash(ShapedType type) {
|
||||
auto byteSize = pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "vector byte size");
|
||||
if (failed(byteSize))
|
||||
llvm_unreachable("Failed to compute checked vector byte size");
|
||||
return pim::checkedI32OrCrash(*byteSize, "vector byte size");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional<unsigned> lane) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
size_t allocSize = getShapedTypeSizeInBytes(type);
|
||||
auto checkedAllocSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size");
|
||||
if (failed(checkedAllocSize))
|
||||
llvm_unreachable("Failed to compute checked allocation byte size");
|
||||
size_t allocSize = static_cast<size_t>(*checkedAllocSize);
|
||||
MemEntry memEntry = {0, allocSize};
|
||||
return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first;
|
||||
}
|
||||
@@ -272,7 +285,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
|
||||
llvm_unreachable("Missing mem entry");
|
||||
}
|
||||
|
||||
return iter->second.address + resolvedAddress->byteOffset;
|
||||
size_t byteOffset = pim::checkedSizeOrCrash(resolvedAddress->byteOffset, "resolved PIM byte offset");
|
||||
return pim::checkedAddOrCrash(iter->second.address, byteOffset, "resolved PIM address");
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
|
||||
@@ -291,8 +305,12 @@ llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
|
||||
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
|
||||
|
||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||
reportEntries.push_back(
|
||||
{MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row, row.numAlloca, row.sizeAlloca});
|
||||
reportEntries.push_back({MemoryReportEntry::Kind::Core,
|
||||
coreId,
|
||||
{pim::checkedI32OrCrash(coreId, "memory report core id")},
|
||||
row,
|
||||
row.numAlloca,
|
||||
row.sizeAlloca});
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
|
||||
@@ -402,24 +420,24 @@ void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t i
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::sldi;
|
||||
instruction.rd = static_cast<uint8_t>(registerNumber);
|
||||
instruction.r2OrImm = static_cast<int32_t>(immediate);
|
||||
instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
|
||||
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
|
||||
genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
|
||||
}
|
||||
|
||||
void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const {
|
||||
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
|
||||
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
|
||||
genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
|
||||
genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address"));
|
||||
}
|
||||
|
||||
void PimCodeGen::setupRdRs1Rs2(
|
||||
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const {
|
||||
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
|
||||
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
|
||||
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
|
||||
genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
|
||||
genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address"));
|
||||
genSetRegisterImmediateUnsigned(2, pim::checkedAddOrCrash(rs2Address, rs2Offset, "rs2 address"));
|
||||
}
|
||||
|
||||
void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||
@@ -437,8 +455,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||
instruction.r1 = 1;
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
(void) sizeFieldName;
|
||||
instruction.generic3 = pim::checkedI32OrCrash(size, sizeFieldName);
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -448,10 +465,10 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::opcodeFromString(opName);
|
||||
instruction.rd = 0;
|
||||
instruction.r2OrImm = static_cast<int32_t>(remapCoreId(coreId));
|
||||
instruction.r2OrImm = pim::checkedI32OrCrash(remapCoreId(coreId), "communication core id");
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
instruction.generic3 = pim::checkedI32OrCrash(size, "communication byte size");
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -464,7 +481,7 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 8;
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = static_cast<int32_t>(groupId);
|
||||
instruction.generic2 = pim::checkedI32OrCrash(groupId, "mvm group id");
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -578,7 +595,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvaddOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -593,7 +610,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvsubOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -608,7 +625,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvmulOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -623,7 +640,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvmaxOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -638,7 +655,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvdmulOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -653,7 +670,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 1;
|
||||
instruction.generic1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vavgOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -666,7 +683,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vrelu;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vreluOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -679,7 +696,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vtanh;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vtanhOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -692,7 +709,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vsigm;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vsigmOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -705,8 +722,7 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
||||
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 =
|
||||
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vsoftmaxOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -1370,7 +1386,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
|
||||
return err;
|
||||
xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(job.emittedCoreId));
|
||||
reportedCoreIds.push_back(pim::checkedI32OrCrash(job.emittedCoreId, "batch report core id"));
|
||||
if (!batchPerCoreRow)
|
||||
batchPerCoreRow = result.reportRow;
|
||||
batchRow = addMemoryReportRows(batchRow, result.reportRow);
|
||||
|
||||
Reference in New Issue
Block a user