#include #include #include #include #include #include "src/Accelerators/PIM/Common/LabeledList.hpp" using onnx_mlir::LabeledList; using onnx_mlir::LabeledListNode; namespace { struct TestNode : public LabeledListNode { explicit TestNode(int id) : id(id) {} int id; }; void assertOrder(LabeledList& list, std::initializer_list expectedOrder) { auto expectedIt = expectedOrder.begin(); for (auto& node : list) { assert(expectedIt != expectedOrder.end()); assert(node.id == *expectedIt); ++expectedIt; } assert(expectedIt == expectedOrder.end()); } void assertStrictlyIncreasingLabels(LabeledList& list) { auto it = list.begin(); if (it == list.end()) return; auto previousLabel = it->getOrderLabel(); ++it; for (; it != list.end(); ++it) { assert(previousLabel < it->getOrderLabel()); previousLabel = it->getOrderLabel(); } } int testLabeledListBasicMutation() { std::cout << "testLabeledListBasicMutation:" << std::endl; LabeledList list; TestNode n1(1); TestNode n2(2); TestNode n3(3); TestNode n4(4); TestNode n5(5); assert(list.empty()); assert(list.front() == nullptr); assert(list.back() == nullptr); assert(!list.contains(&n1)); assert(LabeledList::previous(&n1) == nullptr); assert(LabeledList::next(&n1) == nullptr); list.pushBack(&n1); list.pushBack(&n3); list.insertAfter(&n1, &n2); list.pushFront(&n4); list.insertBefore(nullptr, &n5); assert(list.size() == 5); assert(list.front() == &n4); assert(list.back() == &n5); assert(list.contains(&n2)); assertOrder(list, {4, 1, 2, 3, 5}); assert(LabeledList::next(&n4) == &n1); assert(LabeledList::previous(&n1) == &n4); assert(LabeledList::next(&n5) == nullptr); assert(list.comesBefore(&n1, &n3)); assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3)); list.moveBefore(&n5, &n2); assertOrder(list, {4, 1, 5, 2, 3}); list.moveAfter(&n4, &n3); assertOrder(list, {1, 5, 2, 3, 4}); list.remove(&n2); assert(!n2.isLinked()); assert(!list.contains(&n2)); assertOrder(list, {1, 5, 3, 4}); list.clear(); assert(list.empty()); assert(list.size() == 0); assert(list.front() == nullptr); assert(list.back() == nullptr); assert(!n1.isLinked()); assert(!n3.isLinked()); assert(!n4.isLinked()); assert(!n5.isLinked()); return 0; } int testLabeledListRelabelingAndNoopMoves() { std::cout << "testLabeledListRelabelingAndNoopMoves:" << std::endl; constexpr int insertedNodeCount = 80; LabeledList list; TestNode head(0); TestNode tail(999); std::vector insertedNodes; insertedNodes.reserve(insertedNodeCount); for (int i = 0; i < insertedNodeCount; ++i) insertedNodes.emplace_back(i + 1); list.pushBack(&head); list.pushBack(&tail); for (auto& node : insertedNodes) list.insertAfter(&head, &node); assert(list.size() == insertedNodeCount + 2); assert(list.front() == &head); assert(list.back() == &tail); assert(LabeledList::previous(&head) == nullptr); assert(LabeledList::next(&tail) == nullptr); assertStrictlyIncreasingLabels(list); auto* firstInserted = LabeledList::next(&head); auto* secondInserted = LabeledList::next(firstInserted); list.moveBefore(firstInserted, secondInserted); list.moveAfter(&head, nullptr); list.moveAfter(&tail, LabeledList::previous(&tail)); assert(list.front() == &head); assert(list.back() == &tail); assert(firstInserted == &insertedNodes.back()); assert(secondInserted == &insertedNodes[insertedNodeCount - 2]); assertStrictlyIncreasingLabels(list); int expectedId = insertedNodeCount; auto it = std::next(list.begin()); for (; it != list.end() && &*it != &tail; ++it, --expectedId) assert(it->id == expectedId); assert(expectedId == 0); list.clear(); return 0; } } // namespace int main(int argc, char* argv[]) { (void) argc; (void) argv; int failures = 0; failures += testLabeledListBasicMutation(); failures += testLabeledListRelabelingAndNoopMoves(); if (failures != 0) { std::cerr << failures << " test failures\n"; return EXIT_FAILURE; } return EXIT_SUCCESS; }