This commit is contained in:
@@ -41,15 +41,23 @@ using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
||||
if (elementType.isIndex())
|
||||
return sizeof(int64_t);
|
||||
if (elementType.isIntOrFloat())
|
||||
return elementType.getIntOrFloatBitWidth() / 8;
|
||||
llvm_unreachable("unsupported shaped element type");
|
||||
}
|
||||
|
||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
||||
}
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
|
||||
MemEntry memEntry = {0, allocSize};
|
||||
return &memEntries.emplace_back(memEntry, value).first;
|
||||
}
|
||||
@@ -398,20 +406,28 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
|
||||
auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
|
||||
assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
|
||||
&& "pim.memcp_hd offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("ld",
|
||||
addressOf(loadOp.getDeviceTarget(), knowledge),
|
||||
loadOp.getDeviceTargetOffset(),
|
||||
*deviceTargetOffset,
|
||||
addressOf(loadOp.getHostSource(), knowledge),
|
||||
loadOp.getHostSourceOffset(),
|
||||
*hostSourceOffset,
|
||||
loadOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
|
||||
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
|
||||
assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
|
||||
&& "pim.memcp_dh offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("st",
|
||||
addressOf(storeOp.getHostTarget(), knowledge),
|
||||
storeOp.getHostTargetOffset(),
|
||||
*hostTargetOffset,
|
||||
addressOf(storeOp.getDeviceSource(), knowledge),
|
||||
storeOp.getDeviceSourceOffset(),
|
||||
*deviceSourceOffset,
|
||||
storeOp.getSize());
|
||||
}
|
||||
|
||||
@@ -426,8 +442,9 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp(
|
||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
|
||||
assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
@@ -439,7 +456,9 @@ void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
|
||||
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -728,12 +747,19 @@ std::string getMemorySizeAsString(size_t size) {
|
||||
|
||||
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
if (!coreOp)
|
||||
return;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
@@ -795,6 +821,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
/// fully resolved before the JSON instructions are emitted.
|
||||
/// Returns the number of emitted instructions, or -1 on failure.
|
||||
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
|
||||
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
|
||||
if (!coreOp)
|
||||
return std::nullopt;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
};
|
||||
size_t processedOperations = 0;
|
||||
auto result =
|
||||
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
@@ -814,8 +849,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
||||
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
|
||||
auto weightIndex = resolveWeightIndex(vmmOp);
|
||||
if (!weightIndex)
|
||||
return failure();
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
|
||||
}
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
@@ -1004,10 +1043,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
||||
reportedCoreIds.reserve(batchCoreIds.size());
|
||||
MemoryReportRow batchRow;
|
||||
std::optional<MemoryReportRow> batchPerCoreRow;
|
||||
llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
|
||||
SmallVector<size_t> orderedOriginalCoreIds;
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
|
||||
if (inserted)
|
||||
orderedOriginalCoreIds.push_back(originalCoreId);
|
||||
it->second.push_back(lane);
|
||||
}
|
||||
|
||||
for (size_t originalCoreId : orderedOriginalCoreIds) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||
MemoryReportRow laneRow;
|
||||
|
||||
Reference in New Issue
Block a user