This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
@@ -16,6 +14,7 @@
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -41,9 +40,7 @@ static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast<int32
|
||||
class MergeScheduleMaterializerImpl {
|
||||
public:
|
||||
// Builds a materializer bound to `funcOp`: stores the function, its location, and the
// terminator of the first block of its body.
// cast<> (rather than dyn_cast<>) is used, so that terminator is expected to be a func::ReturnOp.
// NOTE(review): two layouts of the same member-initializer list appear below (multi-line and
// single-line); they initialize the same three members with the same expressions.
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
|
||||
: func(funcOp),
|
||||
loc(funcOp.getLoc()),
|
||||
returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
|
||||
: func(funcOp), loc(funcOp.getLoc()), returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
|
||||
|
||||
LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) {
|
||||
schedule = &scheduleResult;
|
||||
@@ -66,10 +63,9 @@ public:
|
||||
|
||||
private:
|
||||
// One scheduled compute instance together with the placement data used to bucket, sort and
// look up tasks. Instances are brace-initialized positionally in collectScheduledTasks, so the
// declaration order of the members below is significant.
struct ScheduledTask {
|
||||
// Identity of the task; used as the key of the per-task lookup maps.
// NOTE(review): `key` and `computeInstance` look like the old and new name of the same member.
ComputeInstance key;
|
||||
ComputeInstance computeInstance;
|
||||
// Operation the task originates from; diagnostics are emitted on it via emitOpError.
Operation* sourceOp = nullptr;
|
||||
// Scheduler cpu the task is assigned to; tasks are grouped per cpu.
size_t cpu = 0;
|
||||
// Per-cpu slot; compared first by the slot-aware sort comparator.
size_t slot = 0;
|
||||
// Sort key within a cpu (tie-break after `slot` where the slot-aware comparator is used).
size_t order = 0;
|
||||
// Index of the task inside its cpu's sorted task list; assigned after sorting.
size_t executionOrder = 0;
|
||||
};
|
||||
@@ -112,7 +108,7 @@ private:
|
||||
values.push_back(value);
|
||||
}
|
||||
|
||||
bool isInternalInputOp(Operation *op) {
|
||||
bool isOldComputeResult(Operation* op) {
|
||||
auto it = isInternalInputOpCache.find(op);
|
||||
if (it != isInternalInputOpCache.end())
|
||||
return it->second;
|
||||
@@ -123,9 +119,9 @@ private:
|
||||
|
||||
for (Value result : extract->getResults()) {
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (toEraseSet.contains(user))
|
||||
if (oldComputeOps.contains(user))
|
||||
continue;
|
||||
if (isInternalInputOp(user))
|
||||
if (isOldComputeResult(user))
|
||||
continue;
|
||||
return isInternalInputOpCache[op] = false;
|
||||
}
|
||||
@@ -135,8 +131,9 @@ private:
|
||||
|
||||
void collectInternalInputOps(Value value) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
// TODO: ExtractSliceOp is not the only legal host op to traverse.
|
||||
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op)) {
|
||||
if (isInternalInputOp(extract.getOperation()))
|
||||
if (isOldComputeResult(extract.getOperation()))
|
||||
internalInputOpsToErase.insert(extract.getOperation());
|
||||
value = extract.getSource();
|
||||
op = value.getDefiningOp();
|
||||
@@ -148,7 +145,7 @@ private:
|
||||
return;
|
||||
for (Value result : op->getResults()) {
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (toEraseSet.contains(user) || isa<func::ReturnOp>(user))
|
||||
if (oldComputeOps.contains(user) || isa<func::ReturnOp>(user))
|
||||
continue;
|
||||
collectExternalUsers(user);
|
||||
}
|
||||
@@ -158,10 +155,12 @@ private:
|
||||
// Walks schedule->dominanceOrderCompute in order and appends one ScheduledTask per instance to
// scheduledTasks, also remembering each instance's op in the set of ops that are being replaced.
void collectScheduledTasks() {
|
||||
// Running counter, post-incremented once per recorded task.
size_t nextOrder = 0;
|
||||
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
||||
toEraseSet.insert(scheduledInstance.op);
|
||||
// The initializer is positional: instance, op, cpu lookup, slot lookup, counter.
scheduledTasks.push_back(
|
||||
{scheduledInstance, scheduledInstance.op, schedule->computeToCpuMap.lookup(scheduledInstance),
|
||||
schedule->computeToCpuSlotMap.lookup(scheduledInstance), nextOrder++});
|
||||
oldComputeOps.insert(scheduledInstance.op);
|
||||
// NOTE(review): still five positional values. If ScheduledTask no longer declares `slot`,
// the computeToCpuSlotMap lookup lands in `order` and the counter in `executionOrder`
// - confirm that this is intended.
scheduledTasks.push_back({scheduledInstance,
|
||||
scheduledInstance.op,
|
||||
schedule->computeToCpuMap.lookup(scheduledInstance),
|
||||
schedule->computeToCpuSlotMap.lookup(scheduledInstance),
|
||||
nextOrder++});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,21 +171,18 @@ private:
|
||||
};
|
||||
|
||||
for (const ScheduledTask& task : scheduledTasks) {
|
||||
taskByKey[task.key] = task;
|
||||
taskByComputeInstance[task.computeInstance] = task;
|
||||
tasksByCpu[task.cpu].push_back(task);
|
||||
markCpuSeen(task.cpu);
|
||||
}
|
||||
|
||||
llvm::sort(orderedCpus);
|
||||
for (size_t cpu : orderedCpus) {
|
||||
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask &lhs, const ScheduledTask &rhs) {
|
||||
if (lhs.slot != rhs.slot)
|
||||
return lhs.slot < rhs.slot;
|
||||
return lhs.order < rhs.order;
|
||||
});
|
||||
llvm::stable_sort(tasksByCpu[cpu],
|
||||
[&](const ScheduledTask& lhs, const ScheduledTask& rhs) { return lhs.order < rhs.order; });
|
||||
for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) {
|
||||
task.executionOrder = executionOrder;
|
||||
taskByKey[task.key].executionOrder = executionOrder;
|
||||
taskByComputeInstance[task.computeInstance].executionOrder = executionOrder;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -194,19 +190,22 @@ private:
|
||||
void collectExternalInputsAndWeights() {
|
||||
for (size_t cpu : orderedCpus) {
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
auto taskWeights = getComputeInstanceWeights(task.key);
|
||||
auto& thisCpuWeights = cpuWeights[cpu];
|
||||
auto& thisSeenWeights = seenWeightsByCpu[cpu];
|
||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||
for (Value weight : taskWeights)
|
||||
appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight);
|
||||
if (thisSeenWeights.insert(weight).second)
|
||||
thisCpuWeights.push_back(weight);
|
||||
|
||||
auto taskInputs = getComputeInstanceInputs(task.key);
|
||||
auto &remoteInputs = remoteInputsByTask[task.key];
|
||||
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||
remoteInputs.resize(taskInputs.size());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
auto producerRef = getProducerValueRef(input);
|
||||
if (producerRef) {
|
||||
collectInternalInputOps(input);
|
||||
auto producerIt = taskByKey.find(producerRef->instance);
|
||||
if (producerIt != taskByKey.end()) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
if (producerIt->second.cpu != cpu) {
|
||||
ChannelInfo info {
|
||||
(*nextChannelId)++,
|
||||
@@ -216,9 +215,9 @@ private:
|
||||
remoteInputs[inputIndex] = info;
|
||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||
if (perResultChannels.empty())
|
||||
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.key).size());
|
||||
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
|
||||
perResultChannels[producerRef->resultIndex].push_back(
|
||||
{info, task.key, inputIndex, task.executionOrder, 0});
|
||||
{info, task.computeInstance, inputIndex, task.executionOrder, 0});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@@ -226,19 +225,19 @@ private:
|
||||
appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input);
|
||||
}
|
||||
|
||||
auto taskOutputs = getComputeInstanceOutputValues(task.key);
|
||||
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
||||
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||
bool hasExternalUser = false;
|
||||
for (auto& use : output.getUses()) {
|
||||
Operation* useOwner = use.getOwner();
|
||||
if (toEraseSet.contains(useOwner))
|
||||
if (oldComputeOps.contains(useOwner))
|
||||
continue;
|
||||
hasExternalUser = true;
|
||||
if (!isa<func::ReturnOp>(useOwner))
|
||||
collectExternalUsers(useOwner);
|
||||
}
|
||||
if (hasExternalUser)
|
||||
cpuExternalOutputs[cpu].push_back({task.key, resultIndex});
|
||||
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -249,7 +248,7 @@ private:
|
||||
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
||||
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
auto sendsIt = remoteSendsByTask.find(task.key);
|
||||
auto sendsIt = remoteSendsByTask.find(task.computeInstance);
|
||||
if (sendsIt == remoteSendsByTask.end())
|
||||
continue;
|
||||
for (auto& sendInfos : sendsIt->second) {
|
||||
@@ -344,8 +343,8 @@ private:
|
||||
SmallVector<Type> resultTypes;
|
||||
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||
resultTypes.push_back(getComputeInstanceOutputTypes(task.key)[outputRef.resultIndex]);
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]);
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
@@ -373,8 +372,8 @@ private:
|
||||
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
||||
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.key)[outputRef.resultIndex]] =
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
||||
newCompute.getResult(resultIndex);
|
||||
}
|
||||
cpuPrograms[cpu] = std::move(program);
|
||||
@@ -400,15 +399,14 @@ private:
|
||||
size_t& queueIndex = receiveQueueIndices[pairKey];
|
||||
while (queueIndex < queue.size()) {
|
||||
const RemoteReceiveEntry& entry = queue[queueIndex++];
|
||||
auto consumerTaskIt = taskByKey.find(entry.consumer);
|
||||
if (consumerTaskIt == taskByKey.end())
|
||||
auto consumerTaskIt = taskByComputeInstance.find(entry.consumer);
|
||||
if (consumerTaskIt == taskByComputeInstance.end())
|
||||
return failure();
|
||||
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.key);
|
||||
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance);
|
||||
if (consumerInputs.size() <= entry.inputIndex)
|
||||
return failure();
|
||||
Type inputType = consumerInputs[entry.inputIndex].getType();
|
||||
auto receive =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
auto receive = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
loc,
|
||||
inputType,
|
||||
rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
|
||||
@@ -445,24 +443,23 @@ private:
|
||||
};
|
||||
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.key);
|
||||
auto taskWeights = getComputeInstanceWeights(task.key);
|
||||
Block &templateBlock = getComputeInstanceTemplateBlock(task.key);
|
||||
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
|
||||
|
||||
SmallVector<Value> resolvedInputs;
|
||||
resolvedInputs.reserve(taskInputs.size());
|
||||
auto remoteInputsIt = remoteInputsByTask.find(task.key);
|
||||
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
auto producerRef = getProducerValueRef(input);
|
||||
if (producerRef) {
|
||||
auto producerIt = taskByKey.find(producerRef->instance);
|
||||
if (producerIt != taskByKey.end()) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
if (producerIt->second.cpu == cpu) {
|
||||
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||
task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
||||
<< " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot
|
||||
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
||||
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||
return failure();
|
||||
@@ -473,17 +470,21 @@ private:
|
||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.key, inputIndex)) {
|
||||
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
|
||||
resolvedInputs.push_back(*preReceived);
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> received = receiveThroughInput(
|
||||
rewriter, cpu, receiveQueueIndices, preReceivedInputsByTask, channelInfo, task.key, inputIndex);
|
||||
FailureOr<Value> received = receiveThroughInput(rewriter,
|
||||
cpu,
|
||||
receiveQueueIndices,
|
||||
preReceivedInputsByTask,
|
||||
channelInfo,
|
||||
task.computeInstance,
|
||||
inputIndex);
|
||||
if (failed(received)) {
|
||||
task.sourceOp->emitOpError("failed to materialize reordered remote receive")
|
||||
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
||||
<< " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId
|
||||
<< " channelId=" << channelInfo.channelId;
|
||||
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
|
||||
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
|
||||
return failure();
|
||||
}
|
||||
resolvedInputs.push_back(*received);
|
||||
@@ -529,8 +530,9 @@ private:
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) {
|
||||
}
|
||||
else {
|
||||
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
|
||||
IRMapping mapper;
|
||||
if (templateBlock.getNumArguments() == 1)
|
||||
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
||||
@@ -545,8 +547,7 @@ private:
|
||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||
task.sourceOp->emitOpError(
|
||||
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
return failure();
|
||||
}
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
@@ -554,8 +555,7 @@ private:
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||
task.sourceOp->emitOpError(
|
||||
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
return failure();
|
||||
}
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
@@ -565,8 +565,8 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
producedValuesByTask[task.key] = taskYieldValues;
|
||||
if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) {
|
||||
producedValuesByTask[task.computeInstance] = taskYieldValues;
|
||||
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
|
||||
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
||||
if (sendInfos.empty())
|
||||
continue;
|
||||
@@ -588,9 +588,9 @@ private:
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||
<< " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart;
|
||||
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
|
||||
return failure();
|
||||
}
|
||||
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||
@@ -604,13 +604,13 @@ private:
|
||||
// Redirects uses of every old external value to its replacement recorded in
// oldToNewExternalValueMap. Uses whose owning operation is in the set of ops being replaced
// are left untouched.
void replaceExternalUses() {
|
||||
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
||||
// Early-increment iteration: the use is reassigned inside the loop, which presumably
// unlinks it from oldValue's use list while that list is being walked.
// NOTE(review): the two guards below test the same owner against `toEraseSet` and
// `oldComputeOps`; these look like the old and new name of the same set.
for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
|
||||
if (!toEraseSet.contains(use.getOwner()))
|
||||
if (!oldComputeOps.contains(use.getOwner()))
|
||||
use.assign(newValue);
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult eraseOldScheduledOps() {
|
||||
DenseSet<Operation *> allOpsToErase = toEraseSet;
|
||||
DenseSet<Operation*> allOpsToErase = oldComputeOps;
|
||||
for (Operation* op : internalInputOpsToErase)
|
||||
allOpsToErase.insert(op);
|
||||
|
||||
@@ -658,8 +658,8 @@ private:
|
||||
func::ReturnOp returnOp;
|
||||
|
||||
SmallVector<ScheduledTask> scheduledTasks;
|
||||
DenseSet<Operation *> toEraseSet;
|
||||
DenseMap<ComputeInstance, ScheduledTask> taskByKey;
|
||||
DenseSet<Operation*> oldComputeOps;
|
||||
DenseMap<ComputeInstance, ScheduledTask> taskByComputeInstance;
|
||||
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||
SmallVector<size_t> orderedCpus;
|
||||
DenseSet<size_t> seenCpus;
|
||||
|
||||
Reference in New Issue
Block a user