diff --git a/src/PIM/Common/LabeledList.hpp b/src/PIM/Common/LabeledList.hpp new file mode 100644 index 0000000..07d3437 --- /dev/null +++ b/src/PIM/Common/LabeledList.hpp @@ -0,0 +1,318 @@ +#pragma once + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ilist_node.h" +#include "llvm/ADT/simple_ilist.h" + +#include +#include +#include +#include + +namespace onnx_mlir { + +template +class LabeledList; + +template +class LabeledListNode : public llvm::ilist_node { + friend class LabeledList; + +public: + using Label = uint64_t; + + LabeledListNode() = default; + LabeledListNode(const LabeledListNode&) = delete; + LabeledListNode(LabeledListNode&&) = default; + LabeledListNode& operator=(LabeledListNode&&) = delete; + + ~LabeledListNode() { assert(owner_ == nullptr && "destroying a linked LabeledListNode"); } + + bool isLinked() const { return owner_ != nullptr; } + Label getOrderLabel() const { return label; } + + friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt){ + return lft.label < rgt.label; + } + +private: + const void* owner_ = nullptr; + Label label = 0; +}; + +template +class LabeledList { + + using Label = typename NodeT::Label; + + static constexpr Label kLowerSentinel = 0; + static constexpr Label kUpperSentinel = std::numeric_limits