faster (and refactored) DCP analysis

NiccoloN
2026-04-21 12:33:44 +02:00
parent f4c6da8f10
commit 85e2750d6c
20 changed files with 2525 additions and 858 deletions

View File

@@ -661,9 +661,8 @@ void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
broadcastSendOp = op;
continue;
}
if (auto op = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user)) {
if (auto op = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user))
continue;
}
llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
}
@@ -719,7 +718,8 @@ void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter&
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel());
Value receivedValue =
PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
rewriter.replaceOp(receiveOp, receivedValue);
}

View File

@@ -5,6 +5,8 @@ add_pim_library(SpatialOps
SpatialOps.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp

View File

@@ -17,7 +17,7 @@ namespace spatial {
using namespace mlir;
SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
if (!op)
return {};
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -30,32 +30,32 @@ SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
return {};
}
DCPAnalysisResult DCPAnalysis::runAnalysis() {
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
DCPAnalysisResult DCPAnalysis::run() {
llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
llvm::SmallVector<EdgesIndex, 10> edges;
for (auto& regions : entryOp->getRegions())
for (SpatWeightedCompute spatWeightedCompute : regions.getOps<SpatWeightedCompute>())
llvm::SmallVector<IndexedEdge, 10> edges;
for (auto& region : entryOp->getRegions())
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
spatWeightedComputes.push_back(spatWeightedCompute);
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
for (Value input : spatWeightedCompute.getInputs()) {
if (auto spatWeightedComputeArgOp = getOriginalSpatWeightCompute(input.getDefiningOp())) {
auto elemIter = llvm::find(spatWeightedComputes, spatWeightedComputeArgOp);
assert(elemIter != spatWeightedComputes.end());
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), elemIter);
ResultRange outputs = spatWeightedComputeArgOp.getResults();
if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) {
auto producerIt = llvm::find(spatWeightedComputes, producerCompute);
assert(producerIt != spatWeightedComputes.end());
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt);
ResultRange outputs = producerCompute.getResults();
int64_t totalSize = 0;
for (auto output : outputs) {
ShapedType result = cast<ShapedType>(output.getType());
totalSize += getSizeInBytes(result);
ShapedType resultType = cast<ShapedType>(output.getType());
totalSize += getSizeInBytes(resultType);
}
edges.push_back({indexStartEdge, indexEndEdge, totalSize});
}
}
}
GraphDCP graphDCP(spatWeightedComputes, edges);
graphDCP.DCP();
graphDCP.setContext(entryOp->getContext());
graphDCP.runDcp();
return graphDCP.getResult();
}

View File

@@ -3,6 +3,7 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <vector>
@@ -10,8 +11,8 @@
struct DCPAnalysisResult {
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCPUMap;
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfACpu;
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCpuMap;
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfCpu;
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap;
};
@@ -21,12 +22,12 @@ struct DCPAnalysis {
private:
DCPAnalysisResult result;
mlir::Operation* entryOp;
DCPAnalysisResult runAnalysis();
DCPAnalysisResult run();
public:
DCPAnalysis(mlir::Operation* op)
: entryOp(op) {
result = runAnalysis();
result = run();
}
DCPAnalysisResult& getResult() { return result; }
};

View File

@@ -2,6 +2,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <list>
#include <optional>
@@ -12,90 +13,144 @@
#include "Task.hpp"
#include "Utils.hpp"
std::optional<DoubleEdge> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
void removeEdge(TaskDCP* parent, TaskDCP* child);
int getTranferCost(TaskDCP* parent, TaskDCP* child);
namespace mlir {
class MLIRContext;
} // namespace mlir
std::optional<EdgePair> addEdge(TaskDCP* parent, TaskDCP* child, Weight weight, bool isScheduling = false);
void removeEdge(TaskDCP* parent, TaskDCP* child, bool isScheduling = false);
Weight getTransferCost(TaskDCP* parent, TaskDCP* child);
class GraphDCP {
public:
struct CandidateRelations {
llvm::DenseSet<TaskDCP*> ancestors;
llvm::DenseSet<TaskDCP*> descendants;
// descendants ordered by position in the graph's topological order;
// iterating this avoids walking non-descendant tail tasks on hot paths.
llvm::SmallVector<TaskDCP*, 32> descendantsTopoOrder;
};
struct ScheduledTaskInfo {
size_t nodeIndex;
int aest;
int alst;
int weight;
Time aest;
Time alst;
Weight weight;
};
private:
using CpuTaskList = std::list<TaskDCP*>;
struct FindSlot {
int aest;
Time aest;
int index;
};
std::vector<TaskDCP> nodes;
onnx_mlir::LabeledList<TaskDCP> topologicalOrder;
std::unordered_map<CPU, std::list<TaskDCP*>> mapCPUTasks;
CPU last_cpu = 0;
std::vector<CpuTaskList> cpuTasks;
std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage;
CPU lastCpu = 0;
long long flag = 1;
int DCPL;
Time dcpl = 0;
Time maxCompletion = 0;
Time secondMaxCompletion = 0;
TaskDCP* maxCompletionTask = nullptr;
int maxCpuCount = 1000;
mlir::MLIRContext* context = nullptr;
TaskInsertion insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position);
void removeTaskFromCPU(CPU cpu, TaskDCP* task);
CpuTaskList& getOrCreateCpuTasks(CPU cpu);
const CpuTaskList* findCpuTasks(CPU cpu) const;
std::vector<TaskDCP*> getRoots();
long long getUniqueFlag() { return flag++; }
void initAEST();
int initDCPL();
void initALST();
void initAest();
void initAlst();
int computeAEST(TaskDCP* task, CPU cpu);
int computeDCPL(TaskDCP* task, CPU cpu);
int getDCPL() { return DCPL; }
Time computeAestOnCpu(TaskDCP* task, CPU cpu);
Time computeDcplOnCpu(TaskDCP* task, CPU cpu);
Time getDcpl() const { return dcpl; }
Time computeTaskAlstOnCpu(TaskDCP* task, CPU cpu, Time scheduleDcpl);
void updateAestFromTask(TaskDCP* task);
void updateAestFromTaskWithDescendants(TaskDCP* task, const llvm::DenseSet<TaskDCP*>& descendants);
void updateAestFromTaskWithDescendants(TaskDCP* task, llvm::ArrayRef<TaskDCP*> descendantsTopoOrder);
// Propagates AEST like the overload above but returns early (before touching
// the remaining descendants) as soon as a task's completion exceeds
// `dcplBudget`, signalling that the new DCPL would exceed the budget.
// Returns true iff the full propagation completed without exceeding the
// budget. Uses the caller's snapshot to restore AEST on the aborted tail.
bool tryUpdateAestWithinBudget(TaskDCP* task,
llvm::ArrayRef<TaskDCP*> descendantsTopoOrder,
Time dcplBudget);
void initTopological();
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint);
void topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint);
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
void topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
llvm::DenseMap<TaskDCP*, int> computeALST(TaskDCP* task, CPU cpu);
llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations);
size_t getNodeIndex(const TaskDCP* task) const;
TaskDCP* findCandidate(std::vector<TaskDCP*> nodes);
TaskDCP* findCandidate(const std::vector<TaskDCP*>& readyNodes);
void selectProcessor(TaskDCP* candidate, bool push);
CPU lastCPU() const { return last_cpu; }
void incLastCPU() { last_cpu++; }
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push);
void to_dot();
CPU getLastCpu() const { return lastCpu; }
void incrementLastCpu() { lastCpu++; }
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations);
FindSlot findSlotWithFixedFinalTime(
TaskDCP* candidate, CPU cpu, const CandidateRelations& relations, Time finalTime, Time aestOnCpu);
void dumpDot();
friend TaskInsertion;
friend class TaskDCP;
CrossbarUsage getCpuCrossbarUsage(CPU cpu) const;
CrossbarUsage getCpuCrossbarCapacity() const;
CrossbarUsage getTaskCrossbarFootprint(const TaskDCP* task) const;
void reserveTaskCrossbars(CPU cpu, const TaskDCP* task);
void releaseTaskCrossbars(CPU cpu, const TaskDCP* task);
bool wouldExhaustCrossbarCapacity(CPU cpu, const TaskDCP* task) const;
public:
void DCP();
void runDcp();
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
llvm::ArrayRef<EdgesIndex> edges)
: nodes(), mapCPUTasks() {
llvm::ArrayRef<IndexedEdge> edges)
: nodes(), cpuTasks(), cpuCrossbarUsage() {
for (auto spatWeightedCompute : spatWeightedComputes)
nodes.emplace_back(spatWeightedCompute);
for (auto [start, end, weight] : edges)
makeEdge(start, end, weight);
}
GraphDCP(llvm::ArrayRef<Weight_t> nodeWeights, llvm::ArrayRef<EdgesIndex> edges)
: nodes(), mapCPUTasks() {
GraphDCP(llvm::ArrayRef<Weight> nodeWeights,
llvm::ArrayRef<IndexedEdge> edges,
llvm::ArrayRef<CrossbarUsage> nodeCrossbarUsage = {})
: nodes(), cpuTasks(), cpuCrossbarUsage() {
assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size())
&& "synthetic crossbar usage must match synthetic node weights");
nodes.reserve(nodeWeights.size());
for (auto [index, weight] : llvm::enumerate(nodeWeights))
nodes.emplace_back(index, weight);
nodes.emplace_back(index, weight, nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]);
for (auto [start, end, weight] : edges)
makeEdge(start, end, weight);
}
DCPAnalysisResult getResult();
std::vector<ScheduledTaskInfo> getScheduledTasks(CPU cpu) const;
CPU cpuCount() const { return last_cpu; }
CPU cpuCount() const { return lastCpu; }
void makeEdge(size_t parent_index, size_t child_index, Weight_t weight) {
addEdge(&nodes[parent_index], &nodes[child_index], weight);
void makeEdge(size_t parentIndex, size_t childIndex, Weight weight) {
addEdge(&nodes[parentIndex], &nodes[childIndex], weight);
}
size_t taskInCPU(CPU cpu) { return mapCPUTasks[cpu].size(); }
size_t taskInCpu(CPU cpu) { return getOrCreateCpuTasks(cpu).size(); }
void setMaxCpuCount(int value) { maxCpuCount = value; }
int getMaxCpuCount() const { return maxCpuCount; }
// Optional MLIR context used to drive mlir::parallelFor inside runDcp. If
// null the scheduler runs single-threaded (tests use this path).
void setContext(mlir::MLIRContext* ctx) { context = ctx; }
};
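
The synthetic constructor above lets tests drive the scheduler without building any MLIR ops. A minimal sketch, assuming Graph.hpp is on the include path; the diamond-shaped weights and edge costs here are made up for illustration:

#include "llvm/ADT/SmallVector.h"
#include "Graph.hpp"

// Diamond DAG: node 0 feeds 1 and 2, which both feed 3. Edge tuples are
// (parentIndex, childIndex, transferWeight), matching IndexedEdge.
void runSyntheticSchedule() {
  llvm::SmallVector<Weight, 4> weights = {100, 200, 200, 100};
  llvm::SmallVector<IndexedEdge, 4> edges = {{0, 1, 16}, {0, 2, 16}, {1, 3, 8}, {2, 3, 8}};
  GraphDCP graph(weights, edges); // no crossbar usage passed: all nodes default to 0
  graph.setMaxCpuCount(2);
  // No MLIRContext is set, so per the comment above runDcp stays single-threaded.
  graph.runDcp();
}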

View File

@@ -0,0 +1,152 @@
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <fstream>
#include <string>
#include "GraphDebug.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace dcp_graph {
#ifdef DCP_DEBUG_ENABLED
DcpProgressLogger::DcpProgressLogger(size_t totalTasks)
: logProgress(totalTasks >= 200),
totalTasks(totalTasks),
startTime(std::chrono::steady_clock::now()),
lastProgressPrint(startTime) {}
std::string DcpProgressLogger::formatDuration(double seconds) {
if (seconds < 0)
seconds = 0;
long totalSeconds = static_cast<long>(seconds + 0.5);
long hours = totalSeconds / 3600;
long minutes = (totalSeconds % 3600) / 60;
long secs = totalSeconds % 60;
if (hours > 0)
return llvm::formatv("{0}:{1:02}:{2:02}", hours, minutes, secs).str();
return llvm::formatv("{0}:{1:02}", minutes, secs).str();
}
void DcpProgressLogger::recordFindDuration(double seconds) { findCandidateSeconds += seconds; }
void DcpProgressLogger::recordSelectDuration(double seconds) { selectProcessorSeconds += seconds; }
void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; }
void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; }
void DcpProgressLogger::printStart(size_t readyCount) const {
if (!logProgress)
return;
llvm::errs() << llvm::formatv("[DCP] start: tasks={0} ready={1}\n", totalTasks, readyCount);
}
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
double elapsedSeconds,
size_t readyCount,
CPU cpuCount) const {
if (!logProgress || elapsedSeconds < 1.0)
return;
llvm::errs() << llvm::formatv("[DCP] slow candidate node={0} elapsed={1} ready={2} cpus={3}\n",
nodeIndex,
formatDuration(elapsedSeconds),
readyCount,
cpuCount);
}
void DcpProgressLogger::printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force) {
if (!logProgress)
return;
auto now = std::chrono::steady_clock::now();
if (!force && now - lastProgressPrint < std::chrono::seconds(1) && completedTasks != totalTasks)
return;
double elapsedSeconds = std::chrono::duration<double>(now - startTime).count();
double rate = elapsedSeconds > 0.0 ? static_cast<double>(completedTasks) / elapsedSeconds : 0.0;
double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0;
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F1}%) ready={3} cpus={4} stage={5} elapsed={6} eta={7}\n",
completedTasks,
totalTasks,
percent,
readyCount,
cpuCount,
stage,
formatDuration(elapsedSeconds),
completedTasks == totalTasks ? "0:00" : formatDuration(etaSeconds));
llvm::errs() << llvm::formatv(" time(find={0}, select={1}, update={2})\n",
formatDuration(findCandidateSeconds),
formatDuration(selectProcessorSeconds),
formatDuration(updateTimingSeconds));
lastProgressPrint = now;
}
#else
DcpProgressLogger::DcpProgressLogger(size_t) {}
void DcpProgressLogger::recordFindDuration(double) {}
void DcpProgressLogger::recordSelectDuration(double) {}
void DcpProgressLogger::recordUpdateDuration(double) {}
void DcpProgressLogger::advanceCompleted(size_t) {}
void DcpProgressLogger::printStart(size_t) const {}
void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {}
void DcpProgressLogger::printProgress(size_t, CPU, llvm::StringRef, bool) {}
#endif
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
const std::vector<std::list<TaskDCP*>>& cpuTasks,
CPU lastCpu) {
static int dumpIndex = 0;
std::string outputDir = onnx_mlir::getOutputDir();
if (outputDir.empty())
return;
std::string graphDir = outputDir + "/dcp_graph";
onnx_mlir::createDirectory(graphDir);
std::fstream file(graphDir + "/graph_" + std::to_string(dumpIndex++) + ".dot", std::ios::out);
file << "digraph G {\n";
if (!cpuTasks.empty()) {
for (CPU cpu = 0; cpu < lastCpu; cpu++) {
file << "subgraph cluster_" << cpu << "{\nstyle=filled;\ncolor=lightgrey;\n";
size_t cpuIndex = static_cast<size_t>(cpu);
if (cpuIndex >= cpuTasks.size()) {
file << " }\n";
continue;
}
for (auto node : cpuTasks[cpuIndex]) {
file << node->Id() << " [label=\"";
file << "n:" << node->Id() << "\n";
file << "aest:" << node->getAest() << "\n";
file << "alst:" << node->getAlst() << "\n";
file << "weight:" << node->getWeight() << "\"]\n";
}
file << " }\n";
}
}
else {
for (const auto& node : nodes) {
file << node.Id() << " [label=\"";
file << "n:" << node.Id() << "\n";
file << "aest:" << node.getAest() << "\n";
file << "alst:" << node.getAlst() << "\n";
file << "weight:" << node.getWeight() << "\"]\n";
}
}
for (const auto& node : nodes)
for (const auto& child : node.children) {
file << node.Id() << " -> " << child.first->Id();
file << " [label=\"" << child.second << "\"]\n";
}
file << "}\n";
file.flush();
file.close();
}
} // namespace dcp_graph

View File

@@ -0,0 +1,57 @@
#pragma once
#include "llvm/ADT/StringRef.h"
#include <chrono>
#include <list>
#include <vector>
#include "Task.hpp"
#include "Utils.hpp"
// Comment out DCP_DEBUG_ENABLED below to disable DCP progress logging and
// per-phase profiling during development. When disabled the logger methods
// are no-ops and the helpers compile away.
#define DCP_DEBUG_ENABLED
#ifdef DCP_DEBUG_ENABLED
#define DCP_DEBUG_IF(...) __VA_ARGS__
#else
#define DCP_DEBUG_IF(...)
#endif
namespace dcp_graph {
class DcpProgressLogger {
public:
explicit DcpProgressLogger(size_t totalTasks);
void recordFindDuration(double seconds);
void recordSelectDuration(double seconds);
void recordUpdateDuration(double seconds);
void advanceCompleted(size_t taskCount = 1);
void printStart(size_t readyCount) const;
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
void printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force);
#ifdef DCP_DEBUG_ENABLED
private:
static std::string formatDuration(double seconds);
bool logProgress = false;
size_t totalTasks = 0;
size_t completedTasks = 0;
std::chrono::steady_clock::time_point startTime;
std::chrono::steady_clock::time_point lastProgressPrint;
double findCandidateSeconds = 0.0;
double selectProcessorSeconds = 0.0;
double updateTimingSeconds = 0.0;
#endif
};
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
const std::vector<std::list<TaskDCP*>>& cpuTasks,
CPU lastCpu);
} // namespace dcp_graph
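
To illustrate how DCP_DEBUG_IF and the logger are meant to cooperate inside the scheduling loop, here is a hedged sketch; the call sites are hypothetical, since runDcp's body is not part of this diff:

#include <chrono>
#include "GraphDebug.hpp"

void sketchSchedulerIteration(dcp_graph::DcpProgressLogger& logger, size_t readyCount, CPU cpuCount) {
  // When DCP_DEBUG_ENABLED is off, both DCP_DEBUG_IF bodies vanish and the
  // logger calls compile down to empty inline no-ops.
  DCP_DEBUG_IF(auto findStart = std::chrono::steady_clock::now();)
  // ... findCandidate / selectProcessor would run here ...
  DCP_DEBUG_IF(logger.recordFindDuration(
      std::chrono::duration<double>(std::chrono::steady_clock::now() - findStart).count());)
  logger.advanceCompleted();
  logger.printProgress(readyCount, cpuCount, "select", /*force=*/false);
}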

View File

@@ -0,0 +1,105 @@
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <vector>
#include "GraphSupport.hpp"
#include "Task.hpp"
#include "UniqueWorklist.hpp"
namespace dcp_graph {
llvm::DenseSet<TaskDCP*> collectReachableTasks(TaskDCP* root, bool followParents) {
llvm::DenseSet<TaskDCP*> reachable;
std::vector<TaskDCP*> worklist;
worklist.reserve(32);
auto enqueueEdges = [&](TaskDCP* task) {
const auto& edges = followParents ? task->parents : task->children;
for (const auto& edge : edges)
if (reachable.insert(edge.first).second)
worklist.push_back(edge.first);
};
enqueueEdges(root);
while (!worklist.empty()) {
TaskDCP* task = worklist.back();
worklist.pop_back();
enqueueEdges(task);
}
return reachable;
}
GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate) {
return {collectReachableTasks(candidate, true), collectReachableTasks(candidate, false)};
}
LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task,
const llvm::DenseSet<TaskDCP*>& descendants,
Time dcpl,
Time maxCompletion,
Time secondMaxCompletion,
TaskDCP* maxCompletionTask) {
LocalScheduleSnapshot snapshot;
snapshot.aestBackup.reserve(descendants.size() + 1);
snapshot.aestBackup.emplace_back(task, task->getAest());
for (TaskDCP* descendant : descendants)
snapshot.aestBackup.emplace_back(descendant, descendant->getAest());
snapshot.dcpl = dcpl;
snapshot.maxCompletion = maxCompletion;
snapshot.secondMaxCompletion = secondMaxCompletion;
snapshot.maxCompletionTask = maxCompletionTask;
return snapshot;
}
void restoreLocalScheduleState(const LocalScheduleSnapshot& snapshot,
Time& dcpl,
Time& maxCompletion,
Time& secondMaxCompletion,
TaskDCP*& maxCompletionTask) {
for (const auto& [task, aest] : snapshot.aestBackup)
task->setAest(aest);
dcpl = snapshot.dcpl;
maxCompletion = snapshot.maxCompletion;
secondMaxCompletion = snapshot.secondMaxCompletion;
maxCompletionTask = snapshot.maxCompletionTask;
}
int countDependencyParents(const TaskDCP* task) {
return static_cast<int>(llvm::count_if(task->parents, [](const Edge& edge) { return !edge.isScheduling; }));
}
void recordTopologicalMove(TaskDCP* task, TaskInsertion* insertion) {
if (insertion == nullptr)
return;
auto alreadyRecorded =
llvm::any_of(insertion->topologicalMoves,
[task](const TaskInsertion::TopologicalMoveRecord& move) { return move.task == task; });
if (alreadyRecorded)
return;
insertion->topologicalMoves.push_back({task, onnx_mlir::LabeledList<TaskDCP>::next(task)});
}
std::vector<TaskDCP*> collectDominanceOrder(llvm::ArrayRef<TaskDCP*> roots, size_t nodeCount) {
UniqueWorkList<std::vector<TaskDCP*>> worklist(roots);
worklist.reserve(nodeCount);
size_t index = 0;
while (index != worklist.size()) {
bool modified = true;
while (modified) {
modified = false;
for (const auto& child : worklist.at(index)->children)
if (worklist.allElementsContained(
child.first->parents.begin(), child.first->parents.end(), [](Edge edge) { return edge.first; }))
modified |= worklist.pushBack(child.first);
}
index++;
}
return {worklist.begin(), worklist.end()};
}
} // namespace dcp_graph
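
Taken together, the capture/restore helpers and GraphDCP::tryUpdateAestWithinBudget form a speculate-then-rollback protocol. A sketch of a plausible caller follows; this member function is hypothetical (it would need a matching declaration in GraphDCP), and the real call site lives in Graph.cpp, which is not shown in this diff:

// Hypothetical GraphDCP member illustrating the intended protocol.
bool GraphDCP::trySpeculativeAestUpdate(TaskDCP* task, const CandidateRelations& relations) {
  using namespace dcp_graph;
  // Snapshot everything the propagation may touch: the task's and its
  // descendants' AEST values plus the graph-wide completion bookkeeping.
  LocalScheduleSnapshot snapshot = captureLocalScheduleState(
      task, relations.descendants, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
  if (tryUpdateAestWithinBudget(task, relations.descendantsTopoOrder, /*dcplBudget=*/dcpl))
    return true; // budget respected: keep the propagated AEST values
  // Budget exceeded: roll every touched value back to the snapshot.
  restoreLocalScheduleState(snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
  return false;
}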

View File

@@ -0,0 +1,41 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include <utility>
#include <vector>
#include "Graph.hpp"
namespace dcp_graph {
struct LocalScheduleSnapshot {
llvm::SmallVector<std::pair<TaskDCP*, Time>, 64> aestBackup;
Time dcpl = 0;
Time maxCompletion = 0;
Time secondMaxCompletion = 0;
TaskDCP* maxCompletionTask = nullptr;
};
llvm::DenseSet<TaskDCP*> collectReachableTasks(TaskDCP* root, bool followParents);
GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate);
LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task,
const llvm::DenseSet<TaskDCP*>& descendants,
Time dcpl,
Time maxCompletion,
Time secondMaxCompletion,
TaskDCP* maxCompletionTask);
void restoreLocalScheduleState(const LocalScheduleSnapshot& snapshot,
Time& dcpl,
Time& maxCompletion,
Time& secondMaxCompletion,
TaskDCP*& maxCompletionTask);
int countDependencyParents(const TaskDCP* task);
void recordTopologicalMove(TaskDCP* task, TaskInsertion* insertion);
std::vector<TaskDCP*> collectDominanceOrder(llvm::ArrayRef<TaskDCP*> roots, size_t nodeCount);
} // namespace dcp_graph

View File

@@ -4,57 +4,63 @@
#include "Task.hpp"
#include "UniqueWorklist.hpp"
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
std::optional<Edge_t> oldEdge = std::nullopt;
auto founded_element =
std::find_if(childs.begin(), childs.end(), [child](Edge_t element) { return child == element.first; });
if (founded_element != childs.end()) {
oldEdge = *founded_element;
fastRemove(childs, founded_element);
std::optional<Edge> TaskDCP::addChild(TaskDCP* child, Weight weight, bool isScheduling) {
std::optional<Edge> oldEdge = std::nullopt;
auto foundElement = std::find_if(children.begin(), children.end(), [child, isScheduling](Edge element) {
return child == element.first && isScheduling == element.isScheduling;
});
if (foundElement != children.end()) {
oldEdge = *foundElement;
fastRemove(children, foundElement);
}
childs.emplace_back(child, weight);
children.emplace_back(Edge {child, weight, isScheduling});
return oldEdge;
}
std::optional<Edge_t> TaskDCP::addParent(TaskDCP* parent, Weight_t weight) {
std::optional<Edge_t> oldEdge = std::nullopt;
auto founded_element =
std::find_if(parents.begin(), parents.end(), [parent](Edge_t element) { return parent == element.first; });
if (founded_element != parents.end()) {
oldEdge = *founded_element;
fastRemove(parents, founded_element);
std::optional<Edge> TaskDCP::addParent(TaskDCP* parent, Weight weight, bool isScheduling) {
std::optional<Edge> oldEdge = std::nullopt;
auto foundElement = std::find_if(parents.begin(), parents.end(), [parent, isScheduling](Edge element) {
return parent == element.first && isScheduling == element.isScheduling;
});
if (foundElement != parents.end()) {
oldEdge = *foundElement;
fastRemove(parents, foundElement);
}
parents.emplace_back(parent, weight);
parents.emplace_back(Edge {parent, weight, isScheduling});
return oldEdge;
}
bool TaskDCP::hasDescendent(TaskDCP* child) {
bool TaskDCP::hasDescendant(TaskDCP* child) {
UniqueWorkList<std::vector<TaskDCP*>> worklist;
worklist.reserve(32);
worklist.push_back(this);
worklist.pushBack(this);
while (!worklist.empty()) {
TaskDCP* task = worklist.back();
worklist.pop_back();
worklist.popBack();
if (task == child)
return true;
for (auto c : task->childs)
worklist.push_back(c.first);
for (auto edge : task->children)
worklist.pushBack(edge.first);
}
return false;
}
// TODO do something sensible
int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { return origWeight; }
Weight TaskDCP::computeWeightOnCpu(GraphDCP* graph, CPU cpu) {
if (crossbarUsage != 0 && graph->wouldExhaustCrossbarCapacity(cpu, this))
return std::numeric_limits<Weight>::max();
return baseWeight;
}
void TaskInsertion::rollBack() {
graph->removeTaskFromCPU(cpuModified, taskInserted);
if (beforeNode.has_value()) {
auto double_edge = *beforeNode;
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
auto edgePair = *beforeNode;
addEdge(edgePair.first.first, edgePair.second.first, edgePair.first.second, edgePair.first.isScheduling);
}
if (afterNode.has_value()) {
auto double_edge = *afterNode;
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
auto edgePair = *afterNode;
addEdge(edgePair.first.first, edgePair.second.first, edgePair.first.second, edgePair.first.isScheduling);
}
graph->topologicalOrder.moveBefore(taskInserted, &*oldTopologicalPosition);
// for (auto it = topologicalMoves.rbegin(); it != topologicalMoves.rend(); ++it)
// graph->topologicalOrder.moveBefore(it->task, it->nextTask);
}
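
rollBack undoes one speculative placement: it removes the task from the CPU list, re-adds the two edges the insertion severed, and restores the topological position. A hypothetical caller sketch, with a made-up acceptance test for illustration (selectProcessor's real logic is in Graph.cpp, outside this excerpt):

// Hypothetical GraphDCP member showing the insert-then-rollback protocol.
void GraphDCP::tryInsertAt(TaskDCP* task, CPU cpu, size_t position, Time budget) {
  TaskInsertion insertion = insertTaskInCPU(cpu, task, position);
  if (task->getAest() > budget) { // made-up acceptance test for illustration
    insertion.rollBack();         // restore CPU list, severed edges, topo order
    return;
  }
  reserveTaskCrossbars(cpu, task); // keep the placement: account its crossbars
}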

View File

@@ -7,110 +7,117 @@
#include "Utils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
std::optional<DoubleEdge> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
void removeEdge(TaskDCP* parent, TaskDCP* child);
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute;
int aest;
int alst;
std::optional<CPU> scheduledCPU;
int weight;
int origWeight;
Time aest;
Time alst;
std::optional<CPU> scheduledCpu;
Weight weight;
Weight baseWeight;
CrossbarUsage crossbarUsage;
long long flag = 0;
int64_t syntheticId = -1;
std::optional<Edge_t> addChild(TaskDCP* child, Weight_t weight);
std::optional<Edge_t> addChild(TaskDCP& child, Weight_t weight) { return addChild(&child, weight); }
std::optional<Edge> addChild(TaskDCP* child, Weight weight, bool isScheduling);
std::optional<Edge> addChild(TaskDCP& child, Weight weight, bool isScheduling) {
return addChild(&child, weight, isScheduling);
}
void removeChild(TaskDCP* to_remove) { fastRemove(childs, to_remove); }
void removeChild(TaskDCP& to_remove) { fastRemove(childs, &to_remove); }
void removeChild(TaskDCP* toRemove, bool isScheduling) { fastRemove(children, toRemove, isScheduling); }
void removeChild(TaskDCP& toRemove, bool isScheduling) { fastRemove(children, &toRemove, isScheduling); }
std::optional<Edge_t> addParent(TaskDCP* parent, Weight_t weight);
std::optional<Edge_t> addParent(TaskDCP& parent, Weight_t weight) { return addParent(&parent, weight); }
std::optional<Edge> addParent(TaskDCP* parent, Weight weight, bool isScheduling);
std::optional<Edge> addParent(TaskDCP& parent, Weight weight, bool isScheduling) {
return addParent(&parent, weight, isScheduling);
}
void removeParent(TaskDCP* to_remove) { fastRemove(parents, to_remove); }
void removeParent(TaskDCP& to_remove) { fastRemove(parents, &to_remove); }
void removeParent(TaskDCP* toRemove, bool isScheduling) { fastRemove(parents, toRemove, isScheduling); }
void removeParent(TaskDCP& toRemove, bool isScheduling) { fastRemove(parents, &toRemove, isScheduling); }
public:
std::vector<Edge_t> parents;
std::vector<Edge_t> childs;
std::vector<Edge> parents;
std::vector<Edge> children;
TaskDCP() = default;
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
: onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(spatWeightedCompute),
aest(0),
alst(0),
scheduledCPU(),
weight(getSpatWeightCompute(spatWeightedCompute)),
origWeight(weight),
scheduledCpu(),
weight(getSpatComputeWeight(spatWeightedCompute)),
baseWeight(weight),
crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)),
syntheticId(-1),
parents(),
childs() {}
children() {}
TaskDCP(int64_t id, int weight)
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
: onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(),
aest(0),
alst(0),
scheduledCPU(),
scheduledCpu(),
weight(weight),
origWeight(weight),
baseWeight(weight),
crossbarUsage(crossbarUsage),
flag(0),
syntheticId(id),
parents(),
childs() {}
children() {}
TaskDCP(const TaskDCP& node) = delete;
TaskDCP(TaskDCP&& node) = default;
void setCPU(CPU cpu) { scheduledCPU = cpu; }
std::optional<CPU> getCPU() const { return scheduledCPU; }
void resetCPU() { scheduledCPU = std::nullopt; }
int getWeight() const {
void setCpu(CPU cpu) { scheduledCpu = cpu; }
std::optional<CPU> getCpu() const { return scheduledCpu; }
void resetCpu() { scheduledCpu = std::nullopt; }
Weight getWeight() const {
if (isScheduled())
return weight;
return origWeight;
return baseWeight;
}
void setWeight(int val) { weight = val; }
void resetWeight() { weight = origWeight; }
int computeWeight(GraphDCP* graph, CPU cpu);
void setWeight(Weight value) { weight = value; }
void resetWeight() { weight = baseWeight; }
Weight computeWeightOnCpu(GraphDCP* graph, CPU cpu);
CrossbarUsage getCrossbarUsage() const { return crossbarUsage; }
bool hasParents() const { return parents.size() != 0; }
bool hasChilds() const { return childs.size() != 0; }
bool hasChildren() const { return children.size() != 0; }
int getAEST() const { return aest; }
int getALST() const { return alst; }
void setAEST(int val) {
assert(val >= 0);
aest = val;
}
void setALST(int val) { alst = val; }
bool hasDescendent(TaskDCP* child);
Time getAest() const { return aest; }
Time getAlst() const { return alst; }
void setAest(Time value) { aest = value; }
void setAlst(Time value) { alst = value; }
bool hasDescendant(TaskDCP* child);
int64_t Id() const {
if (spatWeightedCompute)
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer());
return syntheticId;
}
bool isCP() const { return alst == aest; }
bool isScheduled() const { return scheduledCPU.has_value(); }
bool isCriticalPath() const { return alst == aest; }
bool isScheduled() const { return scheduledCpu.has_value(); }
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; }
void setFlag(long long val) { flag = val; }
long long getFlag() const { return flag; }
onnx_mlir::LabeledList<TaskDCP>::Iterator getTopologicalPosition() { return getIterator(); }
onnx_mlir::LabeledList<TaskDCP>::Iterator getTopologicalIterator() { return getIterator(); }
friend std::optional<DoubleEdge> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
friend void removeEdge(TaskDCP* parent, TaskDCP* child);
friend int getTranferCost(TaskDCP* parent, TaskDCP* child);
friend std::optional<EdgePair> addEdge(TaskDCP* parent, TaskDCP* child, Weight weight, bool isScheduling);
friend void removeEdge(TaskDCP* parent, TaskDCP* child, bool isScheduling);
friend Weight getTransferCost(TaskDCP* parent, TaskDCP* child);
};
struct TaskInsertion {
std::optional<DoubleEdge> beforeNode;
std::optional<DoubleEdge> afterNode;
onnx_mlir::LabeledList<TaskDCP>::Iterator oldTopologicalPosition;
struct TopologicalMoveRecord {
TaskDCP* task;
TaskDCP* nextTask;
};
std::optional<EdgePair> beforeNode;
std::optional<EdgePair> afterNode;
std::vector<TopologicalMoveRecord> topologicalMoves;
CPU cpuModified;
TaskDCP* taskInserted;
GraphDCP* graph;

View File

@@ -1,58 +1,57 @@
#pragma once
#include "llvm/ADT/DenseSet.h"
#include <cassert>
#include <type_traits>
#include <iostream>
#include <unordered_set>
template <typename T, typename = void>
struct has_pop_front : std::false_type {};
struct HasPopFront : std::false_type {};
template <typename T>
struct has_pop_front<T, std::void_t<decltype(std::declval<T>().pop_front())>> : std::true_type {};
struct HasPopFront<T, std::void_t<decltype(std::declval<T>().pop_front())>> : std::true_type {};
template <typename T>
class UniqueWorkList {
using V = typename T::value_type;
using ValueType = typename T::value_type;
T storage;
llvm::DenseSet<V> set;
llvm::DenseSet<ValueType> uniqueElements;
public:
UniqueWorkList() = default;
template <typename arg_ty>
UniqueWorkList(const arg_ty& from)
template <typename RangeT>
UniqueWorkList(const RangeT& from)
: storage() {
for (auto& element : from) {
if (!set.contains(element)) {
if (!uniqueElements.contains(element)) {
storage.push_back(element);
set.insert(element);
uniqueElements.insert(element);
}
}
}
bool empty() const { return storage.empty(); }
void reserve(size_t val) { return storage.reserve(val); }
void reserve(size_t value) { return storage.reserve(value); }
size_t size() const { return storage.size(); }
V& at(size_t i) { return storage.at(i); }
const V& at(size_t i) const { return storage.at(i); }
ValueType& at(size_t index) { return storage.at(index); }
const ValueType& at(size_t index) const { return storage.at(index); }
V& front() { return storage.front(); }
V& back() { return storage.back(); }
ValueType& front() { return storage.front(); }
ValueType& back() { return storage.back(); }
bool push_back(const V& val) {
if (!set.contains(val)) {
storage.push_back(val);
set.insert(val);
bool pushBack(const ValueType& value) {
if (!uniqueElements.contains(value)) {
storage.push_back(value);
uniqueElements.insert(value);
return true;
}
return false;
}
void pop_front() {
if constexpr (has_pop_front<T>::value)
void popFront() {
if constexpr (HasPopFront<T>::value)
storage.pop_front();
else
assert(false && "Underlying storage type does not support pop_front()");
@@ -61,15 +60,15 @@ public:
auto cbegin() const { return storage.cbegin(); }
auto cend() const { return storage.cend(); }
void pop_back() { storage.pop_back(); }
void popBack() { storage.pop_back(); }
template <typename Iterator, typename Mapper>
bool allElementContained(Iterator start, Iterator end, Mapper map) {
while (start != end) {
if (!set.contains(map(*start)))
bool allElementsContained(Iterator begin, Iterator end, Mapper map) const {
auto it = begin;
while (it != end) {
if (!uniqueElements.contains(map(*it)))
return false;
std::advance(start, 1);
std::advance(it, 1);
}
return true;
}
@@ -77,4 +76,8 @@ public:
auto begin() { return storage.begin(); }
auto end() { return storage.end(); }
auto begin() const { return storage.begin(); }
auto end() const { return storage.end(); }
};
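
A small usage sketch for the renamed worklist API, assuming UniqueWorklist.hpp is on the include path; plain ints stand in for the TaskDCP pointers used elsewhere:

#include <cassert>
#include <vector>
#include "UniqueWorklist.hpp"

int main() {
  std::vector<int> seed = {1, 2, 2, 3};
  UniqueWorkList<std::vector<int>> worklist(seed); // duplicate 2 is dropped
  assert(worklist.size() == 3);
  assert(!worklist.pushBack(2)); // rejected: already seen
  assert(worklist.pushBack(4));  // accepted: new element
  // allElementsContained applies the mapper before the membership test.
  std::vector<int> probe = {1, 4};
  assert(worklist.allElementsContained(probe.begin(), probe.end(), [](int v) { return v; }));
  worklist.popBack(); // shrinks storage only; the seen-set still remembers 4
  return 0;
}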

View File

@@ -6,60 +6,106 @@
#include <algorithm>
#include <cstdint>
#include <limits>
#include <list>
#include <type_traits>
#include <utility>
#include <vector>
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Support/TypeUtilities.hpp"
using CPU = int;
using Weight_t = int;
using Weight = unsigned long long;
using Time = unsigned long long;
using CrossbarUsage = unsigned long long;
class TaskDCP;
class GraphDCP;
using Edge_t = std::pair<TaskDCP*, Weight_t>;
using DoubleEdge = std::pair<Edge_t, Edge_t>;
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
struct Edge {
TaskDCP* first;
Weight second;
bool isScheduling = false;
};
using EdgePair = std::pair<Edge, Edge>;
using IndexedEdge = std::tuple<int64_t, int64_t, int64_t>;
inline void fastRemove(std::vector<Edge>& vector, TaskDCP* toRemove, bool isScheduling) {
auto position = std::find_if(vector.begin(), vector.end(), [toRemove, isScheduling](Edge edge) {
return edge.first == toRemove && edge.isScheduling == isScheduling;
});
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
}
inline void fastRemove(std::vector<TaskDCP*>& vector, TaskDCP* toRemove) {
auto position =
std::find_if(vector.begin(), vector.end(), [toRemove](TaskDCP* element) { return element == toRemove; });
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
}
template <typename P>
void fastRemove(std::vector<Edge>& vector, P position) {
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
}
template <typename T>
void fastRemove(std::vector<std::pair<T*, Weight_t>>& vector, T* to_remove) {
auto position =
std::find_if(vector.begin(), vector.end(), [to_remove](Edge_t edge) { return edge.first == to_remove; });
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
inline T checkedAdd(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "checkedAdd only supports unsigned types");
assert(lhs <= std::numeric_limits<T>::max() - rhs && "unsigned addition overflow");
return lhs + rhs;
}
inline void fastRemove(std::vector<TaskDCP*>& vector, TaskDCP* to_remove) {
auto position =
std::find_if(vector.begin(), vector.end(), [to_remove](TaskDCP* element) { return element == to_remove; });
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
template <typename T>
inline T checkedMultiply(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "checkedMultiply only supports unsigned types");
if (lhs == 0 || rhs == 0)
return 0;
assert(lhs <= std::numeric_limits<T>::max() / rhs && "unsigned multiplication overflow");
return lhs * rhs;
}
template <typename T, typename P>
void fastRemove(std::vector<std::pair<T*, Weight_t>>& vector, P position) {
if (position != vector.end()) {
std::swap(*(vector.end() - 1), *position);
vector.pop_back();
}
template <typename T>
inline T addOrMax(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "addOrMax only supports unsigned types");
if (lhs == std::numeric_limits<T>::max() || rhs == std::numeric_limits<T>::max())
return std::numeric_limits<T>::max();
return checkedAdd(lhs, rhs);
}
// TODO Do something sensible
inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
int64_t tot = 0;
for (auto& region : spatWeightedCompute.getBody()) {
for (auto& inst : region) {
for (auto result : inst.getResults())
if (auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType()))
tot += onnx_mlir::getSizeInBytes(element);
}
}
return tot;
template <typename T>
inline T subtractOrZero(T lhs, T rhs) {
static_assert(std::is_unsigned_v<T>, "subtractOrZero only supports unsigned types");
if (lhs == std::numeric_limits<T>::max())
return lhs;
if (rhs == std::numeric_limits<T>::max() || lhs <= rhs)
return 0;
return lhs - rhs;
}
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto& block : spatWeightedCompute.getBody())
for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
CrossbarUsage crossbarUsage = 0;
for (auto& region : spatWeightedCompute.getBody())
for (auto& inst : region)
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage;
}
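
The overflow-checked and saturating helpers above define the scheduler's time algebra: max() acts as an unreachable "infinite" time that absorbs additions and subtractions. A minimal sketch of the semantics, assuming this header (Utils.hpp) is includable on its own:

#include <cassert>
#include <limits>
#include "Utils.hpp"

int main() {
  constexpr Time kNever = std::numeric_limits<Time>::max();
  assert(checkedAdd<Time>(3, 4) == 7);            // asserts on overflow instead of wrapping
  assert(checkedMultiply(Time(0), kNever) == 0);  // zero short-circuits before the overflow check
  assert(addOrMax(kNever, Time(5)) == kNever);    // "infinity" absorbs additions
  assert(subtractOrZero(Time(3), Time(7)) == 0);  // clamps instead of underflowing
  assert(slackOrZero(/*earliestStart=*/2, /*latestStart=*/9) == 7);
  return 0;
}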

View File

@@ -5,7 +5,6 @@
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
@@ -14,13 +13,12 @@
#include "llvm/Support/Debug.h"
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
@@ -36,10 +34,10 @@ struct ComputeValueResults {
class LazyInsertComputeResult {
using InsertPoint = mlir::IRRewriter::InsertPoint;
ComputeValueResults computeResults;
Value channelNewOpVal;
Value channelValue;
bool onlyChannel;
std::function<void(InsertPoint insertPoint)> channelSendInserter;
InsertPoint insertPointSend;
InsertPoint sendInsertPoint;
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;
public:
@@ -49,7 +47,7 @@ public:
: computeResults(computeValueResults),
onlyChannel(isOnlyChannel),
channelSendInserter(nullptr),
insertPointSend({}),
sendInsertPoint({}),
channelNewInserter(channelNewInserter) {}
struct ChannelOrLocalOp {
@@ -59,23 +57,23 @@ public:
bool onlyChanneled() const { return onlyChannel; }
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute spatWeightedCompute) {
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute currentCompute) {
auto [first, second] = channelNewInserter();
channelNewOpVal = first;
channelSendInserter = second;
auto BB = computeResults.innerValue.getParentBlock();
if (!BB->empty() && isa<spatial::SpatYieldOp>(BB->back()))
insertPointSend = InsertPoint(BB, --BB->end());
auto [newChannelValue, senderInserter] = channelNewInserter();
channelValue = newChannelValue;
channelSendInserter = senderInserter;
auto* block = computeResults.innerValue.getParentBlock();
if (!block->empty() && isa<spatial::SpatYieldOp>(block->back()))
sendInsertPoint = InsertPoint(block, --block->end());
else
insertPointSend = InsertPoint(BB, BB->end());
if (spatWeightedCompute) {
for (auto& BB : spatWeightedCompute.getBody())
if (&BB == insertPointSend.getBlock())
sendInsertPoint = InsertPoint(block, block->end());
if (currentCompute) {
for (auto& block : currentCompute.getBody())
if (&block == sendInsertPoint.getBlock())
return {computeResults.innerValue, false};
}
channelSendInserter(insertPointSend);
return {channelNewOpVal, true};
channelSendInserter(sendInsertPoint);
return {channelValue, true};
}
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
@@ -86,7 +84,7 @@ struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<
private:
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
DenseMap<int64_t, SpatWeightedCompute> cpuToNewComputeMap;
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
@@ -101,17 +99,16 @@ public:
void runOnOperation() override {
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
auto& lastComputeOfCpu = analysisResult.isLastComputeOfACpu;
auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu;
auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
IRRewriter rewriter(&getContext());
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
size_t cpu = analysisResult.computeToCPUMap.at(currentComputeNode);
if (!cputToNewComputeMap.contains(cpu)) {
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
if (!cpuToNewComputeMap.contains(cpu)) {
ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
auto [newWeightedCompute, computeValueResult] = createNewComputeNode(
currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode));
cputToNewComputeMap[cpu] = newWeightedCompute;
cpuToNewComputeMap[cpu] = newWeightedCompute;
newComputeNodeResults.insert(
std::make_pair(currentComputeNode,
createLazyComputeResult(
@@ -119,7 +116,7 @@ public:
}
else {
auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode(
cputToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
newComputeNodeResults.insert(
std::make_pair(currentComputeNode,
createLazyComputeResult(
@@ -127,10 +124,10 @@ public:
}
}
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
for (auto users : computeNodetoRemove->getUsers())
for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
for (auto users : computeNodeToRemove->getUsers())
users->dump();
computeNodetoRemove.erase();
computeNodeToRemove.erase();
}
func::FuncOp func = getOperation();
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
@@ -186,9 +183,9 @@ private:
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
assert(isChannel == true);
spatial::SpatChannelReceiveOp reciveOp =
spatial::SpatChannelReceiveOp receiveOp =
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
mapper.map(oldBB.getArgument(indexOld - indexOldStart), reciveOp);
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
}
}
@@ -238,8 +235,8 @@ private:
auto& toBB = toCompute.getBody().front();
auto& fromBB = fromCompute.getBody().front();
auto inputeArgMutable = toCompute.getInputsMutable();
// Insert reciveOp
auto inputArgMutable = toCompute.getInputsMutable();
// Insert receiveOp
rewriter.setInsertionPointToEnd(&toBB);
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
@@ -248,9 +245,9 @@ private:
LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
if (channelOrLocal.isChannel) {
spatial::SpatChannelReceiveOp reciveOp =
spatial::SpatChannelReceiveOp receiveOp =
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
mapper.map(fromBB.getArgument(bbIndex), reciveOp.getResult());
mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult());
}
else {
mapper.map(fromBB.getArgument(bbIndex), channelOrLocal.data);
@@ -262,7 +259,7 @@ private:
if (founded == toCompute.getInputs().end()) {
size_t sizeW = toCompute.getWeights().size();
size_t sizeI = toCompute.getInputs().size();
inputeArgMutable.append(arg);
inputArgMutable.append(arg);
assert(sizeW == toCompute.getWeights().size());
assert(sizeI + 1 == toCompute.getInputs().size());
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
@@ -281,6 +278,12 @@ private:
assert(mapper.contains(oldBBarg));
ComputeValueResults computeValueResults;
auto remapWeightIndex = [&](auto weightedOp) {
auto oldIndex = weightedOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
weightedOp.setWeightIndex(newIndex);
};
for (auto& op : fromCompute.getOps()) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
@@ -289,20 +292,10 @@ private:
}
else {
auto newInst = rewriter.clone(op, mapper);
// TODO Refactor into a lambda? Same code, just a different cast, but templated lambdas are C++20 and a free
// function is a bit too much
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto weightedMvmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst))
remapWeightIndex(weightedMvmOp);
if (auto weightedVmmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst))
remapWeightIndex(weightedVmmOp);
}
}
@@ -323,19 +316,18 @@ private:
IRRewriter rewriter(context);
rewriter.setInsertionPointToStart(&funcOp.front());
auto saveInsertionPointChnNew = rewriter.saveInsertionPoint();
auto insertNew = [saveInsertionPointChnNew, context, loc, computeValueResults]() {
auto savedChannelInsertPoint = rewriter.saveInsertionPoint();
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults]() {
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(saveInsertionPointChnNew);
rewriter.restoreInsertionPoint(savedChannelInsertPoint);
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
auto channelVal = channelOp.getResult();
auto insertVal =
[&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint insertPointChnSend) {
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(insertPointChnSend);
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
return spatSend;
};
auto insertVal = [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint sendInsertPoint) {
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(sendInsertPoint);
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
return spatSend;
};
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
return ret;
};