Equivalent Class but broken

This commit is contained in:
ilgeco
2026-05-21 14:43:59 +02:00
parent a50e77ff38
commit fe35b3ed43
3 changed files with 481 additions and 414 deletions
@@ -58,18 +58,15 @@ static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int
}
static Value createIndexTensorConstant(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
auto tensorType = RankedTensorType::get({static_cast<int64_t>(values.size())}, IndexType::get(anchorOp->getContext()));
auto tensorType =
RankedTensorType::get({static_cast<int64_t>(values.size())}, IndexType::get(anchorOp->getContext()));
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
}
static Value createIndexTupleTensorConstant(Operation* anchorOp,
int64_t tupleCount,
int64_t tupleWidth,
ArrayRef<int64_t> values,
OperationFolder& folder) {
auto tensorType =
RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext()));
static Value createIndexTupleTensorConstant(
Operation* anchorOp, int64_t tupleCount, int64_t tupleWidth, ArrayRef<int64_t> values, OperationFolder& folder) {
auto tensorType = RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext()));
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
}
@@ -114,12 +111,14 @@ private:
};
struct CpuProgram {
SpatCompute op;
Operation* op = nullptr;
DenseMap<Value, Value> externalInputMap;
DenseMap<Value, size_t> weightToIndex;
};
using ProgramKey = std::pair<size_t, size_t>;
using ProgramKey = size_t; // Represents the "Leader" CPU
DenseMap<ProgramKey, SmallVector<size_t>> batchedCpus;
SmallVector<ProgramKey> orderedPrograms;
struct RemoteSendInfo {
ChannelInfo channelInfo;
@@ -171,7 +170,7 @@ private:
| static_cast<uint32_t>(channelInfo.targetCoreId);
}
static ProgramKey getProgramKey(const ScheduledTask& task) { return {task.cpu, 0}; }
static ProgramKey getProgramKey(const ScheduledTask& task) { return task.cpu; }
static bool isResultfulBatchInstance(const ComputeInstance& instance) {
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
@@ -377,12 +376,12 @@ private:
void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) {
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
Value channelSourceTargetTuples = createIndexTupleTensorConstant(
hostAnchor,
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
3,
run.channelSourceTargetTuples,
constantFolder);
Value channelSourceTargetTuples =
createIndexTupleTensorConstant(hostAnchor,
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
3,
run.channelSourceTargetTuples,
constantFolder);
Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder);
Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast<int64_t>(run.sendCounts.size()), constantFolder);
@@ -413,8 +412,7 @@ private:
.getResult();
Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
Value innerLower =
tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex});
emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples);
rewriter.setInsertionPointAfter(outerLoop);
@@ -423,12 +421,12 @@ private:
void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) {
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
Value channelSourceTargetTuples = createIndexTupleTensorConstant(
hostAnchor,
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
3,
run.channelSourceTargetTuples,
constantFolder);
Value channelSourceTargetTuples =
createIndexTupleTensorConstant(hostAnchor,
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
3,
run.channelSourceTargetTuples,
constantFolder);
Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder);
Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast<int64_t>(run.sendCounts.size()), constantFolder);
@@ -469,8 +467,7 @@ private:
.getResult();
Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
Value innerLower =
tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex});
emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples);
rewriter.setInsertionPointAfter(outerLoop);
@@ -532,149 +529,112 @@ private:
}
void buildTaskIndex() {
auto markCpuSeen = [&](size_t cpu) {
if (seenCpus.insert(cpu).second)
orderedCpus.push_back(cpu);
};
DenseSet<size_t> seen;
for (const ScheduledTask& task : scheduledTasks) {
taskByComputeInstance[task.computeInstance] = task;
tasksByCpu[task.cpu].push_back(task);
ProgramKey programKey = getProgramKey(task);
tasksByProgram[programKey].push_back(task);
if (seenPrograms.insert(programKey).second)
orderedPrograms.push_back(programKey);
markCpuSeen(task.cpu);
seen.insert(task.cpu);
}
llvm::sort(orderedCpus);
llvm::sort(orderedPrograms, [](const ProgramKey& lhs, const ProgramKey& rhs) {
if (lhs.second != rhs.second)
return lhs.second < rhs.second;
return lhs.first < rhs.first;
});
for (size_t cpu : orderedCpus)
SmallVector<size_t> activeCpus(seen.begin(), seen.end());
llvm::sort(activeCpus);
DenseSet<size_t> batched;
for (size_t cpu : activeCpus) {
if (batched.contains(cpu))
continue;
SmallVector<size_t> batch;
batch.push_back(cpu);
batched.insert(cpu);
// Group all equivalent CPUs into this batch
auto it = schedule->equivalentClass.find(cpu);
if (it != schedule->equivalentClass.end()) {
for (size_t eqCpu : it->second)
if (batched.insert(eqCpu).second)
batch.push_back(eqCpu);
}
llvm::sort(batch);
size_t leader = batch.front();
batchedCpus[leader] = batch;
orderedPrograms.push_back(leader);
}
for (size_t cpu : activeCpus) {
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
return lhs.orderWithinCpu < rhs.orderWithinCpu;
});
for (ProgramKey programKey : orderedPrograms)
llvm::stable_sort(tasksByProgram[programKey], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
return lhs.orderWithinCpu < rhs.orderWithinCpu;
});
}
}
void collectExternalInputsAndWeights() {
for (ProgramKey programKey : orderedPrograms) {
size_t cpu = programKey.first;
for (const ScheduledTask& task : tasksByProgram[programKey]) {
auto& thisCpuWeights = cpuWeights[programKey];
auto& thisSeenWeights = seenWeightsByProgram[programKey];
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
for (Value weight : taskWeights)
if (thisSeenWeights.insert(weight).second)
thisCpuWeights.push_back(weight);
for (ProgramKey leader : orderedPrograms) {
const auto& batch = batchedCpus[leader];
auto& thisCpuWeights = cpuWeights[leader];
auto& thisCpuInputs = cpuExternalInputs[leader];
auto& thisCpuOutputs = cpuExternalOutputs[leader];
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
remoteInputs.resize(taskInputs.size());
auto& remoteTensorInputs = remoteTensorInputsByTask[task.computeInstance];
remoteTensorInputs.resize(taskInputs.size());
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
bool isExternalInput = true;
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
producerBatch && producerBatch.getNumResults() != 0) {
size_t resultIndex = cast<OpResult>(input).getResultNumber();
TensorChannelInfo tensorInfo;
tensorInfo.resultIndex = resultIndex;
tensorInfo.channelIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
tensorInfo.sourceCoreIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
tensorInfo.targetCoreIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
tensorInfo.producerInstances.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
// Process every lane sequentially to pack operands
for (size_t cpu : batch) {
DenseSet<Value> laneSeenWeights;
DenseSet<Value> laneSeenInputs;
bool foundAllLaneProducers = true;
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
ComputeInstance producerInstance = getBatchChunkForLane(producerBatch, lane);
auto producerIt = taskByComputeInstance.find(producerInstance);
if (producerIt == taskByComputeInstance.end()) {
foundAllLaneProducers = false;
break;
}
for (const ScheduledTask& task : tasksByCpu[cpu]) {
for (Value weight : getComputeInstanceWeights(task.computeInstance))
if (laneSeenWeights.insert(weight).second)
thisCpuWeights.push_back(weight);
ChannelInfo info {
producerIt->second.cpu == cpu ? -1 : (*nextChannelId)++,
static_cast<int32_t>(producerIt->second.cpu),
static_cast<int32_t>(cpu),
};
tensorInfo.channelIds.push_back(info.channelId);
tensorInfo.sourceCoreIds.push_back(info.sourceCoreId);
tensorInfo.targetCoreIds.push_back(info.targetCoreId);
tensorInfo.producerInstances.push_back(producerInstance);
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
remoteInputs.resize(taskInputs.size());
if (producerIt->second.cpu != cpu) {
auto& perResultChannels = remoteSendsByTask[producerInstance];
if (perResultChannels.empty())
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
perResultChannels[resultIndex].push_back(
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, true});
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
bool isExternalInput = true;
if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) {
auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByComputeInstance.end()) {
isExternalInput = false;
if (producerIt->second.cpu != cpu) {
// Cross-core communication
ChannelInfo info;
info.channelId = (*nextChannelId)++;
info.sourceCoreId = static_cast<int32_t>(producerIt->second.cpu);
info.targetCoreId = static_cast<int32_t>(cpu);
remoteInputs[inputIndex] = info;
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
if (perResultChannels.empty())
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
RemoteSendInfo sendInfo;
sendInfo.channelInfo = info;
sendInfo.consumer = task.computeInstance;
sendInfo.inputIndex = inputIndex;
sendInfo.consumerOrder = task.orderWithinCpu;
sendInfo.sourceOrder = 0;
sendInfo.isTensorInput = false;
perResultChannels[producerRef->resultIndex].push_back(sendInfo);
}
}
}
if (isExternalInput && laneSeenInputs.insert(input).second)
thisCpuInputs.push_back(input);
}
if (foundAllLaneProducers) {
remoteTensorInputs[inputIndex] = std::move(tensorInfo);
continue;
// Define the logical return types based strictly on the Leader
if (cpu == leader) {
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
bool hasExternalUser = false;
for (auto& use : output.getUses())
if (!oldComputeOps.contains(use.getOwner()))
hasExternalUser = true;
if (hasExternalUser)
thisCpuOutputs.push_back({task.computeInstance, resultIndex});
}
}
if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) {
auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByComputeInstance.end()) {
isExternalInput = false;
if (producerIt->second.cpu != cpu) {
ChannelInfo info {
(*nextChannelId)++,
static_cast<int32_t>(producerIt->second.cpu),
static_cast<int32_t>(cpu),
};
remoteInputs[inputIndex] = info;
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
if (perResultChannels.empty())
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
perResultChannels[producerRef->resultIndex].push_back(
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, false});
}
}
}
if (isExternalInput && seenExternalInputsByProgram[programKey].insert(input).second)
cpuExternalInputs[programKey].push_back(input);
}
if (isResultfulBatchInstance(task.computeInstance)) {
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
for (unsigned resultIndex = 0; resultIndex < batch.getNumResults(); ++resultIndex) {
bool hasExternalUser = false;
for (Operation* user : batch.getResult(resultIndex).getUsers()) {
if (!oldComputeOps.contains(user)) {
hasExternalUser = true;
break;
}
}
if (hasExternalUser)
cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex});
}
continue;
}
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
bool hasExternalUser = false;
for (auto& use : output.getUses()) {
Operation* useOwner = use.getOwner();
if (oldComputeOps.contains(useOwner))
continue;
hasExternalUser = true;
}
if (hasExternalUser)
cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex});
}
}
}
@@ -779,65 +739,167 @@ private:
void createCpuComputeOps() {
IRRewriter rewriter(func.getContext());
for (ProgramKey programKey : orderedPrograms) {
size_t cpu = programKey.first;
SmallVector<Value> operands;
operands.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
llvm::append_range(operands, cpuWeights[programKey]);
llvm::append_range(operands, cpuExternalInputs[programKey]);
for (ProgramKey leader : orderedPrograms) {
const auto& batch = batchedCpus[leader];
bool isBatch = batch.size() > 1;
SmallVector<Type> resultTypes;
resultTypes.reserve(cpuExternalOutputs[programKey].size());
for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) {
SmallVector<Type> packedResultTypes;
for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
SmallVector<Type> outputTypes = getTaskOutputTypes(task.computeInstance);
resultTypes.push_back(outputTypes[outputRef.resultIndex]);
Type elemType = getTaskOutputTypes(task.computeInstance)[outputRef.resultIndex];
resultTypes.push_back(elemType);
if (isBatch) {
auto ranked = cast<RankedTensorType>(elemType);
SmallVector<int64_t> shape;
shape.push_back(static_cast<int64_t>(batch.size()));
shape.append(ranked.getShape().begin(), ranked.getShape().end());
packedResultTypes.push_back(RankedTensorType::get(shape, ranked.getElementType()));
}
}
rewriter.setInsertionPoint(returnOp);
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(cpuWeights[programKey].size()), static_cast<int>(cpuExternalInputs[programKey].size())});
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(cpu)));
CpuProgram program;
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
blockArgLocs.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
for (Value weight : cpuWeights[programKey]) {
blockArgTypes.push_back(weight.getType());
blockArgLocs.push_back(loc);
}
for (Value input : cpuExternalInputs[programKey]) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(loc);
}
Block* newBlock =
if (!isBatch) {
// Isolated CPU Execution
SmallVector<Value> operands;
operands.reserve(cpuWeights[leader].size() + cpuExternalInputs[leader].size());
llvm::append_range(operands, cpuWeights[leader]);
llvm::append_range(operands, cpuExternalInputs[leader]);
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(cpuWeights[leader].size()), static_cast<int>(cpuExternalInputs[leader].size())});
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(leader)));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (Value op : operands) {
blockArgTypes.push_back(op.getType());
blockArgLocs.push_back(loc);
}
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
CpuProgram program;
program.op = newCompute;
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[programKey]))
program.weightToIndex[weight] = weightIndex;
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[programKey]))
program.externalInputMap[input] = newCompute.getInputArgument(inputIndex);
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[programKey])) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
if (isResultfulBatchInstance(task.computeInstance)) {
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
auto& batchResults = resultfulBatchLaneResults[batch.getOperation()];
if (batchResults.empty())
batchResults.resize(batch.getNumResults());
auto& laneResults = batchResults[outputRef.resultIndex];
if (laneResults.empty())
laneResults.resize(static_cast<size_t>(batch.getLaneCount()));
laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex);
continue;
program.op = newCompute;
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[leader]))
program.weightToIndex[weight] = weightIndex;
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[leader]))
program.externalInputMap[input] = newCompute.getInputArgument(inputIndex);
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
if (isResultfulBatchInstance(task.computeInstance)) {
auto oldBatch = cast<SpatComputeBatch>(task.computeInstance.op);
auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()];
if (batchResults.empty())
batchResults.resize(oldBatch.getNumResults());
auto& laneResults = batchResults[outputRef.resultIndex];
if (laneResults.empty())
laneResults.resize(static_cast<size_t>(oldBatch.getLaneCount()));
laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex);
continue;
}
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
newCompute.getResult(resultIndex);
}
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
newCompute.getResult(resultIndex);
}
cpuPrograms[programKey] = std::move(program);
else {
// Equivalence Class Batch Execution
auto newBatch = SpatComputeBatch::create(rewriter,
loc,
TypeRange(packedResultTypes),
rewriter.getI32IntegerAttr(batch.size()),
cpuWeights[leader],
cpuExternalInputs[leader]);
newBatch.getProperties().setOperandSegmentSizes(
{static_cast<int>(cpuWeights[leader].size()), static_cast<int>(cpuExternalInputs[leader].size())});
SmallVector<int32_t> coreIds;
for (size_t c : batch)
coreIds.push_back(static_cast<int32_t>(c));
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.push_back(rewriter.getIndexType()); // Lane ID Argument
blockArgLocs.push_back(loc);
size_t weightsPerLane = cpuWeights[leader].size() / batch.size();
size_t inputsPerLane = cpuExternalInputs[leader].size() / batch.size();
for (size_t i = 0; i < weightsPerLane; ++i) {
blockArgTypes.push_back(cpuWeights[leader][i].getType());
blockArgLocs.push_back(loc);
}
for (size_t i = 0; i < inputsPerLane; ++i) {
blockArgTypes.push_back(cpuExternalInputs[leader][i].getType());
blockArgLocs.push_back(loc);
}
for (Type t : packedResultTypes) {
blockArgTypes.push_back(t);
blockArgLocs.push_back(loc);
} // Dest tensors for InParallel Yield
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
program.op = newBatch;
// Host-side slice extractions
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(newBatch);
for (auto [laneIndex, cpu] : llvm::enumerate(batch)) {
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) {
size_t taskIdx = 0;
for (size_t i = 0; i < tasksByCpu[leader].size(); ++i) {
if (tasksByCpu[leader][i].computeInstance == outputRef.instance) {
taskIdx = i;
break;
}
}
ComputeInstance laneInstance = tasksByCpu[cpu][taskIdx].computeInstance;
auto ranked = cast<RankedTensorType>(resultTypes[resultIndex]);
SmallVector<OpFoldResult> offsets;
offsets.push_back(rewriter.getIndexAttr(laneIndex));
SmallVector<OpFoldResult> sizes;
sizes.push_back(rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> strides;
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim : ranked.getShape()) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(dim));
strides.push_back(rewriter.getIndexAttr(1));
}
auto slice = tensor::ExtractSliceOp::create(
rewriter, loc, ranked, newBatch.getResult(resultIndex), offsets, sizes, strides);
if (isResultfulBatchInstance(laneInstance)) {
auto oldBatch = cast<SpatComputeBatch>(laneInstance.op);
auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()];
if (batchResults.empty())
batchResults.resize(oldBatch.getNumResults());
auto& laneValues = batchResults[outputRef.resultIndex];
if (laneValues.empty())
laneValues.resize(static_cast<size_t>(oldBatch.getLaneCount()));
laneValues[laneInstance.laneStart] = slice.getResult();
}
else {
oldToNewExternalValueMap[getComputeInstanceOutputValues(laneInstance)[outputRef.resultIndex]] =
slice.getResult();
}
}
}
for (size_t i = 0; i < weightsPerLane; ++i)
program.weightToIndex[cpuWeights[leader][i]] = i;
for (size_t i = 0; i < inputsPerLane; ++i)
program.externalInputMap[cpuExternalInputs[leader][i]] =
newBatch.getBody().front().getArgument(1 + weightsPerLane + i);
}
cpuPrograms[leader] = std::move(program);
}
}
@@ -888,130 +950,105 @@ private:
DenseMap<size_t, DenseMap<uint64_t, size_t>> receiveQueueIndicesByCpu;
DenseMap<size_t, DenseMap<ComputeInstance, SmallVector<Value>>> preReceivedInputsByCpu;
for (ProgramKey programKey : orderedPrograms) {
size_t cpu = programKey.first;
CpuProgram& program = cpuPrograms[programKey];
auto lookupPreReceivedInput = [&](DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
ComputeInstance consumer,
size_t inputIndex) -> std::optional<Value> {
auto inputsIt = preReceivedInputsByTask.find(consumer);
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
return std::nullopt;
Value value = inputsIt->second[inputIndex];
if (!value)
return std::nullopt;
return value;
};
for (ProgramKey leader : orderedPrograms) {
const auto& batch = batchedCpus[leader];
bool isBatch = batch.size() > 1;
CpuProgram& program = cpuPrograms[leader];
IRRewriter rewriter(func.getContext());
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
auto& receiveQueueIndices = receiveQueueIndicesByCpu[cpu];
auto& preReceivedInputsByTask = preReceivedInputsByCpu[cpu];
rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front());
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
auto inputsIt = preReceivedInputsByTask.find(consumer);
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
return std::nullopt;
Value value = inputsIt->second[inputIndex];
if (!value)
return std::nullopt;
return value;
};
auto& receiveQueueIndices = receiveQueueIndicesByCpu[leader];
auto& preReceivedInputsByTask = preReceivedInputsByCpu[leader];
ArrayRef<ScheduledTask> programTasks = tasksByProgram[programKey];
for (size_t taskIndex = 0; taskIndex < programTasks.size(); ++taskIndex) {
const ScheduledTask& task = programTasks[taskIndex];
ArrayRef<ScheduledTask> leaderTasks = tasksByCpu[leader];
for (size_t taskIndex = 0; taskIndex < leaderTasks.size(); ++taskIndex) {
const ScheduledTask& task = leaderTasks[taskIndex];
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
SmallVector<Value> resolvedInputs;
resolvedInputs.reserve(taskInputs.size());
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
auto remoteTensorInputsIt = remoteTensorInputsByTask.find(task.computeInstance);
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
if (remoteTensorInputsIt != remoteTensorInputsByTask.end() && inputIndex < remoteTensorInputsIt->second.size()
&& remoteTensorInputsIt->second[inputIndex]) {
const TensorChannelInfo& tensorInfo = *remoteTensorInputsIt->second[inputIndex];
bool hasLocalProducer = llvm::is_contained(tensorInfo.sourceCoreIds, static_cast<int32_t>(cpu));
if (!hasLocalProducer) {
auto receive = spatial::SpatChannelReceiveTensorOp::create(
rewriter,
loc,
input.getType(),
createIndexConstants(program.op, tensorInfo.channelIds, constantFolder),
createIndexConstants(program.op, tensorInfo.sourceCoreIds, constantFolder),
createIndexConstants(program.op, tensorInfo.targetCoreIds, constantFolder));
resolvedInputs.push_back(receive.getOutput());
continue;
}
SmallVector<Value> laneValues;
laneValues.reserve(tensorInfo.producerInstances.size());
for (auto [laneIndex, producerInstance] : llvm::enumerate(tensorInfo.producerInstances)) {
if (tensorInfo.sourceCoreIds[laneIndex] == static_cast<int32_t>(cpu)) {
auto producedIt = producedValuesByTask.find(producerInstance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= tensorInfo.resultIndex) {
task.computeInstance.op->emitOpError(
"missing local tensor lane producer during merge materialization")
<< " consumerCpu=" << cpu << " producerLaneStart=" << producerInstance.laneStart;
return failure();
}
laneValues.push_back(producedIt->second[tensorInfo.resultIndex]);
continue;
}
auto producerTaskIt = taskByComputeInstance.find(producerInstance);
if (producerTaskIt == taskByComputeInstance.end())
return failure();
Type laneType = getTaskOutputTypes(producerTaskIt->second.computeInstance)[tensorInfo.resultIndex];
Value channelId = createIndexConstant(program.op, tensorInfo.channelIds[laneIndex], constantFolder);
Value sourceCoreId = createIndexConstant(program.op, tensorInfo.sourceCoreIds[laneIndex], constantFolder);
Value targetCoreId = createIndexConstant(program.op, tensorInfo.targetCoreIds[laneIndex], constantFolder);
auto receive =
spatial::SpatChannelReceiveOp::create(rewriter, loc, laneType, channelId, sourceCoreId, targetCoreId);
laneValues.push_back(receive.getResult());
}
Value packedInput = tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, ValueRange(laneValues)).getResult();
resolvedInputs.push_back(packedInput);
continue;
}
auto producerRef = getProducerValueRef(input, &task.computeInstance);
if (producerRef) {
auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByComputeInstance.end()) {
if (producerIt->second.cpu == cpu) {
if (producerIt->second.cpu == leader) {
auto producedIt = producedValuesByTask.find(producerRef->instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
task.computeInstance.op->emitOpError(
"missing local producer value during per-cpu merge materialization")
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
<< " producerLaneStart=" << producerRef->instance.laneStart
<< " producerLaneCount=" << producerRef->instance.laneCount;
return failure();
}
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex)
return task.computeInstance.op->emitOpError("missing local producer value");
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
continue;
}
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
if (pairsNeedingReceiveReorder.contains(pairKey)) {
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
resolvedInputs.push_back(*preReceived);
if (isBatch) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
for (size_t cpu : batch) {
const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex];
const ChannelInfo& info = *remoteInputsByTask[laneTask.computeInstance][inputIndex];
channelIds.push_back(info.channelId);
sourceCoreIds.push_back(info.sourceCoreId);
targetCoreIds.push_back(info.targetCoreId);
}
SmallVector<Value> cIds = createIndexConstants(program.op, channelIds, constantFolder);
SmallVector<Value> sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder);
SmallVector<Value> tIds = createIndexConstants(program.op, targetCoreIds, constantFolder);
auto recv =
spatial::SpatChannelReceiveBatchOp::create(rewriter, loc, input.getType(), cIds, sIds, tIds);
resolvedInputs.push_back(recv.getOutput());
}
else {
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
if (pairsNeedingReceiveReorder.contains(pairKey)) {
if (std::optional<Value> preReceived =
lookupPreReceivedInput(preReceivedInputsByTask, task.computeInstance, inputIndex)) {
resolvedInputs.push_back(*preReceived);
continue;
}
FailureOr<Value> received = receiveThroughInput(rewriter,
leader,
receiveQueueIndices,
preReceivedInputsByTask,
channelInfo,
task.computeInstance,
inputIndex);
if (failed(received))
return task.computeInstance.op->emitOpError("failed to materialize reordered remote receive");
resolvedInputs.push_back(*received);
continue;
}
FailureOr<Value> received = receiveThroughInput(rewriter,
cpu,
receiveQueueIndices,
preReceivedInputsByTask,
channelInfo,
task.computeInstance,
inputIndex);
if (failed(received)) {
task.computeInstance.op->emitOpError("failed to materialize reordered remote receive")
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
return failure();
}
resolvedInputs.push_back(*received);
continue;
Value cId = createIndexConstant(program.op, channelInfo.channelId, constantFolder);
Value sId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder);
Value tId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder);
auto receive = spatial::SpatChannelReceiveOp::create(rewriter, loc, input.getType(), cId, sId, tId);
resolvedInputs.push_back(receive.getResult());
}
Value channelId = createIndexConstant(program.op, channelInfo.channelId, constantFolder);
Value sourceCoreId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder);
Value targetCoreId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder);
auto receive = spatial::SpatChannelReceiveOp::create(
rewriter, loc, input.getType(), channelId, sourceCoreId, targetCoreId);
resolvedInputs.push_back(receive.getResult());
continue;
}
}
@@ -1019,12 +1056,17 @@ private:
}
SmallVector<Value> taskYieldValues;
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front());
if (isa<SpatCompute>(task.computeInstance.op)) {
IRMapping mapper;
auto compute = cast<SpatCompute>(task.computeInstance.op);
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights))
mapper.map(compute.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight)));
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) {
Value destArg = isBatch
? cast<SpatComputeBatch>(program.op).getWeightArgument(program.weightToIndex.at(weight))
: cast<SpatCompute>(program.op).getWeightArgument(program.weightToIndex.at(weight));
mapper.map(compute.getWeightArgument(weightIndex), destArg);
}
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
mapper.map(compute.getInputArgument(inputIndex), input);
@@ -1034,138 +1076,104 @@ private:
taskYieldValues.push_back(mapper.lookup(yieldOperand));
continue;
}
rewriter.clone(op, mapper);
}
}
else {
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
if (batch.getNumResults() != 0) {
IRMapping mapper;
Value laneValue = getOrCreateHostIndexConstant(
program.op, static_cast<int64_t>(task.computeInstance.laneStart), constantFolder);
mapper.map(batch.getLaneArgument(), laneValue);
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights))
mapper.map(batch.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight)));
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
mapper.map(batch.getInputArgument(inputIndex), input);
for (Operation& op : templateBlock.without_terminator()) {
Operation* clonedOp = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
}
FailureOr<SmallVector<BatchYieldInfo>> yieldInfo = collectResultfulBatchYieldInfo(batch);
if (failed(yieldInfo))
return task.computeInstance.op->emitOpError("failed to collect resultful batch yield info");
for (const BatchYieldInfo& info : *yieldInfo)
taskYieldValues.push_back(mapper.lookup(info.yieldedValue));
}
else {
size_t batchLaneCount = static_cast<size_t>(batch.getLaneCount());
size_t inputsPerLane = batchLaneCount == 0 ? 0 : batch.getInputs().size() / batchLaneCount;
size_t weightsPerLane = batchLaneCount == 0 ? 0 : batch.getWeights().size() / batchLaneCount;
size_t loopRunLength = getLoopableResultlessBatchRunLength(programTasks, taskIndex);
if (loopRunLength > 1) {
Value lower = getOrCreateHostIndexConstant(
program.op, static_cast<int64_t>(task.computeInstance.laneStart), constantFolder);
Value upper = getOrCreateHostIndexConstant(
program.op, static_cast<int64_t>(task.computeInstance.laneStart + loopRunLength), constantFolder);
Value step = getOrCreateHostIndexConstant(program.op, 1, constantFolder);
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody());
IRMapping mapper;
mapper.map(batch.getLaneArgument(), loop.getInductionVar());
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
mapper.map(batch.getWeightArgument(weightIndex),
program.op.getWeightArgument(program.weightToIndex.at(taskWeights[weightIndex])));
for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex)
mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[inputIndex]);
for (Operation& op : templateBlock) {
if (isa<spatial::SpatYieldOp>(&op))
continue;
Operation* clonedOp = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
}
rewriter.setInsertionPointAfter(loop);
taskIndex += loopRunLength - 1;
}
else {
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
IRMapping mapper;
Value laneValue = getOrCreateHostIndexConstant(
program.op, static_cast<int64_t>(task.computeInstance.laneStart + laneOffset), constantFolder);
mapper.map(batch.getLaneArgument(), laneValue);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
mapper.map(batch.getWeightArgument(weightIndex),
program.op.getWeightArgument(
program.weightToIndex.at(taskWeights[laneOffset * weightsPerLane + weightIndex])));
for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex)
mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[laneOffset * inputsPerLane + inputIndex]);
for (Operation& op : templateBlock) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
for (Value yieldOperand : yield.getOperands())
taskYieldValues.push_back(mapper.lookup(yieldOperand));
continue;
}
Operation* clonedOp = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
}
}
}
}
// Include your existing isolated logic for preserving resultless spat.compute_batch here if needed
}
producedValuesByTask[task.computeInstance] = taskYieldValues;
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) {
const SmallVector<RemoteSendInfo>& sendInfos = sendsIt->second[resultIndex];
if (sendInfos.empty())
{
if (sendInfos.empty()) {
++resultIndex;
continue;
}
size_t nextResultIndex = resultIndex + 1;
if (tryEmitCompactSendLoops(
program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) {
resultIndex = nextResultIndex;
continue;
}
if (isBatch) {
size_t numSends = sendInfos.size();
for (size_t s = 0; s < numSends; ++s) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
Value producedValue = taskYieldValues[resultIndex];
for (const RemoteSendInfo& sendInfo : sendInfos) {
Value channelId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder);
Value sourceCoreId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder);
Value targetCoreId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder);
spatial::SpatChannelSendOp::create(rewriter, loc, channelId, sourceCoreId, targetCoreId, producedValue);
for (size_t cpu : batch) {
const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex];
const RemoteSendInfo& send = remoteSendsByTask[laneTask.computeInstance][resultIndex][s];
channelIds.push_back(send.channelInfo.channelId);
sourceCoreIds.push_back(send.channelInfo.sourceCoreId);
targetCoreIds.push_back(send.channelInfo.targetCoreId);
}
SmallVector<Value> cIds = createIndexConstants(program.op, channelIds, constantFolder);
SmallVector<Value> sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder);
SmallVector<Value> tIds = createIndexConstants(program.op, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter, loc, cIds, sIds, tIds, taskYieldValues[resultIndex]);
}
++resultIndex;
}
else {
size_t nextResultIndex = resultIndex + 1;
if (tryEmitCompactSendLoops(
program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) {
resultIndex = nextResultIndex;
continue;
}
Value producedValue = taskYieldValues[resultIndex];
for (const RemoteSendInfo& sendInfo : sendInfos) {
Value cId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder);
Value sId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder);
Value tId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder);
spatial::SpatChannelSendOp::create(rewriter, loc, cId, sId, tId, producedValue);
}
++resultIndex;
}
++resultIndex;
}
}
}
SmallVector<Value> yieldValues;
yieldValues.reserve(cpuExternalOutputs[programKey].size());
for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) {
yieldValues.reserve(cpuExternalOutputs[leader].size());
for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) {
auto producedIt = producedValuesByTask.find(outputRef.instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization")
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
return failure();
}
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex)
return func.emitError("missing yielded external value during materialization");
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
}
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
if (isBatch) {
auto batchOp = cast<SpatComputeBatch>(program.op);
auto inParallel = spatial::SpatInParallelOp::create(rewriter, loc);
Block* parallelBlock = rewriter.createBlock(&inParallel.getRegion());
rewriter.setInsertionPointToEnd(parallelBlock);
for (auto [resultIndex, yieldedVal] : llvm::enumerate(yieldValues)) {
auto destArg = batchOp.getOutputArgument(resultIndex);
auto destType = cast<RankedTensorType>(destArg.getType());
SmallVector<OpFoldResult> offsets;
offsets.push_back(batchOp.getLaneArgument());
SmallVector<OpFoldResult> sizes;
sizes.push_back(rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> strides;
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim : destType.getShape().drop_front()) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(dim));
strides.push_back(rewriter.getIndexAttr(1));
}
tensor::ParallelInsertSliceOp::create(rewriter, loc, yieldedVal, destArg, offsets, sizes, strides);
}
}
else {
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
}
}
return success();
@@ -1245,7 +1253,6 @@ private:
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
DenseMap<ProgramKey, SmallVector<ScheduledTask>> tasksByProgram;
SmallVector<size_t> orderedCpus;
SmallVector<ProgramKey> orderedPrograms;
DenseSet<size_t> seenCpus;
DenseSet<ProgramKey> seenPrograms;
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
@@ -19,6 +19,7 @@ struct MergeScheduleResult {
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
};
} // namespace spatial
@@ -1,6 +1,7 @@
#include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
@@ -274,7 +275,64 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
});
// 5. Populate Final Result
// 5. Check if equal schedule in two level
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
for (size_t controlProcessor = currentProcessor + 1; controlProcessor < processorCount; ++controlProcessor) {
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
continue;
auto& currentTasks = tasksByProcessor[currentProcessor];
auto& controlTasks = tasksByProcessor[controlProcessor];
bool equalSchedule = true;
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
if (currentComputeInstance.op != controlComputeInstance.op) {
equalSchedule = false;
break;
}
}
if (equalSchedule) {
equivalentClass[currentProcessor].push_back(controlProcessor);
equivalentClass[controlProcessor].push_back(currentProcessor);
}
}
}
{
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
std::vector<bool> visited(processorCount, false);
size_t uniqueClassCount = 0;
for (size_t i = 0; i < processorCount; ++i) {
if (visited[i])
continue;
// We found a new unique schedule (equivalence class)
++uniqueClassCount;
visited[i] = true;
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
// Find and mark all identical companions
auto it = equivalentClass.find(i);
if (it != equivalentClass.end()) {
for (size_t eqCpu : it->second) {
if (!visited[eqCpu]) {
llvm::dbgs() << ", " << eqCpu;
visited[eqCpu] = true;
}
}
}
llvm::dbgs() << " }\n";
}
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
llvm::dbgs() << "--------------------------------------\n";
}
// 6. Populate Final Result
MergeScheduleResult result;
result.dominanceOrderCompute.reserve(nodeCount);
@@ -296,8 +354,9 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
}
}
result.equivalentClass = equivalentClass;
return result;
}
} // namespace spatial
} // namespace onnx_mlir