Compare commits
3 Commits
a103ba328b
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 |
@@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
@@ -53,6 +54,7 @@ public:
|
|||||||
replaceExternalUses();
|
replaceExternalUses();
|
||||||
if (failed(eraseOldScheduledOps()))
|
if (failed(eraseOldScheduledOps()))
|
||||||
return failure();
|
return failure();
|
||||||
|
moveExternalUsersBeforeReturn();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,6 +97,18 @@ private:
|
|||||||
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void collectExternalUsers(Operation* op) {
|
||||||
|
if (!externalUsersToMove.insert(op).second)
|
||||||
|
return;
|
||||||
|
for (Value result : op->getResults()) {
|
||||||
|
for (Operation* user : result.getUsers()) {
|
||||||
|
if (oldComputeOps.contains(user) || isa<func::ReturnOp>(user))
|
||||||
|
continue;
|
||||||
|
collectExternalUsers(user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void collectScheduledTasks() {
|
void collectScheduledTasks() {
|
||||||
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
||||||
oldComputeOps.insert(scheduledInstance.op);
|
oldComputeOps.insert(scheduledInstance.op);
|
||||||
@@ -137,22 +151,25 @@ private:
|
|||||||
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||||
remoteInputs.resize(taskInputs.size());
|
remoteInputs.resize(taskInputs.size());
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
if (auto producerRef = getProducerValueRef(input)) {
|
auto producerRef = getProducerValueRef(input);
|
||||||
|
if (producerRef) {
|
||||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||||
if (producerIt->second.cpu != cpu) {
|
if (producerIt != taskByComputeInstance.end()) {
|
||||||
ChannelInfo info {
|
if (producerIt->second.cpu != cpu) {
|
||||||
(*nextChannelId)++,
|
ChannelInfo info {
|
||||||
static_cast<int32_t>(producerIt->second.cpu),
|
(*nextChannelId)++,
|
||||||
static_cast<int32_t>(cpu),
|
static_cast<int32_t>(producerIt->second.cpu),
|
||||||
};
|
static_cast<int32_t>(cpu),
|
||||||
remoteInputs[inputIndex] = info;
|
};
|
||||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
remoteInputs[inputIndex] = info;
|
||||||
if (perResultChannels.empty())
|
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||||
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
|
if (perResultChannels.empty())
|
||||||
perResultChannels[producerRef->resultIndex].push_back(
|
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
|
||||||
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0});
|
perResultChannels[producerRef->resultIndex].push_back(
|
||||||
|
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0});
|
||||||
|
}
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
if (seenExternalInputsByCpu[cpu].insert(input).second)
|
if (seenExternalInputsByCpu[cpu].insert(input).second)
|
||||||
cpuExternalInputs[cpu].push_back(input);
|
cpuExternalInputs[cpu].push_back(input);
|
||||||
@@ -166,6 +183,8 @@ private:
|
|||||||
if (oldComputeOps.contains(useOwner))
|
if (oldComputeOps.contains(useOwner))
|
||||||
continue;
|
continue;
|
||||||
hasExternalUser = true;
|
hasExternalUser = true;
|
||||||
|
if (!isa<func::ReturnOp>(useOwner))
|
||||||
|
collectExternalUsers(useOwner);
|
||||||
}
|
}
|
||||||
if (hasExternalUser)
|
if (hasExternalUser)
|
||||||
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
|
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
|
||||||
@@ -388,8 +407,7 @@ private:
|
|||||||
if (producerIt->second.cpu == cpu) {
|
if (producerIt->second.cpu == cpu) {
|
||||||
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||||
task.computeInstance.op->emitOpError(
|
task.computeInstance.op->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||||
"missing local producer value during per-cpu merge materialization")
|
|
||||||
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
||||||
<< " producerLaneStart=" << producerRef->instance.laneStart
|
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||||
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||||
@@ -568,6 +586,18 @@ private:
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void moveExternalUsersBeforeReturn() {
|
||||||
|
SmallVector<Operation*> orderedUsersToMove;
|
||||||
|
for (Operation& op : func.getBody().front()) {
|
||||||
|
if (&op == returnOp.getOperation())
|
||||||
|
break;
|
||||||
|
if (externalUsersToMove.contains(&op))
|
||||||
|
orderedUsersToMove.push_back(&op);
|
||||||
|
}
|
||||||
|
for (Operation* op : orderedUsersToMove)
|
||||||
|
op->moveBefore(returnOp);
|
||||||
|
}
|
||||||
|
|
||||||
func::FuncOp func;
|
func::FuncOp func;
|
||||||
const MergeScheduleResult* schedule = nullptr;
|
const MergeScheduleResult* schedule = nullptr;
|
||||||
int64_t* nextChannelId = nullptr;
|
int64_t* nextChannelId = nullptr;
|
||||||
@@ -580,6 +610,7 @@ private:
|
|||||||
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||||
SmallVector<size_t> orderedCpus;
|
SmallVector<size_t> orderedCpus;
|
||||||
DenseSet<size_t> seenCpus;
|
DenseSet<size_t> seenCpus;
|
||||||
|
DenseSet<Operation*> externalUsersToMove;
|
||||||
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||||
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||||
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallSet.h"
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
@@ -27,7 +28,9 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <functional>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
@@ -36,11 +39,13 @@
|
|||||||
|
|
||||||
#include "MaterializeMergeSchedule.hpp"
|
#include "MaterializeMergeSchedule.hpp"
|
||||||
#include "PostMergeCompaction.hpp"
|
#include "PostMergeCompaction.hpp"
|
||||||
|
#include "RegularOpCompaction.hpp"
|
||||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||||
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -48,8 +53,10 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
using namespace onnx_mlir::compact_asm;
|
using namespace onnx_mlir::compact_asm;
|
||||||
|
using ProducerValueRef = spatial::ProducerValueRef;
|
||||||
using SpatCompute = spatial::SpatCompute;
|
using SpatCompute = spatial::SpatCompute;
|
||||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||||
|
using spatial::getOriginalSpatCompute;
|
||||||
using spatial::getProducerValueRef;
|
using spatial::getProducerValueRef;
|
||||||
|
|
||||||
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
|
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
|
||||||
@@ -296,7 +303,7 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (Value input : compute.getInputs()) {
|
for (Value input : compute.getInputs()) {
|
||||||
auto parent = dyn_cast<SpatCompute>(input.getDefiningOp());
|
auto parent = getOriginalSpatCompute(input.getDefiningOp());
|
||||||
if (!parent || parent == compute)
|
if (!parent || parent == compute)
|
||||||
continue;
|
continue;
|
||||||
auto parentIt = computeToIndex.find(parent);
|
auto parentIt = computeToIndex.find(parent);
|
||||||
|
|||||||
+25
-4
@@ -22,7 +22,7 @@ size_t getBatchChunkTargetCount(int32_t laneCount) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||||
size_t totalLanes = batch.getLaneCount();
|
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
size_t baseChunkSize = totalLanes / chunkCount;
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
size_t largeChunkCount = totalLanes % chunkCount;
|
||||||
@@ -33,7 +33,7 @@ ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex)
|
|||||||
}
|
}
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||||
size_t totalLanes = batch.getLaneCount();
|
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
size_t baseChunkSize = totalLanes / chunkCount;
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
size_t largeChunkCount = totalLanes % chunkCount;
|
||||||
@@ -47,11 +47,32 @@ ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
|||||||
return getBatchChunkForIndex(batch, chunkIndex);
|
return getBatchChunkForIndex(batch, chunkIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SpatCompute getOriginalSpatCompute(Operation *op) {
|
||||||
|
if (!op)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
op = extract.getSource().getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
return dyn_cast<SpatCompute>(op);
|
||||||
|
}
|
||||||
|
|
||||||
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
||||||
Operation *op = value.getDefiningOp();
|
Operation *op = value.getDefiningOp();
|
||||||
if (!op)
|
if (!op)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
//TODO Extract Slice is not the only global non compute operation. There are other legal op
|
||||||
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
value = extract.getSource();
|
||||||
|
op = value.getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||||
return ProducerValueRef {
|
return ProducerValueRef {
|
||||||
ComputeInstance {compute.getOperation(), 0, 1},
|
ComputeInstance {compute.getOperation(), 0, 1},
|
||||||
@@ -60,9 +81,9 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||||
uint32_t lane = cast<OpResult>(value).getResultNumber();
|
uint32_t lane = static_cast<uint32_t>(cast<OpResult>(value).getResultNumber());
|
||||||
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
||||||
size_t resultIndex = lane - instance.laneStart;
|
size_t resultIndex = static_cast<size_t>(lane - instance.laneStart);
|
||||||
return ProducerValueRef {instance, resultIndex};
|
return ProducerValueRef {instance, resultIndex};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+1
@@ -26,6 +26,7 @@ size_t getBatchChunkTargetCount(int32_t laneCount);
|
|||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||||
|
|
||||||
|
SpatCompute getOriginalSpatCompute(mlir::Operation *op);
|
||||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
|
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
|
||||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ struct ScheduledTask {
|
|||||||
size_t slot = 0;
|
size_t slot = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph &graph) {
|
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||||
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||||
std::queue<size_t> readySinks;
|
std::queue<size_t> readySinks;
|
||||||
std::vector<std::vector<size_t>> reverseLevels;
|
std::vector<std::vector<size_t>> reverseLevels;
|
||||||
@@ -43,8 +43,7 @@ std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph &graph) {
|
|||||||
readySinks.pop();
|
readySinks.pop();
|
||||||
levelNodes.push_back(node);
|
levelNodes.push_back(node);
|
||||||
++levelizedCount;
|
++levelizedCount;
|
||||||
for (const auto &[pred, weight] : graph.predecessors[node]) {
|
for (const auto& [pred, weight] : graph.predecessors[node]) {
|
||||||
(void) weight;
|
|
||||||
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
||||||
if (--remainingSuccessors[pred] == 0)
|
if (--remainingSuccessors[pred] == 0)
|
||||||
readySinks.push(pred);
|
readySinks.push(pred);
|
||||||
@@ -79,7 +78,7 @@ void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options) {
|
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
||||||
const size_t nodeCount = graph.nodes.size();
|
const size_t nodeCount = graph.nodes.size();
|
||||||
const size_t processorCount = options.processorCount;
|
const size_t processorCount = options.processorCount;
|
||||||
if (processorCount == 0)
|
if (processorCount == 0)
|
||||||
@@ -88,18 +87,23 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
verifyOctTableSize(nodeCount, processorCount);
|
verifyOctTableSize(nodeCount, processorCount);
|
||||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||||
|
|
||||||
|
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||||
|
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
|
||||||
|
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
|
||||||
|
|
||||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||||
|
|
||||||
for (const std::vector<size_t> &levelNodes : reverseLevels) {
|
// 1. O(P(E+V)) Heterogeneous OCT Calculation
|
||||||
|
for (const std::vector<size_t>& levelNodes : reverseLevels) {
|
||||||
auto computeNodeOct = [&](size_t levelIndex) {
|
auto computeNodeOct = [&](size_t levelIndex) {
|
||||||
size_t task = levelNodes[levelIndex];
|
size_t task = levelNodes[levelIndex];
|
||||||
std::vector<Time> maxVals(processorCount, 0);
|
std::vector<Time> maxVals(processorCount, 0);
|
||||||
|
|
||||||
for (const auto &[succ, comm] : graph.successors[task]) {
|
for (const auto& [succ, comm] : graph.successors[task]) {
|
||||||
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
|
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], graph.nodes[succ].weight);
|
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], getComputeCost(succ, processor));
|
||||||
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
|
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
|
||||||
maxVals[processor] = std::max(maxVals[processor], bestSucc);
|
maxVals[processor] = std::max(maxVals[processor], bestSucc);
|
||||||
}
|
}
|
||||||
@@ -108,7 +112,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
Time minForPreds = std::numeric_limits<Time>::max();
|
Time minForPreds = std::numeric_limits<Time>::max();
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
oct[task * processorCount + processor] = maxVals[processor];
|
oct[task * processorCount + processor] = maxVals[processor];
|
||||||
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], graph.nodes[task].weight));
|
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], getComputeCost(task, processor)));
|
||||||
}
|
}
|
||||||
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
|
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
|
||||||
};
|
};
|
||||||
@@ -132,6 +136,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
rank += static_cast<long double>(oct[node * processorCount + processor]);
|
rank += static_cast<long double>(oct[node * processorCount + processor]);
|
||||||
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
|
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
|
||||||
};
|
};
|
||||||
|
|
||||||
if (options.context != nullptr)
|
if (options.context != nullptr)
|
||||||
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
|
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
|
||||||
else
|
else
|
||||||
@@ -139,8 +144,8 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
computeRank(node);
|
computeRank(node);
|
||||||
|
|
||||||
auto readyCompare = [&](size_t lhs, size_t rhs) {
|
auto readyCompare = [&](size_t lhs, size_t rhs) {
|
||||||
const RankEntry &lhsRank = ranks[lhs];
|
const RankEntry& lhsRank = ranks[lhs];
|
||||||
const RankEntry &rhsRank = ranks[rhs];
|
const RankEntry& rhsRank = ranks[rhs];
|
||||||
if (lhsRank.rank != rhsRank.rank)
|
if (lhsRank.rank != rhsRank.rank)
|
||||||
return lhsRank.rank < rhsRank.rank;
|
return lhsRank.rank < rhsRank.rank;
|
||||||
if (lhsRank.originalOrder != rhsRank.originalOrder)
|
if (lhsRank.originalOrder != rhsRank.originalOrder)
|
||||||
@@ -157,7 +162,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> scheduled(nodeCount, false);
|
std::vector<char> scheduled(nodeCount, false);
|
||||||
std::vector<Time> processorAvailable(processorCount, 0);
|
|
||||||
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
|
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
|
||||||
std::vector<ScheduledTask> schedules(nodeCount);
|
std::vector<ScheduledTask> schedules(nodeCount);
|
||||||
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
||||||
@@ -176,26 +180,46 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
bool crossbarRejected = false;
|
bool crossbarRejected = false;
|
||||||
|
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
if (graph.nodes[task].crossbarUsage != 0 &&
|
if (graph.nodes[task].crossbarUsage != 0
|
||||||
addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
|
&& addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
|
||||||
crossbarRejected = true;
|
crossbarRejected = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Time dataReady = 0;
|
Time dataReady = 0;
|
||||||
for (const auto &[pred, comm] : graph.predecessors[task]) {
|
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||||
const ScheduledTask &predSchedule = schedules[pred];
|
const ScheduledTask& predSchedule = schedules[pred];
|
||||||
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
||||||
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||||
}
|
}
|
||||||
|
|
||||||
Time est = std::max(processorAvailable[processor], dataReady);
|
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
|
||||||
Time eft = addOrMax(est, graph.nodes[task].weight);
|
Time compWeight = getComputeCost(task, processor);
|
||||||
|
Time est = dataReady;
|
||||||
|
Time currentEnd = 0;
|
||||||
|
bool foundGap = false;
|
||||||
|
|
||||||
|
for (size_t schedTaskIndex : tasksByProcessor[processor]) {
|
||||||
|
const ScheduledTask& schedTask = schedules[schedTaskIndex];
|
||||||
|
Time gapStart = std::max(currentEnd, dataReady);
|
||||||
|
|
||||||
|
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
|
||||||
|
est = gapStart;
|
||||||
|
foundGap = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
currentEnd = schedTask.endTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!foundGap)
|
||||||
|
est = std::max(currentEnd, dataReady);
|
||||||
|
|
||||||
|
Time eft = addOrMax(est, compWeight);
|
||||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||||
|
|
||||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft) ||
|
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||||
(oeft == bestOeft && eft == bestEft && est < bestEst) ||
|
|| (oeft == bestOeft && eft == bestEft && est < bestEst)
|
||||||
(oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
|| (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
||||||
bestProcessor = processor;
|
bestProcessor = processor;
|
||||||
bestEst = est;
|
bestEst = est;
|
||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
@@ -219,15 +243,18 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
llvm::report_fatal_error(llvm::StringRef(message));
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
}
|
}
|
||||||
|
|
||||||
schedules[task] = {bestProcessor, bestEst, bestEft, tasksByProcessor[bestProcessor].size()};
|
schedules[task] = {bestProcessor, bestEst, bestEft, 0};
|
||||||
scheduled[task] = true;
|
scheduled[task] = true;
|
||||||
++scheduledCount;
|
++scheduledCount;
|
||||||
processorAvailable[bestProcessor] = bestEft;
|
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||||
processorCrossbars[bestProcessor] =
|
|
||||||
addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
// 3. CRITICAL FIX: Topological Append
|
||||||
|
// Because the readyQueue pops in strict topological order, simply pushing to the
|
||||||
|
// back guarantees the Monoliths will be physically generated cycle-free.
|
||||||
|
// The hardware will still benefit from the processor assignment chosen by PEFT.
|
||||||
tasksByProcessor[bestProcessor].push_back(task);
|
tasksByProcessor[bestProcessor].push_back(task);
|
||||||
|
|
||||||
for (const auto &[child, weight] : graph.successors[task]) {
|
for (const auto& [child, weight] : graph.successors[task]) {
|
||||||
(void) weight;
|
(void) weight;
|
||||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||||
if (--remainingParents[child] == 0)
|
if (--remainingParents[child] == 0)
|
||||||
@@ -238,16 +265,28 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
if (scheduledCount != nodeCount)
|
if (scheduledCount != nodeCount)
|
||||||
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
|
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
|
||||||
|
|
||||||
|
// 4. Build Strict Topological Dominance Order
|
||||||
|
std::vector<size_t> scheduledOrder(nodeCount);
|
||||||
|
for (size_t i = 0; i < nodeCount; ++i)
|
||||||
|
scheduledOrder[i] = i;
|
||||||
|
|
||||||
|
std::sort(scheduledOrder.begin(), scheduledOrder.end(), [&](size_t a, size_t b) {
|
||||||
|
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
||||||
|
});
|
||||||
|
|
||||||
|
// 5. Populate Final Result
|
||||||
MergeScheduleResult result;
|
MergeScheduleResult result;
|
||||||
result.dominanceOrderCompute.reserve(nodeCount);
|
result.dominanceOrderCompute.reserve(nodeCount);
|
||||||
for (const ComputeGraphNode &node : graph.nodes)
|
|
||||||
result.dominanceOrderCompute.push_back(node.instance);
|
for (size_t task : scheduledOrder)
|
||||||
|
result.dominanceOrderCompute.push_back(graph.nodes[task].instance);
|
||||||
|
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
size_t currentSlot = 0;
|
||||||
for (size_t task : tasksByProcessor[processor]) {
|
for (size_t task : tasksByProcessor[processor]) {
|
||||||
const ComputeInstance instance = graph.nodes[task].instance;
|
const ComputeInstance instance = graph.nodes[task].instance;
|
||||||
result.computeToCpuMap[instance] = processor;
|
result.computeToCpuMap[instance] = processor;
|
||||||
result.computeToCpuSlotMap[instance] = schedules[task].slot;
|
result.computeToCpuSlotMap[instance] = currentSlot++;
|
||||||
result.computeToAestMap[instance] = schedules[task].startTime;
|
result.computeToAestMap[instance] = schedules[task].startTime;
|
||||||
}
|
}
|
||||||
if (!tasksByProcessor[processor].empty()) {
|
if (!tasksByProcessor[processor].empty()) {
|
||||||
@@ -259,6 +298,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftSchedu
|
|||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ def main():
|
|||||||
help="Core count to pass to Raptor. Required for PIM validation.")
|
help="Core count to pass to Raptor. Required for PIM validation.")
|
||||||
ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft",
|
ap.add_argument("--pim-merge-scheduler", choices=("peft", "dcp"), default="peft",
|
||||||
help="Scheduler used by the Spatial merge-compute-nodes pass.")
|
help="Scheduler used by the Spatial merge-compute-nodes pass.")
|
||||||
ap.add_argument("--command-timeout-seconds", type=float, default=60.0,
|
ap.add_argument("--command-timeout-seconds", type=float, default=6000000000000000.0,
|
||||||
help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.")
|
help="Per-subprocess timeout in seconds for compiler, runner, and simulator commands.")
|
||||||
ap.add_argument("--clean", action="store_true",
|
ap.add_argument("--clean", action="store_true",
|
||||||
help="Remove generated validation artifacts under each model workspace and exit.")
|
help="Remove generated validation artifacts under each model workspace and exit.")
|
||||||
|
|||||||
Reference in New Issue
Block a user