All checks were successful
Validate Operations / validate-operations (push) Successful in 2h16m17s
163 lines
4.2 KiB
C++
163 lines
4.2 KiB
C++
#include <cassert>
|
|
#include <cstdlib>
|
|
#include <initializer_list>
|
|
#include <iostream>
|
|
#include <vector>
|
|
|
|
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
|
|
|
using onnx_mlir::LabeledList;
|
|
using onnx_mlir::LabeledListNode;
|
|
|
|
namespace {
|
|
|
|
// List element used by these tests. It derives from LabeledListNode using
// CRTP and adds only an integer `id`, which the tests use to identify nodes
// when checking the order of a list.
struct TestNode : LabeledListNode<TestNode> {
  explicit TestNode(int nodeId) : id(nodeId) {}

  int id;
};
|
|
|
|
// Checks that walking `list` from front to back yields node ids equal to
// `expectedOrder`, element by element, with no extra and no missing nodes.
// Any mismatch is reported through assert().
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
  const int* expected = expectedOrder.begin();
  const int* const expectedEnd = expectedOrder.end();
  for (auto current = list.begin(); current != list.end(); ++current, ++expected) {
    // The list must not be longer than the expected sequence.
    assert(expected != expectedEnd);
    assert(current->id == *expected);
  }
  // The list must not be shorter than the expected sequence.
  assert(expected == expectedEnd);
}
|
|
|
|
// Checks that each node's order label is strictly greater than the label of
// the node before it, so that label order agrees with link order. An empty
// list passes trivially.
void assertStrictlyIncreasingLabels(LabeledList<TestNode>& list) {
  auto current = list.begin();
  auto last = list.end();
  if (current == last)
    return;

  auto lastSeenLabel = current->getOrderLabel();
  for (++current; current != last; ++current) {
    auto label = current->getOrderLabel();
    assert(lastSeenLabel < label);
    lastSeenLabel = label;
  }
}
|
|
|
|
int testLabeledListBasicMutation() {
|
|
std::cout << "testLabeledListBasicMutation:" << std::endl;
|
|
|
|
LabeledList<TestNode> list;
|
|
TestNode n1(1);
|
|
TestNode n2(2);
|
|
TestNode n3(3);
|
|
TestNode n4(4);
|
|
TestNode n5(5);
|
|
|
|
assert(list.empty());
|
|
assert(list.front() == nullptr);
|
|
assert(list.back() == nullptr);
|
|
assert(!list.contains(&n1));
|
|
assert(LabeledList<TestNode>::previous(&n1) == nullptr);
|
|
assert(LabeledList<TestNode>::next(&n1) == nullptr);
|
|
|
|
list.pushBack(&n1);
|
|
list.pushBack(&n3);
|
|
list.insertAfter(&n1, &n2);
|
|
list.pushFront(&n4);
|
|
list.insertBefore(nullptr, &n5);
|
|
|
|
assert(list.size() == 5);
|
|
assert(list.front() == &n4);
|
|
assert(list.back() == &n5);
|
|
assert(list.contains(&n2));
|
|
assertOrder(list, {4, 1, 2, 3, 5});
|
|
assert(LabeledList<TestNode>::next(&n4) == &n1);
|
|
assert(LabeledList<TestNode>::previous(&n1) == &n4);
|
|
assert(LabeledList<TestNode>::next(&n5) == nullptr);
|
|
assert(list.comesBefore(&n1, &n3));
|
|
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
|
|
|
|
list.moveBefore(&n5, &n2);
|
|
assertOrder(list, {4, 1, 5, 2, 3});
|
|
|
|
list.moveAfter(&n4, &n3);
|
|
assertOrder(list, {1, 5, 2, 3, 4});
|
|
|
|
list.remove(&n2);
|
|
assert(!n2.isLinked());
|
|
assert(!list.contains(&n2));
|
|
assertOrder(list, {1, 5, 3, 4});
|
|
|
|
list.clear();
|
|
assert(list.empty());
|
|
assert(list.size() == 0);
|
|
assert(list.front() == nullptr);
|
|
assert(list.back() == nullptr);
|
|
assert(!n1.isLinked());
|
|
assert(!n3.isLinked());
|
|
assert(!n4.isLinked());
|
|
assert(!n5.isLinked());
|
|
|
|
return 0;
|
|
}
|
|
|
|
int testLabeledListRelabelingAndNoopMoves() {
|
|
std::cout << "testLabeledListRelabelingAndNoopMoves:" << std::endl;
|
|
|
|
constexpr int insertedNodeCount = 80;
|
|
LabeledList<TestNode> list;
|
|
TestNode head(0);
|
|
TestNode tail(999);
|
|
std::vector<TestNode> insertedNodes;
|
|
insertedNodes.reserve(insertedNodeCount);
|
|
for (int i = 0; i < insertedNodeCount; ++i)
|
|
insertedNodes.emplace_back(i + 1);
|
|
|
|
list.pushBack(&head);
|
|
list.pushBack(&tail);
|
|
for (auto& node : insertedNodes)
|
|
list.insertAfter(&head, &node);
|
|
|
|
assert(list.size() == insertedNodeCount + 2);
|
|
assert(list.front() == &head);
|
|
assert(list.back() == &tail);
|
|
assert(LabeledList<TestNode>::previous(&head) == nullptr);
|
|
assert(LabeledList<TestNode>::next(&tail) == nullptr);
|
|
assertStrictlyIncreasingLabels(list);
|
|
|
|
auto* firstInserted = LabeledList<TestNode>::next(&head);
|
|
auto* secondInserted = LabeledList<TestNode>::next(firstInserted);
|
|
list.moveBefore(firstInserted, secondInserted);
|
|
list.moveAfter(&head, nullptr);
|
|
list.moveAfter(&tail, LabeledList<TestNode>::previous(&tail));
|
|
|
|
assert(list.front() == &head);
|
|
assert(list.back() == &tail);
|
|
assert(firstInserted == &insertedNodes.back());
|
|
assert(secondInserted == &insertedNodes[insertedNodeCount - 2]);
|
|
assertStrictlyIncreasingLabels(list);
|
|
|
|
int expectedId = insertedNodeCount;
|
|
auto it = std::next(list.begin());
|
|
for (; it != list.end() && &*it != &tail; ++it, --expectedId)
|
|
assert(it->id == expectedId);
|
|
assert(expectedId == 0);
|
|
list.clear();
|
|
|
|
return 0;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
// Test driver. It runs every LabeledList test in a fixed order; the calls are
// kept as separate statements so that order is guaranteed. Each test returns
// its failure count. A non-zero total is reported on stderr and turned into
// EXIT_FAILURE.
int main(int argc, char* argv[]) {
  // The driver takes no command-line options.
  static_cast<void>(argc);
  static_cast<void>(argv);

  int failures = testLabeledListBasicMutation();
  failures += testLabeledListRelabelingAndNoopMoves();

  if (failures == 0)
    return EXIT_SUCCESS;

  std::cerr << failures << " test failures\n";
  return EXIT_FAILURE;
}
|