diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index 2de9861..ebfd18e 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "llvm/Support/raw_os_ostream.h" diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index dd6b418..880c37b 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -33,6 +33,7 @@ void addPassesPim(OwningOpRef& module, } if (pimEmissionTarget >= EmitPim) { + pm.addPass(createMergeComputeNodePass()); pm.addPass(createSpatialToPimPass()); // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Spatial lowered to Pim")); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 759f4e7..d74ae85 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -14,6 +14,7 @@ #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Compiler/CompilerOptions.hpp" diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index ba07c16..1a412a6 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -4,6 +4,10 @@ add_onnx_mlir_dialect_doc(spat Spatial.td) add_pim_library(SpatialOps SpatialOps.cpp Transforms/SpatialBufferizableOpInterface.cpp + Transforms/MergeComputeNode/MergeComputeNodePass.cpp + 
DCPGraph/Graph.cpp + DCPGraph/Task.cpp + DCPGraph/DCPAnalysis.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp new file mode 100644 index 0000000..fa078a4 --- /dev/null +++ b/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp @@ -0,0 +1,52 @@ +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" + +#include + +#include "../SpatialOps.hpp" +#include "DCPAnalysis.hpp" +#include "Graph.hpp" +#include "src/Support/TypeUtilities.hpp" + +namespace onnx_mlir { +namespace spatial { + +using namespace mlir; + +DCPAnalysisResult DCPAnalysis::runAnalysis() { + using EdgesIndex = std::tuple; + llvm::SmallVector spatWeightedComputes; + llvm::SmallVector edges; + for (auto& regions : entryOp->getRegions()) + for (SpatWeightedCompute spatWeightedCompute : regions.getOps()) + spatWeightedComputes.push_back(spatWeightedCompute); + + for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) { + for (Value input : spatWeightedCompute.getInputs()) { + if (auto spatWeightedComputeArgOp = llvm::dyn_cast_if_present(input.getDefiningOp()); + spatWeightedComputeArgOp) { + auto elemIter = llvm::find(spatWeightedComputes, spatWeightedComputeArgOp); + assert(elemIter != spatWeightedComputes.end()); + auto indexStartEdge = std::distance(spatWeightedComputes.begin(), elemIter); + ResultRange outputs = spatWeightedComputeArgOp.getResults(); + int64_t totalSize = 0; + for (auto output : outputs) { + ShapedType result = cast(output.getType()); + totalSize += getSizeInBytes(result); + } + edges.push_back({indexStartEdge, indexEndEdge, totalSize}); + } + } + } + + GraphDCP graphDCP(spatWeightedComputes, edges); + graphDCP.DCP(); + return graphDCP.getResult(); +} + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp 
b/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp new file mode 100644 index 0000000..eee0dd8 --- /dev/null +++ b/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "mlir/IR/Operation.h" + +#include "llvm/ADT/DenseMap.h" + +#include + +#include "../SpatialOps.hpp" + +struct DCPAnalysisResult { + std::vector dominanceOrderCompute; + llvm::DenseMap computeToCPUMap; + llvm::DenseSet isLastComputeOfACpu; + llvm::DenseMap cpuToLastComputeMap; +}; + +namespace onnx_mlir { +namespace spatial { +struct DCPAnalysis { +private: + DCPAnalysisResult result; + mlir::Operation* entryOp; + DCPAnalysisResult runAnalysis(); + +public: + DCPAnalysis(mlir::Operation* op) + : entryOp(op) { + result = runAnalysis(); + } + DCPAnalysisResult& getResult() { return result; } +}; + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp b/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp new file mode 100644 index 0000000..15325f1 --- /dev/null +++ b/src/PIM/Dialect/Spatial/DCPGraph/Graph.cpp @@ -0,0 +1,496 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../../Common/PimCommon.hpp" +#include "DCPAnalysis.hpp" +#include "Graph.hpp" +#include "Task.hpp" +#include "Uniqueworklist.hpp" +#include "Utils.hpp" + +std::optional addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight) { + auto old_child = parent->addChild(child, weight); + auto old_parent = child->addParent(parent, weight); + assert(old_child.has_value() == old_parent.has_value() && "The edge must be present in both element"); + if (old_child.has_value()) { + + return { + {*old_parent, *old_child} + }; + } + return {}; +} + +void removeEdge(TaskDCP* parent, TaskDCP* child) { + parent->removeChild(child); + child->removeParent(parent); +} + +int getTranferCost(TaskDCP* parent, TaskDCP* child) { + if (parent->scheduledCPU.has_value() && child->scheduledCPU.has_value() + && *parent->scheduledCPU == 
*child->scheduledCPU) { + return 0; + } + auto child_position = + std::find_if(parent->childs.begin(), parent->childs.end(), [&child](Edge_t elem) { return elem.first == child; }); + assert(child_position != parent->childs.end()); + return child_position->second; +} + +std::array, 2> GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) { + std::array, 2> ret; + task->setCPU(cpu); + task->setWeight(task->computeWeight(this, cpu)); + auto& list = mapCPUTasks[cpu]; + unsigned int total_size = list.size(); + assert(position <= total_size && "Inserting in a not valid position"); + auto inserted_point = list.begin(); + inserted_point = list.insert(std::next(list.begin(), position), task); + + if (inserted_point != list.begin()) { + auto precedent_point = std::prev(inserted_point, 1); + auto old_edge = addEdge(*precedent_point, *inserted_point, 0); + ret[0] = old_edge; + } + + if (std::next(inserted_point) != list.end()) { + auto next_point = std::next(inserted_point, 1); + auto old_edge = addEdge(*inserted_point, *next_point, 0); + ret[1] = old_edge; + } + return ret; +} + +void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) { + task->resetCPU(); + task->resetWeight(); + auto& list = mapCPUTasks[cpu]; + auto task_position = std::find(list.begin(), list.end(), task); + assert(task_position != list.end() && "Removing a not present task"); + if (task_position != list.begin()) { + auto precedent_point = std::prev(task_position, 1); + removeEdge(*precedent_point, *task_position); + } + + if (std::next(task_position) != list.end()) { + auto next_point = std::next(task_position, 1); + removeEdge(*task_position, *next_point); + } + list.erase(task_position); +} + +std::vector GraphDCP::getRoots() { + std::vector tmp; + for (auto& task : nodes) + if (!task.hasParent()) + tmp.push_back(&task); + return tmp; +} + +void GraphDCP::initAEST() { + UniqueWorkList> worklists(getRoots()); + + while (!worklists.empty()) { + TaskDCP& task = *worklists.front(); + bool 
modified = true; + while (modified) { + modified = false; + for (auto& child : task.childs) { + if (worklists.allElementContained( + child.first->parents.begin(), child.first->parents.end(), [](Edge_t edge) { return edge.first; })) { + modified |= worklists.push_back(child.first); + } + } + } + int max_parent_aest = 0; + for (auto& parent : task.parents) { + max_parent_aest = std::max( + parent.first->getAEST() + parent.first->getWeight() + getTranferCost(parent.first, &task), max_parent_aest); + } + task.setAEST(max_parent_aest); + worklists.pop_front(); + } +} + +int GraphDCP::computeAEST(TaskDCP* task, CPU cpu) { + int max_parent_aest = 0; + for (auto& parent : task->parents) { + int transfer_cost = 0; + if (!(parent.first->isScheduled() && cpu == *parent.first->getCPU())) + transfer_cost = getTranferCost(parent.first, task); + + max_parent_aest = std::max(parent.first->getAEST() + parent.first->getWeight() + transfer_cost, max_parent_aest); + } + return max_parent_aest; +} + +int GraphDCP::initDCPL() { + int max_aest = 0; + for (auto& node : nodes) + max_aest = std::max(node.getAEST() + node.getWeight(), max_aest); + return max_aest; +} + +int GraphDCP::computeDCPL(TaskDCP* task, CPU cpu) { + int max_aest = 0; + for (auto& node : nodes) + if (&node != task) + max_aest = std::max(node.getAEST() + node.getWeight(), max_aest); + else + max_aest = std::max(computeAEST(task, cpu) + node.computeWeight(this, cpu), max_aest); + return max_aest; +} + +void GraphDCP::initALST() { + int dcpl = initDCPL(); + std::vector roots = getRoots(); + UniqueWorkList> worklists(roots); + worklists.reserve(nodes.size()); + size_t i = 0; + while (i != worklists.size()) { + bool modified = true; + while (modified) { + modified = false; + for (auto& child : worklists.at(i)->childs) { + if (worklists.allElementContained( + child.first->parents.begin(), child.first->parents.end(), [](Edge_t edge) { return edge.first; })) { + modified |= worklists.push_back(child.first); + } + } + } + i++; 
+ } + + while (!worklists.empty()) { + TaskDCP& node = *worklists.back(); + int min_alst = INT_MAX; + if (!node.hasChilds()) + min_alst = dcpl - node.getWeight(); + for (auto child : node.childs) + min_alst = std::min(min_alst, child.first->getALST() - node.getWeight() - getTranferCost(&node, child.first)); + node.setALST(min_alst); + worklists.pop_back(); + } +} + +std::unordered_map GraphDCP::computeALST(TaskDCP* task, CPU cpu) { + int dcpl = computeDCPL(task, cpu); + std::unordered_map temp_ALST; + UniqueWorkList> worklists(getRoots()); + size_t i = 0; + while (i != worklists.size()) { + + bool modified = true; + while (modified) { + modified = false; + for (auto& child : worklists.at(i)->childs) { + if (worklists.allElementContained( + child.first->parents.begin(), child.first->parents.end(), [](Edge_t edge) { return edge.first; })) { + modified |= worklists.push_back(child.first); + } + } + } + i++; + } + + while (!worklists.empty()) { + TaskDCP& node = *worklists.back(); + int min_alst = INT_MAX; + if (!node.hasChilds()) { + if (&node != task) + min_alst = dcpl - node.getWeight(); + else + min_alst = dcpl - node.computeWeight(this, cpu); + } + + for (auto child : node.childs) { + int transfer_cost = getTranferCost(&node, child.first); + if (&node == task && child.first->isScheduled() && cpu == *child.first->getCPU()) + transfer_cost = 0; + min_alst = std::min(min_alst, temp_ALST[child.first] - node.getWeight() - transfer_cost); + } + temp_ALST[&node] = min_alst; + worklists.pop_back(); + } + return temp_ALST; +} + +TaskDCP* GraphDCP::findCandidate(std::vector nodes) { + auto hasNoCPParentUnsecheduled = [](TaskDCP* node) { + return std::all_of( + node->parents.begin(), node->parents.end(), [](Edge_t element) { return element.first->isScheduled() == true; }); + }; + + auto findBestNode = [](auto lft, auto rgt) { + int lft_difference = (*lft)->getALST() - (*lft)->getAEST(); + int rgt_difference = (*rgt)->getALST() - (*rgt)->getAEST(); + if (lft_difference < 
rgt_difference) + return lft; + if (rgt_difference < lft_difference) + return rgt; + if ((*lft)->getAEST() < (*rgt)->getAEST()) + return lft; + return rgt; + }; + + auto valid_node = std::find_if(nodes.begin(), nodes.end(), hasNoCPParentUnsecheduled); + auto best_node = valid_node; + + while (valid_node != nodes.end()) { + if (!hasNoCPParentUnsecheduled(*valid_node)) { + std::advance(valid_node, 1); + continue; + } + best_node = findBestNode(valid_node, best_node); + std::advance(valid_node, 1); + } + return *best_node; +} + +GraphDCP::FindSlot GraphDCP::findSlot(TaskDCP* candidate, CPU cpu, bool push) { + int aest_on_cpu = computeAEST(candidate, cpu); + auto tmp_ALST = computeALST(candidate, cpu); + int final_time = tmp_ALST[candidate] + candidate->computeWeight(this, cpu); + std::list& scheduledTasks = mapCPUTasks[cpu]; + + // Search last non ancestor + auto after_last_anc = scheduledTasks.end(); + while (after_last_anc != scheduledTasks.begin()) { + if (after_last_anc != scheduledTasks.end() && (*after_last_anc)->hasDescendent(candidate)) + break; + after_last_anc = std::prev(after_last_anc, 1); + } + + if (after_last_anc != scheduledTasks.end() && (*after_last_anc)->hasDescendent(candidate)) + std::advance(after_last_anc, 1); + + auto first_descendent_index = scheduledTasks.begin(); + while (first_descendent_index != scheduledTasks.end()) { + if (first_descendent_index != scheduledTasks.end() && candidate->hasDescendent(*first_descendent_index)) + break; + first_descendent_index = std::next(first_descendent_index, 1); + } + + auto iter_index = after_last_anc; + auto best_index = scheduledTasks.end(); + int best_max; + assert(std::distance(scheduledTasks.begin(), after_last_anc) + <= std::distance(scheduledTasks.begin(), first_descendent_index)); + + bool keep = true; + while (keep && iter_index != scheduledTasks.end()) { + if (iter_index == first_descendent_index) + keep = false; + int min = INT_MAX; + if (!push) + min = std::min(final_time, 
(*iter_index)->getAEST()); + else if (tmp_ALST.count(*iter_index) == 1) + min = std::min(final_time, tmp_ALST[*iter_index]); + else + min = std::min(final_time, (*iter_index)->getALST()); + int max = aest_on_cpu; + if (iter_index != scheduledTasks.begin()) { + auto prev_iter = std::prev(iter_index); + max = std::max(aest_on_cpu, (*prev_iter)->getAEST() + (*prev_iter)->getWeight()); + } + if (min - max >= candidate->computeWeight(this, cpu)) { + best_index = iter_index; + best_max = max; + break; + } + + std::advance(iter_index, 1); + if (iter_index == scheduledTasks.end()) + keep = false; + } + + if (best_index != scheduledTasks.end()) + return FindSlot {best_max, (int) std::distance(scheduledTasks.begin(), best_index)}; + + if (iter_index == scheduledTasks.end()) { + best_max = aest_on_cpu; + if (iter_index != scheduledTasks.begin()) { + auto prev_iter = std::prev(iter_index); + best_max = std::max(aest_on_cpu, (*prev_iter)->getAEST() + (*prev_iter)->getWeight()); + } + return FindSlot {best_max, (int) std::distance(scheduledTasks.begin(), scheduledTasks.end())}; + } + return FindSlot {INT_MAX, 0}; +} + +void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) { + + std::vector processors; + processors.reserve(lastCPU()); + for (CPU c = push ? 
lastCPU() : lastCPU(); c >= 0; c--) + processors.push_back(c); + + CPU best_process = -1; + int best_composite = INT_MAX; + FindSlot best_slot; + + while (!processors.empty()) { + CPU current_cpu = processors.back(); + processors.pop_back(); + auto slot = findSlot(candidate, current_cpu, 0); + if (slot.aest == INT_MAX && push) + slot = findSlot(candidate, current_cpu, 1); + if (slot.aest == INT_MAX) + continue; + if (std::all_of(candidate->childs.begin(), candidate->childs.end(), [](Edge_t child) { + return child.first->isScheduled(); + })) { + if (slot.aest < best_composite) { + best_process = current_cpu; + best_composite = slot.aest; + best_slot = slot; + } + } + else if (candidate->hasChilds()) { + auto dcpl = initDCPL(); + auto old_edges = insertTaskInCPU(current_cpu, candidate, slot.index); + initAEST(); + initALST(); + Edge_t smallest_child {nullptr, 0}; + for (auto child : candidate->childs) { + if (child.first->isScheduled()) + continue; + if (smallest_child.first == nullptr) { + smallest_child = child; + continue; + } + if (smallest_child.first->getALST() - smallest_child.first->getAEST() + > child.first->getALST() - child.first->getAEST()) { + smallest_child = child; + } + } + auto child_slot = findSlot(smallest_child.first, current_cpu, false); + auto dcpl_with_child = computeDCPL(smallest_child.first, current_cpu); + if (child_slot.aest != INT_MAX and child_slot.aest + slot.aest < best_composite and dcpl_with_child <= dcpl) { + best_process = current_cpu; + best_composite = slot.aest + child_slot.aest; + best_slot = slot; + } + removeTaskFromCPU(current_cpu, candidate); + initAEST(); + initALST(); + for (auto opt_edge : old_edges) { + if (opt_edge.has_value()) { + auto double_edge = *opt_edge; + addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second); + } + } + } + } + if (best_process == -1) { + best_process = lastCPU(); + incLastCPU(); + } + if (best_process == lastCPU()) + incLastCPU(); + insertTaskInCPU(best_process, 
candidate, best_slot.index); +} + +void GraphDCP::DCP() { + initAEST(); + initALST(); + to_dot(); + std::vector worklists; + worklists.reserve(nodes.size()); + for (auto& node : nodes) + worklists.push_back(&node); + while (!worklists.empty()) { + auto candidate = findCandidate(worklists); + selectProcessor(candidate, candidate->isCP()); + initAEST(); + initALST(); + fastRemove(worklists, candidate); + } + to_dot(); +} + +void GraphDCP::to_dot() { + static int index = 0; + std::string outputDir = onnx_mlir::getOutputDir(); + if (outputDir.empty()) + return; + std::string graphDir = outputDir + "/DCPGraph"; + onnx_mlir::createDirectory(graphDir); + std::fstream file(graphDir + "/graph_" + std::to_string(index++) + ".dot", std::ios::out); + file << "digraph G {\n"; + if (mapCPUTasks.size() != 0) { + for (CPU c = 0; c < lastCPU(); c++) { + file << "subgraph cluster_" << c << "{\nstyle=filled;\ncolor=lightgrey;\n"; + for (auto node : mapCPUTasks[c]) { + file << node->Id() << " [label=\""; + file << "n:" << node->Id() << "\n"; + file << "aest:" << node->getAEST() << "\n"; + file << "alst:" << node->getALST() << "\n"; + file << "weight:" << node->getWeight() << "\"]\n"; + } + file << " }\n"; + } + } + else { + for (auto& node : nodes) { + file << node.Id() << " [label=\""; + file << "n:" << node.Id() << "\n"; + file << "aest:" << node.getAEST() << "\n"; + file << "alst:" << node.getALST() << "\n"; + file << "weight:" << node.getWeight() << "\"]\n"; + } + } + for (auto& node : nodes) { + for (auto& child : node.childs) { + file << node.Id() << " -> " << child.first->Id(); + file << " [label=\"" << child.second << "\"]\n"; + } + } + file << "}\n"; + file.flush(); + file.close(); +} + +DCPAnalysisResult GraphDCP::getResult() { + DCPAnalysisResult ret; + + std::vector roots = getRoots(); + UniqueWorkList> worklists(roots); + worklists.reserve(nodes.size()); + size_t i = 0; + while (i != worklists.size()) { + bool modified = true; + while (modified) { + modified = false; + 
for (auto& child : worklists.at(i)->childs) { + if (worklists.allElementContained( + child.first->parents.begin(), child.first->parents.end(), [](Edge_t edge) { return edge.first; })) { + modified |= worklists.push_back(child.first); + } + } + } + i++; + } + ret.dominanceOrderCompute.reserve(worklists.size()); + for (auto elem : worklists) + ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute()); + + for (auto [cpu, nodes] : mapCPUTasks) { + size_t i = 0; + for (auto node : nodes) { + ret.computeToCPUMap[node->getSpatWeightedCompute()] = cpu; + if (i++ == nodes.size() - 1){ + ret.isLastComputeOfACpu.insert(node->getSpatWeightedCompute()); + ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute(); + } + } + } + + return ret; +} diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Graph.hpp b/src/PIM/Dialect/Spatial/DCPGraph/Graph.hpp new file mode 100644 index 0000000..ad115ba --- /dev/null +++ b/src/PIM/Dialect/Spatial/DCPGraph/Graph.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include "llvm/ADT/ArrayRef.h" + +#include +#include +#include +#include + +#include "DCPAnalysis.hpp" +#include "Task.hpp" +#include "Utils.hpp" + +std::optional addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight); +void removeEdge(TaskDCP* parent, TaskDCP* child); +int getTranferCost(TaskDCP* parent, TaskDCP* child); + +class GraphDCP { + + struct FindSlot { + int aest; + int index; + }; + + std::vector nodes; + std::unordered_map> mapCPUTasks; + CPU last_cpu = 0; + + std::array, 2> insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position); + void removeTaskFromCPU(CPU cpu, TaskDCP* task); + + std::vector getRoots(); + + void initAEST(); + int computeAEST(TaskDCP* task, CPU cpu); + int initDCPL(); + int computeDCPL(TaskDCP* task, CPU cpu); + void initALST(); + std::unordered_map computeALST(TaskDCP* task, CPU cpu); + + TaskDCP* findCandidate(std::vector nodes); + void selectProcessor(TaskDCP* candidate, bool push); + CPU lastCPU() const { return last_cpu; } + void incLastCPU() { 
last_cpu++; } + FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push); + void to_dot(); + +public: + void DCP(); + GraphDCP(llvm::ArrayRef spatWeightedComputes, + llvm::ArrayRef edges) + : nodes(), mapCPUTasks() { + for (auto spatWeightedCompute : spatWeightedComputes) + nodes.emplace_back(spatWeightedCompute); + + for (auto [start, end, weight] : edges) + makeEdge(start, end, weight); + } + + DCPAnalysisResult getResult(); + + void makeEdge(size_t parent_index, size_t child_index, Weight_t weight) { + addEdge(&nodes[parent_index], &nodes[child_index], weight); + } + + size_t taskInCPU(CPU cpu) { return mapCPUTasks[cpu].size(); } +}; diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp b/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp new file mode 100644 index 0000000..724475b --- /dev/null +++ b/src/PIM/Dialect/Spatial/DCPGraph/Task.cpp @@ -0,0 +1,49 @@ +#include + +#include "Graph.hpp" +#include "Task.hpp" +#include "Uniqueworklist.hpp" + +std::optional TaskDCP::addChild(TaskDCP* child, Weight_t weight) { + std::optional oldEdge = std::nullopt; + auto founded_element = + std::find_if(childs.begin(), childs.end(), [child](Edge_t element) { return child == element.first; }); + if (founded_element != childs.end()) { + oldEdge = *founded_element; + fastRemove(childs, founded_element); + } + childs.emplace_back(child, weight); + return oldEdge; +} + +std::optional TaskDCP::addParent(TaskDCP* parent, Weight_t weight) { + std::optional oldEdge = std::nullopt; + auto founded_element = + std::find_if(parents.begin(), parents.end(), [parent](Edge_t element) { return parent == element.first; }); + if (founded_element != parents.end()) { + oldEdge = *founded_element; + fastRemove(parents, founded_element); + } + parents.emplace_back(parent, weight); + return oldEdge; +} + +bool TaskDCP::hasDescendent(TaskDCP* child) { + UniqueWorkList> worklist; + worklist.reserve(32); + worklist.push_back(this); + while (!worklist.empty()) { + TaskDCP* task = worklist.back(); + 
worklist.pop_back(); + if (task == child) + return true; + for (auto c : task->childs) + worklist.push_back(c.first); + } + return false; +} + +//TODO fare qualcosa di sensato +int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { + return orig_weight; +} diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp b/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp new file mode 100644 index 0000000..748df3f --- /dev/null +++ b/src/PIM/Dialect/Spatial/DCPGraph/Task.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include +#include + +#include "../SpatialOps.hpp" +#include "Utils.hpp" + +std::optional addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight); +void removeEdge(TaskDCP* parent, TaskDCP* child); + +class TaskDCP { + onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute; + int aest; + int alst; + std::optional scheduledCPU; + int weight; + int orig_weight; + + std::optional addChild(TaskDCP* child, Weight_t weight); + std::optional addChild(TaskDCP& child, Weight_t weight) { return addChild(&child, weight); } + + void removeChild(TaskDCP* to_remove) { fastRemove(childs, to_remove); } + void removeChild(TaskDCP& to_remove) { fastRemove(childs, &to_remove); } + + std::optional addParent(TaskDCP* parent, Weight_t weight); + std::optional addParent(TaskDCP& parent, Weight_t weight) { return addParent(&parent, weight); } + + void removeParent(TaskDCP* to_remove) { fastRemove(parents, to_remove); } + void removeParent(TaskDCP& to_remove) { fastRemove(parents, &to_remove); } + +public: + std::vector parents; + std::vector childs; + TaskDCP() = default; + TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) + : spatWeightedCompute(spatWeightedCompute), + aest(0), + alst(0), + scheduledCPU(), + weight(getSpatWeightCompute(spatWeightedCompute)), + orig_weight(weight), + parents(), + childs() {} + + TaskDCP(const TaskDCP& node) = delete; + TaskDCP(TaskDCP&& node) = default; + + void setCPU(CPU cpu) { scheduledCPU = cpu; } + std::optional 
getCPU() const { return scheduledCPU; } + void resetCPU() { scheduledCPU = std::nullopt; } + int getWeight() { + if (isScheduled()) + return weight; + else + return orig_weight; + } + void setWeight(int val) { weight = val; } + void resetWeight() { weight = orig_weight; } + int computeWeight(GraphDCP* graph, CPU cpu); + + bool hasParent() { return parents.size() != 0; } + bool hasChilds() { return childs.size() != 0; } + + int getAEST() { return aest; } + int getALST() { return alst; } + void setAEST(int val) { + assert(val >= 0); + aest = val; + } + void setALST(int val) { + assert(val >= 0 && val >= aest); + alst = val; + } + bool hasDescendent(TaskDCP* child); + int64_t Id() const { return (int64_t)spatWeightedCompute.getAsOpaquePointer(); } + + bool isCP() const { return alst == aest; } + bool isScheduled() const { return scheduledCPU.has_value(); } + onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute(){return spatWeightedCompute;} + + friend std::optional addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight); + friend void removeEdge(TaskDCP* parent, TaskDCP* child); + friend int getTranferCost(TaskDCP* parent, TaskDCP* child); +}; diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp b/src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp new file mode 100644 index 0000000..59d6bbe --- /dev/null +++ b/src/PIM/Dialect/Spatial/DCPGraph/Uniqueworklist.hpp @@ -0,0 +1,82 @@ +#pragma once + +#include +#include +#include + +template +struct has_pop_front : std::false_type {}; + +template +struct has_pop_front().pop_front())>> : std::true_type {}; + +template +class UniqueWorkList { + + using V = typename T::value_type; + T storage; + std::unordered_set set; + +public: + UniqueWorkList() = default; + + template + UniqueWorkList(const arg_ty& from) + : storage() { + for (auto& element : from) { + if (set.count(element) == 0) { + storage.push_back(element); + set.insert(element); + } + } + } + + bool empty() const { return storage.empty(); } + 
void reserve(size_t val) { return storage.reserve(val); } + size_t size() const { return storage.size(); } + V& at(size_t i) { return storage.at(i); } + const V& at(size_t i) const { return storage.at(i); } + + V& front() { return storage.front(); } + V& back() { return storage.back(); } + + bool push_back(const V& val) { + if (set.count(val) == 0) { + storage.push_back(val); + set.insert(val); + return true; + } + return false; + } + + void pop_front() { + if constexpr (has_pop_front::value) + storage.pop_front(); + else + assert(false && "Underlying storage type does not support pop_front()"); + } + + auto cbegin() const { return storage.cbegin(); } + auto cend() const { return storage.cend(); } + + void pop_back() { storage.pop_back(); } + + template + bool allElementContained(Iterator start, Iterator end, Mapper map) { + while (start != end) { + if (set.count(map(*start)) == 0) + return false; + std::advance(start, 1); + } + return true; + } + + auto begin() { + return storage.begin(); + } + + auto end() { + return storage.end(); + } + +}; diff --git a/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp b/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp new file mode 100644 index 0000000..3ca2b95 --- /dev/null +++ b/src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp @@ -0,0 +1,60 @@ +#pragma once + +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "llvm/Support/Casting.h" +#include +#include +#include +#include + +#include "../SpatialOps.hpp" +#include "src/Support/TypeUtilities.hpp" + +using CPU = int; +using Weight_t = int; +class TaskDCP; +class GraphDCP; +using Edge_t = std::pair; +using Edge_pair = std::pair; +using EdgesIndex = std::tuple; + +template +void fastRemove(std::vector>& vector, T* to_remove) { + auto position = + std::find_if(vector.begin(), vector.end(), [to_remove](Edge_t edge) { return edge.first == to_remove; }); + if (position != vector.end()) { + std::swap(*(vector.end() - 1), *position); + vector.pop_back(); + } +} + +inline void fastRemove(std::vector& 
vector, TaskDCP* to_remove) { + auto position = + std::find_if(vector.begin(), vector.end(), [to_remove](TaskDCP* element) { return element == to_remove; }); + if (position != vector.end()) { + std::swap(*(vector.end() - 1), *position); + vector.pop_back(); + } +} + +template +void fastRemove(std::vector>& vector, P position) { + if (position != vector.end()) { + std::swap(*(vector.end() - 1), *position); + vector.pop_back(); + } +} + +// TODO: Do something sensible here (rough size-based cost estimate) +inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) { + int64_t tot = 0; + for (auto& region : spatWeightedCompute.getBody()) { + for (auto& inst : region) { + for(auto result : inst.getResults()){ + if(auto element = llvm::dyn_cast(result.getType())) + tot += onnx_mlir::getSizeInBytes(element); + } + } + } + return tot; +} diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index a970bdf..12fdd87 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -47,6 +47,7 @@ def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegmen let regions = (region SizedRegion<1>:$body); let hasVerifier = 1; + let hasFolder = 1; let assemblyFormat = [{ `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 7023b35..9c1ee13 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -204,45 +204,47 @@ LogicalResult SpatWeightedCompute::verify() { // Check that it has a terminator, it is a yieldOp, and it has a single // operand with the same type as the result auto& block = getBody().front(); - auto yieldOp = dyn_cast_or_null(block.getTerminator()); - if (!yieldOp) - return emitError("ComputeOp must have a single yield operation"); + if (block.mightHaveTerminator()) { + auto yieldOp = 
dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return emitError("ComputeOp must have a single yield operation"); - auto resultTypes = getResultTypes(); - auto yieldTypes = yieldOp->getOperandTypes(); - if (resultTypes.size() != yieldTypes.size()) { - return emitError("ComputeOp must have same number of results as yieldOp " - "operands"); - } - - for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) { - auto resultType = std::get<0>(it); - auto yieldType = std::get<1>(it); - - // Same type and compatible shape - if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) { - return emitError("ComputeOp output must be of the same type as yieldOp " - "operand"); + auto resultTypes = getResultTypes(); + auto yieldTypes = yieldOp->getOperandTypes(); + if (resultTypes.size() != yieldTypes.size()) { + return emitError("ComputeOp must have same number of results as yieldOp " + "operands"); } - // Same encoding - if (auto resultRankedType = dyn_cast(resultType)) { - if (auto yieldRankedType = dyn_cast(yieldType)) { - if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) { - return emitError("ComputeOp output must have the same encoding as " - "yieldOp operand"); + for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) { + auto resultType = std::get<0>(it); + auto yieldType = std::get<1>(it); + + // Same type and compatible shape + if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) { + return emitError("ComputeOp output must be of the same type as yieldOp " + "operand"); + } + + // Same encoding + if (auto resultRankedType = dyn_cast(resultType)) { + if (auto yieldRankedType = dyn_cast(yieldType)) { + if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) { + return emitError("ComputeOp output must have the same encoding as " + "yieldOp operand"); + } + } + else { + return emitError("ComputeOp output has an encoding while yieldOp " + "operand does not 
have one"); } } else { - return emitError("ComputeOp output has an encoding while yieldOp " - "operand does not have one"); - } - } - else { - // If result does not have an encoding, yield shouldn't either - if (auto yieldRankedType = dyn_cast(yieldType)) { - return emitError("ComputeOp output must not have an encoding if " - "yieldOp operand has one"); + // If result does not have an encoding, yield shouldn't either + if (auto yieldRankedType = dyn_cast(yieldType)) { + return emitError("ComputeOp output must not have an encoding if " + "yieldOp operand has one"); + } } } } @@ -255,6 +257,27 @@ LogicalResult SpatWeightedCompute::verify() { return success(); } +LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) { + Block& block = getBody().front(); + if (!llvm::hasSingleElement(block)) + return failure(); + + auto yieldOp = dyn_cast(block.front()); + if (!yieldOp) + return failure(); + + for (Value yieldedValue : yieldOp.getOperands()) { + if (auto blockArg = dyn_cast(yieldedValue)) { + if (blockArg.getOwner() == &block) { + results.push_back(getOperand(blockArg.getArgNumber())); + continue; + } + } + results.push_back(yieldedValue); + } + return success(); +} + } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp new file mode 100644 index 0000000..3d5da4d --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp @@ -0,0 +1,318 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/STLExtras.h" +#include 
"llvm/ADT/SmallVector.h" + +#include +#include +#include + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { +using SpatWeightedCompute = spatial::SpatWeightedCompute; + +struct ComputeValueResults { + // Value yielded by the yieldOp + Value innerValue; +}; + +class LazyInsertComputeResult { + using InsertPoint = mlir::IRRewriter::InsertPoint; + ComputeValueResults computeResults; + Value channelNewOpVal; + bool onlyChannel; + std::function channelSendInserter; + InsertPoint insertPointSend; + std::function>()> channelNewInserter; + +public: + LazyInsertComputeResult(ComputeValueResults computeValueResults, + std::function>()> channelNewInserter, + bool isOnlyChannel) + : computeResults(computeValueResults), + onlyChannel(isOnlyChannel), + channelSendInserter(nullptr), + insertPointSend({}), + channelNewInserter(channelNewInserter) {} + + struct ChannelOrLocalOp { + Value data; + bool isChannel; + }; + + bool onlyChanneled() const { return onlyChannel; } + + ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute spatWeightedCompute) { + + if (channelSendInserter == nullptr) { + auto [first, second] = channelNewInserter(); + channelNewOpVal = first; + channelSendInserter = second; + auto op = computeResults.innerValue.getDefiningOp(); + if (op) { + insertPointSend = InsertPoint(op->getBlock(), ++Block::iterator(op)); + } + else { + auto BB = computeResults.innerValue.getParentBlock(); + insertPointSend = InsertPoint(BB, BB->begin()); + } + } + if (spatWeightedCompute) { + for (auto& BB : spatWeightedCompute.getBody()) + if (&BB == insertPointSend.getBlock()) + return {computeResults.innerValue, false}; + } + channelSendInserter(insertPointSend); + return {channelNewOpVal, true}; + } + + ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); } +}; + +struct 
MergeComputeNodePass : PassWrapper> { + +private: + DenseMap newComputeNodeResults; + DenseMap oldToNewComputeMap; + DenseMap cputToNewComputeMap; + +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass) + + StringRef getArgument() const override { return "pim-merge-node-pass"; } + StringRef getDescription() const override { + return "Merge Spatial-Weighted-Compute-Node in order to reduce the total " + "execution time"; + } + + LogicalResult initialize(MLIRContext* context) override { return success(); } + + void runOnOperation() override { + DCPAnalysisResult& analysisResult = getAnalysis().getResult(); + auto& lastComputeOfCpu = analysisResult.isLastComputeOfACpu; + auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap; + IRRewriter rewriter(&getContext()); + + for (auto currentComputeNode : analysisResult.dominanceOrderCompute) { + size_t cpu = analysisResult.computeToCPUMap.at(currentComputeNode); + if (!cputToNewComputeMap.contains(cpu)) { + ValueTypeRange newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes(); + auto [newWeightedCompute, computeValueResult] = createNewComputeNode( + currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode)); + cputToNewComputeMap[cpu] = newWeightedCompute; + newComputeNodeResults.insert( + std::make_pair(currentComputeNode, + createLazyComputeResult( + newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); + } + else { + auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode( + cputToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode)); + newComputeNodeResults.insert( + std::make_pair(currentComputeNode, + createLazyComputeResult( + newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); + } + } + + for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) + 
computeNodetoRemove.erase(); + func::FuncOp func = getOperation(); + dumpModule(cast(func->getParentOp()), "SpatialDCPMerged"); + } + +private: + std::pair createNewComputeNode( + SpatWeightedCompute oldWeightedCompute, ValueTypeRange newWeightedComputeType, bool lastCompute) { + func::FuncOp func = getOperation(); + auto loc = func.getLoc(); + IRRewriter rewriter(&getContext()); + rewriter.setInsertionPoint(&*std::prev(func.getBody().front().end(), 1)); + + ComputeValueResults computeValueResults; + + IRMapping mapper; + llvm::SmallVector newComputeOperand; + llvm::SmallVector newBBOperandType; + llvm::SmallVector newBBLocations; + + for (auto arg : oldWeightedCompute.getWeights()) + newComputeOperand.push_back(arg); + + for (auto arg : oldWeightedCompute.getInputs()) + if (!llvm::isa(arg.getDefiningOp())) { + newComputeOperand.push_back(arg); + newBBOperandType.push_back(arg.getType()); + newBBLocations.push_back(loc); + } + + auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand); + rewriter.createBlock( + &newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations); + newWeightedCompute.getProperties().setOperandSegmentSizes( + {(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()}); + rewriter.setInsertionPointToEnd(&newWeightedCompute.getBody().front()); + + int indexNew = 0; + int indexOld = oldWeightedCompute.getWeights().size(); + int indexOldStart = oldWeightedCompute.getWeights().size(); + for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) { + if (!llvm::isa(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) { + mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart), + newWeightedCompute.getBody().front().getArgument(indexNew++)); + } + else { + auto argWeightCompute = + llvm::dyn_cast_if_present(oldWeightedCompute.getOperand(indexOld).getDefiningOp()); + + LazyInsertComputeResult& 
lazyArgWeight = newComputeNodeResults.at(argWeightCompute); + auto [channelVal, _] = lazyArgWeight.getAsChannelValueAndInsertSender(); + spatial::SpatChannelReceiveOp reciveOp = + spatial::SpatChannelReceiveOp::create(rewriter, loc, channelVal.getType(), channelVal); + mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart), reciveOp); + } + } + + for (auto& op : oldWeightedCompute.getOps()) { + if (auto yield = dyn_cast(&op)) { + computeValueResults.innerValue = mapper.lookup(yield.getOperand(0)); + if (lastCompute) + rewriter.clone(op, mapper); + } + else + rewriter.clone(op, mapper); + } + + for (auto users : oldWeightedCompute->getUsers()) + if (auto funcRet = dyn_cast(users)) + funcRet.setOperand(0, newWeightedCompute.getResult(0)); + + oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute}); + return {cast(newWeightedCompute), computeValueResults}; + } + + std::pair + mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) { + func::FuncOp func = getOperation(); + auto loc = func.getLoc(); + IRRewriter rewriter(&getContext()); + IRMapping mapper; + + auto weightMutableIter = toCompute.getWeightsMutable(); + for (auto weight : fromCompute.getWeights()) { + int sizeW = toCompute.getWeights().size(); + int sizeI = toCompute.getInputs().size(); + weightMutableIter.append(weight); + assert(sizeW + 1 == toCompute.getWeights().size()); + assert(sizeI == toCompute.getInputs().size()); + assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); + } + + + auto inputeArgMutable = toCompute.getInputsMutable(); + // Insert reciveOp + rewriter.setInsertionPointToEnd(&toCompute.getBody().front()); + int newBBindex = toCompute.getBody().front().getArguments().size(); + for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) { + if (auto argWeightCompute = llvm::dyn_cast_if_present(arg.getDefiningOp())) { + LazyInsertComputeResult& lazyArgWeight = 
newComputeNodeResults.at(argWeightCompute); + + LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal = + lazyArgWeight.getAsChannelValueAndInsertSender(toCompute); + if (channelOrLocal.isChannel) { + spatial::SpatChannelReceiveOp reciveOp = + spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data); + mapper.map(fromCompute.getBody().front().getArgument(bbIndex), reciveOp.getResult()); + } + else { + mapper.map(fromCompute.getBody().front().getArgument(bbIndex), channelOrLocal.data); + } + } + else { + + int sizeW = toCompute.getWeights().size(); + int sizeI = toCompute.getInputs().size(); + inputeArgMutable.append(arg); + assert(sizeW == toCompute.getWeights().size()); + assert(sizeI + 1 == toCompute.getInputs().size()); + assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); + + toCompute.getBody().front().addArgument( + fromCompute.getBody().front().getArgument(bbIndex).getType(),loc); + + mapper.map(fromCompute.getBody().front().getArgument(bbIndex), + toCompute.getBody().front().getArgument(newBBindex++)); + } + } + + for (auto oldBBarg : fromCompute.getBody().front().getArguments()) + assert(mapper.contains(oldBBarg)); + + ComputeValueResults computeValueResults; + for (auto& op : fromCompute.getOps()) { + if (auto yield = dyn_cast(&op)) { + computeValueResults.innerValue = mapper.lookup(yield.getOperand(0)); + if (lastCompute) + rewriter.clone(op, mapper); + } + else { + rewriter.clone(op, mapper); + } + } + + for (auto users : fromCompute->getUsers()) + if (auto funcRet = dyn_cast(users)) + funcRet.setOperand(0, toCompute.getResult(0)); + + oldToNewComputeMap.insert({fromCompute, toCompute}); + return {cast(toCompute), computeValueResults}; + } + + LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute, + ComputeValueResults computeValueResults, + bool lastCompute) { + func::FuncOp funcOp = cast(weightedCompute->getParentOp()); + auto* context = &getContext(); + auto loc = 
funcOp.getLoc(); + IRRewriter rewriter(context); + + rewriter.setInsertionPointToStart(&funcOp.front()); + auto saveInsertionPointChnNew = rewriter.saveInsertionPoint(); + auto insertNew = [saveInsertionPointChnNew, context, loc, computeValueResults]() { + IRRewriter rewriter(context); + rewriter.restoreInsertionPoint(saveInsertionPointChnNew); + auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context)); + auto channelVal = channelOp.getResult(); + auto insertVal = + [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint insertPointChnSend) { + IRRewriter rewriter(context); + rewriter.restoreInsertionPoint(insertPointChnSend); + auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue); + return spatSend; + }; + std::pair> ret {channelVal, insertVal}; + return ret; + }; + return LazyInsertComputeResult(computeValueResults, insertNew, false); + } +}; + +} // namespace + +std::unique_ptr createMergeComputeNodePass() { return std::make_unique(); } + +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PIMPasses.h b/src/PIM/Pass/PIMPasses.h index ba0724c..4417f50 100644 --- a/src/PIM/Pass/PIMPasses.h +++ b/src/PIM/Pass/PIMPasses.h @@ -15,6 +15,8 @@ std::unique_ptr createSpatialToPimPass(); std::unique_ptr createPimBufferizationPass(); +std::unique_ptr createMergeComputeNodePass(); + std::unique_ptr createPimConstantFoldingPass(); std::unique_ptr createPimMaterializeConstantsPass(); diff --git a/src/PIM/PimAccelerator.cpp b/src/PIM/PimAccelerator.cpp index 657d98e..21bd62c 100644 --- a/src/PIM/PimAccelerator.cpp +++ b/src/PIM/PimAccelerator.cpp @@ -74,6 +74,7 @@ void PimAccelerator::registerPasses(int optLevel) const { registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToPimPass); registerPass(createPimBufferizationPass); + registerPass(createMergeComputeNodePass); registerPass(createPimConstantFoldingPass); 
registerPass(createPimMaterializeConstantsPass); registerPass(createPimVerificationPass);