Add DCP algorithm, partially working test

This commit is contained in:
ilgeco
2026-04-07 22:05:39 +02:00
parent ef4743c986
commit ca56e3d4f1
17 changed files with 1313 additions and 33 deletions

View File

@@ -1,4 +1,5 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"

View File

@@ -33,6 +33,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPim) { if (pimEmissionTarget >= EmitPim) {
pm.addPass(createMergeComputeNodePass());
pm.addPass(createSpatialToPimPass()); pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim")); pm.addPass(createMessagePass("Spatial lowered to Pim"));

View File

@@ -14,6 +14,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"

View File

@@ -4,6 +4,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td)
add_pim_library(SpatialOps add_pim_library(SpatialOps
SpatialOps.cpp SpatialOps.cpp
Transforms/SpatialBufferizableOpInterface.cpp Transforms/SpatialBufferizableOpInterface.cpp
Transforms/MergeComputeNode/MergeComputeNodePass.cpp
DCPGraph/Graph.cpp
DCPGraph/Task.cpp
DCPGraph/DCPAnalysis.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -0,0 +1,52 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include "../SpatialOps.hpp"
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir {
namespace spatial {
using namespace mlir;
// Builds the DCP task graph from the spat.compute ops nested under `entryOp`,
// runs the DCP scheduling over it, and returns the resulting schedule.
DCPAnalysisResult DCPAnalysis::runAnalysis() {
  using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
  llvm::SmallVector<SpatWeightedCompute, 10> computes;
  llvm::SmallVector<EdgesIndex, 10> dataEdges;
  // Collect every SpatWeightedCompute directly nested in the entry op's regions.
  for (auto& region : entryOp->getRegions())
    for (SpatWeightedCompute computeOp : region.getOps<SpatWeightedCompute>())
      computes.push_back(computeOp);
  // One edge per producer/consumer pair; the edge weight is the total byte
  // size of the producer's results (the data that would have to be moved).
  for (auto [consumerIndex, consumer] : llvm::enumerate(computes)) {
    for (Value operand : consumer.getInputs()) {
      auto producer = llvm::dyn_cast_if_present<SpatWeightedCompute>(operand.getDefiningOp());
      if (!producer)
        continue;
      auto producerIter = llvm::find(computes, producer);
      assert(producerIter != computes.end());
      int64_t producerIndex = std::distance(computes.begin(), producerIter);
      int64_t transferBytes = 0;
      for (Value result : producer.getResults())
        transferBytes += getSizeInBytes(cast<ShapedType>(result.getType()));
      dataEdges.push_back({producerIndex, consumerIndex, transferBytes});
    }
  }
  GraphDCP graphDCP(computes, dataEdges);
  graphDCP.DCP();
  return graphDCP.getResult();
}
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,35 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMap.h"
#include <vector>
#include "../SpatialOps.hpp"
// Aggregated output of the DCP (Dynamic Critical Path) scheduling analysis.
struct DCPAnalysisResult {
  // Compute ops listed parents-before-children (a topological order of the
  // task graph built by GraphDCP::getResult).
  std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
  // CPU index each compute op was scheduled on.
  llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCPUMap;
  // Compute ops that are the last task scheduled on their CPU.
  llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfACpu;
  // Inverse view of the set above: CPU index -> its last scheduled compute op.
  llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap;
};
namespace onnx_mlir {
namespace spatial {
// Entry point of the DCP analysis: constructing it immediately runs the
// analysis over `op`'s regions and caches the result.
struct DCPAnalysis {
private:
  DCPAnalysisResult result;
  // Op whose regions are scanned for spat.compute ops.
  mlir::Operation* entryOp;
  // Builds the task graph and runs DCP scheduling (defined in DCPAnalysis.cpp).
  DCPAnalysisResult runAnalysis();

public:
  DCPAnalysis(mlir::Operation* op)
      : entryOp(op) {
    result = runAnalysis();
  }
  DCPAnalysisResult& getResult() { return result; }
};
} // namespace spatial
} // namespace onnx_mlir

View File

@@ -0,0 +1,496 @@
#include <algorithm>
#include <array>
#include <cassert>
#include <climits>
#include <cstddef>
#include <deque>
#include <fstream>
#include <vector>
#include "../../../Common/PimCommon.hpp"
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "Task.hpp"
#include "Uniqueworklist.hpp"
#include "Utils.hpp"
// Links parent -> child with the given transfer weight, mirroring the edge on
// both endpoints. If an edge between the two tasks already existed, the old
// (parent-side, child-side) edge pair is returned so the caller can restore it.
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) {
  std::optional<Edge_t> previousChildEdge = parent->addChild(child, weight);
  std::optional<Edge_t> previousParentEdge = child->addParent(parent, weight);
  assert(previousChildEdge.has_value() == previousParentEdge.has_value()
         && "The edge must be present in both element");
  if (!previousChildEdge.has_value())
    return std::nullopt;
  return Edge_pair {*previousParentEdge, *previousChildEdge};
}
// Unlinks the parent -> child edge on both endpoints.
void removeEdge(TaskDCP* parent, TaskDCP* child) {
  child->removeParent(parent);
  parent->removeChild(child);
}
// Cost of moving `parent`'s output to `child`: zero when both tasks are
// scheduled on the same CPU, otherwise the weight stored on the child edge.
// (Function name kept as-is to match the header declaration.)
int getTranferCost(TaskDCP* parent, TaskDCP* child) {
  const bool bothScheduled = parent->scheduledCPU.has_value() && child->scheduledCPU.has_value();
  if (bothScheduled && *parent->scheduledCPU == *child->scheduledCPU)
    return 0;
  auto edgeIter = std::find_if(parent->childs.begin(), parent->childs.end(),
      [child](const Edge_t& edge) { return edge.first == child; });
  assert(edgeIter != parent->childs.end());
  return edgeIter->second;
}
// Schedules `task` on `cpu` at `position` in that CPU's task list, wiring
// zero-weight ordering edges to its new neighbours (same CPU -> no transfer).
// Returns the (up to two) pre-existing edges that were overwritten by those
// ordering edges so the caller can restore them on rollback.
std::array<std::optional<Edge_pair>, 2> GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
  std::array<std::optional<Edge_pair>, 2> replacedEdges;
  task->setCPU(cpu);
  task->setWeight(task->computeWeight(this, cpu));
  auto& list = mapCPUTasks[cpu];
  assert(position <= list.size() && "Inserting in a not valid position");
  // Previously the iterator was first initialized to list.begin() and then
  // immediately overwritten (dead store); insert directly.
  auto inserted_point = list.insert(std::next(list.begin(), position), task);
  // Ordering edge from the previous task scheduled on this CPU.
  if (inserted_point != list.begin())
    replacedEdges[0] = addEdge(*std::prev(inserted_point), *inserted_point, 0);
  // Ordering edge to the next task scheduled on this CPU.
  if (std::next(inserted_point) != list.end())
    replacedEdges[1] = addEdge(*inserted_point, *std::next(inserted_point), 0);
  return replacedEdges;
}
// Undoes insertTaskInCPU: clears the task's CPU and weight, drops the
// ordering edges to its neighbours on `cpu`, and unlinks it from the list.
// Edges that were replaced on insertion are NOT restored here; callers do it.
void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) {
  task->resetCPU();
  task->resetWeight();
  auto& scheduled = mapCPUTasks[cpu];
  auto where = std::find(scheduled.begin(), scheduled.end(), task);
  assert(where != scheduled.end() && "Removing a not present task");
  if (where != scheduled.begin())
    removeEdge(*std::prev(where), *where);
  if (std::next(where) != scheduled.end())
    removeEdge(*where, *std::next(where));
  scheduled.erase(where);
}
std::vector<TaskDCP*> GraphDCP::getRoots() {
std::vector<TaskDCP*> tmp;
for (auto& task : nodes)
if (!task.hasParent())
tmp.push_back(&task);
return tmp;
}
// Computes the Absolute Earliest Start Time (AEST) of every task, visiting
// the graph parents-first: a task is only enqueued once all of its parents
// are already in the worklist, so their AESTs are final when it is popped.
void GraphDCP::initAEST() {
  UniqueWorkList<std::deque<TaskDCP*>> worklists(getRoots());
  while (!worklists.empty()) {
    TaskDCP& task = *worklists.front();
    // Enqueue every child whose parents are all already enqueued/visited.
    // Repeat until a fixed point: enqueuing one child can unlock siblings.
    bool modified = true;
    while (modified) {
      modified = false;
      for (auto& child : task.childs) {
        if (worklists.allElementContained(
                child.first->parents.begin(), child.first->parents.end(), [](Edge_t edge) { return edge.first; })) {
          modified |= worklists.push_back(child.first);
        }
      }
    }
    // AEST = latest (parent finish time + transfer cost from that parent);
    // roots get 0.
    int max_parent_aest = 0;
    for (auto& parent : task.parents) {
      max_parent_aest = std::max(
          parent.first->getAEST() + parent.first->getWeight() + getTranferCost(parent.first, &task), max_parent_aest);
    }
    task.setAEST(max_parent_aest);
    worklists.pop_front();
  }
}
// AEST `task` would have if it were placed on `cpu`: transfer costs from
// parents already scheduled on that same CPU drop to zero.
int GraphDCP::computeAEST(TaskDCP* task, CPU cpu) {
  int earliestStart = 0;
  for (const Edge_t& parentEdge : task->parents) {
    TaskDCP* parent = parentEdge.first;
    const bool sameCpu = parent->isScheduled() && cpu == *parent->getCPU();
    const int transferCost = sameCpu ? 0 : getTranferCost(parent, task);
    earliestStart = std::max(earliestStart, parent->getAEST() + parent->getWeight() + transferCost);
  }
  return earliestStart;
}
int GraphDCP::initDCPL() {
int max_aest = 0;
for (auto& node : nodes)
max_aest = std::max(node.getAEST() + node.getWeight(), max_aest);
return max_aest;
}
// DCPL the graph would have if `task` were placed on `cpu`: `task`'s AEST and
// weight are re-evaluated for that CPU, every other task is left as-is.
int GraphDCP::computeDCPL(TaskDCP* task, CPU cpu) {
  int criticalPathLength = 0;
  for (TaskDCP& node : nodes) {
    const int finishTime = (&node == task)
        ? computeAEST(task, cpu) + node.computeWeight(this, cpu)
        : node.getAEST() + node.getWeight();
    criticalPathLength = std::max(criticalPathLength, finishTime);
  }
  return criticalPathLength;
}
// Computes the Absolute Latest Start Time (ALST) of every task. Tasks are
// first collected parents-before-children, then visited in reverse so each
// child's ALST is final when its parent is processed.
void GraphDCP::initALST() {
  int dcpl = initDCPL();
  std::vector<TaskDCP*> roots = getRoots();
  UniqueWorkList<std::vector<TaskDCP*>> worklists(roots);
  worklists.reserve(nodes.size());
  // Forward pass: build a topological order inside `worklists` (same
  // fixed-point expansion as initAEST, kept in a vector so it can be walked
  // backwards afterwards).
  size_t i = 0;
  while (i != worklists.size()) {
    bool modified = true;
    while (modified) {
      modified = false;
      for (auto& child : worklists.at(i)->childs) {
        if (worklists.allElementContained(
                child.first->parents.begin(), child.first->parents.end(), [](Edge_t edge) { return edge.first; })) {
          modified |= worklists.push_back(child.first);
        }
      }
    }
    i++;
  }
  // Backward pass: leaves may start as late as dcpl - weight; inner nodes
  // must start early enough to feed every child before that child's ALST.
  while (!worklists.empty()) {
    TaskDCP& node = *worklists.back();
    int min_alst = INT_MAX;
    if (!node.hasChilds())
      min_alst = dcpl - node.getWeight();
    for (auto child : node.childs)
      min_alst = std::min(min_alst, child.first->getALST() - node.getWeight() - getTranferCost(&node, child.first));
    node.setALST(min_alst);
    worklists.pop_back();
  }
}
// ALSTs the graph would have if `task` were placed on `cpu`, returned in a
// side map so the tasks' stored ALSTs stay untouched. Mirrors initALST, with
// `task`'s weight and same-CPU transfer discounts evaluated for `cpu`.
std::unordered_map<TaskDCP*, int> GraphDCP::computeALST(TaskDCP* task, CPU cpu) {
  int dcpl = computeDCPL(task, cpu);
  std::unordered_map<TaskDCP*, int> temp_ALST;
  UniqueWorkList<std::vector<TaskDCP*>> worklists(getRoots());
  // Forward pass: parents-before-children order (see initALST).
  size_t i = 0;
  while (i != worklists.size()) {
    bool modified = true;
    while (modified) {
      modified = false;
      for (auto& child : worklists.at(i)->childs) {
        if (worklists.allElementContained(
                child.first->parents.begin(), child.first->parents.end(), [](Edge_t edge) { return edge.first; })) {
          modified |= worklists.push_back(child.first);
        }
      }
    }
    i++;
  }
  // Backward pass, writing into the temporary map instead of the tasks.
  while (!worklists.empty()) {
    TaskDCP& node = *worklists.back();
    int min_alst = INT_MAX;
    if (!node.hasChilds()) {
      if (&node != task)
        min_alst = dcpl - node.getWeight();
      else
        min_alst = dcpl - node.computeWeight(this, cpu);
    }
    for (auto child : node.childs) {
      int transfer_cost = getTranferCost(&node, child.first);
      // The hypothetical placement cancels the transfer cost towards
      // children that are already scheduled on `cpu`.
      if (&node == task && child.first->isScheduled() && cpu == *child.first->getCPU())
        transfer_cost = 0;
      min_alst = std::min(min_alst, temp_ALST[child.first] - node.getWeight() - transfer_cost);
    }
    temp_ALST[&node] = min_alst;
    worklists.pop_back();
  }
  return temp_ALST;
}
// Picks the next task to schedule among `nodes`: only tasks whose parents are
// all scheduled are eligible; among those, the smallest ALST - AEST slack
// (most critical) wins, with ties broken by the earlier AEST.
// (The vector is taken by value to match the header declaration.)
TaskDCP* GraphDCP::findCandidate(std::vector<TaskDCP*> nodes) {
  auto isReady = [](TaskDCP* node) {
    return std::all_of(node->parents.begin(), node->parents.end(),
        [](const Edge_t& element) { return element.first->isScheduled(); });
  };
  TaskDCP* best = nullptr;
  for (TaskDCP* node : nodes) {
    if (!isReady(node))
      continue;
    if (best == nullptr) {
      best = node;
      continue;
    }
    const int nodeSlack = node->getALST() - node->getAEST();
    const int bestSlack = best->getALST() - best->getAEST();
    // Strictly smaller slack wins; on equal slack, the smaller AEST wins.
    if (nodeSlack < bestSlack || (nodeSlack == bestSlack && node->getAEST() < best->getAEST()))
      best = node;
  }
  // Previously, when no node was ready, the end iterator was dereferenced
  // (undefined behavior); make the precondition explicit instead.
  assert(best != nullptr && "findCandidate: no schedulable task available");
  return best;
}
// Searches `cpu`'s task list for a position where `candidate` fits without
// violating precedence: after its last already-scheduled ancestor and not
// past its first already-scheduled descendant. Returns {INT_MAX, 0} when no
// slot fits. With `push`, later tasks are allowed to slide towards their
// (tentative) ALSTs to open a gap.
GraphDCP::FindSlot GraphDCP::findSlot(TaskDCP* candidate, CPU cpu, bool push) {
  int aest_on_cpu = computeAEST(candidate, cpu);
  auto tmp_ALST = computeALST(candidate, cpu);
  // Latest admissible finish time of the candidate on this CPU.
  int final_time = tmp_ALST[candidate] + candidate->computeWeight(this, cpu);
  std::list<TaskDCP*>& scheduledTasks = mapCPUTasks[cpu];
  // Search last non ancestor
  auto after_last_anc = scheduledTasks.end();
  while (after_last_anc != scheduledTasks.begin()) {
    if (after_last_anc != scheduledTasks.end() && (*after_last_anc)->hasDescendent(candidate))
      break;
    after_last_anc = std::prev(after_last_anc, 1);
  }
  if (after_last_anc != scheduledTasks.end() && (*after_last_anc)->hasDescendent(candidate))
    std::advance(after_last_anc, 1);
  // First scheduled task that depends on the candidate.
  auto first_descendent_index = scheduledTasks.begin();
  while (first_descendent_index != scheduledTasks.end()) {
    if (first_descendent_index != scheduledTasks.end() && candidate->hasDescendent(*first_descendent_index))
      break;
    first_descendent_index = std::next(first_descendent_index, 1);
  }
  auto iter_index = after_last_anc;
  auto best_index = scheduledTasks.end();
  // NOTE(review): best_max is only read on paths that assigned it first;
  // consider initializing it anyway.
  int best_max;
  assert(std::distance(scheduledTasks.begin(), after_last_anc)
         <= std::distance(scheduledTasks.begin(), first_descendent_index));
  // Scan [after_last_anc, first_descendent_index] for a gap wide enough to
  // hold the candidate's weight on this CPU.
  bool keep = true;
  while (keep && iter_index != scheduledTasks.end()) {
    if (iter_index == first_descendent_index)
      keep = false;
    // Upper bound of the gap: the next task's start (AEST, or its tentative
    // ALST when pushing), capped by the candidate's own deadline.
    int min = INT_MAX;
    if (!push)
      min = std::min(final_time, (*iter_index)->getAEST());
    else if (tmp_ALST.count(*iter_index) == 1)
      min = std::min(final_time, tmp_ALST[*iter_index]);
    else
      min = std::min(final_time, (*iter_index)->getALST());
    // Lower bound of the gap: candidate ready time vs previous task's finish.
    int max = aest_on_cpu;
    if (iter_index != scheduledTasks.begin()) {
      auto prev_iter = std::prev(iter_index);
      max = std::max(aest_on_cpu, (*prev_iter)->getAEST() + (*prev_iter)->getWeight());
    }
    if (min - max >= candidate->computeWeight(this, cpu)) {
      best_index = iter_index;
      best_max = max;
      break;
    }
    std::advance(iter_index, 1);
    if (iter_index == scheduledTasks.end())
      keep = false;
  }
  if (best_index != scheduledTasks.end())
    return FindSlot {best_max, (int) std::distance(scheduledTasks.begin(), best_index)};
  // No gap found but the scan ran off the end: append after the last task.
  if (iter_index == scheduledTasks.end()) {
    best_max = aest_on_cpu;
    if (iter_index != scheduledTasks.begin()) {
      auto prev_iter = std::prev(iter_index);
      best_max = std::max(aest_on_cpu, (*prev_iter)->getAEST() + (*prev_iter)->getWeight());
    }
    return FindSlot {best_max, (int) std::distance(scheduledTasks.begin(), scheduledTasks.end())};
  }
  // Stopped at a descendant without finding room: no valid slot on this CPU.
  return FindSlot {INT_MAX, 0};
}
// Chooses a CPU and slot for `candidate` and inserts it there. Every CPU
// opened so far (plus the next fresh one, id == lastCPU()) is tried; when the
// candidate still has unscheduled children, a placement is also judged by how
// well its most critical child would fit on the same CPU. Falls back to
// opening a new CPU when nothing fits.
void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
  std::vector<CPU> processors;
  processors.reserve(lastCPU());
  // NOTE(review): both branches of this conditional are identical — looks
  // like a leftover; confirm what the `push` case was meant to start from.
  for (CPU c = push ? lastCPU() : lastCPU(); c >= 0; c--)
    processors.push_back(c);
  CPU best_process = -1;
  int best_composite = INT_MAX;
  // Initialized to a valid empty-CPU slot: previously best_slot was left
  // default-constructed, and best_slot.index was read uninitialized on the
  // fallback path where no CPU accepted the candidate.
  FindSlot best_slot {0, 0};
  while (!processors.empty()) {
    CPU current_cpu = processors.back();
    processors.pop_back();
    // First try without pushing; retry with push only when allowed.
    auto slot = findSlot(candidate, current_cpu, 0);
    if (slot.aest == INT_MAX && push)
      slot = findSlot(candidate, current_cpu, 1);
    if (slot.aest == INT_MAX)
      continue;
    if (std::all_of(candidate->childs.begin(), candidate->childs.end(), [](Edge_t child) {
          return child.first->isScheduled();
        })) {
      // All children already placed: simply pick the earliest start.
      if (slot.aest < best_composite) {
        best_process = current_cpu;
        best_composite = slot.aest;
        best_slot = slot;
      }
    }
    else if (candidate->hasChilds()) {
      // Tentatively place the candidate, then measure how its most critical
      // (smallest ALST - AEST slack) unscheduled child would do on this CPU.
      auto dcpl = initDCPL();
      auto old_edges = insertTaskInCPU(current_cpu, candidate, slot.index);
      initAEST();
      initALST();
      Edge_t smallest_child {nullptr, 0};
      for (auto child : candidate->childs) {
        if (child.first->isScheduled())
          continue;
        if (smallest_child.first == nullptr) {
          smallest_child = child;
          continue;
        }
        if (smallest_child.first->getALST() - smallest_child.first->getAEST()
            > child.first->getALST() - child.first->getAEST()) {
          smallest_child = child;
        }
      }
      auto child_slot = findSlot(smallest_child.first, current_cpu, false);
      auto dcpl_with_child = computeDCPL(smallest_child.first, current_cpu);
      // Accept only placements that do not stretch the critical path.
      if (child_slot.aest != INT_MAX and child_slot.aest + slot.aest < best_composite and dcpl_with_child <= dcpl) {
        best_process = current_cpu;
        best_composite = slot.aest + child_slot.aest;
        best_slot = slot;
      }
      // Roll back the tentative insertion and restore any replaced edges.
      removeTaskFromCPU(current_cpu, candidate);
      initAEST();
      initALST();
      for (auto opt_edge : old_edges) {
        if (opt_edge.has_value()) {
          auto double_edge = *opt_edge;
          addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
        }
      }
    }
  }
  // No CPU could host the candidate: open a brand-new one (position 0).
  if (best_process == -1) {
    best_process = lastCPU();
    incLastCPU();
  }
  // Scheduling on the newest (still empty) CPU consumes it: open the next id.
  if (best_process == lastCPU())
    incLastCPU();
  insertTaskInCPU(best_process, candidate, best_slot.index);
}
void GraphDCP::DCP() {
initAEST();
initALST();
to_dot();
std::vector<TaskDCP*> worklists;
worklists.reserve(nodes.size());
for (auto& node : nodes)
worklists.push_back(&node);
while (!worklists.empty()) {
auto candidate = findCandidate(worklists);
selectProcessor(candidate, candidate->isCP());
initAEST();
initALST();
fastRemove(worklists, candidate);
}
to_dot();
}
void GraphDCP::to_dot() {
static int index = 0;
std::string outputDir = onnx_mlir::getOutputDir();
if (outputDir.empty())
return;
std::string graphDir = outputDir + "/DCPGraph";
onnx_mlir::createDirectory(graphDir);
std::fstream file(graphDir + "/graph_" + std::to_string(index++) + ".dot", std::ios::out);
file << "digraph G {\n";
if (mapCPUTasks.size() != 0) {
for (CPU c = 0; c < lastCPU(); c++) {
file << "subgraph cluster_" << c << "{\nstyle=filled;\ncolor=lightgrey;\n";
for (auto node : mapCPUTasks[c]) {
file << node->Id() << " [label=\"";
file << "n:" << node->Id() << "\n";
file << "aest:" << node->getAEST() << "\n";
file << "alst:" << node->getALST() << "\n";
file << "weight:" << node->getWeight() << "\"]\n";
}
file << " }\n";
}
}
else {
for (auto& node : nodes) {
file << node.Id() << " [label=\"";
file << "n:" << node.Id() << "\n";
file << "aest:" << node.getAEST() << "\n";
file << "alst:" << node.getALST() << "\n";
file << "weight:" << node.getWeight() << "\"]\n";
}
}
for (auto& node : nodes) {
for (auto& child : node.childs) {
file << node.Id() << " -> " << child.first->Id();
file << " [label=\"" << child.second << "\"]\n";
}
}
file << "}\n";
file.flush();
file.close();
}
// Packages the computed schedule into the DCPAnalysisResult consumed by the
// lowering pass: a parents-before-children op order plus per-CPU maps.
DCPAnalysisResult GraphDCP::getResult() {
  DCPAnalysisResult ret;
  // Rebuild a parents-before-children order of all tasks (same fixed-point
  // expansion used by initALST).
  std::vector<TaskDCP*> roots = getRoots();
  UniqueWorkList<std::vector<TaskDCP*>> worklists(roots);
  worklists.reserve(nodes.size());
  size_t i = 0;
  while (i != worklists.size()) {
    bool modified = true;
    while (modified) {
      modified = false;
      for (auto& child : worklists.at(i)->childs) {
        if (worklists.allElementContained(
                child.first->parents.begin(), child.first->parents.end(), [](Edge_t edge) { return edge.first; })) {
          modified |= worklists.push_back(child.first);
        }
      }
    }
    i++;
  }
  ret.dominanceOrderCompute.reserve(worklists.size());
  for (auto elem : worklists)
    ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute());
  // Per-CPU maps. Iterate by const reference: the previous by-value binding
  // copied the whole std::list for every CPU and shadowed the `nodes` member.
  for (const auto& [cpu, cpuTasks] : mapCPUTasks) {
    for (auto it = cpuTasks.begin(); it != cpuTasks.end(); ++it) {
      TaskDCP* task = *it;
      ret.computeToCPUMap[task->getSpatWeightedCompute()] = cpu;
      // The last task on each CPU (previously detected with
      // `i == size() - 1`, which underflows for an empty list).
      if (std::next(it) == cpuTasks.end()) {
        ret.isLastComputeOfACpu.insert(task->getSpatWeightedCompute());
        ret.cpuToLastComputeMap[cpu] = task->getSpatWeightedCompute();
      }
    }
  }
  return ret;
}

View File

@@ -0,0 +1,67 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include <list>
#include <optional>
#include <unordered_map>
#include <vector>
#include "DCPAnalysis.hpp"
#include "Task.hpp"
#include "Utils.hpp"
// Mirrored-edge helpers (defined in Graph.cpp): edges are stored on both
// endpoints, so they must always be added/removed through these.
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
void removeEdge(TaskDCP* parent, TaskDCP* child);
// Transfer cost parent -> child: 0 on the same CPU, else the edge weight.
int getTranferCost(TaskDCP* parent, TaskDCP* child);
// Task graph over which the DCP (Dynamic Critical Path) scheduling runs.
// NOTE(review): std::array is used below but <array> is not included in this
// header — relies on a transitive include; confirm and include it directly.
class GraphDCP {
  // Result of findSlot: earliest start found plus the insertion position.
  struct FindSlot {
    int aest;
    int index;
  };
  // Owns every task; edges hold raw pointers into this vector, so it must not
  // grow once edges exist (the constructor emplaces all nodes first).
  std::vector<TaskDCP> nodes;
  // Per-CPU schedule: ordered list of the tasks placed on each CPU.
  std::unordered_map<CPU, std::list<TaskDCP*>> mapCPUTasks;
  // First CPU id that has not been opened yet.
  CPU last_cpu = 0;
  // Places `task` at `position` on `cpu`; returns any replaced edges.
  std::array<std::optional<Edge_pair>, 2> insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position);
  void removeTaskFromCPU(CPU cpu, TaskDCP* task);
  // Tasks without parents (graph entry points).
  std::vector<TaskDCP*> getRoots();
  // AEST/ALST: absolute earliest/latest start times; DCPL: critical path
  // length. The init* members write into the tasks; the compute* members
  // evaluate a hypothetical placement of `task` on `cpu` without mutating.
  void initAEST();
  int computeAEST(TaskDCP* task, CPU cpu);
  int initDCPL();
  int computeDCPL(TaskDCP* task, CPU cpu);
  void initALST();
  std::unordered_map<TaskDCP*, int> computeALST(TaskDCP* task, CPU cpu);
  // Most critical ready task among `nodes`.
  TaskDCP* findCandidate(std::vector<TaskDCP*> nodes);
  void selectProcessor(TaskDCP* candidate, bool push);
  CPU lastCPU() const { return last_cpu; }
  void incLastCPU() { last_cpu++; }
  FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push);
  // Graphviz dump for debugging (no-op when no output dir is configured).
  void to_dot();

public:
  // Runs the scheduling; afterwards every task has a CPU assignment.
  void DCP();
  // Builds all nodes first (so pointers into `nodes` stay stable), then wires
  // the edges from (parent index, child index, weight) triples.
  GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
      llvm::ArrayRef<EdgesIndex> edges)
      : nodes(), mapCPUTasks() {
    for (auto spatWeightedCompute : spatWeightedComputes)
      nodes.emplace_back(spatWeightedCompute);
    for (auto [start, end, weight] : edges)
      makeEdge(start, end, weight);
  }
  DCPAnalysisResult getResult();
  void makeEdge(size_t parent_index, size_t child_index, Weight_t weight) {
    addEdge(&nodes[parent_index], &nodes[child_index], weight);
  }
  // NOTE: operator[] default-inserts an empty list for CPUs never scheduled.
  size_t taskInCPU(CPU cpu) { return mapCPUTasks[cpu].size(); }
};

View File

@@ -0,0 +1,49 @@
#include <optional>
#include "Graph.hpp"
#include "Task.hpp"
#include "Uniqueworklist.hpp"
// Registers `child` as a successor with the given edge weight. If an edge to
// that child already exists it is replaced, and the old edge is returned so
// the caller can undo the replacement later.
std::optional<Edge_t> TaskDCP::addChild(TaskDCP* child, Weight_t weight) {
  std::optional<Edge_t> replaced;
  auto existing = std::find_if(
      childs.begin(), childs.end(), [child](const Edge_t& element) { return child == element.first; });
  if (existing != childs.end()) {
    replaced = *existing;
    fastRemove(childs, existing);
  }
  childs.emplace_back(child, weight);
  return replaced;
}
// Registers `parent` as a predecessor with the given edge weight. If an edge
// from that parent already exists it is replaced, and the old edge returned.
std::optional<Edge_t> TaskDCP::addParent(TaskDCP* parent, Weight_t weight) {
  std::optional<Edge_t> replaced;
  auto existing = std::find_if(
      parents.begin(), parents.end(), [parent](const Edge_t& element) { return parent == element.first; });
  if (existing != parents.end()) {
    replaced = *existing;
    fastRemove(parents, existing);
  }
  parents.emplace_back(parent, weight);
  return replaced;
}
// Reachability test along child edges (depth-first, duplicate-free via
// UniqueWorkList). Note: hasDescendent(this) is true, i.e. a task counts as
// its own descendant.
bool TaskDCP::hasDescendent(TaskDCP* child) {
  UniqueWorkList<std::vector<TaskDCP*>> pending;
  pending.reserve(32);
  pending.push_back(this);
  while (!pending.empty()) {
    TaskDCP* current = pending.back();
    pending.pop_back();
    if (current == child)
      return true;
    for (const Edge_t& edge : current->childs)
      pending.push_back(edge.first);
  }
  return false;
}
// TODO: implement something sensible — placeholder: the weight does not yet
// depend on the target CPU, it is always the op-derived original weight.
int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) {
  return orig_weight;
}

View File

@@ -0,0 +1,87 @@
#pragma once
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>
#include "../SpatialOps.hpp"
#include "Utils.hpp"
// Mirrored-edge helpers (defined in Graph.cpp), declared here so TaskDCP can
// befriend them.
std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
void removeEdge(TaskDCP* parent, TaskDCP* child);
// One schedulable unit of the DCP graph: wraps a spat.compute op together
// with its scheduling state (AEST/ALST, weight, assigned CPU) and its
// parent/child edge lists. Edges are mirrored on both endpoints, which is why
// mutation is restricted to the friend helpers addEdge/removeEdge.
class TaskDCP {
  onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute;
  // Default member initializers: TaskDCP() is defaulted below, and previously
  // left these ints (and therefore getAEST()/getWeight()) indeterminate.
  int aest = 0; // absolute earliest start time
  int alst = 0; // absolute latest start time
  std::optional<CPU> scheduledCPU; // empty until the task is scheduled
  int weight = 0;      // weight while scheduled (may depend on the CPU)
  int orig_weight = 0; // CPU-independent weight derived from the op
  // Edge mutators are private: only the free addEdge/removeEdge friends call
  // them, keeping the two mirrored lists consistent.
  std::optional<Edge_t> addChild(TaskDCP* child, Weight_t weight);
  std::optional<Edge_t> addChild(TaskDCP& child, Weight_t weight) { return addChild(&child, weight); }
  void removeChild(TaskDCP* to_remove) { fastRemove(childs, to_remove); }
  void removeChild(TaskDCP& to_remove) { fastRemove(childs, &to_remove); }
  std::optional<Edge_t> addParent(TaskDCP* parent, Weight_t weight);
  std::optional<Edge_t> addParent(TaskDCP& parent, Weight_t weight) { return addParent(&parent, weight); }
  void removeParent(TaskDCP* to_remove) { fastRemove(parents, to_remove); }
  void removeParent(TaskDCP& to_remove) { fastRemove(parents, &to_remove); }

public:
  std::vector<Edge_t> parents;
  std::vector<Edge_t> childs;
  TaskDCP() = default;
  TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
      : spatWeightedCompute(spatWeightedCompute),
        aest(0),
        alst(0),
        scheduledCPU(),
        weight(getSpatWeightCompute(spatWeightedCompute)),
        orig_weight(weight),
        parents(),
        childs() {}
  // Non-copyable (edges hold raw pointers to this object); movable so the
  // owning std::vector in GraphDCP can be populated.
  TaskDCP(const TaskDCP& node) = delete;
  TaskDCP(TaskDCP&& node) = default;
  void setCPU(CPU cpu) { scheduledCPU = cpu; }
  std::optional<CPU> getCPU() const { return scheduledCPU; }
  void resetCPU() { scheduledCPU = std::nullopt; }
  // Scheduled tasks report the (possibly CPU-specific) weight; unscheduled
  // ones fall back to the original op weight.
  int getWeight() {
    if (isScheduled())
      return weight;
    else
      return orig_weight;
  }
  void setWeight(int val) { weight = val; }
  void resetWeight() { weight = orig_weight; }
  // Weight this task would have on `cpu` (currently CPU-independent; see Task.cpp).
  int computeWeight(GraphDCP* graph, CPU cpu);
  bool hasParent() { return parents.size() != 0; }
  bool hasChilds() { return childs.size() != 0; }
  int getAEST() { return aest; }
  int getALST() { return alst; }
  void setAEST(int val) {
    assert(val >= 0);
    aest = val;
  }
  // ALST is assigned after AEST and can never precede it.
  void setALST(int val) {
    assert(val >= 0 && val >= aest);
    alst = val;
  }
  // True when `child` is reachable along child edges (includes this == child).
  bool hasDescendent(TaskDCP* child);
  // Stable id for dot dumps: the wrapped op's pointer value.
  int64_t Id() const { return (int64_t)spatWeightedCompute.getAsOpaquePointer(); }
  // On the dynamic critical path: zero slack between earliest/latest start.
  bool isCP() const { return alst == aest; }
  bool isScheduled() const { return scheduledCPU.has_value(); }
  onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() { return spatWeightedCompute; }
  friend std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
  friend void removeEdge(TaskDCP* parent, TaskDCP* child);
  friend int getTranferCost(TaskDCP* parent, TaskDCP* child);
};

View File

@@ -0,0 +1,82 @@
#pragma once
#include <cassert>
#include <type_traits>
#include <unordered_set>
// Detects whether the underlying container supports pop_front() (std::deque
// does, std::vector does not).
template <typename T, typename = void>
struct has_pop_front : std::false_type {};
template <typename T>
struct has_pop_front<T, std::void_t<decltype(std::declval<T>().pop_front())>> : std::true_type {};
// Detects whether the underlying container supports reserve(). Previously
// reserve() called storage.reserve() unconditionally, which failed to compile
// for containers without it (e.g. UniqueWorkList<std::deque<...>>).
template <typename T, typename = void>
struct has_reserve : std::false_type {};
template <typename T>
struct has_reserve<T, std::void_t<decltype(std::declval<T>().reserve(std::size_t{}))>> : std::true_type {};
// Worklist that keeps insertion order in `storage` while rejecting duplicates
// via `set`. Popped elements intentionally stay in `set`, so an element can
// be enqueued at most once over the worklist's lifetime (visited-set
// semantics, relied upon by the graph traversals).
template <typename T>
class UniqueWorkList {
  using V = typename T::value_type;
  T storage;
  std::unordered_set<V> set;

public:
  UniqueWorkList() = default;
  // Seeds the worklist from any iterable, dropping duplicates.
  template <typename arg_ty>
  UniqueWorkList(const arg_ty& from)
      : storage() {
    for (auto& element : from) {
      if (set.count(element) == 0) {
        storage.push_back(element);
        set.insert(element);
      }
    }
  }
  bool empty() const { return storage.empty(); }
  // No-op for containers without reserve() instead of a compile error.
  void reserve(size_t val) {
    if constexpr (has_reserve<T>::value)
      storage.reserve(val);
  }
  size_t size() const { return storage.size(); }
  V& at(size_t i) { return storage.at(i); }
  const V& at(size_t i) const { return storage.at(i); }
  V& front() { return storage.front(); }
  V& back() { return storage.back(); }
  // Returns true when `val` was actually enqueued (never seen before).
  bool push_back(const V& val) {
    if (set.count(val) == 0) {
      storage.push_back(val);
      set.insert(val);
      return true;
    }
    return false;
  }
  void pop_front() {
    if constexpr (has_pop_front<T>::value)
      storage.pop_front();
    else
      assert(false && "Underlying storage type does not support pop_front()");
  }
  auto cbegin() const { return storage.cbegin(); }
  auto cend() const { return storage.cend(); }
  void pop_back() { storage.pop_back(); }
  // True when every mapped element of [start, end) has ever been enqueued.
  template <typename Iterator, typename Mapper>
  bool allElementContained(Iterator start, Iterator end, Mapper map) {
    while (start != end) {
      if (set.count(map(*start)) == 0)
        return false;
      std::advance(start, 1);
    }
    return true;
  }
  auto begin() {
    return storage.begin();
  }
  auto end() {
    return storage.end();
  }
};

View File

@@ -0,0 +1,60 @@
#pragma once
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>
#include "../SpatialOps.hpp"
#include "src/Support/TypeUtilities.hpp"
// Basic scheduling vocabulary shared by the DCP graph files.
using CPU = int;
using Weight_t = int;
class TaskDCP;
class GraphDCP;
// An edge endpoint: the task on the other side plus the transfer weight.
using Edge_t = std::pair<TaskDCP*, Weight_t>;
// (parent-side edge, child-side edge) pair, as mirrored on the two endpoints.
using Edge_pair = std::pair<Edge_t, Edge_t>;
// (parent index, child index, weight) triple used to build the graph.
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
// Swap-with-last removal of the edge pointing at `to_remove` (O(1) erase;
// does not preserve element order). No-op when the edge is absent.
template <typename T>
void fastRemove(std::vector<std::pair<T*, Weight_t>>& vector, T* to_remove) {
  // The predicate previously took the concrete Edge_t type, which only
  // compiled for T == TaskDCP; use the generic pair type (by const ref).
  auto position = std::find_if(vector.begin(), vector.end(),
      [to_remove](const std::pair<T*, Weight_t>& edge) { return edge.first == to_remove; });
  if (position != vector.end()) {
    std::swap(vector.back(), *position);
    vector.pop_back();
  }
}
// Swap-with-last removal of a plain task pointer.
inline void fastRemove(std::vector<TaskDCP*>& vector, TaskDCP* to_remove) {
  auto position = std::find(vector.begin(), vector.end(), to_remove);
  if (position != vector.end()) {
    std::swap(vector.back(), *position);
    vector.pop_back();
  }
}
// Swap-with-last removal at a known iterator position.
template <typename T, typename P>
void fastRemove(std::vector<std::pair<T*, Weight_t>>& vector, P position) {
  if (position != vector.end()) {
    std::swap(vector.back(), *position);
    vector.pop_back();
  }
}
// TODO: do something sensible — placeholder weight: the summed byte size of
// every shaped result produced inside the compute body, used as a proxy for
// the op's execution cost.
// NOTE(review): getBody() returns the op's region, so `region` here actually
// iterates as a Block and `inst` as an Operation — consider renaming.
inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
  int64_t tot = 0;
  for (auto& region : spatWeightedCompute.getBody()) {
    for (auto& inst : region) {
      for(auto result : inst.getResults()){
        // Only shaped (tensor/memref-like) results contribute to the weight.
        if(auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType()))
          tot += onnx_mlir::getSizeInBytes(element);
      }
    }
  }
  return tot;
}

View File

@@ -47,6 +47,7 @@ def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegmen
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);
let hasVerifier = 1; let hasVerifier = 1;
let hasFolder = 1;
let assemblyFormat = [{ let assemblyFormat = [{
`[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body

View File

@@ -204,45 +204,47 @@ LogicalResult SpatWeightedCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single // Check that it has a terminator, it is a yieldOp, and it has a single
// operand with the same type as the result // operand with the same type as the result
auto& block = getBody().front(); auto& block = getBody().front();
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator()); if (block.mightHaveTerminator()) {
if (!yieldOp) auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
return emitError("ComputeOp must have a single yield operation"); if (!yieldOp)
return emitError("ComputeOp must have a single yield operation");
auto resultTypes = getResultTypes(); auto resultTypes = getResultTypes();
auto yieldTypes = yieldOp->getOperandTypes(); auto yieldTypes = yieldOp->getOperandTypes();
if (resultTypes.size() != yieldTypes.size()) { if (resultTypes.size() != yieldTypes.size()) {
return emitError("ComputeOp must have same number of results as yieldOp " return emitError("ComputeOp must have same number of results as yieldOp "
"operands"); "operands");
}
for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
auto resultType = std::get<0>(it);
auto yieldType = std::get<1>(it);
// Same type and compatible shape
if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
return emitError("ComputeOp output must be of the same type as yieldOp "
"operand");
} }
// Same encoding for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) { auto resultType = std::get<0>(it);
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) { auto yieldType = std::get<1>(it);
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
return emitError("ComputeOp output must have the same encoding as " // Same type and compatible shape
"yieldOp operand"); if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
return emitError("ComputeOp output must be of the same type as yieldOp "
"operand");
}
// Same encoding
if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
return emitError("ComputeOp output must have the same encoding as "
"yieldOp operand");
}
}
else {
return emitError("ComputeOp output has an encoding while yieldOp "
"operand does not have one");
} }
} }
else { else {
return emitError("ComputeOp output has an encoding while yieldOp " // If result does not have an encoding, yield shouldn't either
"operand does not have one"); if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
} return emitError("ComputeOp output must not have an encoding if "
} "yieldOp operand has one");
else { }
// If result does not have an encoding, yield shouldn't either
if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
return emitError("ComputeOp output must not have an encoding if "
"yieldOp operand has one");
} }
} }
} }
@@ -255,6 +257,27 @@ LogicalResult SpatWeightedCompute::verify() {
return success(); return success();
} }
/// Folds a trivial compute node whose body is nothing but a yield: each
/// result is replaced by the value the yield forwards, letting the op be
/// erased by the canonicalizer.
///
/// Returns success and fills `results` when the fold applies, failure
/// otherwise (body contains real work, or terminator is not a SpatYieldOp).
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
  Block& block = getBody().front();

  // Only fold when the yield is the sole operation in the body; otherwise
  // the region performs real computation and must be kept.
  if (!llvm::hasSingleElement(block))
    return failure();

  auto yieldOp = dyn_cast<SpatYieldOp>(block.front());
  if (!yieldOp)
    return failure();

  for (Value yieldedValue : yieldOp.getOperands()) {
    if (auto blockArg = dyn_cast<BlockArgument>(yieldedValue)) {
      if (blockArg.getOwner() == &block) {
        // Block arguments mirror only the `inputs` operand segment — the
        // `weights` segment has no corresponding block arguments — so the
        // argument number indexes getInputs(), not the full operand list.
        // (getOperand(argNumber) would be off by getWeights().size() and
        // return a weight instead of the forwarded input.)
        results.push_back(getInputs()[blockArg.getArgNumber()]);
        continue;
      }
    }
    // Value captured from the enclosing region: defined outside this op, so
    // it stays valid after the op is folded away.
    results.push_back(yieldedValue);
  }
  return success();
}
} // namespace spatial } // namespace spatial
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -0,0 +1,318 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>
#include <functional>
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
using SpatWeightedCompute = spatial::SpatWeightedCompute;
// Carrier for the SSA value a merged compute node produces inside its body.
// This is the value other merged nodes consume, either directly (same CPU)
// or through a channel send/receive pair (different CPU).
struct ComputeValueResults {
  // Value yielded by the yieldOp
  Value innerValue;
};
// Lazily materializes the channel plumbing (SpatChannelNewOp +
// SpatChannelSendOp) needed to export a compute node's result to compute
// nodes merged onto other CPUs. The channel ops are only created on the
// first cross-node request, so results that are never consumed across CPUs
// get no channel at all.
class LazyInsertComputeResult {
  using InsertPoint = mlir::IRRewriter::InsertPoint;

  // The result value produced inside the merged compute node's body.
  ComputeValueResults computeResults;
  // Value produced by the SpatChannelNewOp, once created.
  Value channelNewOpVal;
  // Flag stored at construction; exposed via onlyChanneled().
  bool onlyChannel;
  // Deferred callback inserting the SpatChannelSendOp; nullptr until the
  // first call to getAsChannelValueAndInsertSender creates the channel.
  std::function<void(InsertPoint insertPoint)> channelSendInserter;
  // Where the send op goes: right after the op defining innerValue, or at
  // the top of its block when innerValue is a block argument.
  InsertPoint insertPointSend;
  // Factory creating the SpatChannelNewOp; returns the channel value plus
  // the matching send-inserter callback.
  std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;

public:
  LazyInsertComputeResult(ComputeValueResults computeValueResults,
      std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter,
      bool isOnlyChannel)
      : computeResults(computeValueResults),
        onlyChannel(isOnlyChannel),
        channelSendInserter(nullptr),
        insertPointSend({}),
        channelNewInserter(channelNewInserter) {}

  // Either a channel value (isChannel == true; the consumer must insert a
  // SpatChannelReceiveOp) or the raw local SSA value when producer and
  // consumer live in the same compute body.
  struct ChannelOrLocalOp {
    Value data;
    bool isChannel;
  };

  bool onlyChanneled() const { return onlyChannel; }

  // Returns the value the consumer `spatWeightedCompute` should read.
  // First use creates the channel op and records the send insertion point;
  // a consumer in the producer's own body short-circuits to the local value.
  ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute spatWeightedCompute) {
    if (channelSendInserter == nullptr) {
      auto [first, second] = channelNewInserter();
      channelNewOpVal = first;
      channelSendInserter = second;
      auto op = computeResults.innerValue.getDefiningOp();
      if (op) {
        // Send immediately after the op that defines the value.
        insertPointSend = InsertPoint(op->getBlock(), ++Block::iterator(op));
      }
      else {
        // Block argument: send at the top of its owning block.
        auto BB = computeResults.innerValue.getParentBlock();
        insertPointSend = InsertPoint(BB, BB->begin());
      }
    }
    if (spatWeightedCompute) {
      // Producer and consumer merged into the same compute node: no channel
      // round-trip needed, hand back the local value.
      for (auto& BB : spatWeightedCompute.getBody())
        if (&BB == insertPointSend.getBlock())
          return {computeResults.innerValue, false};
    }
    // NOTE(review): this executes on every cross-node request, so a result
    // consumed by several other compute nodes gets one send op inserted per
    // consumer — confirm the channel semantics expect one send per receive.
    channelSendInserter(insertPointSend);
    return {channelNewOpVal, true};
  }

  ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
};
// Pass that folds all SpatWeightedCompute ops assigned to the same CPU (per
// the DCP analysis) into a single compute op per CPU, wiring cross-CPU data
// flow through channel new/send/receive ops.
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
private:
  // Per-ORIGINAL-compute lazy materializer for the channel exporting that
  // node's result to compute nodes merged onto other CPUs.
  DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
  // Original compute op -> merged compute op that absorbed it.
  DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
  // CPU index -> the single merged compute op built for that CPU.
  // ("cput" is a pre-existing typo for "cpu".)
  DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)

  StringRef getArgument() const override { return "pim-merge-node-pass"; }

  StringRef getDescription() const override {
    return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
           "execution time";
  }

  LogicalResult initialize(MLIRContext* context) override { return success(); }

  // Walks the compute nodes in dominance order and, per CPU, either creates
  // the merged op (first node seen) or appends into the existing one; the
  // original ops are erased afterwards.
  void runOnOperation() override {
    DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
    auto& lastComputeOfCpu = analysisResult.isLastComputeOfACpu;
    auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
    IRRewriter rewriter(&getContext());

    for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
      size_t cpu = analysisResult.computeToCPUMap.at(currentComputeNode);
      if (!cputToNewComputeMap.contains(cpu)) {
        // First node for this CPU: clone it into a fresh merged op whose
        // result types are taken from the CPU's LAST compute node (the one
        // whose yield survives).
        ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
        auto [newWeightedCompute, computeValueResult] = createNewComputeNode(
            currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode));
        cputToNewComputeMap[cpu] = newWeightedCompute;
        newComputeNodeResults.insert(
            std::make_pair(currentComputeNode,
                createLazyComputeResult(
                    newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
      }
      else {
        // CPU already has a merged op: splice this node's body into it.
        auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode(
            cputToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
        newComputeNodeResults.insert(
            std::make_pair(currentComputeNode,
                createLazyComputeResult(
                    newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
      }
    }

    // Erase originals in reverse dominance order so users are removed
    // before their producers.
    for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute)))
      computeNodetoRemove.erase();

    func::FuncOp func = getOperation();
    dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
  }

private:
  // Clones `oldWeightedCompute` into a brand-new merged compute op inserted
  // just before the function terminator. Weights are forwarded as operands;
  // inputs not produced by another compute become operands + block args;
  // inputs produced by another compute are read through a channel receive.
  // The yield is cloned only when `lastCompute` is set; its operand 0 is
  // recorded as the node's inner result either way.
  std::pair<SpatWeightedCompute, ComputeValueResults> createNewComputeNode(
      SpatWeightedCompute oldWeightedCompute, ValueTypeRange<ResultRange> newWeightedComputeType, bool lastCompute) {
    func::FuncOp func = getOperation();
    auto loc = func.getLoc();
    IRRewriter rewriter(&getContext());
    // Insert right before the function's terminator.
    rewriter.setInsertionPoint(&*std::prev(func.getBody().front().end(), 1));

    ComputeValueResults computeValueResults;
    IRMapping mapper;
    llvm::SmallVector<Value> newComputeOperand;
    llvm::SmallVector<Type> newBBOperandType;
    llvm::SmallVector<Location> newBBLocations;

    for (auto arg : oldWeightedCompute.getWeights())
      newComputeOperand.push_back(arg);
    // Only non-compute-produced inputs get block arguments; compute-produced
    // ones are replaced by channel receives below.
    for (auto arg : oldWeightedCompute.getInputs())
      if (!llvm::isa<SpatWeightedCompute>(arg.getDefiningOp())) {
        newComputeOperand.push_back(arg);
        newBBOperandType.push_back(arg.getType());
        newBBLocations.push_back(loc);
      }

    auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
    rewriter.createBlock(
        &newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
    // Keep the weights/inputs operand segment bookkeeping consistent with
    // what was actually appended above.
    newWeightedCompute.getProperties().setOperandSegmentSizes(
        {(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
    rewriter.setInsertionPointToEnd(&newWeightedCompute.getBody().front());

    // Map the old block arguments onto the new ones (or onto channel
    // receives). Old block arg i corresponds to operand i + numWeights.
    int indexNew = 0;
    int indexOld = oldWeightedCompute.getWeights().size();
    int indexOldStart = oldWeightedCompute.getWeights().size();
    for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
      if (!llvm::isa<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
        mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart),
            newWeightedCompute.getBody().front().getArgument(indexNew++));
      }
      else {
        // Input produced by another compute node: fetch it through the
        // producer's (lazily created) channel.
        auto argWeightCompute =
            llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
        LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
        auto [channelVal, _] = lazyArgWeight.getAsChannelValueAndInsertSender();
        spatial::SpatChannelReceiveOp reciveOp =
            spatial::SpatChannelReceiveOp::create(rewriter, loc, channelVal.getType(), channelVal);
        mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart), reciveOp);
      }
    }

    // Clone the body; the yield's operand 0 (through the mapper) becomes the
    // node's inner result. NOTE(review): only operand 0 is recorded —
    // assumes single-result compute nodes; confirm for multi-result ops.
    for (auto& op : oldWeightedCompute.getOps()) {
      if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
        computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
        if (lastCompute)
          rewriter.clone(op, mapper);
      }
      else
        rewriter.clone(op, mapper);
    }

    // Redirect func.return users of the old op to the merged op.
    // NOTE(review): setOperand(0, ...) assumes the return has exactly one
    // operand and it comes from this compute — confirm.
    for (auto users : oldWeightedCompute->getUsers())
      if (auto funcRet = dyn_cast<func::ReturnOp>(users))
        funcRet.setOperand(0, newWeightedCompute.getResult(0));

    oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
    return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
  }

  // Appends `fromCompute`'s weights, non-compute inputs and body operations
  // onto the existing merged op `toCompute` for the same CPU. Inputs that
  // come from other compute nodes are resolved through the producer's lazy
  // channel (or the local value when the producer lives in `toCompute`).
  std::pair<SpatWeightedCompute, ComputeValueResults>
  mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) {
    func::FuncOp func = getOperation();
    auto loc = func.getLoc();
    IRRewriter rewriter(&getContext());
    IRMapping mapper;

    // Append weights; the asserts check the mutable operand range keeps the
    // weights/inputs segment sizes consistent after each append.
    auto weightMutableIter = toCompute.getWeightsMutable();
    for (auto weight : fromCompute.getWeights()) {
      int sizeW = toCompute.getWeights().size();
      int sizeI = toCompute.getInputs().size();
      weightMutableIter.append(weight);
      assert(sizeW + 1 == toCompute.getWeights().size());
      assert(sizeI == toCompute.getInputs().size());
      assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
    }

    auto inputeArgMutable = toCompute.getInputsMutable();
    // Insert reciveOp
    rewriter.setInsertionPointToEnd(&toCompute.getBody().front());
    int newBBindex = toCompute.getBody().front().getArguments().size();
    for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
      if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
        // Input produced by another compute: channel receive, unless the
        // producer was merged into this very op (then use the local value).
        LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
        LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
            lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
        if (channelOrLocal.isChannel) {
          spatial::SpatChannelReceiveOp reciveOp =
              spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
          mapper.map(fromCompute.getBody().front().getArgument(bbIndex), reciveOp.getResult());
        }
        else {
          mapper.map(fromCompute.getBody().front().getArgument(bbIndex), channelOrLocal.data);
        }
      }
      else {
        // Plain input: append as a new operand plus matching block argument.
        int sizeW = toCompute.getWeights().size();
        int sizeI = toCompute.getInputs().size();
        inputeArgMutable.append(arg);
        assert(sizeW == toCompute.getWeights().size());
        assert(sizeI + 1 == toCompute.getInputs().size());
        assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
        toCompute.getBody().front().addArgument(
            fromCompute.getBody().front().getArgument(bbIndex).getType(), loc);
        mapper.map(fromCompute.getBody().front().getArgument(bbIndex),
            toCompute.getBody().front().getArgument(newBBindex++));
      }
    }

    // Every old block argument must have been remapped before cloning.
    for (auto oldBBarg : fromCompute.getBody().front().getArguments())
      assert(mapper.contains(oldBBarg));

    // Clone the body; same yield handling as createNewComputeNode.
    ComputeValueResults computeValueResults;
    for (auto& op : fromCompute.getOps()) {
      if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
        computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
        if (lastCompute)
          rewriter.clone(op, mapper);
      }
      else {
        rewriter.clone(op, mapper);
      }
    }

    // Redirect func.return users (see NOTE in createNewComputeNode about the
    // single-operand assumption).
    for (auto users : fromCompute->getUsers())
      if (auto funcRet = dyn_cast<func::ReturnOp>(users))
        funcRet.setOperand(0, toCompute.getResult(0));

    oldToNewComputeMap.insert({fromCompute, toCompute});
    return {cast<SpatWeightedCompute>(toCompute), computeValueResults};
  }

  // Builds the lazy channel materializer for a node's result: the
  // SpatChannelNewOp is deferred to the saved insertion point at the top of
  // the function, and the SpatChannelSendOp insertion is deferred further
  // until a consumer asks for the value. Always constructed with
  // isOnlyChannel = false; `weightedCompute` and `lastCompute` are currently
  // unused here.
  LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute,
      ComputeValueResults computeValueResults,
      bool lastCompute) {
    func::FuncOp funcOp = cast<func::FuncOp>(weightedCompute->getParentOp());
    auto* context = &getContext();
    auto loc = funcOp.getLoc();
    IRRewriter rewriter(context);
    // Channel-new ops go at the very start of the function entry block.
    rewriter.setInsertionPointToStart(&funcOp.front());
    auto saveInsertionPointChnNew = rewriter.saveInsertionPoint();

    auto insertNew = [saveInsertionPointChnNew, context, loc, computeValueResults]() {
      IRRewriter rewriter(context);
      rewriter.restoreInsertionPoint(saveInsertionPointChnNew);
      auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
      auto channelVal = channelOp.getResult();
      // NOTE(review): `[&context]` captures a reference to the enclosing
      // closure's copy of the pointer; if the stored std::function is ever
      // moved (e.g. DenseMap rehash) before this inner lambda runs, that
      // reference dangles. Capturing `context` by value would be safer —
      // confirm and fix in a follow-up.
      auto insertVal =
          [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint insertPointChnSend) {
            IRRewriter rewriter(context);
            rewriter.restoreInsertionPoint(insertPointChnSend);
            auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
            return spatSend;
          };
      std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
      return ret;
    };
    return LazyInsertComputeResult(computeValueResults, insertNew, false);
  }
};
} // namespace
std::unique_ptr<Pass> createMergeComputeNodePass() { return std::make_unique<MergeComputeNodePass>(); }
} // namespace onnx_mlir

View File

@@ -15,6 +15,8 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createPimBufferizationPass(); std::unique_ptr<mlir::Pass> createPimBufferizationPass();
std::unique_ptr<mlir::Pass> createMergeComputeNodePass();
std::unique_ptr<mlir::Pass> createPimConstantFoldingPass(); std::unique_ptr<mlir::Pass> createPimConstantFoldingPass();
std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass(); std::unique_ptr<mlir::Pass> createPimMaterializeConstantsPass();

View File

@@ -74,6 +74,7 @@ void PimAccelerator::registerPasses(int optLevel) const {
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPimPass); registerPass(createSpatialToPimPass);
registerPass(createPimBufferizationPass); registerPass(createPimBufferizationPass);
registerPass(createMergeComputeNodePass);
registerPass(createPimConstantFoldingPass); registerPass(createPimConstantFoldingPass);
registerPass(createPimMaterializeConstantsPass); registerPass(createPimMaterializeConstantsPass);
registerPass(createPimVerificationPass); registerPass(createPimVerificationPass);