This commit is contained in:
@@ -0,0 +1,318 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ilist_node.h"
|
||||
#include "llvm/ADT/simple_ilist.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledList;
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledListNode : public llvm::ilist_node<NodeT> {
|
||||
friend class LabeledList<NodeT>;
|
||||
|
||||
public:
|
||||
using Label = uint64_t;
|
||||
|
||||
LabeledListNode() = default;
|
||||
LabeledListNode(const LabeledListNode&) = delete;
|
||||
LabeledListNode(LabeledListNode&&) = default;
|
||||
LabeledListNode& operator=(LabeledListNode&&) = delete;
|
||||
|
||||
~LabeledListNode() { assert(owner_ == nullptr && "destroying a linked LabeledListNode"); }
|
||||
|
||||
bool isLinked() const { return owner_ != nullptr; }
|
||||
Label getOrderLabel() const { return label; }
|
||||
|
||||
friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt){
|
||||
return lft.label < rgt.label;
|
||||
}
|
||||
|
||||
private:
|
||||
const void* owner_ = nullptr;
|
||||
Label label = 0;
|
||||
};
|
||||
|
||||
template <typename NodeT>
|
||||
class LabeledList {
|
||||
|
||||
using Label = typename NodeT::Label;
|
||||
|
||||
static constexpr Label kLowerSentinel = 0;
|
||||
static constexpr Label kUpperSentinel = std::numeric_limits<Label>::max();
|
||||
static constexpr Label kRelabelGap = 2;
|
||||
|
||||
public:
|
||||
using List = llvm::simple_ilist<NodeT>;
|
||||
using Iterator = typename List::iterator;
|
||||
using RIterator = typename List::reverse_iterator;
|
||||
using ConstIterator = typename List::const_iterator;
|
||||
|
||||
LabeledList() = default;
|
||||
LabeledList(const LabeledList&) = delete;
|
||||
LabeledList& operator=(const LabeledList&) = delete;
|
||||
LabeledList(LabeledList&&) = delete;
|
||||
LabeledList& operator=(LabeledList&&) = delete;
|
||||
|
||||
~LabeledList() { clear(); }
|
||||
|
||||
bool empty() const { return size_ == 0; }
|
||||
size_t size() const { return size_; }
|
||||
|
||||
NodeT* front() { return empty() ? nullptr : &nodes_.front(); }
|
||||
const NodeT* front() const { return empty() ? nullptr : &nodes_.front(); }
|
||||
|
||||
NodeT* back() { return empty() ? nullptr : &nodes_.back(); }
|
||||
const NodeT* back() const { return empty() ? nullptr : &nodes_.back(); }
|
||||
|
||||
static NodeT* previous(NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
auto* list = owner(node);
|
||||
auto it = node->getIterator();
|
||||
if (it == list->nodes_.begin())
|
||||
return nullptr;
|
||||
return *std::prev(it);
|
||||
}
|
||||
|
||||
static const NodeT* previous(const NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
const auto* list = owner(node);
|
||||
auto it = node->getIterator();
|
||||
if (it == list->nodes_.begin())
|
||||
return nullptr;
|
||||
return *std::prev(it);
|
||||
}
|
||||
|
||||
static NodeT* next(NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
auto* list = owner(node);
|
||||
auto it = std::next(node->getIterator());
|
||||
if (it == list->nodes_.end())
|
||||
return nullptr;
|
||||
return *it;
|
||||
}
|
||||
|
||||
static const NodeT* next(const NodeT* node) {
|
||||
if (!node || !owner(node))
|
||||
return nullptr;
|
||||
const auto* list = owner(node);
|
||||
auto it = std::next(node->getIterator());
|
||||
if (it == list->nodes_.end())
|
||||
return nullptr;
|
||||
return *it;
|
||||
}
|
||||
|
||||
bool contains(const NodeT* node) const { return node && node->owner_ == this; }
|
||||
|
||||
Label getOrderLabel(const NodeT* node) const {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
return node->label_;
|
||||
}
|
||||
|
||||
bool comesBefore(const NodeT* lhs, const NodeT* rhs) const {
|
||||
assert(contains(lhs) && contains(rhs) && "nodes must belong to this list");
|
||||
return lhs->label_ < rhs->label_;
|
||||
}
|
||||
|
||||
void pushFront(NodeT* node) { insertBefore(front(), node); }
|
||||
|
||||
void pushBack(NodeT* node) { insertBefore(nullptr, node); }
|
||||
|
||||
void insertBefore(NodeT* nextNode, NodeT* node) {
|
||||
assert(node && "cannot insert a null node");
|
||||
assert(!node->owner_ && "node is already linked");
|
||||
assert(nextNode == nullptr || contains(nextNode));
|
||||
|
||||
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
|
||||
nodes_.insert(nextIt, *node);
|
||||
node->owner_ = this;
|
||||
++size_;
|
||||
assignLabel(getIteratorFor(node));
|
||||
}
|
||||
|
||||
void insertAfter(NodeT* prevNode, NodeT* node) {
|
||||
assert(prevNode == nullptr || contains(prevNode));
|
||||
if (prevNode == nullptr)
|
||||
insertBefore(front(), node);
|
||||
else
|
||||
insertBefore(next(prevNode), node);
|
||||
}
|
||||
|
||||
void remove(NodeT* node) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
nodes_.remove(*node);
|
||||
node->owner_ = nullptr;
|
||||
node->label_ = 0;
|
||||
--size_;
|
||||
}
|
||||
|
||||
void moveBefore(NodeT* node, NodeT* nextNode) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
assert(nextNode == nullptr || contains(nextNode));
|
||||
|
||||
Iterator nodeIt = getIteratorFor(node);
|
||||
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
|
||||
if (nodeIt == nextIt || std::next(nodeIt) == nextIt)
|
||||
return;
|
||||
|
||||
nodes_.splice(nextIt, nodes_, nodeIt);
|
||||
assignLabel(getIteratorFor(node));
|
||||
}
|
||||
|
||||
void moveAfter(NodeT* node, NodeT* prevNode) {
|
||||
assert(contains(node) && "node must belong to this list");
|
||||
assert(prevNode == nullptr || contains(prevNode));
|
||||
|
||||
Iterator nextIt = prevNode ? std::next(getIteratorFor(prevNode)) : nodes_.begin();
|
||||
if (getIteratorFor(node) == nextIt)
|
||||
return;
|
||||
moveBefore(node, nextIt == nodes_.end() ? nullptr : &*nextIt);
|
||||
}
|
||||
|
||||
void clear() {
|
||||
while (!nodes_.empty()) {
|
||||
NodeT* node = &nodes_.front();
|
||||
node->owner_ = nullptr;
|
||||
node->label = 0;
|
||||
nodes_.remove(*node);
|
||||
}
|
||||
size_ = 0;
|
||||
}
|
||||
|
||||
Iterator begin() { return nodes_.begin(); }
|
||||
|
||||
Iterator end() { return nodes_.end(); }
|
||||
|
||||
RIterator rbegin() { return nodes_.rbegin(); }
|
||||
RIterator rend() { return nodes_.rend(); }
|
||||
|
||||
private:
|
||||
static const LabeledList* owner(const NodeT* node) { return node->owner_; }
|
||||
static LabeledList* owner(NodeT* node) { return node->owner_; }
|
||||
|
||||
static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; }
|
||||
static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; }
|
||||
|
||||
static Label labelGap(Label lower, Label upper) {
|
||||
assert(lower < upper && "labels must be strictly ordered");
|
||||
return upper - lower;
|
||||
}
|
||||
|
||||
static bool hasMidpoint(Label lower, Label upper) { return labelGap(lower, upper) > 1; }
|
||||
|
||||
static bool hasRelabelSlack(Label lower, Label upper, size_t nodeCount) {
|
||||
Label gap = labelGap(lower, upper);
|
||||
return gap / static_cast<Label>(nodeCount + 1) >= kRelabelGap;
|
||||
}
|
||||
|
||||
Iterator getIteratorFor(NodeT* node) { return node->getIterator(); }
|
||||
ConstIterator getiteratorFor(const NodeT* node) const { return node->getIterator(); }
|
||||
|
||||
NodeT* previousNode(Iterator it) {
|
||||
if (it == nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
const NodeT* previousNode(ConstIterator it) const {
|
||||
if (it == nodes_.begin())
|
||||
return nullptr;
|
||||
return &*std::prev(it);
|
||||
}
|
||||
|
||||
NodeT* nextNode(Iterator it) {
|
||||
++it;
|
||||
if (it == nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
const NodeT* nextNode(ConstIterator it) const {
|
||||
++it;
|
||||
if (it == nodes_.end())
|
||||
return nullptr;
|
||||
return &*it;
|
||||
}
|
||||
|
||||
void assignLabel(Iterator it) {
|
||||
Label lower = lowerLabel(previousNode(it));
|
||||
Label upper = upperLabel(nextNode(it));
|
||||
if (hasMidpoint(lower, upper)) {
|
||||
(*it).label = lower + static_cast<Label>(labelGap(lower, upper) / 2);
|
||||
return;
|
||||
}
|
||||
|
||||
relabelAround(it);
|
||||
}
|
||||
|
||||
void relabelAround(Iterator center) {
|
||||
size_t targetCount = 1;
|
||||
while (true) {
|
||||
Iterator left = center;
|
||||
Iterator right = center;
|
||||
size_t actualCount = 1;
|
||||
expandWindow(center, targetCount, left, right, actualCount);
|
||||
|
||||
Label lower = lowerLabel(previousNode(left));
|
||||
Label upper = upperLabel(nextNode(right));
|
||||
if (hasRelabelSlack(lower, upper, actualCount)) {
|
||||
relabelWindow(left, actualCount, lower, upper);
|
||||
return;
|
||||
}
|
||||
|
||||
if (left == nodes_.begin() && nextNode(right) == nullptr) {
|
||||
assert(hasRelabelSlack(lower, upper, actualCount) && "label space exhausted");
|
||||
relabelWindow(left, actualCount, lower, upper);
|
||||
return;
|
||||
}
|
||||
|
||||
targetCount *= 2;
|
||||
}
|
||||
}
|
||||
|
||||
void expandWindow(Iterator center, size_t targetCount, Iterator& left, Iterator& right, size_t& actualCount) {
|
||||
left = center;
|
||||
right = center;
|
||||
actualCount = 1;
|
||||
|
||||
while (actualCount < targetCount && (left != nodes_.begin() || nextNode(right) != nullptr)) {
|
||||
if (left != nodes_.begin()) {
|
||||
--left;
|
||||
++actualCount;
|
||||
if (actualCount == targetCount)
|
||||
break;
|
||||
}
|
||||
if (nextNode(right) != nullptr) {
|
||||
++right;
|
||||
++actualCount;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void relabelWindow(Iterator left, size_t nodeCount, Label lower, Label upper) {
|
||||
assert(nodeCount > 0 && "relabel window must not be empty");
|
||||
Label step = labelGap(lower, upper) / static_cast<Label>(nodeCount + 1);
|
||||
assert(step >= 1 && "relabel step must be positive");
|
||||
|
||||
Iterator it = left;
|
||||
for (size_t index = 1; index <= nodeCount; ++index) {
|
||||
(*it).label = lower + step * index;
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
List nodes_;
|
||||
size_t size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user