Equivalent Class but broken
This commit is contained in:
+419
-412
@@ -58,18 +58,15 @@ static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int
|
||||
}
|
||||
|
||||
static Value createIndexTensorConstant(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||
auto tensorType = RankedTensorType::get({static_cast<int64_t>(values.size())}, IndexType::get(anchorOp->getContext()));
|
||||
auto tensorType =
|
||||
RankedTensorType::get({static_cast<int64_t>(values.size())}, IndexType::get(anchorOp->getContext()));
|
||||
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
|
||||
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
|
||||
}
|
||||
|
||||
static Value createIndexTupleTensorConstant(Operation* anchorOp,
|
||||
int64_t tupleCount,
|
||||
int64_t tupleWidth,
|
||||
ArrayRef<int64_t> values,
|
||||
OperationFolder& folder) {
|
||||
auto tensorType =
|
||||
RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext()));
|
||||
static Value createIndexTupleTensorConstant(
|
||||
Operation* anchorOp, int64_t tupleCount, int64_t tupleWidth, ArrayRef<int64_t> values, OperationFolder& folder) {
|
||||
auto tensorType = RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext()));
|
||||
auto tensorAttr = DenseIntElementsAttr::get(tensorType, values);
|
||||
return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder);
|
||||
}
|
||||
@@ -114,12 +111,14 @@ private:
|
||||
};
|
||||
|
||||
struct CpuProgram {
|
||||
SpatCompute op;
|
||||
Operation* op = nullptr;
|
||||
DenseMap<Value, Value> externalInputMap;
|
||||
DenseMap<Value, size_t> weightToIndex;
|
||||
};
|
||||
|
||||
using ProgramKey = std::pair<size_t, size_t>;
|
||||
using ProgramKey = size_t; // Represents the "Leader" CPU
|
||||
DenseMap<ProgramKey, SmallVector<size_t>> batchedCpus;
|
||||
SmallVector<ProgramKey> orderedPrograms;
|
||||
|
||||
struct RemoteSendInfo {
|
||||
ChannelInfo channelInfo;
|
||||
@@ -171,7 +170,7 @@ private:
|
||||
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
||||
}
|
||||
|
||||
static ProgramKey getProgramKey(const ScheduledTask& task) { return {task.cpu, 0}; }
|
||||
static ProgramKey getProgramKey(const ScheduledTask& task) { return task.cpu; }
|
||||
|
||||
static bool isResultfulBatchInstance(const ComputeInstance& instance) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||
@@ -377,12 +376,12 @@ private:
|
||||
void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) {
|
||||
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
|
||||
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
|
||||
Value channelSourceTargetTuples = createIndexTupleTensorConstant(
|
||||
hostAnchor,
|
||||
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
|
||||
3,
|
||||
run.channelSourceTargetTuples,
|
||||
constantFolder);
|
||||
Value channelSourceTargetTuples =
|
||||
createIndexTupleTensorConstant(hostAnchor,
|
||||
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
|
||||
3,
|
||||
run.channelSourceTargetTuples,
|
||||
constantFolder);
|
||||
|
||||
Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder);
|
||||
Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast<int64_t>(run.sendCounts.size()), constantFolder);
|
||||
@@ -413,8 +412,7 @@ private:
|
||||
.getResult();
|
||||
|
||||
Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
|
||||
Value innerLower =
|
||||
tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
|
||||
Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
|
||||
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex});
|
||||
emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples);
|
||||
rewriter.setInsertionPointAfter(outerLoop);
|
||||
@@ -423,12 +421,12 @@ private:
|
||||
void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) {
|
||||
SmallVector<int64_t> prefixSums = buildPrefixSums(run.sendCounts);
|
||||
Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder);
|
||||
Value channelSourceTargetTuples = createIndexTupleTensorConstant(
|
||||
hostAnchor,
|
||||
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
|
||||
3,
|
||||
run.channelSourceTargetTuples,
|
||||
constantFolder);
|
||||
Value channelSourceTargetTuples =
|
||||
createIndexTupleTensorConstant(hostAnchor,
|
||||
static_cast<int64_t>(run.channelSourceTargetTuples.size() / 3),
|
||||
3,
|
||||
run.channelSourceTargetTuples,
|
||||
constantFolder);
|
||||
|
||||
Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder);
|
||||
Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast<int64_t>(run.sendCounts.size()), constantFolder);
|
||||
@@ -469,8 +467,7 @@ private:
|
||||
.getResult();
|
||||
|
||||
Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step);
|
||||
Value innerLower =
|
||||
tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
|
||||
Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()});
|
||||
Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex});
|
||||
emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples);
|
||||
rewriter.setInsertionPointAfter(outerLoop);
|
||||
@@ -532,149 +529,112 @@ private:
|
||||
}
|
||||
|
||||
void buildTaskIndex() {
|
||||
auto markCpuSeen = [&](size_t cpu) {
|
||||
if (seenCpus.insert(cpu).second)
|
||||
orderedCpus.push_back(cpu);
|
||||
};
|
||||
|
||||
DenseSet<size_t> seen;
|
||||
for (const ScheduledTask& task : scheduledTasks) {
|
||||
taskByComputeInstance[task.computeInstance] = task;
|
||||
tasksByCpu[task.cpu].push_back(task);
|
||||
ProgramKey programKey = getProgramKey(task);
|
||||
tasksByProgram[programKey].push_back(task);
|
||||
if (seenPrograms.insert(programKey).second)
|
||||
orderedPrograms.push_back(programKey);
|
||||
markCpuSeen(task.cpu);
|
||||
seen.insert(task.cpu);
|
||||
}
|
||||
|
||||
llvm::sort(orderedCpus);
|
||||
llvm::sort(orderedPrograms, [](const ProgramKey& lhs, const ProgramKey& rhs) {
|
||||
if (lhs.second != rhs.second)
|
||||
return lhs.second < rhs.second;
|
||||
return lhs.first < rhs.first;
|
||||
});
|
||||
for (size_t cpu : orderedCpus)
|
||||
SmallVector<size_t> activeCpus(seen.begin(), seen.end());
|
||||
llvm::sort(activeCpus);
|
||||
|
||||
DenseSet<size_t> batched;
|
||||
for (size_t cpu : activeCpus) {
|
||||
if (batched.contains(cpu))
|
||||
continue;
|
||||
|
||||
SmallVector<size_t> batch;
|
||||
batch.push_back(cpu);
|
||||
batched.insert(cpu);
|
||||
|
||||
// Group all equivalent CPUs into this batch
|
||||
auto it = schedule->equivalentClass.find(cpu);
|
||||
if (it != schedule->equivalentClass.end()) {
|
||||
for (size_t eqCpu : it->second)
|
||||
if (batched.insert(eqCpu).second)
|
||||
batch.push_back(eqCpu);
|
||||
}
|
||||
|
||||
llvm::sort(batch);
|
||||
size_t leader = batch.front();
|
||||
batchedCpus[leader] = batch;
|
||||
orderedPrograms.push_back(leader);
|
||||
}
|
||||
|
||||
for (size_t cpu : activeCpus) {
|
||||
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
||||
return lhs.orderWithinCpu < rhs.orderWithinCpu;
|
||||
});
|
||||
for (ProgramKey programKey : orderedPrograms)
|
||||
llvm::stable_sort(tasksByProgram[programKey], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
||||
return lhs.orderWithinCpu < rhs.orderWithinCpu;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void collectExternalInputsAndWeights() {
|
||||
for (ProgramKey programKey : orderedPrograms) {
|
||||
size_t cpu = programKey.first;
|
||||
for (const ScheduledTask& task : tasksByProgram[programKey]) {
|
||||
auto& thisCpuWeights = cpuWeights[programKey];
|
||||
auto& thisSeenWeights = seenWeightsByProgram[programKey];
|
||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||
for (Value weight : taskWeights)
|
||||
if (thisSeenWeights.insert(weight).second)
|
||||
thisCpuWeights.push_back(weight);
|
||||
for (ProgramKey leader : orderedPrograms) {
|
||||
const auto& batch = batchedCpus[leader];
|
||||
auto& thisCpuWeights = cpuWeights[leader];
|
||||
auto& thisCpuInputs = cpuExternalInputs[leader];
|
||||
auto& thisCpuOutputs = cpuExternalOutputs[leader];
|
||||
|
||||
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||
remoteInputs.resize(taskInputs.size());
|
||||
auto& remoteTensorInputs = remoteTensorInputsByTask[task.computeInstance];
|
||||
remoteTensorInputs.resize(taskInputs.size());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
bool isExternalInput = true;
|
||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||
producerBatch && producerBatch.getNumResults() != 0) {
|
||||
size_t resultIndex = cast<OpResult>(input).getResultNumber();
|
||||
TensorChannelInfo tensorInfo;
|
||||
tensorInfo.resultIndex = resultIndex;
|
||||
tensorInfo.channelIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
|
||||
tensorInfo.sourceCoreIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
|
||||
tensorInfo.targetCoreIds.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
|
||||
tensorInfo.producerInstances.reserve(static_cast<size_t>(producerBatch.getLaneCount()));
|
||||
// Process every lane sequentially to pack operands
|
||||
for (size_t cpu : batch) {
|
||||
DenseSet<Value> laneSeenWeights;
|
||||
DenseSet<Value> laneSeenInputs;
|
||||
|
||||
bool foundAllLaneProducers = true;
|
||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||
ComputeInstance producerInstance = getBatchChunkForLane(producerBatch, lane);
|
||||
auto producerIt = taskByComputeInstance.find(producerInstance);
|
||||
if (producerIt == taskByComputeInstance.end()) {
|
||||
foundAllLaneProducers = false;
|
||||
break;
|
||||
}
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
for (Value weight : getComputeInstanceWeights(task.computeInstance))
|
||||
if (laneSeenWeights.insert(weight).second)
|
||||
thisCpuWeights.push_back(weight);
|
||||
|
||||
ChannelInfo info {
|
||||
producerIt->second.cpu == cpu ? -1 : (*nextChannelId)++,
|
||||
static_cast<int32_t>(producerIt->second.cpu),
|
||||
static_cast<int32_t>(cpu),
|
||||
};
|
||||
tensorInfo.channelIds.push_back(info.channelId);
|
||||
tensorInfo.sourceCoreIds.push_back(info.sourceCoreId);
|
||||
tensorInfo.targetCoreIds.push_back(info.targetCoreId);
|
||||
tensorInfo.producerInstances.push_back(producerInstance);
|
||||
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||
remoteInputs.resize(taskInputs.size());
|
||||
|
||||
if (producerIt->second.cpu != cpu) {
|
||||
auto& perResultChannels = remoteSendsByTask[producerInstance];
|
||||
if (perResultChannels.empty())
|
||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
|
||||
perResultChannels[resultIndex].push_back(
|
||||
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, true});
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
bool isExternalInput = true;
|
||||
if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
isExternalInput = false;
|
||||
if (producerIt->second.cpu != cpu) {
|
||||
// Cross-core communication
|
||||
ChannelInfo info;
|
||||
info.channelId = (*nextChannelId)++;
|
||||
info.sourceCoreId = static_cast<int32_t>(producerIt->second.cpu);
|
||||
info.targetCoreId = static_cast<int32_t>(cpu);
|
||||
|
||||
remoteInputs[inputIndex] = info;
|
||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||
if (perResultChannels.empty())
|
||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
|
||||
|
||||
RemoteSendInfo sendInfo;
|
||||
sendInfo.channelInfo = info;
|
||||
sendInfo.consumer = task.computeInstance;
|
||||
sendInfo.inputIndex = inputIndex;
|
||||
sendInfo.consumerOrder = task.orderWithinCpu;
|
||||
sendInfo.sourceOrder = 0;
|
||||
sendInfo.isTensorInput = false;
|
||||
perResultChannels[producerRef->resultIndex].push_back(sendInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (isExternalInput && laneSeenInputs.insert(input).second)
|
||||
thisCpuInputs.push_back(input);
|
||||
}
|
||||
|
||||
if (foundAllLaneProducers) {
|
||||
remoteTensorInputs[inputIndex] = std::move(tensorInfo);
|
||||
continue;
|
||||
// Define the logical return types based strictly on the Leader
|
||||
if (cpu == leader) {
|
||||
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
||||
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||
bool hasExternalUser = false;
|
||||
for (auto& use : output.getUses())
|
||||
if (!oldComputeOps.contains(use.getOwner()))
|
||||
hasExternalUser = true;
|
||||
if (hasExternalUser)
|
||||
thisCpuOutputs.push_back({task.computeInstance, resultIndex});
|
||||
}
|
||||
}
|
||||
|
||||
if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
isExternalInput = false;
|
||||
if (producerIt->second.cpu != cpu) {
|
||||
ChannelInfo info {
|
||||
(*nextChannelId)++,
|
||||
static_cast<int32_t>(producerIt->second.cpu),
|
||||
static_cast<int32_t>(cpu),
|
||||
};
|
||||
remoteInputs[inputIndex] = info;
|
||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||
if (perResultChannels.empty())
|
||||
perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size());
|
||||
perResultChannels[producerRef->resultIndex].push_back(
|
||||
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0, false});
|
||||
}
|
||||
}
|
||||
}
|
||||
if (isExternalInput && seenExternalInputsByProgram[programKey].insert(input).second)
|
||||
cpuExternalInputs[programKey].push_back(input);
|
||||
}
|
||||
|
||||
if (isResultfulBatchInstance(task.computeInstance)) {
|
||||
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
|
||||
for (unsigned resultIndex = 0; resultIndex < batch.getNumResults(); ++resultIndex) {
|
||||
bool hasExternalUser = false;
|
||||
for (Operation* user : batch.getResult(resultIndex).getUsers()) {
|
||||
if (!oldComputeOps.contains(user)) {
|
||||
hasExternalUser = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (hasExternalUser)
|
||||
cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
||||
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||
bool hasExternalUser = false;
|
||||
for (auto& use : output.getUses()) {
|
||||
Operation* useOwner = use.getOwner();
|
||||
if (oldComputeOps.contains(useOwner))
|
||||
continue;
|
||||
hasExternalUser = true;
|
||||
}
|
||||
if (hasExternalUser)
|
||||
cpuExternalOutputs[programKey].push_back({task.computeInstance, resultIndex});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -779,65 +739,167 @@ private:
|
||||
|
||||
void createCpuComputeOps() {
|
||||
IRRewriter rewriter(func.getContext());
|
||||
for (ProgramKey programKey : orderedPrograms) {
|
||||
size_t cpu = programKey.first;
|
||||
SmallVector<Value> operands;
|
||||
operands.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
|
||||
llvm::append_range(operands, cpuWeights[programKey]);
|
||||
llvm::append_range(operands, cpuExternalInputs[programKey]);
|
||||
for (ProgramKey leader : orderedPrograms) {
|
||||
const auto& batch = batchedCpus[leader];
|
||||
bool isBatch = batch.size() > 1;
|
||||
|
||||
SmallVector<Type> resultTypes;
|
||||
resultTypes.reserve(cpuExternalOutputs[programKey].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) {
|
||||
SmallVector<Type> packedResultTypes;
|
||||
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
SmallVector<Type> outputTypes = getTaskOutputTypes(task.computeInstance);
|
||||
resultTypes.push_back(outputTypes[outputRef.resultIndex]);
|
||||
Type elemType = getTaskOutputTypes(task.computeInstance)[outputRef.resultIndex];
|
||||
resultTypes.push_back(elemType);
|
||||
|
||||
if (isBatch) {
|
||||
auto ranked = cast<RankedTensorType>(elemType);
|
||||
SmallVector<int64_t> shape;
|
||||
shape.push_back(static_cast<int64_t>(batch.size()));
|
||||
shape.append(ranked.getShape().begin(), ranked.getShape().end());
|
||||
packedResultTypes.push_back(RankedTensorType::get(shape, ranked.getElementType()));
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(cpuWeights[programKey].size()), static_cast<int>(cpuExternalInputs[programKey].size())});
|
||||
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(cpu)));
|
||||
CpuProgram program;
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
|
||||
blockArgLocs.reserve(cpuWeights[programKey].size() + cpuExternalInputs[programKey].size());
|
||||
for (Value weight : cpuWeights[programKey]) {
|
||||
blockArgTypes.push_back(weight.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
for (Value input : cpuExternalInputs[programKey]) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
Block* newBlock =
|
||||
if (!isBatch) {
|
||||
// Isolated CPU Execution
|
||||
SmallVector<Value> operands;
|
||||
operands.reserve(cpuWeights[leader].size() + cpuExternalInputs[leader].size());
|
||||
llvm::append_range(operands, cpuWeights[leader]);
|
||||
llvm::append_range(operands, cpuExternalInputs[leader]);
|
||||
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(cpuWeights[leader].size()), static_cast<int>(cpuExternalInputs[leader].size())});
|
||||
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(leader)));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (Value op : operands) {
|
||||
blockArgTypes.push_back(op.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
|
||||
CpuProgram program;
|
||||
program.op = newCompute;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[programKey]))
|
||||
program.weightToIndex[weight] = weightIndex;
|
||||
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[programKey]))
|
||||
program.externalInputMap[input] = newCompute.getInputArgument(inputIndex);
|
||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[programKey])) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
if (isResultfulBatchInstance(task.computeInstance)) {
|
||||
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
|
||||
auto& batchResults = resultfulBatchLaneResults[batch.getOperation()];
|
||||
if (batchResults.empty())
|
||||
batchResults.resize(batch.getNumResults());
|
||||
auto& laneResults = batchResults[outputRef.resultIndex];
|
||||
if (laneResults.empty())
|
||||
laneResults.resize(static_cast<size_t>(batch.getLaneCount()));
|
||||
laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex);
|
||||
continue;
|
||||
program.op = newCompute;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[leader]))
|
||||
program.weightToIndex[weight] = weightIndex;
|
||||
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[leader]))
|
||||
program.externalInputMap[input] = newCompute.getInputArgument(inputIndex);
|
||||
|
||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
if (isResultfulBatchInstance(task.computeInstance)) {
|
||||
auto oldBatch = cast<SpatComputeBatch>(task.computeInstance.op);
|
||||
auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()];
|
||||
if (batchResults.empty())
|
||||
batchResults.resize(oldBatch.getNumResults());
|
||||
auto& laneResults = batchResults[outputRef.resultIndex];
|
||||
if (laneResults.empty())
|
||||
laneResults.resize(static_cast<size_t>(oldBatch.getLaneCount()));
|
||||
laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex);
|
||||
continue;
|
||||
}
|
||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
||||
newCompute.getResult(resultIndex);
|
||||
}
|
||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
||||
newCompute.getResult(resultIndex);
|
||||
}
|
||||
cpuPrograms[programKey] = std::move(program);
|
||||
else {
|
||||
// Equivalence Class Batch Execution
|
||||
auto newBatch = SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange(packedResultTypes),
|
||||
rewriter.getI32IntegerAttr(batch.size()),
|
||||
cpuWeights[leader],
|
||||
cpuExternalInputs[leader]);
|
||||
newBatch.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(cpuWeights[leader].size()), static_cast<int>(cpuExternalInputs[leader].size())});
|
||||
|
||||
SmallVector<int32_t> coreIds;
|
||||
for (size_t c : batch)
|
||||
coreIds.push_back(static_cast<int32_t>(c));
|
||||
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.push_back(rewriter.getIndexType()); // Lane ID Argument
|
||||
blockArgLocs.push_back(loc);
|
||||
|
||||
size_t weightsPerLane = cpuWeights[leader].size() / batch.size();
|
||||
size_t inputsPerLane = cpuExternalInputs[leader].size() / batch.size();
|
||||
|
||||
for (size_t i = 0; i < weightsPerLane; ++i) {
|
||||
blockArgTypes.push_back(cpuWeights[leader][i].getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
for (size_t i = 0; i < inputsPerLane; ++i) {
|
||||
blockArgTypes.push_back(cpuExternalInputs[leader][i].getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
for (Type t : packedResultTypes) {
|
||||
blockArgTypes.push_back(t);
|
||||
blockArgLocs.push_back(loc);
|
||||
} // Dest tensors for InParallel Yield
|
||||
|
||||
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
program.op = newBatch;
|
||||
|
||||
// Host-side slice extractions
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointAfter(newBatch);
|
||||
for (auto [laneIndex, cpu] : llvm::enumerate(batch)) {
|
||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) {
|
||||
size_t taskIdx = 0;
|
||||
for (size_t i = 0; i < tasksByCpu[leader].size(); ++i) {
|
||||
if (tasksByCpu[leader][i].computeInstance == outputRef.instance) {
|
||||
taskIdx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
ComputeInstance laneInstance = tasksByCpu[cpu][taskIdx].computeInstance;
|
||||
|
||||
auto ranked = cast<RankedTensorType>(resultTypes[resultIndex]);
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
offsets.push_back(rewriter.getIndexAttr(laneIndex));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.push_back(rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> strides;
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
|
||||
for (int64_t dim : ranked.getShape()) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
auto slice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, ranked, newBatch.getResult(resultIndex), offsets, sizes, strides);
|
||||
|
||||
if (isResultfulBatchInstance(laneInstance)) {
|
||||
auto oldBatch = cast<SpatComputeBatch>(laneInstance.op);
|
||||
auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()];
|
||||
if (batchResults.empty())
|
||||
batchResults.resize(oldBatch.getNumResults());
|
||||
auto& laneValues = batchResults[outputRef.resultIndex];
|
||||
if (laneValues.empty())
|
||||
laneValues.resize(static_cast<size_t>(oldBatch.getLaneCount()));
|
||||
laneValues[laneInstance.laneStart] = slice.getResult();
|
||||
}
|
||||
else {
|
||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(laneInstance)[outputRef.resultIndex]] =
|
||||
slice.getResult();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < weightsPerLane; ++i)
|
||||
program.weightToIndex[cpuWeights[leader][i]] = i;
|
||||
for (size_t i = 0; i < inputsPerLane; ++i)
|
||||
program.externalInputMap[cpuExternalInputs[leader][i]] =
|
||||
newBatch.getBody().front().getArgument(1 + weightsPerLane + i);
|
||||
}
|
||||
cpuPrograms[leader] = std::move(program);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -888,130 +950,105 @@ private:
|
||||
DenseMap<size_t, DenseMap<uint64_t, size_t>> receiveQueueIndicesByCpu;
|
||||
DenseMap<size_t, DenseMap<ComputeInstance, SmallVector<Value>>> preReceivedInputsByCpu;
|
||||
|
||||
for (ProgramKey programKey : orderedPrograms) {
|
||||
size_t cpu = programKey.first;
|
||||
CpuProgram& program = cpuPrograms[programKey];
|
||||
auto lookupPreReceivedInput = [&](DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
|
||||
ComputeInstance consumer,
|
||||
size_t inputIndex) -> std::optional<Value> {
|
||||
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
||||
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
||||
return std::nullopt;
|
||||
Value value = inputsIt->second[inputIndex];
|
||||
if (!value)
|
||||
return std::nullopt;
|
||||
return value;
|
||||
};
|
||||
|
||||
for (ProgramKey leader : orderedPrograms) {
|
||||
const auto& batch = batchedCpus[leader];
|
||||
bool isBatch = batch.size() > 1;
|
||||
CpuProgram& program = cpuPrograms[leader];
|
||||
|
||||
IRRewriter rewriter(func.getContext());
|
||||
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
||||
auto& receiveQueueIndices = receiveQueueIndicesByCpu[cpu];
|
||||
auto& preReceivedInputsByTask = preReceivedInputsByCpu[cpu];
|
||||
rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front());
|
||||
|
||||
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
|
||||
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
||||
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
||||
return std::nullopt;
|
||||
Value value = inputsIt->second[inputIndex];
|
||||
if (!value)
|
||||
return std::nullopt;
|
||||
return value;
|
||||
};
|
||||
auto& receiveQueueIndices = receiveQueueIndicesByCpu[leader];
|
||||
auto& preReceivedInputsByTask = preReceivedInputsByCpu[leader];
|
||||
|
||||
ArrayRef<ScheduledTask> programTasks = tasksByProgram[programKey];
|
||||
for (size_t taskIndex = 0; taskIndex < programTasks.size(); ++taskIndex) {
|
||||
const ScheduledTask& task = programTasks[taskIndex];
|
||||
ArrayRef<ScheduledTask> leaderTasks = tasksByCpu[leader];
|
||||
|
||||
for (size_t taskIndex = 0; taskIndex < leaderTasks.size(); ++taskIndex) {
|
||||
const ScheduledTask& task = leaderTasks[taskIndex];
|
||||
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
|
||||
|
||||
SmallVector<Value> resolvedInputs;
|
||||
resolvedInputs.reserve(taskInputs.size());
|
||||
|
||||
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
|
||||
auto remoteTensorInputsIt = remoteTensorInputsByTask.find(task.computeInstance);
|
||||
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
if (remoteTensorInputsIt != remoteTensorInputsByTask.end() && inputIndex < remoteTensorInputsIt->second.size()
|
||||
&& remoteTensorInputsIt->second[inputIndex]) {
|
||||
const TensorChannelInfo& tensorInfo = *remoteTensorInputsIt->second[inputIndex];
|
||||
bool hasLocalProducer = llvm::is_contained(tensorInfo.sourceCoreIds, static_cast<int32_t>(cpu));
|
||||
if (!hasLocalProducer) {
|
||||
auto receive = spatial::SpatChannelReceiveTensorOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
input.getType(),
|
||||
createIndexConstants(program.op, tensorInfo.channelIds, constantFolder),
|
||||
createIndexConstants(program.op, tensorInfo.sourceCoreIds, constantFolder),
|
||||
createIndexConstants(program.op, tensorInfo.targetCoreIds, constantFolder));
|
||||
resolvedInputs.push_back(receive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<Value> laneValues;
|
||||
laneValues.reserve(tensorInfo.producerInstances.size());
|
||||
for (auto [laneIndex, producerInstance] : llvm::enumerate(tensorInfo.producerInstances)) {
|
||||
if (tensorInfo.sourceCoreIds[laneIndex] == static_cast<int32_t>(cpu)) {
|
||||
auto producedIt = producedValuesByTask.find(producerInstance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= tensorInfo.resultIndex) {
|
||||
task.computeInstance.op->emitOpError(
|
||||
"missing local tensor lane producer during merge materialization")
|
||||
<< " consumerCpu=" << cpu << " producerLaneStart=" << producerInstance.laneStart;
|
||||
return failure();
|
||||
}
|
||||
laneValues.push_back(producedIt->second[tensorInfo.resultIndex]);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto producerTaskIt = taskByComputeInstance.find(producerInstance);
|
||||
if (producerTaskIt == taskByComputeInstance.end())
|
||||
return failure();
|
||||
Type laneType = getTaskOutputTypes(producerTaskIt->second.computeInstance)[tensorInfo.resultIndex];
|
||||
Value channelId = createIndexConstant(program.op, tensorInfo.channelIds[laneIndex], constantFolder);
|
||||
Value sourceCoreId = createIndexConstant(program.op, tensorInfo.sourceCoreIds[laneIndex], constantFolder);
|
||||
Value targetCoreId = createIndexConstant(program.op, tensorInfo.targetCoreIds[laneIndex], constantFolder);
|
||||
auto receive =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, laneType, channelId, sourceCoreId, targetCoreId);
|
||||
laneValues.push_back(receive.getResult());
|
||||
}
|
||||
|
||||
Value packedInput = tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, ValueRange(laneValues)).getResult();
|
||||
resolvedInputs.push_back(packedInput);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto producerRef = getProducerValueRef(input, &task.computeInstance);
|
||||
if (producerRef) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
if (producerIt->second.cpu == cpu) {
|
||||
if (producerIt->second.cpu == leader) {
|
||||
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||
task.computeInstance.op->emitOpError(
|
||||
"missing local producer value during per-cpu merge materialization")
|
||||
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
||||
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||
return failure();
|
||||
}
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex)
|
||||
return task.computeInstance.op->emitOpError("missing local producer value");
|
||||
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||
continue;
|
||||
}
|
||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
|
||||
resolvedInputs.push_back(*preReceived);
|
||||
|
||||
if (isBatch) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
for (size_t cpu : batch) {
|
||||
const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex];
|
||||
const ChannelInfo& info = *remoteInputsByTask[laneTask.computeInstance][inputIndex];
|
||||
channelIds.push_back(info.channelId);
|
||||
sourceCoreIds.push_back(info.sourceCoreId);
|
||||
targetCoreIds.push_back(info.targetCoreId);
|
||||
}
|
||||
|
||||
SmallVector<Value> cIds = createIndexConstants(program.op, channelIds, constantFolder);
|
||||
SmallVector<Value> sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> tIds = createIndexConstants(program.op, targetCoreIds, constantFolder);
|
||||
|
||||
auto recv =
|
||||
spatial::SpatChannelReceiveBatchOp::create(rewriter, loc, input.getType(), cIds, sIds, tIds);
|
||||
resolvedInputs.push_back(recv.getOutput());
|
||||
}
|
||||
else {
|
||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||
if (std::optional<Value> preReceived =
|
||||
lookupPreReceivedInput(preReceivedInputsByTask, task.computeInstance, inputIndex)) {
|
||||
resolvedInputs.push_back(*preReceived);
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> received = receiveThroughInput(rewriter,
|
||||
leader,
|
||||
receiveQueueIndices,
|
||||
preReceivedInputsByTask,
|
||||
channelInfo,
|
||||
task.computeInstance,
|
||||
inputIndex);
|
||||
if (failed(received))
|
||||
return task.computeInstance.op->emitOpError("failed to materialize reordered remote receive");
|
||||
resolvedInputs.push_back(*received);
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> received = receiveThroughInput(rewriter,
|
||||
cpu,
|
||||
receiveQueueIndices,
|
||||
preReceivedInputsByTask,
|
||||
channelInfo,
|
||||
task.computeInstance,
|
||||
inputIndex);
|
||||
if (failed(received)) {
|
||||
task.computeInstance.op->emitOpError("failed to materialize reordered remote receive")
|
||||
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
|
||||
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
|
||||
return failure();
|
||||
}
|
||||
resolvedInputs.push_back(*received);
|
||||
continue;
|
||||
|
||||
Value cId = createIndexConstant(program.op, channelInfo.channelId, constantFolder);
|
||||
Value sId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder);
|
||||
Value tId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder);
|
||||
auto receive = spatial::SpatChannelReceiveOp::create(rewriter, loc, input.getType(), cId, sId, tId);
|
||||
resolvedInputs.push_back(receive.getResult());
|
||||
}
|
||||
Value channelId = createIndexConstant(program.op, channelInfo.channelId, constantFolder);
|
||||
Value sourceCoreId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder);
|
||||
Value targetCoreId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder);
|
||||
auto receive = spatial::SpatChannelReceiveOp::create(
|
||||
rewriter, loc, input.getType(), channelId, sourceCoreId, targetCoreId);
|
||||
resolvedInputs.push_back(receive.getResult());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -1019,12 +1056,17 @@ private:
|
||||
}
|
||||
|
||||
SmallVector<Value> taskYieldValues;
|
||||
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
||||
rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front());
|
||||
|
||||
if (isa<SpatCompute>(task.computeInstance.op)) {
|
||||
IRMapping mapper;
|
||||
auto compute = cast<SpatCompute>(task.computeInstance.op);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight)));
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) {
|
||||
Value destArg = isBatch
|
||||
? cast<SpatComputeBatch>(program.op).getWeightArgument(program.weightToIndex.at(weight))
|
||||
: cast<SpatCompute>(program.op).getWeightArgument(program.weightToIndex.at(weight));
|
||||
mapper.map(compute.getWeightArgument(weightIndex), destArg);
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
|
||||
mapper.map(compute.getInputArgument(inputIndex), input);
|
||||
|
||||
@@ -1034,138 +1076,104 @@ private:
|
||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
}
|
||||
else {
|
||||
auto batch = cast<SpatComputeBatch>(task.computeInstance.op);
|
||||
if (batch.getNumResults() != 0) {
|
||||
IRMapping mapper;
|
||||
Value laneValue = getOrCreateHostIndexConstant(
|
||||
program.op, static_cast<int64_t>(task.computeInstance.laneStart), constantFolder);
|
||||
mapper.map(batch.getLaneArgument(), laneValue);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(taskWeights))
|
||||
mapper.map(batch.getWeightArgument(weightIndex), program.op.getWeightArgument(program.weightToIndex.at(weight)));
|
||||
for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs))
|
||||
mapper.map(batch.getInputArgument(inputIndex), input);
|
||||
|
||||
for (Operation& op : templateBlock.without_terminator()) {
|
||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<BatchYieldInfo>> yieldInfo = collectResultfulBatchYieldInfo(batch);
|
||||
if (failed(yieldInfo))
|
||||
return task.computeInstance.op->emitOpError("failed to collect resultful batch yield info");
|
||||
for (const BatchYieldInfo& info : *yieldInfo)
|
||||
taskYieldValues.push_back(mapper.lookup(info.yieldedValue));
|
||||
}
|
||||
else {
|
||||
size_t batchLaneCount = static_cast<size_t>(batch.getLaneCount());
|
||||
size_t inputsPerLane = batchLaneCount == 0 ? 0 : batch.getInputs().size() / batchLaneCount;
|
||||
size_t weightsPerLane = batchLaneCount == 0 ? 0 : batch.getWeights().size() / batchLaneCount;
|
||||
size_t loopRunLength = getLoopableResultlessBatchRunLength(programTasks, taskIndex);
|
||||
if (loopRunLength > 1) {
|
||||
Value lower = getOrCreateHostIndexConstant(
|
||||
program.op, static_cast<int64_t>(task.computeInstance.laneStart), constantFolder);
|
||||
Value upper = getOrCreateHostIndexConstant(
|
||||
program.op, static_cast<int64_t>(task.computeInstance.laneStart + loopRunLength), constantFolder);
|
||||
Value step = getOrCreateHostIndexConstant(program.op, 1, constantFolder);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(batch.getLaneArgument(), loop.getInductionVar());
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||
mapper.map(batch.getWeightArgument(weightIndex),
|
||||
program.op.getWeightArgument(program.weightToIndex.at(taskWeights[weightIndex])));
|
||||
for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex)
|
||||
mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[inputIndex]);
|
||||
|
||||
for (Operation& op : templateBlock) {
|
||||
if (isa<spatial::SpatYieldOp>(&op))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
}
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
taskIndex += loopRunLength - 1;
|
||||
}
|
||||
else {
|
||||
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
|
||||
IRMapping mapper;
|
||||
Value laneValue = getOrCreateHostIndexConstant(
|
||||
program.op, static_cast<int64_t>(task.computeInstance.laneStart + laneOffset), constantFolder);
|
||||
mapper.map(batch.getLaneArgument(), laneValue);
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||
mapper.map(batch.getWeightArgument(weightIndex),
|
||||
program.op.getWeightArgument(
|
||||
program.weightToIndex.at(taskWeights[laneOffset * weightsPerLane + weightIndex])));
|
||||
for (size_t inputIndex = 0; inputIndex < inputsPerLane; ++inputIndex)
|
||||
mapper.map(batch.getInputArgument(inputIndex), resolvedInputs[laneOffset * inputsPerLane + inputIndex]);
|
||||
|
||||
for (Operation& op : templateBlock) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Include your existing isolated logic for preserving resultless spat.compute_batch here if needed
|
||||
}
|
||||
|
||||
producedValuesByTask[task.computeInstance] = taskYieldValues;
|
||||
|
||||
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
|
||||
for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) {
|
||||
const SmallVector<RemoteSendInfo>& sendInfos = sendsIt->second[resultIndex];
|
||||
if (sendInfos.empty())
|
||||
{
|
||||
if (sendInfos.empty()) {
|
||||
++resultIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t nextResultIndex = resultIndex + 1;
|
||||
if (tryEmitCompactSendLoops(
|
||||
program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) {
|
||||
resultIndex = nextResultIndex;
|
||||
continue;
|
||||
}
|
||||
if (isBatch) {
|
||||
size_t numSends = sendInfos.size();
|
||||
for (size_t s = 0; s < numSends; ++s) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
Value producedValue = taskYieldValues[resultIndex];
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
Value channelId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder);
|
||||
Value sourceCoreId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder);
|
||||
Value targetCoreId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, channelId, sourceCoreId, targetCoreId, producedValue);
|
||||
for (size_t cpu : batch) {
|
||||
const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex];
|
||||
const RemoteSendInfo& send = remoteSendsByTask[laneTask.computeInstance][resultIndex][s];
|
||||
channelIds.push_back(send.channelInfo.channelId);
|
||||
sourceCoreIds.push_back(send.channelInfo.sourceCoreId);
|
||||
targetCoreIds.push_back(send.channelInfo.targetCoreId);
|
||||
}
|
||||
|
||||
SmallVector<Value> cIds = createIndexConstants(program.op, channelIds, constantFolder);
|
||||
SmallVector<Value> sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> tIds = createIndexConstants(program.op, targetCoreIds, constantFolder);
|
||||
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter, loc, cIds, sIds, tIds, taskYieldValues[resultIndex]);
|
||||
}
|
||||
++resultIndex;
|
||||
}
|
||||
else {
|
||||
size_t nextResultIndex = resultIndex + 1;
|
||||
if (tryEmitCompactSendLoops(
|
||||
program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) {
|
||||
resultIndex = nextResultIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
Value producedValue = taskYieldValues[resultIndex];
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
Value cId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder);
|
||||
Value sId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder);
|
||||
Value tId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, cId, sId, tId, producedValue);
|
||||
}
|
||||
++resultIndex;
|
||||
}
|
||||
++resultIndex;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Value> yieldValues;
|
||||
yieldValues.reserve(cpuExternalOutputs[programKey].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[programKey]) {
|
||||
yieldValues.reserve(cpuExternalOutputs[leader].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) {
|
||||
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
|
||||
return failure();
|
||||
}
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex)
|
||||
return func.emitError("missing yielded external value during materialization");
|
||||
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
|
||||
|
||||
if (isBatch) {
|
||||
auto batchOp = cast<SpatComputeBatch>(program.op);
|
||||
auto inParallel = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
Block* parallelBlock = rewriter.createBlock(&inParallel.getRegion());
|
||||
rewriter.setInsertionPointToEnd(parallelBlock);
|
||||
|
||||
for (auto [resultIndex, yieldedVal] : llvm::enumerate(yieldValues)) {
|
||||
auto destArg = batchOp.getOutputArgument(resultIndex);
|
||||
auto destType = cast<RankedTensorType>(destArg.getType());
|
||||
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
offsets.push_back(batchOp.getLaneArgument());
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.push_back(rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> strides;
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
|
||||
for (int64_t dim : destType.getShape().drop_front()) {
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
strides.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, yieldedVal, destArg, offsets, sizes, strides);
|
||||
}
|
||||
}
|
||||
else {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
@@ -1245,7 +1253,6 @@ private:
|
||||
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||
DenseMap<ProgramKey, SmallVector<ScheduledTask>> tasksByProgram;
|
||||
SmallVector<size_t> orderedCpus;
|
||||
SmallVector<ProgramKey> orderedPrograms;
|
||||
DenseSet<size_t> seenCpus;
|
||||
DenseSet<ProgramKey> seenPrograms;
|
||||
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||
|
||||
@@ -19,6 +19,7 @@ struct MergeScheduleResult {
|
||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "mlir/IR/Threading.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
@@ -274,7 +275,64 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
||||
});
|
||||
|
||||
// 5. Populate Final Result
|
||||
// 5. Check if equal schedule in two level
|
||||
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
|
||||
for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) {
|
||||
for (size_t controlProcessor = currentProcessor + 1; controlProcessor < processorCount; ++controlProcessor) {
|
||||
if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size())
|
||||
continue;
|
||||
auto& currentTasks = tasksByProcessor[currentProcessor];
|
||||
auto& controlTasks = tasksByProcessor[controlProcessor];
|
||||
bool equalSchedule = true;
|
||||
|
||||
for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) {
|
||||
const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance;
|
||||
const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance;
|
||||
if (currentComputeInstance.op != controlComputeInstance.op) {
|
||||
equalSchedule = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (equalSchedule) {
|
||||
equivalentClass[currentProcessor].push_back(controlProcessor);
|
||||
equivalentClass[controlProcessor].push_back(currentProcessor);
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
|
||||
std::vector<bool> visited(processorCount, false);
|
||||
size_t uniqueClassCount = 0;
|
||||
|
||||
for (size_t i = 0; i < processorCount; ++i) {
|
||||
if (visited[i])
|
||||
continue;
|
||||
|
||||
// We found a new unique schedule (equivalence class)
|
||||
++uniqueClassCount;
|
||||
visited[i] = true;
|
||||
|
||||
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
|
||||
|
||||
// Find and mark all identical companions
|
||||
auto it = equivalentClass.find(i);
|
||||
if (it != equivalentClass.end()) {
|
||||
for (size_t eqCpu : it->second) {
|
||||
if (!visited[eqCpu]) {
|
||||
llvm::dbgs() << ", " << eqCpu;
|
||||
visited[eqCpu] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
llvm::dbgs() << " }\n";
|
||||
}
|
||||
|
||||
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
|
||||
llvm::dbgs() << "--------------------------------------\n";
|
||||
}
|
||||
|
||||
// 6. Populate Final Result
|
||||
MergeScheduleResult result;
|
||||
result.dominanceOrderCompute.reserve(nodeCount);
|
||||
|
||||
@@ -296,8 +354,9 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
}
|
||||
}
|
||||
|
||||
result.equivalentClass = equivalentClass;
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
|
||||
Reference in New Issue
Block a user