add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers; refactor pim passes into Pim/Transforms; more robust memory coalescing pass
This commit is contained in:
@@ -28,7 +28,10 @@ add_pim_library(OMPimCompilerUtils
|
||||
OMPimCompilerOptions
|
||||
OMPimCommon
|
||||
OMPimBufferization
|
||||
OMPimStaticMemoryCoalescing
|
||||
OMPimMemoryCoalescing
|
||||
OMPimHostConstantFolding
|
||||
OMPimHostConstantMaterialization
|
||||
OMPimVerification
|
||||
OMPimPasses
|
||||
OMONNXToSpatial
|
||||
OMSpatialToPim
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
namespace onnx_mlir::pim_binary {
|
||||
|
||||
@@ -95,15 +95,10 @@ inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecor
|
||||
writeInt32LE(os, record.generic3);
|
||||
}
|
||||
|
||||
inline int32_t toI32(int64_t value) {
|
||||
assert(value >= std::numeric_limits<int32_t>::min() && value <= std::numeric_limits<int32_t>::max()
|
||||
&& "PIM binary field out of int32 range");
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
inline int32_t toI32(int64_t value) { return onnx_mlir::pim::checkedI32OrCrash(value, "binary field"); }
|
||||
|
||||
inline uint8_t toU8(int64_t value) {
|
||||
assert(value >= 0 && value <= std::numeric_limits<uint8_t>::max() && "PIM binary field out of uint8 range");
|
||||
return static_cast<uint8_t>(value);
|
||||
return onnx_mlir::pim::checkedU8OrCrash(static_cast<uint64_t>(value), "binary field");
|
||||
}
|
||||
|
||||
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
|
||||
|
||||
@@ -25,12 +25,14 @@
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/IR/CompactAsmUtils.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/Support/CheckedArithmetic.hpp"
|
||||
#include "Common/Support/ReportUtils.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
@@ -71,12 +73,23 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigne
|
||||
return {value, getLaneForMemoryValue(value, lane)};
|
||||
}
|
||||
|
||||
static int32_t getVectorByteSizeOrCrash(ShapedType type) {
|
||||
auto byteSize = pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "vector byte size");
|
||||
if (failed(byteSize))
|
||||
llvm_unreachable("Failed to compute checked vector byte size");
|
||||
return pim::checkedI32OrCrash(*byteSize, "vector byte size");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional<unsigned> lane) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
size_t allocSize = getShapedTypeSizeInBytes(type);
|
||||
auto checkedAllocSize =
|
||||
pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size");
|
||||
if (failed(checkedAllocSize))
|
||||
llvm_unreachable("Failed to compute checked allocation byte size");
|
||||
size_t allocSize = static_cast<size_t>(*checkedAllocSize);
|
||||
MemEntry memEntry = {0, allocSize};
|
||||
return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first;
|
||||
}
|
||||
@@ -272,7 +285,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value,
|
||||
llvm_unreachable("Missing mem entry");
|
||||
}
|
||||
|
||||
return iter->second.address + resolvedAddress->byteOffset;
|
||||
size_t byteOffset = pim::checkedSizeOrCrash(resolvedAddress->byteOffset, "resolved PIM byte offset");
|
||||
return pim::checkedAddOrCrash(iter->second.address, byteOffset, "resolved PIM address");
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
|
||||
@@ -291,8 +305,12 @@ llvm::FailureOr<int64_t> PimAcceleratorMemory::getIndexValue(mlir::Value value,
|
||||
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
|
||||
|
||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||
reportEntries.push_back(
|
||||
{MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row, row.numAlloca, row.sizeAlloca});
|
||||
reportEntries.push_back({MemoryReportEntry::Kind::Core,
|
||||
coreId,
|
||||
{pim::checkedI32OrCrash(coreId, "memory report core id")},
|
||||
row,
|
||||
row.numAlloca,
|
||||
row.sizeAlloca});
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
|
||||
@@ -402,24 +420,24 @@ void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t i
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::sldi;
|
||||
instruction.rd = static_cast<uint8_t>(registerNumber);
|
||||
instruction.r2OrImm = static_cast<int32_t>(immediate);
|
||||
instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
|
||||
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
|
||||
genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
|
||||
}
|
||||
|
||||
void PimCodeGen::setupRdRs1(size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset) const {
|
||||
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
|
||||
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
|
||||
genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
|
||||
genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address"));
|
||||
}
|
||||
|
||||
void PimCodeGen::setupRdRs1Rs2(
|
||||
size_t rdAddress, size_t rdOffset, size_t rs1Address, size_t rs1Offset, size_t rs2Address, size_t rs2Offset) const {
|
||||
genSetRegisterImmediateUnsigned(0, rdAddress + rdOffset);
|
||||
genSetRegisterImmediateUnsigned(1, rs1Address + rs1Offset);
|
||||
genSetRegisterImmediateUnsigned(2, rs2Address + rs2Offset);
|
||||
genSetRegisterImmediateUnsigned(0, pim::checkedAddOrCrash(rdAddress, rdOffset, "rd address"));
|
||||
genSetRegisterImmediateUnsigned(1, pim::checkedAddOrCrash(rs1Address, rs1Offset, "rs1 address"));
|
||||
genSetRegisterImmediateUnsigned(2, pim::checkedAddOrCrash(rs2Address, rs2Offset, "rs2 address"));
|
||||
}
|
||||
|
||||
void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||
@@ -437,8 +455,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||
instruction.r1 = 1;
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
(void) sizeFieldName;
|
||||
instruction.generic3 = pim::checkedI32OrCrash(size, sizeFieldName);
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -448,10 +465,10 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::opcodeFromString(opName);
|
||||
instruction.rd = 0;
|
||||
instruction.r2OrImm = static_cast<int32_t>(remapCoreId(coreId));
|
||||
instruction.r2OrImm = pim::checkedI32OrCrash(remapCoreId(coreId), "communication core id");
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
instruction.generic3 = pim::checkedI32OrCrash(size, "communication byte size");
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -464,7 +481,7 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 8;
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = static_cast<int32_t>(groupId);
|
||||
instruction.generic2 = pim::checkedI32OrCrash(groupId, "mvm group id");
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -578,7 +595,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvaddOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -593,7 +610,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvsubOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -608,7 +625,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvmulOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -623,7 +640,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvmaxOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -638,7 +655,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vvdmulOp.getLhs().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -653,7 +670,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 1;
|
||||
instruction.generic1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vavgOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -666,7 +683,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vrelu;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vreluOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -679,7 +696,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vtanh;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vtanhOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -692,7 +709,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
|
||||
instruction.opcode = pim_binary::Opcode::vsigm;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vsigmOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -705,8 +722,7 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
||||
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 =
|
||||
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
|
||||
instruction.generic3 = getVectorByteSizeOrCrash(cast<ShapedType>(vsoftmaxOp.getInput().getType()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
@@ -1370,7 +1386,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
if (auto err = linkCoreWeights(job.emittedCoreId, mapCoreWeightToFileName[job.emittedCoreId], xbarsPerGroup))
|
||||
return err;
|
||||
xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(job.emittedCoreId));
|
||||
reportedCoreIds.push_back(pim::checkedI32OrCrash(job.emittedCoreId, "batch report core id"));
|
||||
if (!batchPerCoreRow)
|
||||
batchPerCoreRow = result.reportRow;
|
||||
batchRow = addMemoryReportRows(batchRow, result.reportRow);
|
||||
|
||||
@@ -40,7 +40,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
|
||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||
pm.addPass(createPimBufferizationPass());
|
||||
pm.addPass(createPimStaticMemoryCoalescingPass());
|
||||
pm.addPass(createMessagePass("Pim bufferized"));
|
||||
}
|
||||
|
||||
@@ -48,6 +47,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
pm.addPass(createPimHostConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim host constants folded"));
|
||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||
pm.addPass(createPimMemoryCoalescingPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimCodePass());
|
||||
|
||||
Reference in New Issue
Block a user