faster (and refactored) DCP analysis
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h16m17s
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h16m17s
This commit is contained in:
@@ -661,9 +661,8 @@ void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) {
|
||||
broadcastSendOp = op;
|
||||
continue;
|
||||
}
|
||||
if (auto op = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user)) {
|
||||
if (auto op = dyn_cast<spatial::SpatChannelBroadcastReceiveOp>(user))
|
||||
continue;
|
||||
}
|
||||
llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering");
|
||||
}
|
||||
|
||||
@@ -719,7 +718,8 @@ void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter&
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
|
||||
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel());
|
||||
Value receivedValue =
|
||||
PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
.getOutput();
|
||||
rewriter.replaceOp(receiveOp, receivedValue);
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ add_pim_library(SpatialOps
|
||||
SpatialOps.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Task.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
|
||||
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
|
||||
if (!op)
|
||||
return {};
|
||||
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
@@ -30,32 +30,32 @@ SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
|
||||
return {};
|
||||
}
|
||||
|
||||
DCPAnalysisResult DCPAnalysis::runAnalysis() {
|
||||
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
|
||||
DCPAnalysisResult DCPAnalysis::run() {
|
||||
llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
|
||||
llvm::SmallVector<EdgesIndex, 10> edges;
|
||||
for (auto& regions : entryOp->getRegions())
|
||||
for (SpatWeightedCompute spatWeightedCompute : regions.getOps<SpatWeightedCompute>())
|
||||
llvm::SmallVector<IndexedEdge, 10> edges;
|
||||
for (auto& region : entryOp->getRegions())
|
||||
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
|
||||
spatWeightedComputes.push_back(spatWeightedCompute);
|
||||
|
||||
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
|
||||
for (Value input : spatWeightedCompute.getInputs()) {
|
||||
if (auto spatWeightedComputeArgOp = getOriginalSpatWeightCompute(input.getDefiningOp())) {
|
||||
auto elemIter = llvm::find(spatWeightedComputes, spatWeightedComputeArgOp);
|
||||
assert(elemIter != spatWeightedComputes.end());
|
||||
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), elemIter);
|
||||
ResultRange outputs = spatWeightedComputeArgOp.getResults();
|
||||
if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) {
|
||||
auto producerIt = llvm::find(spatWeightedComputes, producerCompute);
|
||||
assert(producerIt != spatWeightedComputes.end());
|
||||
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt);
|
||||
ResultRange outputs = producerCompute.getResults();
|
||||
int64_t totalSize = 0;
|
||||
for (auto output : outputs) {
|
||||
ShapedType result = cast<ShapedType>(output.getType());
|
||||
totalSize += getSizeInBytes(result);
|
||||
ShapedType resultType = cast<ShapedType>(output.getType());
|
||||
totalSize += getSizeInBytes(resultType);
|
||||
}
|
||||
edges.push_back({indexStartEdge, indexEndEdge, totalSize});
|
||||
}
|
||||
}
|
||||
}
|
||||
GraphDCP graphDCP(spatWeightedComputes, edges);
|
||||
graphDCP.DCP();
|
||||
graphDCP.setContext(entryOp->getContext());
|
||||
graphDCP.runDcp();
|
||||
return graphDCP.getResult();
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
@@ -10,8 +11,8 @@
|
||||
|
||||
struct DCPAnalysisResult {
|
||||
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
||||
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCPUMap;
|
||||
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfACpu;
|
||||
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCpuMap;
|
||||
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap;
|
||||
};
|
||||
|
||||
@@ -21,12 +22,12 @@ struct DCPAnalysis {
|
||||
private:
|
||||
DCPAnalysisResult result;
|
||||
mlir::Operation* entryOp;
|
||||
DCPAnalysisResult runAnalysis();
|
||||
DCPAnalysisResult run();
|
||||
|
||||
public:
|
||||
DCPAnalysis(mlir::Operation* op)
|
||||
: entryOp(op) {
|
||||
result = runAnalysis();
|
||||
result = run();
|
||||
}
|
||||
DCPAnalysisResult& getResult() { return result; }
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <list>
|
||||
#include <optional>
|
||||
@@ -12,90 +13,144 @@
|
||||
#include "Task.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
std::optional<DoubleEdge> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||
void removeEdge(TaskDCP* parent, TaskDCP* child);
|
||||
int getTranferCost(TaskDCP* parent, TaskDCP* child);
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
} // namespace mlir
|
||||
|
||||
std::optional<EdgePair> addEdge(TaskDCP* parent, TaskDCP* child, Weight weight, bool isScheduling = false);
|
||||
void removeEdge(TaskDCP* parent, TaskDCP* child, bool isScheduling = false);
|
||||
Weight getTransferCost(TaskDCP* parent, TaskDCP* child);
|
||||
|
||||
class GraphDCP {
|
||||
public:
|
||||
struct CandidateRelations {
|
||||
llvm::DenseSet<TaskDCP*> ancestors;
|
||||
llvm::DenseSet<TaskDCP*> descendants;
|
||||
// descendants ordered by position in the graph's topological order;
|
||||
// iterating this avoids walking non-descendant tail tasks on hot paths.
|
||||
llvm::SmallVector<TaskDCP*, 32> descendantsTopoOrder;
|
||||
};
|
||||
|
||||
struct ScheduledTaskInfo {
|
||||
size_t nodeIndex;
|
||||
int aest;
|
||||
int alst;
|
||||
int weight;
|
||||
Time aest;
|
||||
Time alst;
|
||||
Weight weight;
|
||||
};
|
||||
|
||||
private:
|
||||
using CpuTaskList = std::list<TaskDCP*>;
|
||||
|
||||
struct FindSlot {
|
||||
int aest;
|
||||
Time aest;
|
||||
int index;
|
||||
};
|
||||
|
||||
std::vector<TaskDCP> nodes;
|
||||
onnx_mlir::LabeledList<TaskDCP> topologicalOrder;
|
||||
std::unordered_map<CPU, std::list<TaskDCP*>> mapCPUTasks;
|
||||
CPU last_cpu = 0;
|
||||
std::vector<CpuTaskList> cpuTasks;
|
||||
std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage;
|
||||
CPU lastCpu = 0;
|
||||
long long flag = 1;
|
||||
int DCPL;
|
||||
Time dcpl = 0;
|
||||
Time maxCompletion = 0;
|
||||
Time secondMaxCompletion = 0;
|
||||
TaskDCP* maxCompletionTask = nullptr;
|
||||
int maxCpuCount = 1000;
|
||||
mlir::MLIRContext* context = nullptr;
|
||||
|
||||
TaskInsertion insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position);
|
||||
void removeTaskFromCPU(CPU cpu, TaskDCP* task);
|
||||
CpuTaskList& getOrCreateCpuTasks(CPU cpu);
|
||||
const CpuTaskList* findCpuTasks(CPU cpu) const;
|
||||
|
||||
std::vector<TaskDCP*> getRoots();
|
||||
|
||||
long long getUniqueFlag() { return flag++; }
|
||||
|
||||
void initAEST();
|
||||
int initDCPL();
|
||||
void initALST();
|
||||
void initAest();
|
||||
void initAlst();
|
||||
|
||||
int computeAEST(TaskDCP* task, CPU cpu);
|
||||
int computeDCPL(TaskDCP* task, CPU cpu);
|
||||
int getDCPL() { return DCPL; }
|
||||
Time computeAestOnCpu(TaskDCP* task, CPU cpu);
|
||||
Time computeDcplOnCpu(TaskDCP* task, CPU cpu);
|
||||
Time getDcpl() const { return dcpl; }
|
||||
Time computeTaskAlstOnCpu(TaskDCP* task, CPU cpu, Time scheduleDcpl);
|
||||
void updateAestFromTask(TaskDCP* task);
|
||||
void updateAestFromTaskWithDescendants(TaskDCP* task, const llvm::DenseSet<TaskDCP*>& descendants);
|
||||
void updateAestFromTaskWithDescendants(TaskDCP* task, llvm::ArrayRef<TaskDCP*> descendantsTopoOrder);
|
||||
// Propagates AEST like the overload above but returns early (before touching
|
||||
// the remaining descendants) as soon as a task's completion exceeds
|
||||
// `dcplBudget`, signalling that the new DCPL would exceed the budget.
|
||||
// Returns true iff the full propagation completed without exceeding the
|
||||
// budget. Uses the caller's snapshot to restore AEST on the aborted tail.
|
||||
bool tryUpdateAestWithinBudget(TaskDCP* task,
|
||||
llvm::ArrayRef<TaskDCP*> descendantsTopoOrder,
|
||||
Time dcplBudget);
|
||||
|
||||
void initTopological();
|
||||
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint);
|
||||
void topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint);
|
||||
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
|
||||
void topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
|
||||
|
||||
llvm::DenseMap<TaskDCP*, int> computeALST(TaskDCP* task, CPU cpu);
|
||||
llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations);
|
||||
size_t getNodeIndex(const TaskDCP* task) const;
|
||||
|
||||
TaskDCP* findCandidate(std::vector<TaskDCP*> nodes);
|
||||
TaskDCP* findCandidate(const std::vector<TaskDCP*>& readyNodes);
|
||||
void selectProcessor(TaskDCP* candidate, bool push);
|
||||
CPU lastCPU() const { return last_cpu; }
|
||||
void incLastCPU() { last_cpu++; }
|
||||
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push);
|
||||
void to_dot();
|
||||
CPU getLastCpu() const { return lastCpu; }
|
||||
void incrementLastCpu() { lastCpu++; }
|
||||
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations);
|
||||
FindSlot findSlotWithFixedFinalTime(
|
||||
TaskDCP* candidate, CPU cpu, const CandidateRelations& relations, Time finalTime, Time aestOnCpu);
|
||||
void dumpDot();
|
||||
|
||||
friend TaskInsertion;
|
||||
friend class TaskDCP;
|
||||
|
||||
CrossbarUsage getCpuCrossbarUsage(CPU cpu) const;
|
||||
CrossbarUsage getCpuCrossbarCapacity() const;
|
||||
CrossbarUsage getTaskCrossbarFootprint(const TaskDCP* task) const;
|
||||
void reserveTaskCrossbars(CPU cpu, const TaskDCP* task);
|
||||
void releaseTaskCrossbars(CPU cpu, const TaskDCP* task);
|
||||
bool wouldExhaustCrossbarCapacity(CPU cpu, const TaskDCP* task) const;
|
||||
|
||||
public:
|
||||
void DCP();
|
||||
void runDcp();
|
||||
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
|
||||
llvm::ArrayRef<EdgesIndex> edges)
|
||||
: nodes(), mapCPUTasks() {
|
||||
llvm::ArrayRef<IndexedEdge> edges)
|
||||
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
||||
for (auto spatWeightedCompute : spatWeightedComputes)
|
||||
nodes.emplace_back(spatWeightedCompute);
|
||||
for (auto [start, end, weight] : edges)
|
||||
makeEdge(start, end, weight);
|
||||
}
|
||||
|
||||
GraphDCP(llvm::ArrayRef<Weight_t> nodeWeights, llvm::ArrayRef<EdgesIndex> edges)
|
||||
: nodes(), mapCPUTasks() {
|
||||
GraphDCP(llvm::ArrayRef<Weight> nodeWeights,
|
||||
llvm::ArrayRef<IndexedEdge> edges,
|
||||
llvm::ArrayRef<CrossbarUsage> nodeCrossbarUsage = {})
|
||||
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
||||
assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size())
|
||||
&& "synthetic crossbar usage must match synthetic node weights");
|
||||
nodes.reserve(nodeWeights.size());
|
||||
for (auto [index, weight] : llvm::enumerate(nodeWeights))
|
||||
nodes.emplace_back(index, weight);
|
||||
nodes.emplace_back(index, weight, nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]);
|
||||
for (auto [start, end, weight] : edges)
|
||||
makeEdge(start, end, weight);
|
||||
}
|
||||
|
||||
DCPAnalysisResult getResult();
|
||||
std::vector<ScheduledTaskInfo> getScheduledTasks(CPU cpu) const;
|
||||
CPU cpuCount() const { return last_cpu; }
|
||||
CPU cpuCount() const { return lastCpu; }
|
||||
|
||||
void makeEdge(size_t parent_index, size_t child_index, Weight_t weight) {
|
||||
addEdge(&nodes[parent_index], &nodes[child_index], weight);
|
||||
void makeEdge(size_t parentIndex, size_t childIndex, Weight weight) {
|
||||
addEdge(&nodes[parentIndex], &nodes[childIndex], weight);
|
||||
}
|
||||
|
||||
size_t taskInCPU(CPU cpu) { return mapCPUTasks[cpu].size(); }
|
||||
size_t taskInCpu(CPU cpu) { return getOrCreateCpuTasks(cpu).size(); }
|
||||
|
||||
void setMaxCpuCount(int value) { maxCpuCount = value; }
|
||||
int getMaxCpuCount() const { return maxCpuCount; }
|
||||
|
||||
// Optional MLIR context used to drive mlir::parallelFor inside runDcp. If
|
||||
// null the scheduler runs single-threaded (tests use this path).
|
||||
void setContext(mlir::MLIRContext* ctx) { context = ctx; }
|
||||
};
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
|
||||
#include "GraphDebug.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
namespace dcp_graph {
|
||||
|
||||
#ifdef DCP_DEBUG_ENABLED
|
||||
|
||||
DcpProgressLogger::DcpProgressLogger(size_t totalTasks)
|
||||
: logProgress(totalTasks >= 200),
|
||||
totalTasks(totalTasks),
|
||||
startTime(std::chrono::steady_clock::now()),
|
||||
lastProgressPrint(startTime) {}
|
||||
|
||||
std::string DcpProgressLogger::formatDuration(double seconds) {
|
||||
if (seconds < 0)
|
||||
seconds = 0;
|
||||
|
||||
long totalSeconds = static_cast<long>(seconds + 0.5);
|
||||
long hours = totalSeconds / 3600;
|
||||
long minutes = (totalSeconds % 3600) / 60;
|
||||
long secs = totalSeconds % 60;
|
||||
if (hours > 0)
|
||||
return llvm::formatv("{0}:{1:02}:{2:02}", hours, minutes, secs).str();
|
||||
return llvm::formatv("{0}:{1:02}", minutes, secs).str();
|
||||
}
|
||||
|
||||
void DcpProgressLogger::recordFindDuration(double seconds) { findCandidateSeconds += seconds; }
|
||||
void DcpProgressLogger::recordSelectDuration(double seconds) { selectProcessorSeconds += seconds; }
|
||||
void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; }
|
||||
void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; }
|
||||
|
||||
void DcpProgressLogger::printStart(size_t readyCount) const {
|
||||
if (!logProgress)
|
||||
return;
|
||||
llvm::errs() << llvm::formatv("[DCP] start: tasks={0} ready={1}\n", totalTasks, readyCount);
|
||||
}
|
||||
|
||||
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
|
||||
double elapsedSeconds,
|
||||
size_t readyCount,
|
||||
CPU cpuCount) const {
|
||||
if (!logProgress || elapsedSeconds < 1.0)
|
||||
return;
|
||||
|
||||
llvm::errs() << llvm::formatv("[DCP] slow candidate node={0} elapsed={1} ready={2} cpus={3}\n",
|
||||
nodeIndex,
|
||||
formatDuration(elapsedSeconds),
|
||||
readyCount,
|
||||
cpuCount);
|
||||
}
|
||||
|
||||
void DcpProgressLogger::printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force) {
|
||||
if (!logProgress)
|
||||
return;
|
||||
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
if (!force && now - lastProgressPrint < std::chrono::seconds(1) && completedTasks != totalTasks)
|
||||
return;
|
||||
|
||||
double elapsedSeconds = std::chrono::duration<double>(now - startTime).count();
|
||||
double rate = elapsedSeconds > 0.0 ? static_cast<double>(completedTasks) / elapsedSeconds : 0.0;
|
||||
double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0;
|
||||
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
|
||||
|
||||
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F1}%) ready={3} cpus={4} stage={5} elapsed={6} eta={7}\n",
|
||||
completedTasks,
|
||||
totalTasks,
|
||||
percent,
|
||||
readyCount,
|
||||
cpuCount,
|
||||
stage,
|
||||
formatDuration(elapsedSeconds),
|
||||
completedTasks == totalTasks ? "0:00" : formatDuration(etaSeconds));
|
||||
llvm::errs() << llvm::formatv(" time(find={0}, select={1}, update={2})\n",
|
||||
formatDuration(findCandidateSeconds),
|
||||
formatDuration(selectProcessorSeconds),
|
||||
formatDuration(updateTimingSeconds));
|
||||
lastProgressPrint = now;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
DcpProgressLogger::DcpProgressLogger(size_t) {}
|
||||
void DcpProgressLogger::recordFindDuration(double) {}
|
||||
void DcpProgressLogger::recordSelectDuration(double) {}
|
||||
void DcpProgressLogger::recordUpdateDuration(double) {}
|
||||
void DcpProgressLogger::advanceCompleted(size_t) {}
|
||||
void DcpProgressLogger::printStart(size_t) const {}
|
||||
void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {}
|
||||
void DcpProgressLogger::printProgress(size_t, CPU, llvm::StringRef, bool) {}
|
||||
|
||||
#endif
|
||||
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
|
||||
const std::vector<std::list<TaskDCP*>>& cpuTasks,
|
||||
CPU lastCpu) {
|
||||
static int dumpIndex = 0;
|
||||
std::string outputDir = onnx_mlir::getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return;
|
||||
|
||||
std::string graphDir = outputDir + "/dcp_graph";
|
||||
onnx_mlir::createDirectory(graphDir);
|
||||
std::fstream file(graphDir + "/graph_" + std::to_string(dumpIndex++) + ".dot", std::ios::out);
|
||||
file << "digraph G {\n";
|
||||
if (!cpuTasks.empty()) {
|
||||
for (CPU cpu = 0; cpu < lastCpu; cpu++) {
|
||||
file << "subgraph cluster_" << cpu << "{\nstyle=filled;\ncolor=lightgrey;\n";
|
||||
size_t cpuIndex = static_cast<size_t>(cpu);
|
||||
if (cpuIndex >= cpuTasks.size()) {
|
||||
file << " }\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto node : cpuTasks[cpuIndex]) {
|
||||
file << node->Id() << " [label=\"";
|
||||
file << "n:" << node->Id() << "\n";
|
||||
file << "aest:" << node->getAest() << "\n";
|
||||
file << "alst:" << node->getAlst() << "\n";
|
||||
file << "weight:" << node->getWeight() << "\"]\n";
|
||||
}
|
||||
file << " }\n";
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (const auto& node : nodes) {
|
||||
file << node.Id() << " [label=\"";
|
||||
file << "n:" << node.Id() << "\n";
|
||||
file << "aest:" << node.getAest() << "\n";
|
||||
file << "alst:" << node.getAlst() << "\n";
|
||||
file << "weight:" << node.getWeight() << "\"]\n";
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& node : nodes)
|
||||
for (const auto& child : node.children) {
|
||||
file << node.Id() << " -> " << child.first->Id();
|
||||
file << " [label=\"" << child.second << "\"]\n";
|
||||
}
|
||||
|
||||
file << "}\n";
|
||||
file.flush();
|
||||
file.close();
|
||||
}
|
||||
|
||||
} // namespace dcp_graph
|
||||
@@ -0,0 +1,57 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <list>
|
||||
#include <vector>
|
||||
|
||||
#include "Task.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
// Uncomment to enable DCP progress logging and per-phase profiling during
|
||||
// development. When disabled the logger methods are no-ops and the helpers
|
||||
// compile away.
|
||||
#define DCP_DEBUG_ENABLED
|
||||
|
||||
#ifdef DCP_DEBUG_ENABLED
|
||||
#define DCP_DEBUG_IF(...) __VA_ARGS__
|
||||
#else
|
||||
#define DCP_DEBUG_IF(...)
|
||||
#endif
|
||||
|
||||
namespace dcp_graph {
|
||||
|
||||
class DcpProgressLogger {
|
||||
public:
|
||||
explicit DcpProgressLogger(size_t totalTasks);
|
||||
|
||||
void recordFindDuration(double seconds);
|
||||
void recordSelectDuration(double seconds);
|
||||
void recordUpdateDuration(double seconds);
|
||||
void advanceCompleted(size_t taskCount = 1);
|
||||
|
||||
void printStart(size_t readyCount) const;
|
||||
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
|
||||
void printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force);
|
||||
|
||||
#ifdef DCP_DEBUG_ENABLED
|
||||
private:
|
||||
static std::string formatDuration(double seconds);
|
||||
|
||||
bool logProgress = false;
|
||||
size_t totalTasks = 0;
|
||||
size_t completedTasks = 0;
|
||||
std::chrono::steady_clock::time_point startTime;
|
||||
std::chrono::steady_clock::time_point lastProgressPrint;
|
||||
double findCandidateSeconds = 0.0;
|
||||
double selectProcessorSeconds = 0.0;
|
||||
double updateTimingSeconds = 0.0;
|
||||
#endif
|
||||
};
|
||||
|
||||
void dumpGraphDot(const std::vector<TaskDCP>& nodes,
|
||||
const std::vector<std::list<TaskDCP*>>& cpuTasks,
|
||||
CPU lastCpu);
|
||||
|
||||
} // namespace dcp_graph
|
||||
@@ -0,0 +1,105 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "GraphSupport.hpp"
|
||||
#include "Task.hpp"
|
||||
#include "UniqueWorklist.hpp"
|
||||
|
||||
namespace dcp_graph {
|
||||
|
||||
llvm::DenseSet<TaskDCP*> collectReachableTasks(TaskDCP* root, bool followParents) {
|
||||
llvm::DenseSet<TaskDCP*> reachable;
|
||||
std::vector<TaskDCP*> worklist;
|
||||
worklist.reserve(32);
|
||||
|
||||
auto enqueueEdges = [&](TaskDCP* task) {
|
||||
const auto& edges = followParents ? task->parents : task->children;
|
||||
for (const auto& edge : edges)
|
||||
if (reachable.insert(edge.first).second)
|
||||
worklist.push_back(edge.first);
|
||||
};
|
||||
|
||||
enqueueEdges(root);
|
||||
while (!worklist.empty()) {
|
||||
TaskDCP* task = worklist.back();
|
||||
worklist.pop_back();
|
||||
enqueueEdges(task);
|
||||
}
|
||||
return reachable;
|
||||
}
|
||||
|
||||
GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate) {
|
||||
return {collectReachableTasks(candidate, true), collectReachableTasks(candidate, false)};
|
||||
}
|
||||
|
||||
LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task,
|
||||
const llvm::DenseSet<TaskDCP*>& descendants,
|
||||
Time dcpl,
|
||||
Time maxCompletion,
|
||||
Time secondMaxCompletion,
|
||||
TaskDCP* maxCompletionTask) {
|
||||
LocalScheduleSnapshot snapshot;
|
||||
snapshot.aestBackup.reserve(descendants.size() + 1);
|
||||
snapshot.aestBackup.emplace_back(task, task->getAest());
|
||||
for (TaskDCP* descendant : descendants)
|
||||
snapshot.aestBackup.emplace_back(descendant, descendant->getAest());
|
||||
snapshot.dcpl = dcpl;
|
||||
snapshot.maxCompletion = maxCompletion;
|
||||
snapshot.secondMaxCompletion = secondMaxCompletion;
|
||||
snapshot.maxCompletionTask = maxCompletionTask;
|
||||
return snapshot;
|
||||
}
|
||||
|
||||
void restoreLocalScheduleState(const LocalScheduleSnapshot& snapshot,
|
||||
Time& dcpl,
|
||||
Time& maxCompletion,
|
||||
Time& secondMaxCompletion,
|
||||
TaskDCP*& maxCompletionTask) {
|
||||
for (const auto& [task, aest] : snapshot.aestBackup)
|
||||
task->setAest(aest);
|
||||
dcpl = snapshot.dcpl;
|
||||
maxCompletion = snapshot.maxCompletion;
|
||||
secondMaxCompletion = snapshot.secondMaxCompletion;
|
||||
maxCompletionTask = snapshot.maxCompletionTask;
|
||||
}
|
||||
|
||||
int countDependencyParents(const TaskDCP* task) {
|
||||
return static_cast<int>(llvm::count_if(task->parents, [](const Edge& edge) { return !edge.isScheduling; }));
|
||||
}
|
||||
|
||||
void recordTopologicalMove(TaskDCP* task, TaskInsertion* insertion) {
|
||||
if (insertion == nullptr)
|
||||
return;
|
||||
|
||||
auto alreadyRecorded =
|
||||
llvm::any_of(insertion->topologicalMoves,
|
||||
[task](const TaskInsertion::TopologicalMoveRecord& move) { return move.task == task; });
|
||||
if (alreadyRecorded)
|
||||
return;
|
||||
|
||||
insertion->topologicalMoves.push_back({task, onnx_mlir::LabeledList<TaskDCP>::next(task)});
|
||||
}
|
||||
|
||||
std::vector<TaskDCP*> collectDominanceOrder(llvm::ArrayRef<TaskDCP*> roots, size_t nodeCount) {
|
||||
UniqueWorkList<std::vector<TaskDCP*>> worklist(roots);
|
||||
worklist.reserve(nodeCount);
|
||||
|
||||
size_t index = 0;
|
||||
while (index != worklist.size()) {
|
||||
bool modified = true;
|
||||
while (modified) {
|
||||
modified = false;
|
||||
for (const auto& child : worklist.at(index)->children)
|
||||
if (worklist.allElementsContained(
|
||||
child.first->parents.begin(), child.first->parents.end(), [](Edge edge) { return edge.first; }))
|
||||
modified |= worklist.pushBack(child.first);
|
||||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
return {worklist.begin(), worklist.end()};
|
||||
}
|
||||
|
||||
} // namespace dcp_graph
|
||||
@@ -0,0 +1,41 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Graph.hpp"
|
||||
|
||||
namespace dcp_graph {
|
||||
|
||||
struct LocalScheduleSnapshot {
|
||||
llvm::SmallVector<std::pair<TaskDCP*, Time>, 64> aestBackup;
|
||||
Time dcpl = 0;
|
||||
Time maxCompletion = 0;
|
||||
Time secondMaxCompletion = 0;
|
||||
TaskDCP* maxCompletionTask = nullptr;
|
||||
};
|
||||
|
||||
llvm::DenseSet<TaskDCP*> collectReachableTasks(TaskDCP* root, bool followParents);
|
||||
GraphDCP::CandidateRelations computeCandidateRelations(TaskDCP* candidate);
|
||||
|
||||
LocalScheduleSnapshot captureLocalScheduleState(TaskDCP* task,
|
||||
const llvm::DenseSet<TaskDCP*>& descendants,
|
||||
Time dcpl,
|
||||
Time maxCompletion,
|
||||
Time secondMaxCompletion,
|
||||
TaskDCP* maxCompletionTask);
|
||||
void restoreLocalScheduleState(const LocalScheduleSnapshot& snapshot,
|
||||
Time& dcpl,
|
||||
Time& maxCompletion,
|
||||
Time& secondMaxCompletion,
|
||||
TaskDCP*& maxCompletionTask);
|
||||
|
||||
int countDependencyParents(const TaskDCP* task);
|
||||
void recordTopologicalMove(TaskDCP* task, TaskInsertion* insertion);
|
||||
std::vector<TaskDCP*> collectDominanceOrder(llvm::ArrayRef<TaskDCP*> roots, size_t nodeCount);
|
||||
|
||||
} // namespace dcp_graph
|
||||
@@ -4,57 +4,63 @@
|
||||
#include "Task.hpp"
|
||||
#include "UniqueWorklist.hpp"
|
||||
|
||||
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
|
||||
std::optional<Edge_t> oldEdge = std::nullopt;
|
||||
auto founded_element =
|
||||
std::find_if(childs.begin(), childs.end(), [child](Edge_t element) { return child == element.first; });
|
||||
if (founded_element != childs.end()) {
|
||||
oldEdge = *founded_element;
|
||||
fastRemove(childs, founded_element);
|
||||
std::optional<Edge> TaskDCP::addChild(TaskDCP* child, Weight weight, bool isScheduling) {
|
||||
std::optional<Edge> oldEdge = std::nullopt;
|
||||
auto foundElement = std::find_if(children.begin(), children.end(), [child, isScheduling](Edge element) {
|
||||
return child == element.first && isScheduling == element.isScheduling;
|
||||
});
|
||||
if (foundElement != children.end()) {
|
||||
oldEdge = *foundElement;
|
||||
fastRemove(children, foundElement);
|
||||
}
|
||||
childs.emplace_back(child, weight);
|
||||
children.emplace_back(Edge {child, weight, isScheduling});
|
||||
return oldEdge;
|
||||
}
|
||||
|
||||
std::optional<Edge_t> TaskDCP::addParent(TaskDCP* parent, Weight_t weight) {
|
||||
std::optional<Edge_t> oldEdge = std::nullopt;
|
||||
auto founded_element =
|
||||
std::find_if(parents.begin(), parents.end(), [parent](Edge_t element) { return parent == element.first; });
|
||||
if (founded_element != parents.end()) {
|
||||
oldEdge = *founded_element;
|
||||
fastRemove(parents, founded_element);
|
||||
std::optional<Edge> TaskDCP::addParent(TaskDCP* parent, Weight weight, bool isScheduling) {
|
||||
std::optional<Edge> oldEdge = std::nullopt;
|
||||
auto foundElement = std::find_if(parents.begin(), parents.end(), [parent, isScheduling](Edge element) {
|
||||
return parent == element.first && isScheduling == element.isScheduling;
|
||||
});
|
||||
if (foundElement != parents.end()) {
|
||||
oldEdge = *foundElement;
|
||||
fastRemove(parents, foundElement);
|
||||
}
|
||||
parents.emplace_back(parent, weight);
|
||||
parents.emplace_back(Edge {parent, weight, isScheduling});
|
||||
return oldEdge;
|
||||
}
|
||||
|
||||
bool TaskDCP::hasDescendent(TaskDCP* child) {
|
||||
bool TaskDCP::hasDescendant(TaskDCP* child) {
|
||||
UniqueWorkList<std::vector<TaskDCP*>> worklist;
|
||||
worklist.reserve(32);
|
||||
worklist.push_back(this);
|
||||
worklist.pushBack(this);
|
||||
while (!worklist.empty()) {
|
||||
TaskDCP* task = worklist.back();
|
||||
worklist.pop_back();
|
||||
worklist.popBack();
|
||||
if (task == child)
|
||||
return true;
|
||||
for (auto c : task->childs)
|
||||
worklist.push_back(c.first);
|
||||
for (auto edge : task->children)
|
||||
worklist.pushBack(edge.first);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO fare qualcosa di sensato
|
||||
int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { return origWeight; }
|
||||
Weight TaskDCP::computeWeightOnCpu(GraphDCP* graph, CPU cpu) {
|
||||
if (crossbarUsage != 0 && graph->wouldExhaustCrossbarCapacity(cpu, this))
|
||||
return std::numeric_limits<Weight>::max();
|
||||
return baseWeight;
|
||||
}
|
||||
|
||||
void TaskInsertion::rollBack() {
|
||||
graph->removeTaskFromCPU(cpuModified, taskInserted);
|
||||
if (beforeNode.has_value()) {
|
||||
auto double_edge = *beforeNode;
|
||||
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
|
||||
auto edgePair = *beforeNode;
|
||||
addEdge(edgePair.first.first, edgePair.second.first, edgePair.first.second, edgePair.first.isScheduling);
|
||||
}
|
||||
if (afterNode.has_value()) {
|
||||
auto double_edge = *afterNode;
|
||||
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
|
||||
auto edgePair = *afterNode;
|
||||
addEdge(edgePair.first.first, edgePair.second.first, edgePair.first.second, edgePair.first.isScheduling);
|
||||
}
|
||||
graph->topologicalOrder.moveBefore( taskInserted,&*oldTopologicalPosition );
|
||||
// for (auto it = topologicalMoves.rbegin(); it != topologicalMoves.rend(); ++it)
|
||||
// graph->topologicalOrder.moveBefore(it->task, it->nextTask);
|
||||
}
|
||||
|
||||
@@ -7,110 +7,117 @@
|
||||
#include "Utils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
std::optional<DoubleEdge> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||
void removeEdge(TaskDCP* parent, TaskDCP* child);
|
||||
|
||||
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
|
||||
onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute;
|
||||
int aest;
|
||||
int alst;
|
||||
std::optional<CPU> scheduledCPU;
|
||||
int weight;
|
||||
int origWeight;
|
||||
Time aest;
|
||||
Time alst;
|
||||
std::optional<CPU> scheduledCpu;
|
||||
Weight weight;
|
||||
Weight baseWeight;
|
||||
CrossbarUsage crossbarUsage;
|
||||
long long flag = 0;
|
||||
int64_t syntheticId = -1;
|
||||
|
||||
std::optional<Edge_t> addChild(TaskDCP* child, Weight_t weight);
|
||||
std::optional<Edge_t> addChild(TaskDCP& child, Weight_t weight) { return addChild(&child, weight); }
|
||||
std::optional<Edge> addChild(TaskDCP* child, Weight weight, bool isScheduling);
|
||||
std::optional<Edge> addChild(TaskDCP& child, Weight weight, bool isScheduling) {
|
||||
return addChild(&child, weight, isScheduling);
|
||||
}
|
||||
|
||||
void removeChild(TaskDCP* to_remove) { fastRemove(childs, to_remove); }
|
||||
void removeChild(TaskDCP& to_remove) { fastRemove(childs, &to_remove); }
|
||||
void removeChild(TaskDCP* toRemove, bool isScheduling) { fastRemove(children, toRemove, isScheduling); }
|
||||
void removeChild(TaskDCP& toRemove, bool isScheduling) { fastRemove(children, &toRemove, isScheduling); }
|
||||
|
||||
std::optional<Edge_t> addParent(TaskDCP* parent, Weight_t weight);
|
||||
std::optional<Edge_t> addParent(TaskDCP& parent, Weight_t weight) { return addParent(&parent, weight); }
|
||||
std::optional<Edge> addParent(TaskDCP* parent, Weight weight, bool isScheduling);
|
||||
std::optional<Edge> addParent(TaskDCP& parent, Weight weight, bool isScheduling) {
|
||||
return addParent(&parent, weight, isScheduling);
|
||||
}
|
||||
|
||||
void removeParent(TaskDCP* to_remove) { fastRemove(parents, to_remove); }
|
||||
void removeParent(TaskDCP& to_remove) { fastRemove(parents, &to_remove); }
|
||||
void removeParent(TaskDCP* toRemove, bool isScheduling) { fastRemove(parents, toRemove, isScheduling); }
|
||||
void removeParent(TaskDCP& toRemove, bool isScheduling) { fastRemove(parents, &toRemove, isScheduling); }
|
||||
|
||||
public:
|
||||
std::vector<Edge_t> parents;
|
||||
std::vector<Edge_t> childs;
|
||||
std::vector<Edge> parents;
|
||||
std::vector<Edge> children;
|
||||
TaskDCP() = default;
|
||||
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
|
||||
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
||||
spatWeightedCompute(spatWeightedCompute),
|
||||
aest(0),
|
||||
alst(0),
|
||||
scheduledCPU(),
|
||||
weight(getSpatWeightCompute(spatWeightedCompute)),
|
||||
origWeight(weight),
|
||||
scheduledCpu(),
|
||||
weight(getSpatComputeWeight(spatWeightedCompute)),
|
||||
baseWeight(weight),
|
||||
crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)),
|
||||
syntheticId(-1),
|
||||
parents(),
|
||||
childs() {}
|
||||
children() {}
|
||||
|
||||
TaskDCP(int64_t id, int weight)
|
||||
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
|
||||
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
||||
spatWeightedCompute(),
|
||||
aest(0),
|
||||
alst(0),
|
||||
scheduledCPU(),
|
||||
scheduledCpu(),
|
||||
weight(weight),
|
||||
origWeight(weight),
|
||||
baseWeight(weight),
|
||||
crossbarUsage(crossbarUsage),
|
||||
flag(0),
|
||||
syntheticId(id),
|
||||
parents(),
|
||||
childs() {}
|
||||
children() {}
|
||||
|
||||
TaskDCP(const TaskDCP& node) = delete;
|
||||
TaskDCP(TaskDCP&& node) = default;
|
||||
|
||||
void setCPU(CPU cpu) { scheduledCPU = cpu; }
|
||||
std::optional<CPU> getCPU() const { return scheduledCPU; }
|
||||
void resetCPU() { scheduledCPU = std::nullopt; }
|
||||
int getWeight() const {
|
||||
void setCpu(CPU cpu) { scheduledCpu = cpu; }
|
||||
std::optional<CPU> getCpu() const { return scheduledCpu; }
|
||||
void resetCpu() { scheduledCpu = std::nullopt; }
|
||||
Weight getWeight() const {
|
||||
if (isScheduled())
|
||||
return weight;
|
||||
return origWeight;
|
||||
return baseWeight;
|
||||
}
|
||||
void setWeight(int val) { weight = val; }
|
||||
void resetWeight() { weight = origWeight; }
|
||||
int computeWeight(GraphDCP* graph, CPU cpu);
|
||||
void setWeight(Weight value) { weight = value; }
|
||||
void resetWeight() { weight = baseWeight; }
|
||||
Weight computeWeightOnCpu(GraphDCP* graph, CPU cpu);
|
||||
CrossbarUsage getCrossbarUsage() const { return crossbarUsage; }
|
||||
|
||||
bool hasParents() const { return parents.size() != 0; }
|
||||
bool hasChilds() const { return childs.size() != 0; }
|
||||
bool hasChildren() const { return children.size() != 0; }
|
||||
|
||||
int getAEST() const { return aest; }
|
||||
int getALST() const { return alst; }
|
||||
void setAEST(int val) {
|
||||
assert(val >= 0);
|
||||
aest = val;
|
||||
}
|
||||
void setALST(int val) { alst = val; }
|
||||
bool hasDescendent(TaskDCP* child);
|
||||
Time getAest() const { return aest; }
|
||||
Time getAlst() const { return alst; }
|
||||
void setAest(Time value) { aest = value; }
|
||||
void setAlst(Time value) { alst = value; }
|
||||
bool hasDescendant(TaskDCP* child);
|
||||
int64_t Id() const {
|
||||
if (spatWeightedCompute)
|
||||
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer());
|
||||
return syntheticId;
|
||||
}
|
||||
|
||||
bool isCP() const { return alst == aest; }
|
||||
bool isScheduled() const { return scheduledCPU.has_value(); }
|
||||
bool isCriticalPath() const { return alst == aest; }
|
||||
bool isScheduled() const { return scheduledCpu.has_value(); }
|
||||
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; }
|
||||
|
||||
void setFlag(long long val) { flag = val; }
|
||||
long long getFlag() const { return flag; }
|
||||
|
||||
onnx_mlir::LabeledList<TaskDCP>::Iterator getTopologicalPosition() { return getIterator(); }
|
||||
onnx_mlir::LabeledList<TaskDCP>::Iterator getTopologicalIterator() { return getIterator(); }
|
||||
|
||||
friend std::optional<DoubleEdge> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
|
||||
friend void removeEdge(TaskDCP* parent, TaskDCP* child);
|
||||
friend int getTranferCost(TaskDCP* parent, TaskDCP* child);
|
||||
friend std::optional<EdgePair> addEdge(TaskDCP* parent, TaskDCP* child, Weight weight, bool isScheduling);
|
||||
friend void removeEdge(TaskDCP* parent, TaskDCP* child, bool isScheduling);
|
||||
friend Weight getTransferCost(TaskDCP* parent, TaskDCP* child);
|
||||
};
|
||||
|
||||
struct TaskInsertion {
|
||||
std::optional<DoubleEdge> beforeNode;
|
||||
std::optional<DoubleEdge> afterNode;
|
||||
onnx_mlir::LabeledList<TaskDCP>::Iterator oldTopologicalPosition;
|
||||
struct TopologicalMoveRecord {
|
||||
TaskDCP* task;
|
||||
TaskDCP* nextTask;
|
||||
};
|
||||
|
||||
std::optional<EdgePair> beforeNode;
|
||||
std::optional<EdgePair> afterNode;
|
||||
std::vector<TopologicalMoveRecord> topologicalMoves;
|
||||
CPU cpuModified;
|
||||
TaskDCP* taskInserted;
|
||||
GraphDCP* graph;
|
||||
|
||||
@@ -1,58 +1,57 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <type_traits>
|
||||
#include <iostream>
|
||||
#include <unordered_set>
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct has_pop_front : std::false_type {};
|
||||
struct HasPopFront : std::false_type {};
|
||||
|
||||
template <typename T>
|
||||
struct has_pop_front<T, std::void_t<decltype(std::declval<T>().pop_front())>> : std::true_type {};
|
||||
struct HasPopFront<T, std::void_t<decltype(std::declval<T>().pop_front())>> : std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
class UniqueWorkList {
|
||||
|
||||
using V = typename T::value_type;
|
||||
using ValueType = typename T::value_type;
|
||||
T storage;
|
||||
llvm::DenseSet<V> set;
|
||||
llvm::DenseSet<ValueType> uniqueElements;
|
||||
|
||||
public:
|
||||
UniqueWorkList() = default;
|
||||
|
||||
template <typename arg_ty>
|
||||
UniqueWorkList(const arg_ty& from)
|
||||
template <typename RangeT>
|
||||
UniqueWorkList(const RangeT& from)
|
||||
: storage() {
|
||||
for (auto& element : from) {
|
||||
if (!set.contains(element)) {
|
||||
if (!uniqueElements.contains(element)) {
|
||||
storage.push_back(element);
|
||||
set.insert(element);
|
||||
uniqueElements.insert(element);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool empty() const { return storage.empty(); }
|
||||
void reserve(size_t val) { return storage.reserve(val); }
|
||||
void reserve(size_t value) { return storage.reserve(value); }
|
||||
size_t size() const { return storage.size(); }
|
||||
V& at(size_t i) { return storage.at(i); }
|
||||
const V& at(size_t i) const { return storage.at(i); }
|
||||
ValueType& at(size_t index) { return storage.at(index); }
|
||||
const ValueType& at(size_t index) const { return storage.at(index); }
|
||||
|
||||
V& front() { return storage.front(); }
|
||||
V& back() { return storage.back(); }
|
||||
ValueType& front() { return storage.front(); }
|
||||
ValueType& back() { return storage.back(); }
|
||||
|
||||
bool push_back(const V& val) {
|
||||
if (!set.contains(val)) {
|
||||
storage.push_back(val);
|
||||
set.insert(val);
|
||||
bool pushBack(const ValueType& value) {
|
||||
if (!uniqueElements.contains(value)) {
|
||||
storage.push_back(value);
|
||||
uniqueElements.insert(value);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void pop_front() {
|
||||
if constexpr (has_pop_front<T>::value)
|
||||
void popFront() {
|
||||
if constexpr (HasPopFront<T>::value)
|
||||
storage.pop_front();
|
||||
else
|
||||
assert(false && "Underlying storage type does not support pop_front()");
|
||||
@@ -61,15 +60,15 @@ public:
|
||||
auto cbegin() const { return storage.cbegin(); }
|
||||
auto cend() const { return storage.cend(); }
|
||||
|
||||
void pop_back() { storage.pop_back(); }
|
||||
|
||||
void popBack() { storage.pop_back(); }
|
||||
|
||||
template <typename Iterator, typename Mapper>
|
||||
bool allElementContained(Iterator start, Iterator end, Mapper map) {
|
||||
while (start != end) {
|
||||
if (!set.contains(map(*start)))
|
||||
bool allElementsContained(Iterator begin, Iterator end, Mapper map) const {
|
||||
auto it = begin;
|
||||
while (it != end) {
|
||||
if (!uniqueElements.contains(map(*it)))
|
||||
return false;
|
||||
std::advance(start, 1);
|
||||
std::advance(it, 1);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
@@ -77,4 +76,8 @@ public:
|
||||
auto begin() { return storage.begin(); }
|
||||
|
||||
auto end() { return storage.end(); }
|
||||
|
||||
auto begin() const { return storage.begin(); }
|
||||
|
||||
auto end() const { return storage.end(); }
|
||||
};
|
||||
|
||||
@@ -6,60 +6,106 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <list>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
|
||||
using CPU = int;
|
||||
using Weight_t = int;
|
||||
using Weight = unsigned long long;
|
||||
using Time = unsigned long long;
|
||||
using CrossbarUsage = unsigned long long;
|
||||
class TaskDCP;
|
||||
class GraphDCP;
|
||||
using Edge_t = std::pair<TaskDCP*, Weight_t>;
|
||||
using DoubleEdge = std::pair<Edge_t, Edge_t>;
|
||||
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
|
||||
// A weighted adjacency-list entry of a TaskDCP (used for both the `parents`
// and `children` vectors).
struct Edge {
  TaskDCP* first;              // the task on the other end of this edge
  Weight second;               // edge weight (transfer cost between the two tasks)
  bool isScheduling = false;   // NOTE(review): presumably marks edges added temporarily during scheduling — confirm with addChild/addParent callers
};
|
||||
using EdgePair = std::pair<Edge, Edge>;
|
||||
using IndexedEdge = std::tuple<int64_t, int64_t, int64_t>;
|
||||
|
||||
inline void fastRemove(std::vector<Edge>& vector, TaskDCP* toRemove, bool isScheduling) {
|
||||
auto position = std::find_if(vector.begin(), vector.end(), [toRemove, isScheduling](Edge edge) {
|
||||
return edge.first == toRemove && edge.isScheduling == isScheduling;
|
||||
});
|
||||
if (position != vector.end()) {
|
||||
std::swap(*(vector.end() - 1), *position);
|
||||
vector.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes the first occurrence of `toRemove` from `vector` by swapping it
/// with the last element and popping — O(n) search, O(1) erase. Element order
/// is NOT preserved. No-op when the element is absent.
inline void fastRemove(std::vector<TaskDCP*>& vector, TaskDCP* toRemove) {
  // std::find_if with a trivial equality lambda is just std::find.
  auto position = std::find(vector.begin(), vector.end(), toRemove);
  if (position != vector.end()) {
    std::swap(*(vector.end() - 1), *position);
    vector.pop_back();
  }
}
|
||||
|
||||
template <typename P>
|
||||
void fastRemove(std::vector<Edge>& vector, P position) {
|
||||
if (position != vector.end()) {
|
||||
std::swap(*(vector.end() - 1), *position);
|
||||
vector.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void fastRemove(std::vector<std::pair<T*, Weight_t>>& vector, T* to_remove) {
|
||||
auto position =
|
||||
std::find_if(vector.begin(), vector.end(), [to_remove](Edge_t edge) { return edge.first == to_remove; });
|
||||
if (position != vector.end()) {
|
||||
std::swap(*(vector.end() - 1), *position);
|
||||
vector.pop_back();
|
||||
}
|
||||
/// Adds two unsigned values, asserting (in debug builds) that the sum does
/// not wrap around.
template <typename T>
inline T checkedAdd(T lhs, T rhs) {
  static_assert(std::is_unsigned_v<T>, "checkedAdd only supports unsigned types");
  const T headroom = std::numeric_limits<T>::max() - rhs;
  assert(lhs <= headroom && "unsigned addition overflow");
  return lhs + rhs;
}
|
||||
|
||||
inline void fastRemove(std::vector<TaskDCP*>& vector, TaskDCP* to_remove) {
|
||||
auto position =
|
||||
std::find_if(vector.begin(), vector.end(), [to_remove](TaskDCP* element) { return element == to_remove; });
|
||||
if (position != vector.end()) {
|
||||
std::swap(*(vector.end() - 1), *position);
|
||||
vector.pop_back();
|
||||
}
|
||||
/// Multiplies two unsigned values, asserting (in debug builds) that the
/// product does not wrap around. Zero operands short-circuit to zero.
template <typename T>
inline T checkedMultiply(T lhs, T rhs) {
  static_assert(std::is_unsigned_v<T>, "checkedMultiply only supports unsigned types");
  if (lhs != 0 && rhs != 0) {
    assert(lhs <= std::numeric_limits<T>::max() / rhs && "unsigned multiplication overflow");
    return lhs * rhs;
  }
  return 0;
}
|
||||
|
||||
template <typename T, typename P>
|
||||
void fastRemove(std::vector<std::pair<T*, Weight_t>>& vector, P position) {
|
||||
if (position != vector.end()) {
|
||||
std::swap(*(vector.end() - 1), *position);
|
||||
vector.pop_back();
|
||||
}
|
||||
/// Saturating addition: the maximum representable value acts as "infinity"
/// and absorbs any addend; finite operands go through checkedAdd.
template <typename T>
inline T addOrMax(T lhs, T rhs) {
  static_assert(std::is_unsigned_v<T>, "addOrMax only supports unsigned types");
  constexpr T kInfinity = std::numeric_limits<T>::max();
  if (lhs != kInfinity && rhs != kInfinity)
    return checkedAdd(lhs, rhs);
  return kInfinity;
}
|
||||
|
||||
// TODO: replace this placeholder cost model with something principled.
|
||||
inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
|
||||
int64_t tot = 0;
|
||||
for (auto& region : spatWeightedCompute.getBody()) {
|
||||
for (auto& inst : region) {
|
||||
for (auto result : inst.getResults())
|
||||
if (auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType()))
|
||||
tot += onnx_mlir::getSizeInBytes(element);
|
||||
}
|
||||
}
|
||||
return tot;
|
||||
/// Saturating subtraction with "infinity" semantics: max() minus anything
/// stays max(); subtracting max(), or any value >= lhs, clamps to zero.
template <typename T>
inline T subtractOrZero(T lhs, T rhs) {
  static_assert(std::is_unsigned_v<T>, "subtractOrZero only supports unsigned types");
  constexpr T kInfinity = std::numeric_limits<T>::max();
  if (lhs == kInfinity)
    return kInfinity;
  if (rhs >= lhs || rhs == kInfinity)
    return 0;
  return lhs - rhs;
}
|
||||
|
||||
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
|
||||
|
||||
/// Cost model for a weighted-compute node: a fixed per-operation weight
/// multiplied by the number of operations in the node's body.
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
  constexpr Weight kOperationWeight = 100;
  Weight opCount = 0;
  for (auto& block : spatWeightedCompute.getBody())
    opCount = checkedAdd(opCount, static_cast<Weight>(std::distance(block.begin(), block.end())));
  return checkedMultiply(opCount, kOperationWeight);
}
|
||||
|
||||
/// Crossbar demand of a weighted-compute node: one crossbar per weighted VMM
/// operation found in its body.
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
  CrossbarUsage vmmCount = 0;
  for (auto& block : spatWeightedCompute.getBody()) {
    for (auto& op : block) {
      if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(op))
        vmmCount = checkedAdd(vmmCount, static_cast<CrossbarUsage>(1));
    }
  }
  return vmmCount;
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#include "mlir/IR/Region.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
@@ -14,13 +13,12 @@
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -36,10 +34,10 @@ struct ComputeValueResults {
|
||||
class LazyInsertComputeResult {
|
||||
using InsertPoint = mlir::IRRewriter::InsertPoint;
|
||||
ComputeValueResults computeResults;
|
||||
Value channelNewOpVal;
|
||||
Value channelValue;
|
||||
bool onlyChannel;
|
||||
std::function<void(InsertPoint insertPoint)> channelSendInserter;
|
||||
InsertPoint insertPointSend;
|
||||
InsertPoint sendInsertPoint;
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;
|
||||
|
||||
public:
|
||||
@@ -49,7 +47,7 @@ public:
|
||||
: computeResults(computeValueResults),
|
||||
onlyChannel(isOnlyChannel),
|
||||
channelSendInserter(nullptr),
|
||||
insertPointSend({}),
|
||||
sendInsertPoint({}),
|
||||
channelNewInserter(channelNewInserter) {}
|
||||
|
||||
struct ChannelOrLocalOp {
|
||||
@@ -59,23 +57,23 @@ public:
|
||||
|
||||
bool onlyChanneled() const { return onlyChannel; }
|
||||
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute spatWeightedCompute) {
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute currentCompute) {
|
||||
|
||||
auto [first, second] = channelNewInserter();
|
||||
channelNewOpVal = first;
|
||||
channelSendInserter = second;
|
||||
auto BB = computeResults.innerValue.getParentBlock();
|
||||
if (!BB->empty() && isa<spatial::SpatYieldOp>(BB->back()))
|
||||
insertPointSend = InsertPoint(BB, --BB->end());
|
||||
auto [newChannelValue, senderInserter] = channelNewInserter();
|
||||
channelValue = newChannelValue;
|
||||
channelSendInserter = senderInserter;
|
||||
auto* block = computeResults.innerValue.getParentBlock();
|
||||
if (!block->empty() && isa<spatial::SpatYieldOp>(block->back()))
|
||||
sendInsertPoint = InsertPoint(block, --block->end());
|
||||
else
|
||||
insertPointSend = InsertPoint(BB, BB->end());
|
||||
if (spatWeightedCompute) {
|
||||
for (auto& BB : spatWeightedCompute.getBody())
|
||||
if (&BB == insertPointSend.getBlock())
|
||||
sendInsertPoint = InsertPoint(block, block->end());
|
||||
if (currentCompute) {
|
||||
for (auto& block : currentCompute.getBody())
|
||||
if (&block == sendInsertPoint.getBlock())
|
||||
return {computeResults.innerValue, false};
|
||||
}
|
||||
channelSendInserter(insertPointSend);
|
||||
return {channelNewOpVal, true};
|
||||
channelSendInserter(sendInsertPoint);
|
||||
return {channelValue, true};
|
||||
}
|
||||
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
||||
@@ -86,7 +84,7 @@ struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<
|
||||
private:
|
||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||
DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
|
||||
DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;
|
||||
DenseMap<int64_t, SpatWeightedCompute> cpuToNewComputeMap;
|
||||
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||
@@ -101,17 +99,16 @@ public:
|
||||
|
||||
void runOnOperation() override {
|
||||
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
|
||||
auto& lastComputeOfCpu = analysisResult.isLastComputeOfACpu;
|
||||
auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu;
|
||||
auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
|
||||
IRRewriter rewriter(&getContext());
|
||||
|
||||
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
|
||||
size_t cpu = analysisResult.computeToCPUMap.at(currentComputeNode);
|
||||
if (!cputToNewComputeMap.contains(cpu)) {
|
||||
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
|
||||
if (!cpuToNewComputeMap.contains(cpu)) {
|
||||
ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
|
||||
auto [newWeightedCompute, computeValueResult] = createNewComputeNode(
|
||||
currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
||||
cputToNewComputeMap[cpu] = newWeightedCompute;
|
||||
cpuToNewComputeMap[cpu] = newWeightedCompute;
|
||||
newComputeNodeResults.insert(
|
||||
std::make_pair(currentComputeNode,
|
||||
createLazyComputeResult(
|
||||
@@ -119,7 +116,7 @@ public:
|
||||
}
|
||||
else {
|
||||
auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode(
|
||||
cputToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
|
||||
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
|
||||
newComputeNodeResults.insert(
|
||||
std::make_pair(currentComputeNode,
|
||||
createLazyComputeResult(
|
||||
@@ -127,10 +124,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
||||
for (auto users : computeNodetoRemove->getUsers())
|
||||
for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
||||
for (auto users : computeNodeToRemove->getUsers())
|
||||
users->dump();
|
||||
computeNodetoRemove.erase();
|
||||
computeNodeToRemove.erase();
|
||||
}
|
||||
func::FuncOp func = getOperation();
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||
@@ -186,9 +183,9 @@ private:
|
||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
|
||||
assert(isChannel == true);
|
||||
spatial::SpatChannelReceiveOp reciveOp =
|
||||
spatial::SpatChannelReceiveOp receiveOp =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), reciveOp);
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,8 +235,8 @@ private:
|
||||
|
||||
auto& toBB = toCompute.getBody().front();
|
||||
auto& fromBB = fromCompute.getBody().front();
|
||||
auto inputeArgMutable = toCompute.getInputsMutable();
|
||||
// Insert reciveOp
|
||||
auto inputArgMutable = toCompute.getInputsMutable();
|
||||
// Insert receiveOp
|
||||
rewriter.setInsertionPointToEnd(&toBB);
|
||||
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
|
||||
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
|
||||
@@ -248,9 +245,9 @@ private:
|
||||
LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
|
||||
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
|
||||
if (channelOrLocal.isChannel) {
|
||||
spatial::SpatChannelReceiveOp reciveOp =
|
||||
spatial::SpatChannelReceiveOp receiveOp =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
|
||||
mapper.map(fromBB.getArgument(bbIndex), reciveOp.getResult());
|
||||
mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult());
|
||||
}
|
||||
else {
|
||||
mapper.map(fromBB.getArgument(bbIndex), channelOrLocal.data);
|
||||
@@ -262,7 +259,7 @@ private:
|
||||
if (founded == toCompute.getInputs().end()) {
|
||||
size_t sizeW = toCompute.getWeights().size();
|
||||
size_t sizeI = toCompute.getInputs().size();
|
||||
inputeArgMutable.append(arg);
|
||||
inputArgMutable.append(arg);
|
||||
assert(sizeW == toCompute.getWeights().size());
|
||||
assert(sizeI + 1 == toCompute.getInputs().size());
|
||||
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
||||
@@ -281,6 +278,12 @@ private:
|
||||
assert(mapper.contains(oldBBarg));
|
||||
|
||||
ComputeValueResults computeValueResults;
|
||||
auto remapWeightIndex = [&](auto weightedOp) {
|
||||
auto oldIndex = weightedOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||
weightedOp.setWeightIndex(newIndex);
|
||||
};
|
||||
for (auto& op : fromCompute.getOps()) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
|
||||
@@ -289,20 +292,10 @@ private:
|
||||
}
|
||||
else {
|
||||
auto newInst = rewriter.clone(op, mapper);
|
||||
// TODO Refactor in a lambda? same code just different cast, but templated lambda are C++20 and a free function
|
||||
// is a bit too much
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
if (auto weightedMvmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst))
|
||||
remapWeightIndex(weightedMvmOp);
|
||||
if (auto weightedVmmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst))
|
||||
remapWeightIndex(weightedVmmOp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,19 +316,18 @@ private:
|
||||
IRRewriter rewriter(context);
|
||||
|
||||
rewriter.setInsertionPointToStart(&funcOp.front());
|
||||
auto saveInsertionPointChnNew = rewriter.saveInsertionPoint();
|
||||
auto insertNew = [saveInsertionPointChnNew, context, loc, computeValueResults]() {
|
||||
auto savedChannelInsertPoint = rewriter.saveInsertionPoint();
|
||||
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults]() {
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(saveInsertionPointChnNew);
|
||||
rewriter.restoreInsertionPoint(savedChannelInsertPoint);
|
||||
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
|
||||
auto channelVal = channelOp.getResult();
|
||||
auto insertVal =
|
||||
[&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint insertPointChnSend) {
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(insertPointChnSend);
|
||||
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
|
||||
return spatSend;
|
||||
};
|
||||
auto insertVal = [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint sendInsertPoint) {
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(sendInsertPoint);
|
||||
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
|
||||
return spatSend;
|
||||
};
|
||||
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
|
||||
return ret;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user