Topological initialization
Validate Operations / validate-operations (push) Failing after 27m23s

This commit is contained in:
ilgeco
2026-04-16 16:50:49 +02:00
parent ae93d1c563
commit 933e138012
8 changed files with 577 additions and 126 deletions
+318
View File
@@ -0,0 +1,318 @@
#pragma once
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ilist_node.h"
#include "llvm/ADT/simple_ilist.h"
#include <cassert>
#include <iterator>
#include <limits>
#include <type_traits>
namespace onnx_mlir {
template <typename NodeT>
class LabeledList;
template <typename NodeT>
class LabeledListNode : public llvm::ilist_node<NodeT> {
friend class LabeledList<NodeT>;
public:
using Label = uint64_t;
LabeledListNode() = default;
LabeledListNode(const LabeledListNode&) = delete;
LabeledListNode(LabeledListNode&&) = default;
LabeledListNode& operator=(LabeledListNode&&) = delete;
~LabeledListNode() { assert(owner_ == nullptr && "destroying a linked LabeledListNode"); }
bool isLinked() const { return owner_ != nullptr; }
Label getOrderLabel() const { return label; }
friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt){
return lft.label < rgt.label;
}
private:
const void* owner_ = nullptr;
Label label = 0;
};
template <typename NodeT>
class LabeledList {
using Label = typename NodeT::Label;
static constexpr Label kLowerSentinel = 0;
static constexpr Label kUpperSentinel = std::numeric_limits<Label>::max();
static constexpr Label kRelabelGap = 2;
public:
using List = llvm::simple_ilist<NodeT>;
using Iterator = typename List::iterator;
using RIterator = typename List::reverse_iterator;
using ConstIterator = typename List::const_iterator;
LabeledList() = default;
LabeledList(const LabeledList&) = delete;
LabeledList& operator=(const LabeledList&) = delete;
LabeledList(LabeledList&&) = delete;
LabeledList& operator=(LabeledList&&) = delete;
~LabeledList() { clear(); }
bool empty() const { return size_ == 0; }
size_t size() const { return size_; }
NodeT* front() { return empty() ? nullptr : &nodes_.front(); }
const NodeT* front() const { return empty() ? nullptr : &nodes_.front(); }
NodeT* back() { return empty() ? nullptr : &nodes_.back(); }
const NodeT* back() const { return empty() ? nullptr : &nodes_.back(); }
static NodeT* previous(NodeT* node) {
if (!node || !owner(node))
return nullptr;
auto* list = owner(node);
auto it = node->getIterator();
if (it == list->nodes_.begin())
return nullptr;
return *std::prev(it);
}
static const NodeT* previous(const NodeT* node) {
if (!node || !owner(node))
return nullptr;
const auto* list = owner(node);
auto it = node->getIterator();
if (it == list->nodes_.begin())
return nullptr;
return *std::prev(it);
}
static NodeT* next(NodeT* node) {
if (!node || !owner(node))
return nullptr;
auto* list = owner(node);
auto it = std::next(node->getIterator());
if (it == list->nodes_.end())
return nullptr;
return *it;
}
static const NodeT* next(const NodeT* node) {
if (!node || !owner(node))
return nullptr;
const auto* list = owner(node);
auto it = std::next(node->getIterator());
if (it == list->nodes_.end())
return nullptr;
return *it;
}
bool contains(const NodeT* node) const { return node && node->owner_ == this; }
Label getOrderLabel(const NodeT* node) const {
assert(contains(node) && "node must belong to this list");
return node->label_;
}
bool comesBefore(const NodeT* lhs, const NodeT* rhs) const {
assert(contains(lhs) && contains(rhs) && "nodes must belong to this list");
return lhs->label_ < rhs->label_;
}
void pushFront(NodeT* node) { insertBefore(front(), node); }
void pushBack(NodeT* node) { insertBefore(nullptr, node); }
void insertBefore(NodeT* nextNode, NodeT* node) {
assert(node && "cannot insert a null node");
assert(!node->owner_ && "node is already linked");
assert(nextNode == nullptr || contains(nextNode));
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
nodes_.insert(nextIt, *node);
node->owner_ = this;
++size_;
assignLabel(getIteratorFor(node));
}
void insertAfter(NodeT* prevNode, NodeT* node) {
assert(prevNode == nullptr || contains(prevNode));
if (prevNode == nullptr)
insertBefore(front(), node);
else
insertBefore(next(prevNode), node);
}
void remove(NodeT* node) {
assert(contains(node) && "node must belong to this list");
nodes_.remove(*node);
node->owner_ = nullptr;
node->label_ = 0;
--size_;
}
void moveBefore(NodeT* node, NodeT* nextNode) {
assert(contains(node) && "node must belong to this list");
assert(nextNode == nullptr || contains(nextNode));
Iterator nodeIt = getIteratorFor(node);
Iterator nextIt = nextNode ? getIteratorFor(nextNode) : nodes_.end();
if (nodeIt == nextIt || std::next(nodeIt) == nextIt)
return;
nodes_.splice(nextIt, nodes_, nodeIt);
assignLabel(getIteratorFor(node));
}
void moveAfter(NodeT* node, NodeT* prevNode) {
assert(contains(node) && "node must belong to this list");
assert(prevNode == nullptr || contains(prevNode));
Iterator nextIt = prevNode ? std::next(getIteratorFor(prevNode)) : nodes_.begin();
if (getIteratorFor(node) == nextIt)
return;
moveBefore(node, nextIt == nodes_.end() ? nullptr : &*nextIt);
}
void clear() {
while (!nodes_.empty()) {
NodeT* node = &nodes_.front();
node->owner_ = nullptr;
node->label = 0;
nodes_.remove(*node);
}
size_ = 0;
}
Iterator begin() { return nodes_.begin(); }
Iterator end() { return nodes_.end(); }
RIterator rbegin() { return nodes_.rbegin(); }
RIterator rend() { return nodes_.rend(); }
private:
static const LabeledList* owner(const NodeT* node) { return node->owner_; }
static LabeledList* owner(NodeT* node) { return node->owner_; }
static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; }
static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; }
static Label labelGap(Label lower, Label upper) {
assert(lower < upper && "labels must be strictly ordered");
return upper - lower;
}
static bool hasMidpoint(Label lower, Label upper) { return labelGap(lower, upper) > 1; }
static bool hasRelabelSlack(Label lower, Label upper, size_t nodeCount) {
Label gap = labelGap(lower, upper);
return gap / static_cast<Label>(nodeCount + 1) >= kRelabelGap;
}
Iterator getIteratorFor(NodeT* node) { return node->getIterator(); }
ConstIterator getiteratorFor(const NodeT* node) const { return node->getIterator(); }
NodeT* previousNode(Iterator it) {
if (it == nodes_.begin())
return nullptr;
return &*std::prev(it);
}
const NodeT* previousNode(ConstIterator it) const {
if (it == nodes_.begin())
return nullptr;
return &*std::prev(it);
}
NodeT* nextNode(Iterator it) {
++it;
if (it == nodes_.end())
return nullptr;
return &*it;
}
const NodeT* nextNode(ConstIterator it) const {
++it;
if (it == nodes_.end())
return nullptr;
return &*it;
}
void assignLabel(Iterator it) {
Label lower = lowerLabel(previousNode(it));
Label upper = upperLabel(nextNode(it));
if (hasMidpoint(lower, upper)) {
(*it).label = lower + static_cast<Label>(labelGap(lower, upper) / 2);
return;
}
relabelAround(it);
}
void relabelAround(Iterator center) {
size_t targetCount = 1;
while (true) {
Iterator left = center;
Iterator right = center;
size_t actualCount = 1;
expandWindow(center, targetCount, left, right, actualCount);
Label lower = lowerLabel(previousNode(left));
Label upper = upperLabel(nextNode(right));
if (hasRelabelSlack(lower, upper, actualCount)) {
relabelWindow(left, actualCount, lower, upper);
return;
}
if (left == nodes_.begin() && nextNode(right) == nullptr) {
assert(hasRelabelSlack(lower, upper, actualCount) && "label space exhausted");
relabelWindow(left, actualCount, lower, upper);
return;
}
targetCount *= 2;
}
}
void expandWindow(Iterator center, size_t targetCount, Iterator& left, Iterator& right, size_t& actualCount) {
left = center;
right = center;
actualCount = 1;
while (actualCount < targetCount && (left != nodes_.begin() || nextNode(right) != nullptr)) {
if (left != nodes_.begin()) {
--left;
++actualCount;
if (actualCount == targetCount)
break;
}
if (nextNode(right) != nullptr) {
++right;
++actualCount;
}
}
}
void relabelWindow(Iterator left, size_t nodeCount, Label lower, Label upper) {
assert(nodeCount > 0 && "relabel window must not be empty");
Label step = labelGap(lower, upper) / static_cast<Label>(nodeCount + 1);
assert(step >= 1 && "relabel step must be positive");
Iterator it = left;
for (size_t index = 1; index <= nodeCount; ++index) {
(*it).label = lower + step * index;
++it;
}
}
List nodes_;
size_t size_ = 0;
};
} // namespace onnx_mlir