Equivalent Class but broken

This commit is contained in:
ilgeco
2026-05-21 14:43:59 +02:00
parent a50e77ff38
commit fe35b3ed43
3 changed files with 481 additions and 414 deletions
@@ -58,18 +58,15 @@ static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int
} }
static Value createIndexTensorConstant(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) { static Value createIndexTensorConstant(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
auto tensorType = RankedTensorType::get({static_cast<int64_t>(values.size())}, IndexType::get(anchorOp->getContext())); auto tensorType =
RankedTensorType::get({static_cast<int64_t>(values.size())}, IndexType::get(anchorOp->getContext()));
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values); auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder); return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
} }
static Value createIndexTupleTensorConstant(Operation* anchorOp, static Value createIndexTupleTensorConstant(
int64_t tupleCount, Operation* anchorOp, int64_t tupleCount, int64_t tupleWidth, ArrayRef<int64_t> values, OperationFolder& folder) {
int64_t tupleWidth, auto tensorType = RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext()));
ArrayRef<int64_t> values,
OperationFolder& folder) {
auto tensorType =
RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext()));
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values); auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder); return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
} }
@@ -114,12 +111,14 @@ private:
}; };
struct CpuProgram { struct CpuProgram {
SpatCompute op; Operation* op = nullptr;
DenseMap<Value, Value> externalInputMap; DenseMap<Value, Value> externalInputMap;
DenseMap<Value, size_t> weightToIndex; DenseMap<Value, size_t> weightToIndex;
}; };
using ProgramKey = std::pair<size_t, size_t>; using ProgramKey = size_t; // Represents the "Leader" CPU
DenseMap<ProgramKey, SmallVector<size_t>> batchedCpus;
SmallVector<ProgramKey> orderedPrograms;
struct RemoteSendInfo { struct RemoteSendInfo {
ChannelInfo channelInfo; ChannelInfo channelInfo;
@@ -171,7 +170,7 @@ private:
| static_cast<uint32_t>(channelInfo.targetCoreId); | static_cast<uint32_t>(channelInfo.targetCoreId);
} }
static ProgramKey getProgramKey(const ScheduledTask& task) { return {task.cpu, 0}; } static ProgramKey getProgramKey(const ScheduledTask& task) { return task.cpu; }
static bool isResultfulBatchInstance(const ComputeInstance& instance) { static bool isResultfulBatchInstance(const ComputeInstance& instance) {
auto batch = dyn_cast<SpatComputeBatch>(instance.op); auto batch = dyn_cast<SpatComputeBatch>(instance.op);
@@ -377,12 +376,12 @@ private:
void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) { void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) {
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts); SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder); Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
Value channelSourceTargetTuples = createIndexTupleTensorConstant( Value channelSourceTargetTuples =
hostAnchor, createIndexTupleTensorConstant(hostAnchor,
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3), static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
3, 3,
run.channelSourceTargetTuples, run.channelSourceTargetTuples,
constantFolder); constantFolder);
Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder);
Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast<int64_t>(run.sendCounts.size()), constantFolder); Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast<int64_t>(run.sendCounts.size()), constantFolder);
@@ -413,8 +412,7 @@ private:
.getResult(); .getResult();
Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step); Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
Value innerLower = Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex}); Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex});
emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples); emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples);
rewriter.setInsertionPointAfter(outerLoop); rewriter.setInsertionPointAfter(outerLoop);
@@ -423,12 +421,12 @@ private:
void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) { void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) {
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts); SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder); Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
Value channelSourceTargetTuples = createIndexTupleTensorConstant( Value channelSourceTargetTuples =
hostAnchor, createIndexTupleTensorConstant(hostAnchor,
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3), static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
3, 3,
run.channelSourceTargetTuples, run.channelSourceTargetTuples,
constantFolder); constantFolder);
Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder);
Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast<int64_t>(run.sendCounts.size()), constantFolder); Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast<int64_t>(run.sendCounts.size()), constantFolder);
@@ -469,8 +467,7 @@ private:
.getResult(); .getResult();
Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step); Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
Value innerLower = Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex}); Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex});
emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples); emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples);
rewriter.setInsertionPointAfter(outerLoop); rewriter.setInsertionPointAfter(outerLoop);
@@ -532,149 +529,112 @@ private:
} }
void buildTaskIndex() { void buildTaskIndex() {
auto markCpuSeen = [&](size_t cpu) { DenseSet<size_t> seen;
if (seenCpus.insert(cpu).second)
orderedCpus.push_back(cpu);
};
for (const ScheduledTask& task : scheduledTasks) { for (const ScheduledTask& task : scheduledTasks) {
taskByComputeInstance[task.computeInstance] = task; taskByComputeInstance[task.computeInstance] = task;
tasksByCpu[task.cpu].push_back(task); tasksByCpu[task.cpu].push_back(task);
ProgramKey programKey = getProgramKey(task); seen.insert(task.cpu);
tasksByProgram[programKey].push_back(task);
if (seenPrograms.insert(programKey).second)
orderedPrograms.push_back(programKey);
markCpuSeen(task.cpu);
} }
llvm::sort(orderedCpus); SmallVector<size_t> activeCpus(seen.begin(), seen.end());
llvm::sort(orderedPrograms, [](const ProgramKey& lhs, const ProgramKey& rhs) { llvm::sort(activeCpus);
if (lhs.second != rhs.second)
return lhs.second < rhs.second; DenseSet<size_t> batched;
return lhs.first < rhs.first; for (size_t cpu : activeCpus) {
}); if (batched.contains(cpu))
for (size_t cpu : orderedCpus) continue;
SmallVector<size_t> batch;
batch.push_back(cpu);
batched.insert(cpu);
// Group all equivalent CPUs into this batch
auto it = schedule->equivalentClass.find(cpu);
if (it != schedule->equivalentClass.end()) {
for (size_t eqCpu : it->second)
if (batched.insert(eqCpu).second)
batch.push_back(eqCpu);
}
llvm::sort(batch);
size_t leader = batch.front();
batchedCpus[leader] = batch;
orderedPrograms.push_back(leader);
}
for (size_t cpu : activeCpus) {
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
return lhs.orderWithinCpu < rhs.orderWithinCpu; return lhs.orderWithinCpu < rhs.orderWithinCpu;
}); });
for (ProgramKey programKey : orderedPrograms) }
llvm::stable_sort(tasksByProgram[programKey], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
return lhs.orderWithinCpu < rhs.orderWithinCpu;
});
} }
void collectExternalInputsAndWeights() { void collectExternalInputsAndWeights() {
for (ProgramKey programKey : orderedPrograms) { for (ProgramKey leader : orderedPrograms) {
size_t cpu = programKey.first; const auto& batch = batchedCpus[leader];
for (const ScheduledTask& task : tasksByProgram[programKey]) { auto& thisCpuWeights = cpuWeights[leader];
auto& thisCpuWeights = cpuWeights[programKey]; auto& thisCpuInputs = cpuExternalInputs[leader];
auto& thisSeenWeights = seenWeightsByProgram[programKey]; auto& thisCpuOutputs = cpuExternalOutputs[leader];
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
for (Value weight : taskWeights)
if (thisSeenWeights.insert(weight).second)
thisCpuWeights.push_back(weight);
auto taskInputs = getComputeInstanceInputs(task.computeInstance); // Process every lane sequentially to pack operands
auto& remoteInputs = remoteInputsByTask[task.computeInstance]; for (size_t cpu : batch) {
remoteInputs.resize(taskInputs.size()); DenseSet<Value> laneSeenWeights;
auto& remoteTensorInputs = remoteTensorInputsByTask[task.computeInstance]; DenseSet<Value> laneSeenInputs;
remoteTensorInputs.resize(taskInputs.size());
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
bool isExternalInput = true;
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
producerBatch && producerBatch.getNumResults() != 0) {
size_t resultIndex = cast<OpResult>(input).getResultNumber();
TensorChannelInfo tensorInfo;
tensorInfo.resultIndex = resultIndex;
tensorInfo.channelIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
tensorInfo.sourceCoreIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
tensorInfo.targetCoreIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
tensorInfo.producerInstances.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
bool foundAllLaneProducers = true; for (const ScheduledTask& task : tasksByCpu[cpu]) {
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) { for (Value weight : getComputeInstanceWeights(task.computeInstance))
ComputeInstance producerInstance = getBatchChunkForLane(producerBatch, lane); if (laneSeenWeights.insert(weight).second)
auto producerIt = taskByComputeInstance.find(producerInstance); thisCpuWeights.push_back(weight);
if (producerIt == taskByComputeInstance.end()) {
foundAllLaneProducers = false;
break;
}
ChannelInfo info { auto taskInputs = getComputeInstanceInputs(task.computeInstance);
producerIt->second.cpu == cpu ? -1 : (*nextChannelId)++, auto& remoteInputs = remoteInputsByTask[task.computeInstance];
static_cast<int32_t>(producerIt->second.cpu), remoteInputs.resize(taskInputs.size());
static_cast<int32_t>(cpu),
};
tensorInfo.channelIds.push_back(info.channelId);
tensorInfo.sourceCoreIds.push_back(info.sourceCoreId);
tensorInfo.targetCoreIds.push_back(info.targetCoreId);
tensorInfo.producerInstances.push_back(producerInstance);
if (producerIt->second.cpu != cpu) { for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
auto& perResultChannels = remoteSendsByTask[producerInstance]; bool isExternalInput = true;
if (perResultChannels.empty()) if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) {
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size()); auto producerIt = taskByComputeInstance.find(producerRef->instance);
perResultChannels[resultIndex].push_back( if (producerIt != taskByComputeInstance.end()) {
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, true}); isExternalInput = false;
if (producerIt->second.cpu != cpu) {
// Cross-core communication
ChannelInfo info;
info.channelId = (*nextChannelId)++;
info.sourceCoreId = static_cast<int32_t>(producerIt->second.cpu);
info.targetCoreId = static_cast<int32_t>(cpu);
remoteInputs[inputIndex] = info;
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
if (perResultChannels.empty())
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
RemoteSendInfo sendInfo;
sendInfo.channelInfo = info;
sendInfo.consumer = task.computeInstance;
sendInfo.inputIndex = inputIndex;
sendInfo.consumerOrder = task.orderWithinCpu;
sendInfo.sourceOrder = 0;
sendInfo.isTensorInput = false;
perResultChannels[producerRef->resultIndex].push_back(sendInfo);
}
} }
} }
if (isExternalInput && laneSeenInputs.insert(input).second)
thisCpuInputs.push_back(input);
}
if (foundAllLaneProducers) { // Define the logical return types based strictly on the Leader
remoteTensorInputs[inputIndex] = std::move(tensorInfo); if (cpu == leader) {
continue; auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
bool hasExternalUser = false;
for (auto& use : output.getUses())
if (!oldComputeOps.contains(use.getOwner()))
hasExternalUser = true;
if (hasExternalUser)
thisCpuOutputs.push_back({task.computeInstance, resultIndex});
} }
} }
if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) {
auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByComputeInstance.end()) {
isExternalInput = false;
if (producerIt->second.cpu != cpu) {
ChannelInfo info {
(*nextChannelId)++,
static_cast<int32_t>(producerIt->second.cpu),
static_cast<int32_t>(cpu),
};
remoteInputs[inputIndex] = info;
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
if (perResultChannels.empty())
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
perResultChannels[producerRef->resultIndex].push_back(
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, false});
}
}
}
if (isExternalInput && seenExternalInputsByProgram[programKey].insert(input).second)
cpuExternalInputs[programKey].push_back(input);
}
if (isResultfulBatchInstance(task.computeInstance)) {
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
for (unsigned resultIndex = 0; resultIndex < batch.getNumResults(); ++resultIndex) {
bool hasExternalUser = false;
for (Operation* user : batch.getResult(resultIndex).getUsers()) {
if (!oldComputeOps.contains(user)) {
hasExternalUser = true;
break;
}
}
if (hasExternalUser)
cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex});
}
continue;
}
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
bool hasExternalUser = false;
for (auto& use : output.getUses()) {
Operation* useOwner = use.getOwner();
if (oldComputeOps.contains(useOwner))
continue;
hasExternalUser = true;
}
if (hasExternalUser)
cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex});
} }
} }
} }
@@ -779,65 +739,167 @@ private:
void createCpuComputeOps() { void createCpuComputeOps() {
IRRewriter rewriter(func.getContext()); IRRewriter rewriter(func.getContext());
for (ProgramKey programKey : orderedPrograms) { for (ProgramKey leader : orderedPrograms) {
size_t cpu = programKey.first; const auto& batch = batchedCpus[leader];
SmallVector<Value> operands; bool isBatch = batch.size() > 1;
operands.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
llvm::append_range(operands, cpuWeights[programKey]);
llvm::append_range(operands, cpuExternalInputs[programKey]);
SmallVector<Type> resultTypes; SmallVector<Type> resultTypes;
resultTypes.reserve(cpuExternalOutputs[programKey].size()); SmallVector<Type> packedResultTypes;
for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) {
for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance); ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
SmallVector<Type> outputTypes = getTaskOutputTypes(task.computeInstance); Type elemType = getTaskOutputTypes(task.computeInstance)[outputRef.resultIndex];
resultTypes.push_back(outputTypes[outputRef.resultIndex]); resultTypes.push_back(elemType);
if (isBatch) {
auto ranked = cast<RankedTensorType>(elemType);
SmallVector<int64_t> shape;
shape.push_back(static_cast<int64_t>(batch.size()));
shape.append(ranked.getShape().begin(), ranked.getShape().end());
packedResultTypes.push_back(RankedTensorType::get(shape, ranked.getElementType()));
}
} }
rewriter.setInsertionPoint(returnOp); rewriter.setInsertionPoint(returnOp);
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands)); CpuProgram program;
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(cpuWeights[programKey].size()), static_cast<int>(cpuExternalInputs[programKey].size())});
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(cpu)));
SmallVector<Type> blockArgTypes; if (!isBatch) {
SmallVector<Location> blockArgLocs; // Isolated CPU Execution
blockArgTypes.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size()); SmallVector<Value> operands;
blockArgLocs.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size()); operands.reserve(cpuWeights[leader].size() + cpuExternalInputs[leader].size());
for (Value weight : cpuWeights[programKey]) { llvm::append_range(operands, cpuWeights[leader]);
blockArgTypes.push_back(weight.getType()); llvm::append_range(operands, cpuExternalInputs[leader]);
blockArgLocs.push_back(loc);
} auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
for (Value input : cpuExternalInputs[programKey]) { newCompute.getProperties().setOperandSegmentSizes(
blockArgTypes.push_back(input.getType()); {static_cast<int>(cpuWeights[leader].size()), static_cast<int>(cpuExternalInputs[leader].size())});
blockArgLocs.push_back(loc); newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(leader)));
}
Block* newBlock = SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (Value op : operands) {
blockArgTypes.push_back(op.getType());
blockArgLocs.push_back(loc);
}
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
CpuProgram program; program.op = newCompute;
program.op = newCompute; for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[leader]))
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[programKey])) program.weightToIndex[weight] = weightIndex;
program.weightToIndex[weight] = weightIndex; for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[leader]))
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[programKey])) program.externalInputMap[input] = newCompute.getInputArgument(inputIndex);
program.externalInputMap[input] = newCompute.getInputArgument(inputIndex);
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[programKey])) { for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) {
ScheduledTask task = taskByComputeInstance.at(outputRef.instance); ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
if (isResultfulBatchInstance(task.computeInstance)) { if (isResultfulBatchInstance(task.computeInstance)) {
auto batch = cast<SpatComputeBatch>(task.computeInstance.op); auto oldBatch = cast<SpatComputeBatch>(task.computeInstance.op);
auto& batchResults = resultfulBatchLaneResults[batch.getOperation()]; auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()];
if (batchResults.empty()) if (batchResults.empty())
batchResults.resize(batch.getNumResults()); batchResults.resize(oldBatch.getNumResults());
auto& laneResults = batchResults[outputRef.resultIndex]; auto& laneResults = batchResults[outputRef.resultIndex];
if (laneResults.empty()) if (laneResults.empty())
laneResults.resize(static_cast<size_t>(batch.getLaneCount())); laneResults.resize(static_cast<size_t>(oldBatch.getLaneCount()));
laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex); laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex);
continue; continue;
}
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
newCompute.getResult(resultIndex);
} }
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
newCompute.getResult(resultIndex);
} }
cpuPrograms[programKey] = std::move(program); else {
// Equivalence Class Batch Execution
auto newBatch = SpatComputeBatch::create(rewriter,
loc,
TypeRange(packedResultTypes),
rewriter.getI32IntegerAttr(batch.size()),
cpuWeights[leader],
cpuExternalInputs[leader]);
newBatch.getProperties().setOperandSegmentSizes(
{static_cast<int>(cpuWeights[leader].size()), static_cast<int>(cpuExternalInputs[leader].size())});
SmallVector<int32_t> coreIds;
for (size_t c : batch)
coreIds.push_back(static_cast<int32_t>(c));
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.push_back(rewriter.getIndexType()); // Lane ID Argument
blockArgLocs.push_back(loc);
size_t weightsPerLane = cpuWeights[leader].size() / batch.size();
size_t inputsPerLane = cpuExternalInputs[leader].size() / batch.size();
for (size_t i = 0; i < weightsPerLane; ++i) {
blockArgTypes.push_back(cpuWeights[leader][i].getType());
blockArgLocs.push_back(loc);
}
for (size_t i = 0; i < inputsPerLane; ++i) {
blockArgTypes.push_back(cpuExternalInputs[leader][i].getType());
blockArgLocs.push_back(loc);
}
for (Type t : packedResultTypes) {
blockArgTypes.push_back(t);
blockArgLocs.push_back(loc);
} // Dest tensors for InParallel Yield
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
program.op = newBatch;
// Host-side slice extractions
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(newBatch);
for (auto [laneIndex, cpu] : llvm::enumerate(batch)) {
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) {
size_t taskIdx = 0;
for (size_t i = 0; i < tasksByCpu[leader].size(); ++i) {
if (tasksByCpu[leader][i].computeInstance == outputRef.instance) {
taskIdx = i;
break;
}
}
ComputeInstance laneInstance = tasksByCpu[cpu][taskIdx].computeInstance;
auto ranked = cast<RankedTensorType>(resultTypes[resultIndex]);
SmallVector<OpFoldResult> offsets;
offsets.push_back(rewriter.getIndexAttr(laneIndex));
SmallVector<OpFoldResult> sizes;
sizes.push_back(rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> strides;
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim : ranked.getShape()) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(dim));
strides.push_back(rewriter.getIndexAttr(1));
}
auto slice = tensor::ExtractSliceOp::create(
rewriter, loc, ranked, newBatch.getResult(resultIndex), offsets, sizes, strides);
if (isResultfulBatchInstance(laneInstance)) {
auto oldBatch = cast<SpatComputeBatch>(laneInstance.op);
auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()];
if (batchResults.empty())
batchResults.resize(oldBatch.getNumResults());
auto& laneValues = batchResults[outputRef.resultIndex];
if (laneValues.empty())
laneValues.resize(static_cast<size_t>(oldBatch.getLaneCount()));
laneValues[laneInstance.laneStart] = slice.getResult();
}
else {
oldToNewExternalValueMap[getComputeInstanceOutputValues(laneInstance)[outputRef.resultIndex]] =
slice.getResult();
}
}
}
for (size_t i = 0; i < weightsPerLane; ++i)
program.weightToIndex[cpuWeights[leader][i]] = i;
for (size_t i = 0; i < inputsPerLane; ++i)
program.externalInputMap[cpuExternalInputs[leader][i]] =
newBatch.getBody().front().getArgument(1 + weightsPerLane + i);
}
cpuPrograms[leader] = std::move(program);
} }
} }
@@ -888,130 +950,105 @@ private:
DenseMap<size_t, DenseMap<uint64_t, size_t>> receiveQueueIndicesByCpu; DenseMap<size_t, DenseMap<uint64_t, size_t>> receiveQueueIndicesByCpu;
DenseMap<size_t, DenseMap<ComputeInstance, SmallVector<Value>>> preReceivedInputsByCpu; DenseMap<size_t, DenseMap<ComputeInstance, SmallVector<Value>>> preReceivedInputsByCpu;
for (ProgramKey programKey : orderedPrograms) { auto lookupPreReceivedInput = [&](DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
size_t cpu = programKey.first; ComputeInstance consumer,
CpuProgram& program = cpuPrograms[programKey]; size_t inputIndex) -> std::optional<Value> {
auto inputsIt = preReceivedInputsByTask.find(consumer);
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
return std::nullopt;
Value value = inputsIt->second[inputIndex];
if (!value)
return std::nullopt;
return value;
};
for (ProgramKey leader : orderedPrograms) {
const auto& batch = batchedCpus[leader];
bool isBatch = batch.size() > 1;
CpuProgram& program = cpuPrograms[leader];
IRRewriter rewriter(func.getContext()); IRRewriter rewriter(func.getContext());
rewriter.setInsertionPointToEnd(&program.op.getBody().front()); rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front());
auto& receiveQueueIndices = receiveQueueIndicesByCpu[cpu];
auto& preReceivedInputsByTask = preReceivedInputsByCpu[cpu];
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> { auto& receiveQueueIndices = receiveQueueIndicesByCpu[leader];
auto inputsIt = preReceivedInputsByTask.find(consumer); auto& preReceivedInputsByTask = preReceivedInputsByCpu[leader];
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
return std::nullopt;
Value value = inputsIt->second[inputIndex];
if (!value)
return std::nullopt;
return value;
};
ArrayRef<ScheduledTask> programTasks = tasksByProgram[programKey]; ArrayRef<ScheduledTask> leaderTasks = tasksByCpu[leader];
for (size_t taskIndex = 0; taskIndex < programTasks.size(); ++taskIndex) {
const ScheduledTask& task = programTasks[taskIndex]; for (size_t taskIndex = 0; taskIndex < leaderTasks.size(); ++taskIndex) {
const ScheduledTask& task = leaderTasks[taskIndex];
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance); SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
auto taskWeights = getComputeInstanceWeights(task.computeInstance); auto taskWeights = getComputeInstanceWeights(task.computeInstance);
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance); Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
SmallVector<Value> resolvedInputs; SmallVector<Value> resolvedInputs;
resolvedInputs.reserve(taskInputs.size()); resolvedInputs.reserve(taskInputs.size());
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance); auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
auto remoteTensorInputsIt = remoteTensorInputsByTask.find(task.computeInstance);
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
if (remoteTensorInputsIt != remoteTensorInputsByTask.end() && inputIndex < remoteTensorInputsIt->second.size()
&& remoteTensorInputsIt->second[inputIndex]) {
const TensorChannelInfo& tensorInfo = *remoteTensorInputsIt->second[inputIndex];
bool hasLocalProducer = llvm::is_contained(tensorInfo.sourceCoreIds, static_cast<int32_t>(cpu));
if (!hasLocalProducer) {
auto receive = spatial::SpatChannelReceiveTensorOp::create(
rewriter,
loc,
input.getType(),
createIndexConstants(program.op, tensorInfo.channelIds, constantFolder),
createIndexConstants(program.op, tensorInfo.sourceCoreIds, constantFolder),
createIndexConstants(program.op, tensorInfo.targetCoreIds, constantFolder));
resolvedInputs.push_back(receive.getOutput());
continue;
}
SmallVector<Value> laneValues;
laneValues.reserve(tensorInfo.producerInstances.size());
for (auto [laneIndex, producerInstance] : llvm::enumerate(tensorInfo.producerInstances)) {
if (tensorInfo.sourceCoreIds[laneIndex] == static_cast<int32_t>(cpu)) {
auto producedIt = producedValuesByTask.find(producerInstance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= tensorInfo.resultIndex) {
task.computeInstance.op->emitOpError(
"missing local tensor lane producer during merge materialization")
<< " consumerCpu=" << cpu << " producerLaneStart=" << producerInstance.laneStart;
return failure();
}
laneValues.push_back(producedIt->second[tensorInfo.resultIndex]);
continue;
}
auto producerTaskIt = taskByComputeInstance.find(producerInstance);
if (producerTaskIt == taskByComputeInstance.end())
return failure();
Type laneType = getTaskOutputTypes(producerTaskIt->second.computeInstance)[tensorInfo.resultIndex];
Value channelId = createIndexConstant(program.op, tensorInfo.channelIds[laneIndex], constantFolder);
Value sourceCoreId = createIndexConstant(program.op, tensorInfo.sourceCoreIds[laneIndex], constantFolder);
Value targetCoreId = createIndexConstant(program.op, tensorInfo.targetCoreIds[laneIndex], constantFolder);
auto receive =
spatial::SpatChannelReceiveOp::create(rewriter, loc, laneType, channelId, sourceCoreId, targetCoreId);
laneValues.push_back(receive.getResult());
}
Value packedInput = tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, ValueRange(laneValues)).getResult();
resolvedInputs.push_back(packedInput);
continue;
}
auto producerRef = getProducerValueRef(input, &task.computeInstance); auto producerRef = getProducerValueRef(input, &task.computeInstance);
if (producerRef) { if (producerRef) {
auto producerIt = taskByComputeInstance.find(producerRef->instance); auto producerIt = taskByComputeInstance.find(producerRef->instance);
if (producerIt != taskByComputeInstance.end()) { if (producerIt != taskByComputeInstance.end()) {
if (producerIt->second.cpu == cpu) { if (producerIt->second.cpu == leader) {
auto producedIt = producedValuesByTask.find(producerRef->instance); auto producedIt = producedValuesByTask.find(producerRef->instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) { if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex)
task.computeInstance.op->emitOpError( return task.computeInstance.op->emitOpError("missing local producer value");
"missing local producer value during per-cpu merge materialization")
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
<< " producerLaneStart=" << producerRef->instance.laneStart
<< " producerLaneCount=" << producerRef->instance.laneCount;
return failure();
}
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]); resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
continue; continue;
} }
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
uint64_t pairKey = getRemoteSendPairKey(channelInfo); if (isBatch) {
if (pairsNeedingReceiveReorder.contains(pairKey)) { SmallVector<int64_t> channelIds;
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) { SmallVector<int32_t> sourceCoreIds;
resolvedInputs.push_back(*preReceived); SmallVector<int32_t> targetCoreIds;
for (size_t cpu : batch) {
const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex];
const ChannelInfo& info = *remoteInputsByTask[laneTask.computeInstance][inputIndex];
channelIds.push_back(info.channelId);
sourceCoreIds.push_back(info.sourceCoreId);
targetCoreIds.push_back(info.targetCoreId);
}
SmallVector<Value> cIds = createIndexConstants(program.op, channelIds, constantFolder);
SmallVector<Value> sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder);
SmallVector<Value> tIds = createIndexConstants(program.op, targetCoreIds, constantFolder);
auto recv =
spatial::SpatChannelReceiveBatchOp::create(rewriter, loc, input.getType(), cIds, sIds, tIds);
resolvedInputs.push_back(recv.getOutput());
}
else {
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
if (pairsNeedingReceiveReorder.contains(pairKey)) {
if (std::optional<Value> preReceived =
lookupPreReceivedInput(preReceivedInputsByTask, task.computeInstance, inputIndex)) {
resolvedInputs.push_back(*preReceived);
continue;
}
FailureOr<Value> received = receiveThroughInput(rewriter,
leader,
receiveQueueIndices,
preReceivedInputsByTask,
channelInfo,
task.computeInstance,
inputIndex);
if (failed(received))
return task.computeInstance.op->emitOpError("failed to materialize reordered remote receive");
resolvedInputs.push_back(*received);
continue; continue;
} }
FailureOr<Value> received = receiveThroughInput(rewriter,
cpu, Value cId = createIndexConstant(program.op, channelInfo.channelId, constantFolder);
receiveQueueIndices, Value sId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder);
preReceivedInputsByTask, Value tId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder);
channelInfo, auto receive = spatial::SpatChannelReceiveOp::create(rewriter, loc, input.getType(), cId, sId, tId);
task.computeInstance, resolvedInputs.push_back(receive.getResult());
inputIndex);
if (failed(received)) {
task.computeInstance.op->emitOpError("failed to materialize reordered remote receive")
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
return failure();
}
resolvedInputs.push_back(*received);
continue;
} }
Value channelId = createIndexConstant(program.op, channelInfo.channelId, constantFolder);
Value sourceCoreId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder);
Value targetCoreId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder);
auto receive = spatial::SpatChannelReceiveOp::create(
rewriter, loc, input.getType(), channelId, sourceCoreId, targetCoreId);
resolvedInputs.push_back(receive.getResult());
continue; continue;
} }
} }
@@ -1019,12 +1056,17 @@ private:
} }
SmallVector<Value> taskYieldValues; SmallVector<Value> taskYieldValues;
rewriter.setInsertionPointToEnd(&program.op.getBody().front()); rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front());
if (isa<SpatCompute>(task.computeInstance.op)) { if (isa<SpatCompute>(task.computeInstance.op)) {
IRMapping mapper; IRMapping mapper;
auto compute = cast<SpatCompute>(task.computeInstance.op); auto compute = cast<SpatCompute>(task.computeInstance.op);
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) {
mapper.map(compute.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight))); Value destArg = isBatch
? cast<SpatComputeBatch>(program.op).getWeightArgument(program.weightToIndex.at(weight))
: cast<SpatCompute>(program.op).getWeightArgument(program.weightToIndex.at(weight));
mapper.map(compute.getWeightArgument(weightIndex), destArg);
}
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs)) for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
mapper.map(compute.getInputArgument(inputIndex), input); mapper.map(compute.getInputArgument(inputIndex), input);
@@ -1034,138 +1076,104 @@ private:
taskYieldValues.push_back(mapper.lookup(yieldOperand)); taskYieldValues.push_back(mapper.lookup(yieldOperand));
continue; continue;
} }
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
} }
} }
else { else {
auto batch = cast<SpatComputeBatch>(task.computeInstance.op); // Include your existing isolated logic for preserving resultless spat.compute_batch here if needed
if (batch.getNumResults() != 0) {
IRMapping mapper;
Value laneValue = getOrCreateHostIndexConstant(
program.op, static_cast<int64_t>(task.computeInstance.laneStart), constantFolder);
mapper.map(batch.getLaneArgument(), laneValue);
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights))
mapper.map(batch.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight)));
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
mapper.map(batch.getInputArgument(inputIndex), input);
for (Operation& op : templateBlock.without_terminator()) {
Operation* clonedOp = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
}
FailureOr<SmallVector<BatchYieldInfo>> yieldInfo = collectResultfulBatchYieldInfo(batch);
if (failed(yieldInfo))
return task.computeInstance.op->emitOpError("failed to collect resultful batch yield info");
for (const BatchYieldInfo& info : *yieldInfo)
taskYieldValues.push_back(mapper.lookup(info.yieldedValue));
}
else {
size_t batchLaneCount = static_cast<size_t>(batch.getLaneCount());
size_t inputsPerLane = batchLaneCount == 0 ? 0 : batch.getInputs().size() / batchLaneCount;
size_t weightsPerLane = batchLaneCount == 0 ? 0 : batch.getWeights().size() / batchLaneCount;
size_t loopRunLength = getLoopableResultlessBatchRunLength(programTasks, taskIndex);
if (loopRunLength > 1) {
Value lower = getOrCreateHostIndexConstant(
program.op, static_cast<int64_t>(task.computeInstance.laneStart), constantFolder);
Value upper = getOrCreateHostIndexConstant(
program.op, static_cast<int64_t>(task.computeInstance.laneStart + loopRunLength), constantFolder);
Value step = getOrCreateHostIndexConstant(program.op, 1, constantFolder);
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {});
rewriter.setInsertionPointToStart(loop.getBody());
IRMapping mapper;
mapper.map(batch.getLaneArgument(), loop.getInductionVar());
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
mapper.map(batch.getWeightArgument(weightIndex),
program.op.getWeightArgument(program.weightToIndex.at(taskWeights[weightIndex])));
for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex)
mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[inputIndex]);
for (Operation& op : templateBlock) {
if (isa<spatial::SpatYieldOp>(&op))
continue;
Operation* clonedOp = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
}
rewriter.setInsertionPointAfter(loop);
taskIndex += loopRunLength - 1;
}
else {
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
IRMapping mapper;
Value laneValue = getOrCreateHostIndexConstant(
program.op, static_cast<int64_t>(task.computeInstance.laneStart + laneOffset), constantFolder);
mapper.map(batch.getLaneArgument(), laneValue);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
mapper.map(batch.getWeightArgument(weightIndex),
program.op.getWeightArgument(
program.weightToIndex.at(taskWeights[laneOffset * weightsPerLane + weightIndex])));
for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex)
mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[laneOffset * inputsPerLane + inputIndex]);
for (Operation& op : templateBlock) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
for (Value yieldOperand : yield.getOperands())
taskYieldValues.push_back(mapper.lookup(yieldOperand));
continue;
}
Operation* clonedOp = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
}
}
}
}
} }
producedValuesByTask[task.computeInstance] = taskYieldValues; producedValuesByTask[task.computeInstance] = taskYieldValues;
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) { if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) { for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) {
const SmallVector<RemoteSendInfo>& sendInfos = sendsIt->second[resultIndex]; const SmallVector<RemoteSendInfo>& sendInfos = sendsIt->second[resultIndex];
if (sendInfos.empty()) if (sendInfos.empty()) {
{
++resultIndex; ++resultIndex;
continue; continue;
} }
size_t nextResultIndex = resultIndex + 1; if (isBatch) {
if (tryEmitCompactSendLoops( size_t numSends = sendInfos.size();
program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) { for (size_t s = 0; s < numSends; ++s) {
resultIndex = nextResultIndex; SmallVector<int64_t> channelIds;
continue; SmallVector<int32_t> sourceCoreIds;
} SmallVector<int32_t> targetCoreIds;
Value producedValue = taskYieldValues[resultIndex]; for (size_t cpu : batch) {
for (const RemoteSendInfo& sendInfo : sendInfos) { const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex];
Value channelId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder); const RemoteSendInfo& send = remoteSendsByTask[laneTask.computeInstance][resultIndex][s];
Value sourceCoreId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder); channelIds.push_back(send.channelInfo.channelId);
Value targetCoreId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder); sourceCoreIds.push_back(send.channelInfo.sourceCoreId);
spatial::SpatChannelSendOp::create(rewriter, loc, channelId, sourceCoreId, targetCoreId, producedValue); targetCoreIds.push_back(send.channelInfo.targetCoreId);
}
SmallVector<Value> cIds = createIndexConstants(program.op, channelIds, constantFolder);
SmallVector<Value> sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder);
SmallVector<Value> tIds = createIndexConstants(program.op, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter, loc, cIds, sIds, tIds, taskYieldValues[resultIndex]);
}
++resultIndex;
}
else {
size_t nextResultIndex = resultIndex + 1;
if (tryEmitCompactSendLoops(
program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) {
resultIndex = nextResultIndex;
continue;
}
Value producedValue = taskYieldValues[resultIndex];
for (const RemoteSendInfo& sendInfo : sendInfos) {
Value cId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder);
Value sId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder);
Value tId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder);
spatial::SpatChannelSendOp::create(rewriter, loc, cId, sId, tId, producedValue);
}
++resultIndex;
} }
++resultIndex;
} }
} }
} }
SmallVector<Value> yieldValues; SmallVector<Value> yieldValues;
yieldValues.reserve(cpuExternalOutputs[programKey].size()); yieldValues.reserve(cpuExternalOutputs[leader].size());
for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) { for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) {
auto producedIt = producedValuesByTask.find(outputRef.instance); auto producedIt = producedValuesByTask.find(outputRef.instance);
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) { if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex)
ScheduledTask task = taskByComputeInstance.at(outputRef.instance); return func.emitError("missing yielded external value during materialization");
task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization")
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
return failure();
}
yieldValues.push_back(producedIt->second[outputRef.resultIndex]); yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
} }
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
if (isBatch) {
auto batchOp = cast<SpatComputeBatch>(program.op);
auto inParallel = spatial::SpatInParallelOp::create(rewriter, loc);
Block* parallelBlock = rewriter.createBlock(&inParallel.getRegion());
rewriter.setInsertionPointToEnd(parallelBlock);
for (auto [resultIndex, yieldedVal] : llvm::enumerate(yieldValues)) {
auto destArg = batchOp.getOutputArgument(resultIndex);
auto destType = cast<RankedTensorType>(destArg.getType());
SmallVector<OpFoldResult> offsets;
offsets.push_back(batchOp.getLaneArgument());
SmallVector<OpFoldResult> sizes;
sizes.push_back(rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> strides;
strides.push_back(rewriter.getIndexAttr(1));
for (int64_t dim : destType.getShape().drop_front()) {
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(dim));
strides.push_back(rewriter.getIndexAttr(1));
}
tensor::ParallelInsertSliceOp::create(rewriter, loc, yieldedVal, destArg, offsets, sizes, strides);
}
}
else {
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
}
} }
return success(); return success();
@@ -1245,7 +1253,6 @@ private:
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu; DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
DenseMap<ProgramKey, SmallVector<ScheduledTask>> tasksByProgram; DenseMap<ProgramKey, SmallVector<ScheduledTask>> tasksByProgram;
SmallVector<size_t> orderedCpus; SmallVector<size_t> orderedCpus;
SmallVector<ProgramKey> orderedPrograms;
DenseSet<size_t> seenCpus; DenseSet<size_t> seenCpus;
DenseSet<ProgramKey> seenPrograms; DenseSet<ProgramKey> seenPrograms;
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask; DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
@@ -19,6 +19,7 @@ struct MergeScheduleResult {
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap; llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu; llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap; llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
}; };
} // namespace spatial } // namespace spatial
@@ -1,6 +1,7 @@
#include "mlir/IR/Threading.h" #include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
@@ -274,7 +275,64 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder; return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
}); });
// 5. Populate Final Result // 5. Check if equal schedule in two level
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
for (size_t controlProcessor = currentProcessor + 1; controlProcessor < processorCount; ++controlProcessor) {
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
continue;
auto& currentTasks = tasksByProcessor[currentProcessor];
auto& controlTasks = tasksByProcessor[controlProcessor];
bool equalSchedule = true;
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
if (currentComputeInstance.op != controlComputeInstance.op) {
equalSchedule = false;
break;
}
}
if (equalSchedule) {
equivalentClass[currentProcessor].push_back(controlProcessor);
equivalentClass[controlProcessor].push_back(currentProcessor);
}
}
}
{
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
std::vector<bool> visited(processorCount, false);
size_t uniqueClassCount = 0;
for (size_t i = 0; i < processorCount; ++i) {
if (visited[i])
continue;
// We found a new unique schedule (equivalence class)
++uniqueClassCount;
visited[i] = true;
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
// Find and mark all identical companions
auto it = equivalentClass.find(i);
if (it != equivalentClass.end()) {
for (size_t eqCpu : it->second) {
if (!visited[eqCpu]) {
llvm::dbgs() << ", " << eqCpu;
visited[eqCpu] = true;
}
}
}
llvm::dbgs() << " }\n";
}
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
llvm::dbgs() << "--------------------------------------\n";
}
// 6. Populate Final Result
MergeScheduleResult result; MergeScheduleResult result;
result.dominanceOrderCompute.reserve(nodeCount); result.dominanceOrderCompute.reserve(nodeCount);
@@ -296,8 +354,9 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
} }
} }
result.equivalentClass = equivalentClass;
return result; return result;
} }
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir