Equivalent Class but broken
This commit is contained in:
+347
-340
@@ -58,18 +58,15 @@ static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int
|
|||||||
}
|
}
|
||||||
|
|
||||||
static Value createIndexTensorConstant(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
static Value createIndexTensorConstant(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||||
auto tensorType = RankedTensorType::get({static_cast<int64_t>(values.size())}, IndexType::get(anchorOp->getContext()));
|
auto tensorType =
|
||||||
|
RankedTensorType::get({static_cast<int64_t>(values.size())}, IndexType::get(anchorOp->getContext()));
|
||||||
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
|
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
|
||||||
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
|
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createIndexTupleTensorConstant(Operation* anchorOp,
|
static Value createIndexTupleTensorConstant(
|
||||||
int64_t tupleCount,
|
Operation* anchorOp, int64_t tupleCount, int64_t tupleWidth, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||||
int64_t tupleWidth,
|
auto tensorType = RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext()));
|
||||||
ArrayRef<int64_t> values,
|
|
||||||
OperationFolder& folder) {
|
|
||||||
auto tensorType =
|
|
||||||
RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext()));
|
|
||||||
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
|
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
|
||||||
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
|
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
|
||||||
}
|
}
|
||||||
@@ -114,12 +111,14 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct CpuProgram {
|
struct CpuProgram {
|
||||||
SpatCompute op;
|
Operation* op = nullptr;
|
||||||
DenseMap<Value, Value> externalInputMap;
|
DenseMap<Value, Value> externalInputMap;
|
||||||
DenseMap<Value, size_t> weightToIndex;
|
DenseMap<Value, size_t> weightToIndex;
|
||||||
};
|
};
|
||||||
|
|
||||||
using ProgramKey = std::pair<size_t, size_t>;
|
using ProgramKey = size_t; // Represents the "Leader" CPU
|
||||||
|
DenseMap<ProgramKey, SmallVector<size_t>> batchedCpus;
|
||||||
|
SmallVector<ProgramKey> orderedPrograms;
|
||||||
|
|
||||||
struct RemoteSendInfo {
|
struct RemoteSendInfo {
|
||||||
ChannelInfo channelInfo;
|
ChannelInfo channelInfo;
|
||||||
@@ -171,7 +170,7 @@ private:
|
|||||||
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ProgramKey getProgramKey(const ScheduledTask& task) { return {task.cpu, 0}; }
|
static ProgramKey getProgramKey(const ScheduledTask& task) { return task.cpu; }
|
||||||
|
|
||||||
static bool isResultfulBatchInstance(const ComputeInstance& instance) {
|
static bool isResultfulBatchInstance(const ComputeInstance& instance) {
|
||||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||||
@@ -377,8 +376,8 @@ private:
|
|||||||
void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) {
|
void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) {
|
||||||
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
|
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
|
||||||
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
|
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
|
||||||
Value channelSourceTargetTuples = createIndexTupleTensorConstant(
|
Value channelSourceTargetTuples =
|
||||||
hostAnchor,
|
createIndexTupleTensorConstant(hostAnchor,
|
||||||
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
|
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
|
||||||
3,
|
3,
|
||||||
run.channelSourceTargetTuples,
|
run.channelSourceTargetTuples,
|
||||||
@@ -413,8 +412,7 @@ private:
|
|||||||
.getResult();
|
.getResult();
|
||||||
|
|
||||||
Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
|
Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
|
||||||
Value innerLower =
|
Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
|
||||||
tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
|
|
||||||
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex});
|
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex});
|
||||||
emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples);
|
emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples);
|
||||||
rewriter.setInsertionPointAfter(outerLoop);
|
rewriter.setInsertionPointAfter(outerLoop);
|
||||||
@@ -423,8 +421,8 @@ private:
|
|||||||
void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) {
|
void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) {
|
||||||
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
|
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
|
||||||
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
|
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
|
||||||
Value channelSourceTargetTuples = createIndexTupleTensorConstant(
|
Value channelSourceTargetTuples =
|
||||||
hostAnchor,
|
createIndexTupleTensorConstant(hostAnchor,
|
||||||
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
|
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
|
||||||
3,
|
3,
|
||||||
run.channelSourceTargetTuples,
|
run.channelSourceTargetTuples,
|
||||||
@@ -469,8 +467,7 @@ private:
|
|||||||
.getResult();
|
.getResult();
|
||||||
|
|
||||||
Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
|
Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
|
||||||
Value innerLower =
|
Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
|
||||||
tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
|
|
||||||
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex});
|
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex});
|
||||||
emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples);
|
emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples);
|
||||||
rewriter.setInsertionPointAfter(outerLoop);
|
rewriter.setInsertionPointAfter(outerLoop);
|
||||||
@@ -532,149 +529,112 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void buildTaskIndex() {
|
void buildTaskIndex() {
|
||||||
auto markCpuSeen = [&](size_t cpu) {
|
DenseSet<size_t> seen;
|
||||||
if (seenCpus.insert(cpu).second)
|
|
||||||
orderedCpus.push_back(cpu);
|
|
||||||
};
|
|
||||||
|
|
||||||
for (const ScheduledTask& task : scheduledTasks) {
|
for (const ScheduledTask& task : scheduledTasks) {
|
||||||
taskByComputeInstance[task.computeInstance] = task;
|
taskByComputeInstance[task.computeInstance] = task;
|
||||||
tasksByCpu[task.cpu].push_back(task);
|
tasksByCpu[task.cpu].push_back(task);
|
||||||
ProgramKey programKey = getProgramKey(task);
|
seen.insert(task.cpu);
|
||||||
tasksByProgram[programKey].push_back(task);
|
|
||||||
if (seenPrograms.insert(programKey).second)
|
|
||||||
orderedPrograms.push_back(programKey);
|
|
||||||
markCpuSeen(task.cpu);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::sort(orderedCpus);
|
SmallVector<size_t> activeCpus(seen.begin(), seen.end());
|
||||||
llvm::sort(orderedPrograms, [](const ProgramKey& lhs, const ProgramKey& rhs) {
|
llvm::sort(activeCpus);
|
||||||
if (lhs.second != rhs.second)
|
|
||||||
return lhs.second < rhs.second;
|
DenseSet<size_t> batched;
|
||||||
return lhs.first < rhs.first;
|
for (size_t cpu : activeCpus) {
|
||||||
});
|
if (batched.contains(cpu))
|
||||||
for (size_t cpu : orderedCpus)
|
continue;
|
||||||
|
|
||||||
|
SmallVector<size_t> batch;
|
||||||
|
batch.push_back(cpu);
|
||||||
|
batched.insert(cpu);
|
||||||
|
|
||||||
|
// Group all equivalent CPUs into this batch
|
||||||
|
auto it = schedule->equivalentClass.find(cpu);
|
||||||
|
if (it != schedule->equivalentClass.end()) {
|
||||||
|
for (size_t eqCpu : it->second)
|
||||||
|
if (batched.insert(eqCpu).second)
|
||||||
|
batch.push_back(eqCpu);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::sort(batch);
|
||||||
|
size_t leader = batch.front();
|
||||||
|
batchedCpus[leader] = batch;
|
||||||
|
orderedPrograms.push_back(leader);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t cpu : activeCpus) {
|
||||||
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
||||||
return lhs.orderWithinCpu < rhs.orderWithinCpu;
|
return lhs.orderWithinCpu < rhs.orderWithinCpu;
|
||||||
});
|
});
|
||||||
for (ProgramKey programKey : orderedPrograms)
|
}
|
||||||
llvm::stable_sort(tasksByProgram[programKey], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
|
||||||
return lhs.orderWithinCpu < rhs.orderWithinCpu;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void collectExternalInputsAndWeights() {
|
void collectExternalInputsAndWeights() {
|
||||||
for (ProgramKey programKey : orderedPrograms) {
|
for (ProgramKey leader : orderedPrograms) {
|
||||||
size_t cpu = programKey.first;
|
const auto& batch = batchedCpus[leader];
|
||||||
for (const ScheduledTask& task : tasksByProgram[programKey]) {
|
auto& thisCpuWeights = cpuWeights[leader];
|
||||||
auto& thisCpuWeights = cpuWeights[programKey];
|
auto& thisCpuInputs = cpuExternalInputs[leader];
|
||||||
auto& thisSeenWeights = seenWeightsByProgram[programKey];
|
auto& thisCpuOutputs = cpuExternalOutputs[leader];
|
||||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
|
||||||
for (Value weight : taskWeights)
|
// Process every lane sequentially to pack operands
|
||||||
if (thisSeenWeights.insert(weight).second)
|
for (size_t cpu : batch) {
|
||||||
|
DenseSet<Value> laneSeenWeights;
|
||||||
|
DenseSet<Value> laneSeenInputs;
|
||||||
|
|
||||||
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
|
for (Value weight : getComputeInstanceWeights(task.computeInstance))
|
||||||
|
if (laneSeenWeights.insert(weight).second)
|
||||||
thisCpuWeights.push_back(weight);
|
thisCpuWeights.push_back(weight);
|
||||||
|
|
||||||
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||||
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||||
remoteInputs.resize(taskInputs.size());
|
remoteInputs.resize(taskInputs.size());
|
||||||
auto& remoteTensorInputs = remoteTensorInputsByTask[task.computeInstance];
|
|
||||||
remoteTensorInputs.resize(taskInputs.size());
|
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
bool isExternalInput = true;
|
bool isExternalInput = true;
|
||||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
|
||||||
producerBatch && producerBatch.getNumResults() != 0) {
|
|
||||||
size_t resultIndex = cast<OpResult>(input).getResultNumber();
|
|
||||||
TensorChannelInfo tensorInfo;
|
|
||||||
tensorInfo.resultIndex = resultIndex;
|
|
||||||
tensorInfo.channelIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
|
|
||||||
tensorInfo.sourceCoreIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
|
|
||||||
tensorInfo.targetCoreIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
|
|
||||||
tensorInfo.producerInstances.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
|
|
||||||
|
|
||||||
bool foundAllLaneProducers = true;
|
|
||||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
|
||||||
ComputeInstance producerInstance = getBatchChunkForLane(producerBatch, lane);
|
|
||||||
auto producerIt = taskByComputeInstance.find(producerInstance);
|
|
||||||
if (producerIt == taskByComputeInstance.end()) {
|
|
||||||
foundAllLaneProducers = false;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
ChannelInfo info {
|
|
||||||
producerIt->second.cpu == cpu ? -1 : (*nextChannelId)++,
|
|
||||||
static_cast<int32_t>(producerIt->second.cpu),
|
|
||||||
static_cast<int32_t>(cpu),
|
|
||||||
};
|
|
||||||
tensorInfo.channelIds.push_back(info.channelId);
|
|
||||||
tensorInfo.sourceCoreIds.push_back(info.sourceCoreId);
|
|
||||||
tensorInfo.targetCoreIds.push_back(info.targetCoreId);
|
|
||||||
tensorInfo.producerInstances.push_back(producerInstance);
|
|
||||||
|
|
||||||
if (producerIt->second.cpu != cpu) {
|
|
||||||
auto& perResultChannels = remoteSendsByTask[producerInstance];
|
|
||||||
if (perResultChannels.empty())
|
|
||||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
|
|
||||||
perResultChannels[resultIndex].push_back(
|
|
||||||
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, true});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (foundAllLaneProducers) {
|
|
||||||
remoteTensorInputs[inputIndex] = std::move(tensorInfo);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) {
|
if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) {
|
||||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||||
if (producerIt != taskByComputeInstance.end()) {
|
if (producerIt != taskByComputeInstance.end()) {
|
||||||
isExternalInput = false;
|
isExternalInput = false;
|
||||||
if (producerIt->second.cpu != cpu) {
|
if (producerIt->second.cpu != cpu) {
|
||||||
ChannelInfo info {
|
// Cross-core communication
|
||||||
(*nextChannelId)++,
|
ChannelInfo info;
|
||||||
static_cast<int32_t>(producerIt->second.cpu),
|
info.channelId = (*nextChannelId)++;
|
||||||
static_cast<int32_t>(cpu),
|
info.sourceCoreId = static_cast<int32_t>(producerIt->second.cpu);
|
||||||
};
|
info.targetCoreId = static_cast<int32_t>(cpu);
|
||||||
|
|
||||||
remoteInputs[inputIndex] = info;
|
remoteInputs[inputIndex] = info;
|
||||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||||
if (perResultChannels.empty())
|
if (perResultChannels.empty())
|
||||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
|
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
|
||||||
perResultChannels[producerRef->resultIndex].push_back(
|
|
||||||
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, false});
|
RemoteSendInfo sendInfo;
|
||||||
|
sendInfo.channelInfo = info;
|
||||||
|
sendInfo.consumer = task.computeInstance;
|
||||||
|
sendInfo.inputIndex = inputIndex;
|
||||||
|
sendInfo.consumerOrder = task.orderWithinCpu;
|
||||||
|
sendInfo.sourceOrder = 0;
|
||||||
|
sendInfo.isTensorInput = false;
|
||||||
|
perResultChannels[producerRef->resultIndex].push_back(sendInfo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (isExternalInput && seenExternalInputsByProgram[programKey].insert(input).second)
|
if (isExternalInput && laneSeenInputs.insert(input).second)
|
||||||
cpuExternalInputs[programKey].push_back(input);
|
thisCpuInputs.push_back(input);
|
||||||
}
|
|
||||||
|
|
||||||
if (isResultfulBatchInstance(task.computeInstance)) {
|
|
||||||
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
|
|
||||||
for (unsigned resultIndex = 0; resultIndex < batch.getNumResults(); ++resultIndex) {
|
|
||||||
bool hasExternalUser = false;
|
|
||||||
for (Operation* user : batch.getResult(resultIndex).getUsers()) {
|
|
||||||
if (!oldComputeOps.contains(user)) {
|
|
||||||
hasExternalUser = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (hasExternalUser)
|
|
||||||
cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex});
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Define the logical return types based strictly on the Leader
|
||||||
|
if (cpu == leader) {
|
||||||
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
||||||
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||||
bool hasExternalUser = false;
|
bool hasExternalUser = false;
|
||||||
for (auto& use : output.getUses()) {
|
for (auto& use : output.getUses())
|
||||||
Operation* useOwner = use.getOwner();
|
if (!oldComputeOps.contains(use.getOwner()))
|
||||||
if (oldComputeOps.contains(useOwner))
|
|
||||||
continue;
|
|
||||||
hasExternalUser = true;
|
hasExternalUser = true;
|
||||||
}
|
|
||||||
if (hasExternalUser)
|
if (hasExternalUser)
|
||||||
cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex});
|
thisCpuOutputs.push_back({task.computeInstance, resultIndex});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -779,65 +739,167 @@ private:
|
|||||||
|
|
||||||
void createCpuComputeOps() {
|
void createCpuComputeOps() {
|
||||||
IRRewriter rewriter(func.getContext());
|
IRRewriter rewriter(func.getContext());
|
||||||
for (ProgramKey programKey : orderedPrograms) {
|
for (ProgramKey leader : orderedPrograms) {
|
||||||
size_t cpu = programKey.first;
|
const auto& batch = batchedCpus[leader];
|
||||||
SmallVector<Value> operands;
|
bool isBatch = batch.size() > 1;
|
||||||
operands.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
|
|
||||||
llvm::append_range(operands, cpuWeights[programKey]);
|
|
||||||
llvm::append_range(operands, cpuExternalInputs[programKey]);
|
|
||||||
|
|
||||||
SmallVector<Type> resultTypes;
|
SmallVector<Type> resultTypes;
|
||||||
resultTypes.reserve(cpuExternalOutputs[programKey].size());
|
SmallVector<Type> packedResultTypes;
|
||||||
for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) {
|
|
||||||
|
for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) {
|
||||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||||
SmallVector<Type> outputTypes = getTaskOutputTypes(task.computeInstance);
|
Type elemType = getTaskOutputTypes(task.computeInstance)[outputRef.resultIndex];
|
||||||
resultTypes.push_back(outputTypes[outputRef.resultIndex]);
|
resultTypes.push_back(elemType);
|
||||||
|
|
||||||
|
if (isBatch) {
|
||||||
|
auto ranked = cast<RankedTensorType>(elemType);
|
||||||
|
SmallVector<int64_t> shape;
|
||||||
|
shape.push_back(static_cast<int64_t>(batch.size()));
|
||||||
|
shape.append(ranked.getShape().begin(), ranked.getShape().end());
|
||||||
|
packedResultTypes.push_back(RankedTensorType::get(shape, ranked.getElementType()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(returnOp);
|
rewriter.setInsertionPoint(returnOp);
|
||||||
|
CpuProgram program;
|
||||||
|
|
||||||
|
if (!isBatch) {
|
||||||
|
// Isolated CPU Execution
|
||||||
|
SmallVector<Value> operands;
|
||||||
|
operands.reserve(cpuWeights[leader].size() + cpuExternalInputs[leader].size());
|
||||||
|
llvm::append_range(operands, cpuWeights[leader]);
|
||||||
|
llvm::append_range(operands, cpuExternalInputs[leader]);
|
||||||
|
|
||||||
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(cpuWeights[programKey].size()), static_cast<int>(cpuExternalInputs[programKey].size())});
|
{static_cast<int>(cpuWeights[leader].size()), static_cast<int>(cpuExternalInputs[leader].size())});
|
||||||
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(cpu)));
|
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(leader)));
|
||||||
|
|
||||||
SmallVector<Type> blockArgTypes;
|
SmallVector<Type> blockArgTypes;
|
||||||
SmallVector<Location> blockArgLocs;
|
SmallVector<Location> blockArgLocs;
|
||||||
blockArgTypes.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
|
for (Value op : operands) {
|
||||||
blockArgLocs.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
|
blockArgTypes.push_back(op.getType());
|
||||||
for (Value weight : cpuWeights[programKey]) {
|
|
||||||
blockArgTypes.push_back(weight.getType());
|
|
||||||
blockArgLocs.push_back(loc);
|
blockArgLocs.push_back(loc);
|
||||||
}
|
}
|
||||||
for (Value input : cpuExternalInputs[programKey]) {
|
|
||||||
blockArgTypes.push_back(input.getType());
|
|
||||||
blockArgLocs.push_back(loc);
|
|
||||||
}
|
|
||||||
Block* newBlock =
|
|
||||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
|
||||||
CpuProgram program;
|
|
||||||
program.op = newCompute;
|
program.op = newCompute;
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[programKey]))
|
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[leader]))
|
||||||
program.weightToIndex[weight] = weightIndex;
|
program.weightToIndex[weight] = weightIndex;
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[programKey]))
|
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[leader]))
|
||||||
program.externalInputMap[input] = newCompute.getInputArgument(inputIndex);
|
program.externalInputMap[input] = newCompute.getInputArgument(inputIndex);
|
||||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[programKey])) {
|
|
||||||
|
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) {
|
||||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||||
if (isResultfulBatchInstance(task.computeInstance)) {
|
if (isResultfulBatchInstance(task.computeInstance)) {
|
||||||
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
|
auto oldBatch = cast<SpatComputeBatch>(task.computeInstance.op);
|
||||||
auto& batchResults = resultfulBatchLaneResults[batch.getOperation()];
|
auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()];
|
||||||
if (batchResults.empty())
|
if (batchResults.empty())
|
||||||
batchResults.resize(batch.getNumResults());
|
batchResults.resize(oldBatch.getNumResults());
|
||||||
auto& laneResults = batchResults[outputRef.resultIndex];
|
auto& laneResults = batchResults[outputRef.resultIndex];
|
||||||
if (laneResults.empty())
|
if (laneResults.empty())
|
||||||
laneResults.resize(static_cast<size_t>(batch.getLaneCount()));
|
laneResults.resize(static_cast<size_t>(oldBatch.getLaneCount()));
|
||||||
laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex);
|
laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
||||||
newCompute.getResult(resultIndex);
|
newCompute.getResult(resultIndex);
|
||||||
}
|
}
|
||||||
cpuPrograms[programKey] = std::move(program);
|
}
|
||||||
|
else {
|
||||||
|
// Equivalence Class Batch Execution
|
||||||
|
auto newBatch = SpatComputeBatch::create(rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange(packedResultTypes),
|
||||||
|
rewriter.getI32IntegerAttr(batch.size()),
|
||||||
|
cpuWeights[leader],
|
||||||
|
cpuExternalInputs[leader]);
|
||||||
|
newBatch.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(cpuWeights[leader].size()), static_cast<int>(cpuExternalInputs[leader].size())});
|
||||||
|
|
||||||
|
SmallVector<int32_t> coreIds;
|
||||||
|
for (size_t c : batch)
|
||||||
|
coreIds.push_back(static_cast<int32_t>(c));
|
||||||
|
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
blockArgTypes.push_back(rewriter.getIndexType()); // Lane ID Argument
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
|
||||||
|
size_t weightsPerLane = cpuWeights[leader].size() / batch.size();
|
||||||
|
size_t inputsPerLane = cpuExternalInputs[leader].size() / batch.size();
|
||||||
|
|
||||||
|
for (size_t i = 0; i < weightsPerLane; ++i) {
|
||||||
|
blockArgTypes.push_back(cpuWeights[leader][i].getType());
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < inputsPerLane; ++i) {
|
||||||
|
blockArgTypes.push_back(cpuExternalInputs[leader][i].getType());
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
}
|
||||||
|
for (Type t : packedResultTypes) {
|
||||||
|
blockArgTypes.push_back(t);
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
} // Dest tensors for InParallel Yield
|
||||||
|
|
||||||
|
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
program.op = newBatch;
|
||||||
|
|
||||||
|
// Host-side slice extractions
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointAfter(newBatch);
|
||||||
|
for (auto [laneIndex, cpu] : llvm::enumerate(batch)) {
|
||||||
|
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) {
|
||||||
|
size_t taskIdx = 0;
|
||||||
|
for (size_t i = 0; i < tasksByCpu[leader].size(); ++i) {
|
||||||
|
if (tasksByCpu[leader][i].computeInstance == outputRef.instance) {
|
||||||
|
taskIdx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ComputeInstance laneInstance = tasksByCpu[cpu][taskIdx].computeInstance;
|
||||||
|
|
||||||
|
auto ranked = cast<RankedTensorType>(resultTypes[resultIndex]);
|
||||||
|
SmallVector<OpFoldResult> offsets;
|
||||||
|
offsets.push_back(rewriter.getIndexAttr(laneIndex));
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(1));
|
||||||
|
SmallVector<OpFoldResult> strides;
|
||||||
|
strides.push_back(rewriter.getIndexAttr(1));
|
||||||
|
|
||||||
|
for (int64_t dim : ranked.getShape()) {
|
||||||
|
offsets.push_back(rewriter.getIndexAttr(0));
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
|
strides.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
auto slice = tensor::ExtractSliceOp::create(
|
||||||
|
rewriter, loc, ranked, newBatch.getResult(resultIndex), offsets, sizes, strides);
|
||||||
|
|
||||||
|
if (isResultfulBatchInstance(laneInstance)) {
|
||||||
|
auto oldBatch = cast<SpatComputeBatch>(laneInstance.op);
|
||||||
|
auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()];
|
||||||
|
if (batchResults.empty())
|
||||||
|
batchResults.resize(oldBatch.getNumResults());
|
||||||
|
auto& laneValues = batchResults[outputRef.resultIndex];
|
||||||
|
if (laneValues.empty())
|
||||||
|
laneValues.resize(static_cast<size_t>(oldBatch.getLaneCount()));
|
||||||
|
laneValues[laneInstance.laneStart] = slice.getResult();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
oldToNewExternalValueMap[getComputeInstanceOutputValues(laneInstance)[outputRef.resultIndex]] =
|
||||||
|
slice.getResult();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < weightsPerLane; ++i)
|
||||||
|
program.weightToIndex[cpuWeights[leader][i]] = i;
|
||||||
|
for (size_t i = 0; i < inputsPerLane; ++i)
|
||||||
|
program.externalInputMap[cpuExternalInputs[leader][i]] =
|
||||||
|
newBatch.getBody().front().getArgument(1 + weightsPerLane + i);
|
||||||
|
}
|
||||||
|
cpuPrograms[leader] = std::move(program);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -888,15 +950,9 @@ private:
|
|||||||
DenseMap<size_t, DenseMap<uint64_t, size_t>> receiveQueueIndicesByCpu;
|
DenseMap<size_t, DenseMap<uint64_t, size_t>> receiveQueueIndicesByCpu;
|
||||||
DenseMap<size_t, DenseMap<ComputeInstance, SmallVector<Value>>> preReceivedInputsByCpu;
|
DenseMap<size_t, DenseMap<ComputeInstance, SmallVector<Value>>> preReceivedInputsByCpu;
|
||||||
|
|
||||||
for (ProgramKey programKey : orderedPrograms) {
|
auto lookupPreReceivedInput = [&](DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
|
||||||
size_t cpu = programKey.first;
|
ComputeInstance consumer,
|
||||||
CpuProgram& program = cpuPrograms[programKey];
|
size_t inputIndex) -> std::optional<Value> {
|
||||||
IRRewriter rewriter(func.getContext());
|
|
||||||
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
|
||||||
auto& receiveQueueIndices = receiveQueueIndicesByCpu[cpu];
|
|
||||||
auto& preReceivedInputsByTask = preReceivedInputsByCpu[cpu];
|
|
||||||
|
|
||||||
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
|
|
||||||
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
||||||
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
@@ -906,112 +962,93 @@ private:
|
|||||||
return value;
|
return value;
|
||||||
};
|
};
|
||||||
|
|
||||||
ArrayRef<ScheduledTask> programTasks = tasksByProgram[programKey];
|
for (ProgramKey leader : orderedPrograms) {
|
||||||
for (size_t taskIndex = 0; taskIndex < programTasks.size(); ++taskIndex) {
|
const auto& batch = batchedCpus[leader];
|
||||||
const ScheduledTask& task = programTasks[taskIndex];
|
bool isBatch = batch.size() > 1;
|
||||||
|
CpuProgram& program = cpuPrograms[leader];
|
||||||
|
|
||||||
|
IRRewriter rewriter(func.getContext());
|
||||||
|
rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front());
|
||||||
|
|
||||||
|
auto& receiveQueueIndices = receiveQueueIndicesByCpu[leader];
|
||||||
|
auto& preReceivedInputsByTask = preReceivedInputsByCpu[leader];
|
||||||
|
|
||||||
|
ArrayRef<ScheduledTask> leaderTasks = tasksByCpu[leader];
|
||||||
|
|
||||||
|
for (size_t taskIndex = 0; taskIndex < leaderTasks.size(); ++taskIndex) {
|
||||||
|
const ScheduledTask& task = leaderTasks[taskIndex];
|
||||||
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
|
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||||
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
|
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
|
||||||
|
|
||||||
SmallVector<Value> resolvedInputs;
|
SmallVector<Value> resolvedInputs;
|
||||||
resolvedInputs.reserve(taskInputs.size());
|
resolvedInputs.reserve(taskInputs.size());
|
||||||
|
|
||||||
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
|
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
|
||||||
auto remoteTensorInputsIt = remoteTensorInputsByTask.find(task.computeInstance);
|
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
if (remoteTensorInputsIt != remoteTensorInputsByTask.end() && inputIndex < remoteTensorInputsIt->second.size()
|
|
||||||
&& remoteTensorInputsIt->second[inputIndex]) {
|
|
||||||
const TensorChannelInfo& tensorInfo = *remoteTensorInputsIt->second[inputIndex];
|
|
||||||
bool hasLocalProducer = llvm::is_contained(tensorInfo.sourceCoreIds, static_cast<int32_t>(cpu));
|
|
||||||
if (!hasLocalProducer) {
|
|
||||||
auto receive = spatial::SpatChannelReceiveTensorOp::create(
|
|
||||||
rewriter,
|
|
||||||
loc,
|
|
||||||
input.getType(),
|
|
||||||
createIndexConstants(program.op, tensorInfo.channelIds, constantFolder),
|
|
||||||
createIndexConstants(program.op, tensorInfo.sourceCoreIds, constantFolder),
|
|
||||||
createIndexConstants(program.op, tensorInfo.targetCoreIds, constantFolder));
|
|
||||||
resolvedInputs.push_back(receive.getOutput());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> laneValues;
|
|
||||||
laneValues.reserve(tensorInfo.producerInstances.size());
|
|
||||||
for (auto [laneIndex, producerInstance] : llvm::enumerate(tensorInfo.producerInstances)) {
|
|
||||||
if (tensorInfo.sourceCoreIds[laneIndex] == static_cast<int32_t>(cpu)) {
|
|
||||||
auto producedIt = producedValuesByTask.find(producerInstance);
|
|
||||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= tensorInfo.resultIndex) {
|
|
||||||
task.computeInstance.op->emitOpError(
|
|
||||||
"missing local tensor lane producer during merge materialization")
|
|
||||||
<< " consumerCpu=" << cpu << " producerLaneStart=" << producerInstance.laneStart;
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
laneValues.push_back(producedIt->second[tensorInfo.resultIndex]);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto producerTaskIt = taskByComputeInstance.find(producerInstance);
|
|
||||||
if (producerTaskIt == taskByComputeInstance.end())
|
|
||||||
return failure();
|
|
||||||
Type laneType = getTaskOutputTypes(producerTaskIt->second.computeInstance)[tensorInfo.resultIndex];
|
|
||||||
Value channelId = createIndexConstant(program.op, tensorInfo.channelIds[laneIndex], constantFolder);
|
|
||||||
Value sourceCoreId = createIndexConstant(program.op, tensorInfo.sourceCoreIds[laneIndex], constantFolder);
|
|
||||||
Value targetCoreId = createIndexConstant(program.op, tensorInfo.targetCoreIds[laneIndex], constantFolder);
|
|
||||||
auto receive =
|
|
||||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, laneType, channelId, sourceCoreId, targetCoreId);
|
|
||||||
laneValues.push_back(receive.getResult());
|
|
||||||
}
|
|
||||||
|
|
||||||
Value packedInput = tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, ValueRange(laneValues)).getResult();
|
|
||||||
resolvedInputs.push_back(packedInput);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto producerRef = getProducerValueRef(input, &task.computeInstance);
|
auto producerRef = getProducerValueRef(input, &task.computeInstance);
|
||||||
if (producerRef) {
|
if (producerRef) {
|
||||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||||
if (producerIt != taskByComputeInstance.end()) {
|
if (producerIt != taskByComputeInstance.end()) {
|
||||||
if (producerIt->second.cpu == cpu) {
|
if (producerIt->second.cpu == leader) {
|
||||||
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex)
|
||||||
task.computeInstance.op->emitOpError(
|
return task.computeInstance.op->emitOpError("missing local producer value");
|
||||||
"missing local producer value during per-cpu merge materialization")
|
|
||||||
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
|
||||||
<< " producerLaneStart=" << producerRef->instance.laneStart
|
|
||||||
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isBatch) {
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
|
||||||
|
for (size_t cpu : batch) {
|
||||||
|
const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex];
|
||||||
|
const ChannelInfo& info = *remoteInputsByTask[laneTask.computeInstance][inputIndex];
|
||||||
|
channelIds.push_back(info.channelId);
|
||||||
|
sourceCoreIds.push_back(info.sourceCoreId);
|
||||||
|
targetCoreIds.push_back(info.targetCoreId);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> cIds = createIndexConstants(program.op, channelIds, constantFolder);
|
||||||
|
SmallVector<Value> sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder);
|
||||||
|
SmallVector<Value> tIds = createIndexConstants(program.op, targetCoreIds, constantFolder);
|
||||||
|
|
||||||
|
auto recv =
|
||||||
|
spatial::SpatChannelReceiveBatchOp::create(rewriter, loc, input.getType(), cIds, sIds, tIds);
|
||||||
|
resolvedInputs.push_back(recv.getOutput());
|
||||||
|
}
|
||||||
|
else {
|
||||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||||
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||||
|
|
||||||
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||||
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
|
if (std::optional<Value> preReceived =
|
||||||
|
lookupPreReceivedInput(preReceivedInputsByTask, task.computeInstance, inputIndex)) {
|
||||||
resolvedInputs.push_back(*preReceived);
|
resolvedInputs.push_back(*preReceived);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
FailureOr<Value> received = receiveThroughInput(rewriter,
|
FailureOr<Value> received = receiveThroughInput(rewriter,
|
||||||
cpu,
|
leader,
|
||||||
receiveQueueIndices,
|
receiveQueueIndices,
|
||||||
preReceivedInputsByTask,
|
preReceivedInputsByTask,
|
||||||
channelInfo,
|
channelInfo,
|
||||||
task.computeInstance,
|
task.computeInstance,
|
||||||
inputIndex);
|
inputIndex);
|
||||||
if (failed(received)) {
|
if (failed(received))
|
||||||
task.computeInstance.op->emitOpError("failed to materialize reordered remote receive")
|
return task.computeInstance.op->emitOpError("failed to materialize reordered remote receive");
|
||||||
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
|
|
||||||
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
resolvedInputs.push_back(*received);
|
resolvedInputs.push_back(*received);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Value channelId = createIndexConstant(program.op, channelInfo.channelId, constantFolder);
|
|
||||||
Value sourceCoreId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder);
|
Value cId = createIndexConstant(program.op, channelInfo.channelId, constantFolder);
|
||||||
Value targetCoreId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder);
|
Value sId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder);
|
||||||
auto receive = spatial::SpatChannelReceiveOp::create(
|
Value tId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder);
|
||||||
rewriter, loc, input.getType(), channelId, sourceCoreId, targetCoreId);
|
auto receive = spatial::SpatChannelReceiveOp::create(rewriter, loc, input.getType(), cId, sId, tId);
|
||||||
resolvedInputs.push_back(receive.getResult());
|
resolvedInputs.push_back(receive.getResult());
|
||||||
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1019,12 +1056,17 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> taskYieldValues;
|
SmallVector<Value> taskYieldValues;
|
||||||
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front());
|
||||||
|
|
||||||
if (isa<SpatCompute>(task.computeInstance.op)) {
|
if (isa<SpatCompute>(task.computeInstance.op)) {
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
auto compute = cast<SpatCompute>(task.computeInstance.op);
|
auto compute = cast<SpatCompute>(task.computeInstance.op);
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights))
|
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) {
|
||||||
mapper.map(compute.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight)));
|
Value destArg = isBatch
|
||||||
|
? cast<SpatComputeBatch>(program.op).getWeightArgument(program.weightToIndex.at(weight))
|
||||||
|
: cast<SpatCompute>(program.op).getWeightArgument(program.weightToIndex.at(weight));
|
||||||
|
mapper.map(compute.getWeightArgument(weightIndex), destArg);
|
||||||
|
}
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
|
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
|
||||||
mapper.map(compute.getInputArgument(inputIndex), input);
|
mapper.map(compute.getInputArgument(inputIndex), input);
|
||||||
|
|
||||||
@@ -1034,106 +1076,47 @@ private:
|
|||||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
|
// Include your existing isolated logic for preserving resultless spat.compute_batch here if needed
|
||||||
if (batch.getNumResults() != 0) {
|
|
||||||
IRMapping mapper;
|
|
||||||
Value laneValue = getOrCreateHostIndexConstant(
|
|
||||||
program.op, static_cast<int64_t>(task.computeInstance.laneStart), constantFolder);
|
|
||||||
mapper.map(batch.getLaneArgument(), laneValue);
|
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights))
|
|
||||||
mapper.map(batch.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight)));
|
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
|
|
||||||
mapper.map(batch.getInputArgument(inputIndex), input);
|
|
||||||
|
|
||||||
for (Operation& op : templateBlock.without_terminator()) {
|
|
||||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
|
||||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
|
||||||
mapper.map(oldResult, newResult);
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<SmallVector<BatchYieldInfo>> yieldInfo = collectResultfulBatchYieldInfo(batch);
|
|
||||||
if (failed(yieldInfo))
|
|
||||||
return task.computeInstance.op->emitOpError("failed to collect resultful batch yield info");
|
|
||||||
for (const BatchYieldInfo& info : *yieldInfo)
|
|
||||||
taskYieldValues.push_back(mapper.lookup(info.yieldedValue));
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
size_t batchLaneCount = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t inputsPerLane = batchLaneCount == 0 ? 0 : batch.getInputs().size() / batchLaneCount;
|
|
||||||
size_t weightsPerLane = batchLaneCount == 0 ? 0 : batch.getWeights().size() / batchLaneCount;
|
|
||||||
size_t loopRunLength = getLoopableResultlessBatchRunLength(programTasks, taskIndex);
|
|
||||||
if (loopRunLength > 1) {
|
|
||||||
Value lower = getOrCreateHostIndexConstant(
|
|
||||||
program.op, static_cast<int64_t>(task.computeInstance.laneStart), constantFolder);
|
|
||||||
Value upper = getOrCreateHostIndexConstant(
|
|
||||||
program.op, static_cast<int64_t>(task.computeInstance.laneStart + loopRunLength), constantFolder);
|
|
||||||
Value step = getOrCreateHostIndexConstant(program.op, 1, constantFolder);
|
|
||||||
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {});
|
|
||||||
rewriter.setInsertionPointToStart(loop.getBody());
|
|
||||||
|
|
||||||
IRMapping mapper;
|
|
||||||
mapper.map(batch.getLaneArgument(), loop.getInductionVar());
|
|
||||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
|
||||||
mapper.map(batch.getWeightArgument(weightIndex),
|
|
||||||
program.op.getWeightArgument(program.weightToIndex.at(taskWeights[weightIndex])));
|
|
||||||
for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex)
|
|
||||||
mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[inputIndex]);
|
|
||||||
|
|
||||||
for (Operation& op : templateBlock) {
|
|
||||||
if (isa<spatial::SpatYieldOp>(&op))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
|
||||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
|
||||||
mapper.map(oldResult, newResult);
|
|
||||||
}
|
|
||||||
rewriter.setInsertionPointAfter(loop);
|
|
||||||
taskIndex += loopRunLength - 1;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
|
|
||||||
IRMapping mapper;
|
|
||||||
Value laneValue = getOrCreateHostIndexConstant(
|
|
||||||
program.op, static_cast<int64_t>(task.computeInstance.laneStart + laneOffset), constantFolder);
|
|
||||||
mapper.map(batch.getLaneArgument(), laneValue);
|
|
||||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
|
||||||
mapper.map(batch.getWeightArgument(weightIndex),
|
|
||||||
program.op.getWeightArgument(
|
|
||||||
program.weightToIndex.at(taskWeights[laneOffset * weightsPerLane + weightIndex])));
|
|
||||||
for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex)
|
|
||||||
mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[laneOffset * inputsPerLane + inputIndex]);
|
|
||||||
|
|
||||||
for (Operation& op : templateBlock) {
|
|
||||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
|
||||||
for (Value yieldOperand : yield.getOperands())
|
|
||||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
|
||||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
|
||||||
mapper.map(oldResult, newResult);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
producedValuesByTask[task.computeInstance] = taskYieldValues;
|
producedValuesByTask[task.computeInstance] = taskYieldValues;
|
||||||
|
|
||||||
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
|
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
|
||||||
for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) {
|
for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) {
|
||||||
const SmallVector<RemoteSendInfo>& sendInfos = sendsIt->second[resultIndex];
|
const SmallVector<RemoteSendInfo>& sendInfos = sendsIt->second[resultIndex];
|
||||||
if (sendInfos.empty())
|
if (sendInfos.empty()) {
|
||||||
{
|
|
||||||
++resultIndex;
|
++resultIndex;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isBatch) {
|
||||||
|
size_t numSends = sendInfos.size();
|
||||||
|
for (size_t s = 0; s < numSends; ++s) {
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
|
||||||
|
for (size_t cpu : batch) {
|
||||||
|
const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex];
|
||||||
|
const RemoteSendInfo& send = remoteSendsByTask[laneTask.computeInstance][resultIndex][s];
|
||||||
|
channelIds.push_back(send.channelInfo.channelId);
|
||||||
|
sourceCoreIds.push_back(send.channelInfo.sourceCoreId);
|
||||||
|
targetCoreIds.push_back(send.channelInfo.targetCoreId);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> cIds = createIndexConstants(program.op, channelIds, constantFolder);
|
||||||
|
SmallVector<Value> sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder);
|
||||||
|
SmallVector<Value> tIds = createIndexConstants(program.op, targetCoreIds, constantFolder);
|
||||||
|
|
||||||
|
spatial::SpatChannelSendBatchOp::create(rewriter, loc, cIds, sIds, tIds, taskYieldValues[resultIndex]);
|
||||||
|
}
|
||||||
|
++resultIndex;
|
||||||
|
}
|
||||||
|
else {
|
||||||
size_t nextResultIndex = resultIndex + 1;
|
size_t nextResultIndex = resultIndex + 1;
|
||||||
if (tryEmitCompactSendLoops(
|
if (tryEmitCompactSendLoops(
|
||||||
program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) {
|
program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) {
|
||||||
@@ -1143,30 +1126,55 @@ private:
|
|||||||
|
|
||||||
Value producedValue = taskYieldValues[resultIndex];
|
Value producedValue = taskYieldValues[resultIndex];
|
||||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
Value channelId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder);
|
Value cId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder);
|
||||||
Value sourceCoreId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder);
|
Value sId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder);
|
||||||
Value targetCoreId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder);
|
Value tId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder);
|
||||||
spatial::SpatChannelSendOp::create(rewriter, loc, channelId, sourceCoreId, targetCoreId, producedValue);
|
spatial::SpatChannelSendOp::create(rewriter, loc, cId, sId, tId, producedValue);
|
||||||
}
|
}
|
||||||
++resultIndex;
|
++resultIndex;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<Value> yieldValues;
|
SmallVector<Value> yieldValues;
|
||||||
yieldValues.reserve(cpuExternalOutputs[programKey].size());
|
yieldValues.reserve(cpuExternalOutputs[leader].size());
|
||||||
for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) {
|
for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) {
|
||||||
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex)
|
||||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
return func.emitError("missing yielded external value during materialization");
|
||||||
task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization")
|
|
||||||
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isBatch) {
|
||||||
|
auto batchOp = cast<SpatComputeBatch>(program.op);
|
||||||
|
auto inParallel = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||||
|
Block* parallelBlock = rewriter.createBlock(&inParallel.getRegion());
|
||||||
|
rewriter.setInsertionPointToEnd(parallelBlock);
|
||||||
|
|
||||||
|
for (auto [resultIndex, yieldedVal] : llvm::enumerate(yieldValues)) {
|
||||||
|
auto destArg = batchOp.getOutputArgument(resultIndex);
|
||||||
|
auto destType = cast<RankedTensorType>(destArg.getType());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets;
|
||||||
|
offsets.push_back(batchOp.getLaneArgument());
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(1));
|
||||||
|
SmallVector<OpFoldResult> strides;
|
||||||
|
strides.push_back(rewriter.getIndexAttr(1));
|
||||||
|
|
||||||
|
for (int64_t dim : destType.getShape().drop_front()) {
|
||||||
|
offsets.push_back(rewriter.getIndexAttr(0));
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
|
strides.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
tensor::ParallelInsertSliceOp::create(rewriter, loc, yieldedVal, destArg, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
|
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -1245,7 +1253,6 @@ private:
|
|||||||
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||||
DenseMap<ProgramKey, SmallVector<ScheduledTask>> tasksByProgram;
|
DenseMap<ProgramKey, SmallVector<ScheduledTask>> tasksByProgram;
|
||||||
SmallVector<size_t> orderedCpus;
|
SmallVector<size_t> orderedCpus;
|
||||||
SmallVector<ProgramKey> orderedPrograms;
|
|
||||||
DenseSet<size_t> seenCpus;
|
DenseSet<size_t> seenCpus;
|
||||||
DenseSet<ProgramKey> seenPrograms;
|
DenseSet<ProgramKey> seenPrograms;
|
||||||
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ struct MergeScheduleResult {
|
|||||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||||
|
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#include "mlir/IR/Threading.h"
|
#include "mlir/IR/Threading.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
@@ -274,7 +275,64 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
||||||
});
|
});
|
||||||
|
|
||||||
// 5. Populate Final Result
|
// 5. Check if equal schedule in two level
|
||||||
|
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||||
|
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
|
||||||
|
for (size_t controlProcessor = currentProcessor + 1; controlProcessor < processorCount; ++controlProcessor) {
|
||||||
|
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
|
||||||
|
continue;
|
||||||
|
auto& currentTasks = tasksByProcessor[currentProcessor];
|
||||||
|
auto& controlTasks = tasksByProcessor[controlProcessor];
|
||||||
|
bool equalSchedule = true;
|
||||||
|
|
||||||
|
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
|
||||||
|
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
|
||||||
|
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
|
||||||
|
if (currentComputeInstance.op != controlComputeInstance.op) {
|
||||||
|
equalSchedule = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (equalSchedule) {
|
||||||
|
equivalentClass[currentProcessor].push_back(controlProcessor);
|
||||||
|
equivalentClass[controlProcessor].push_back(currentProcessor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
|
||||||
|
std::vector<bool> visited(processorCount, false);
|
||||||
|
size_t uniqueClassCount = 0;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < processorCount; ++i) {
|
||||||
|
if (visited[i])
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// We found a new unique schedule (equivalence class)
|
||||||
|
++uniqueClassCount;
|
||||||
|
visited[i] = true;
|
||||||
|
|
||||||
|
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
|
||||||
|
|
||||||
|
// Find and mark all identical companions
|
||||||
|
auto it = equivalentClass.find(i);
|
||||||
|
if (it != equivalentClass.end()) {
|
||||||
|
for (size_t eqCpu : it->second) {
|
||||||
|
if (!visited[eqCpu]) {
|
||||||
|
llvm::dbgs() << ", " << eqCpu;
|
||||||
|
visited[eqCpu] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llvm::dbgs() << " }\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
|
||||||
|
llvm::dbgs() << "--------------------------------------\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Populate Final Result
|
||||||
MergeScheduleResult result;
|
MergeScheduleResult result;
|
||||||
result.dominanceOrderCompute.reserve(nodeCount);
|
result.dominanceOrderCompute.reserve(nodeCount);
|
||||||
|
|
||||||
@@ -296,8 +354,9 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result.equivalentClass = equivalentClass;
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user