faster (and refactored) DCP analysis
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h16m17s
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h16m17s
This commit is contained in:
162
test/PIM/LabeledListTest.cpp
Normal file
162
test/PIM/LabeledListTest.cpp
Normal file
@@ -0,0 +1,162 @@
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
||||
|
||||
using onnx_mlir::LabeledList;
|
||||
using onnx_mlir::LabeledListNode;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestNode : public LabeledListNode<TestNode> {
|
||||
explicit TestNode(int id)
|
||||
: id(id) {}
|
||||
|
||||
int id;
|
||||
};
|
||||
|
||||
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
|
||||
auto expectedIt = expectedOrder.begin();
|
||||
for (auto& node : list) {
|
||||
assert(expectedIt != expectedOrder.end());
|
||||
assert(node.id == *expectedIt);
|
||||
++expectedIt;
|
||||
}
|
||||
assert(expectedIt == expectedOrder.end());
|
||||
}
|
||||
|
||||
void assertStrictlyIncreasingLabels(LabeledList<TestNode>& list) {
|
||||
auto it = list.begin();
|
||||
if (it == list.end())
|
||||
return;
|
||||
|
||||
auto previousLabel = it->getOrderLabel();
|
||||
++it;
|
||||
for (; it != list.end(); ++it) {
|
||||
assert(previousLabel < it->getOrderLabel());
|
||||
previousLabel = it->getOrderLabel();
|
||||
}
|
||||
}
|
||||
|
||||
int testLabeledListBasicMutation() {
|
||||
std::cout << "testLabeledListBasicMutation:" << std::endl;
|
||||
|
||||
LabeledList<TestNode> list;
|
||||
TestNode n1(1);
|
||||
TestNode n2(2);
|
||||
TestNode n3(3);
|
||||
TestNode n4(4);
|
||||
TestNode n5(5);
|
||||
|
||||
assert(list.empty());
|
||||
assert(list.front() == nullptr);
|
||||
assert(list.back() == nullptr);
|
||||
assert(!list.contains(&n1));
|
||||
assert(LabeledList<TestNode>::previous(&n1) == nullptr);
|
||||
assert(LabeledList<TestNode>::next(&n1) == nullptr);
|
||||
|
||||
list.pushBack(&n1);
|
||||
list.pushBack(&n3);
|
||||
list.insertAfter(&n1, &n2);
|
||||
list.pushFront(&n4);
|
||||
list.insertBefore(nullptr, &n5);
|
||||
|
||||
assert(list.size() == 5);
|
||||
assert(list.front() == &n4);
|
||||
assert(list.back() == &n5);
|
||||
assert(list.contains(&n2));
|
||||
assertOrder(list, {4, 1, 2, 3, 5});
|
||||
assert(LabeledList<TestNode>::next(&n4) == &n1);
|
||||
assert(LabeledList<TestNode>::previous(&n1) == &n4);
|
||||
assert(LabeledList<TestNode>::next(&n5) == nullptr);
|
||||
assert(list.comesBefore(&n1, &n3));
|
||||
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
|
||||
|
||||
list.moveBefore(&n5, &n2);
|
||||
assertOrder(list, {4, 1, 5, 2, 3});
|
||||
|
||||
list.moveAfter(&n4, &n3);
|
||||
assertOrder(list, {1, 5, 2, 3, 4});
|
||||
|
||||
list.remove(&n2);
|
||||
assert(!n2.isLinked());
|
||||
assert(!list.contains(&n2));
|
||||
assertOrder(list, {1, 5, 3, 4});
|
||||
|
||||
list.clear();
|
||||
assert(list.empty());
|
||||
assert(list.size() == 0);
|
||||
assert(list.front() == nullptr);
|
||||
assert(list.back() == nullptr);
|
||||
assert(!n1.isLinked());
|
||||
assert(!n3.isLinked());
|
||||
assert(!n4.isLinked());
|
||||
assert(!n5.isLinked());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int testLabeledListRelabelingAndNoopMoves() {
|
||||
std::cout << "testLabeledListRelabelingAndNoopMoves:" << std::endl;
|
||||
|
||||
constexpr int insertedNodeCount = 80;
|
||||
LabeledList<TestNode> list;
|
||||
TestNode head(0);
|
||||
TestNode tail(999);
|
||||
std::vector<TestNode> insertedNodes;
|
||||
insertedNodes.reserve(insertedNodeCount);
|
||||
for (int i = 0; i < insertedNodeCount; ++i)
|
||||
insertedNodes.emplace_back(i + 1);
|
||||
|
||||
list.pushBack(&head);
|
||||
list.pushBack(&tail);
|
||||
for (auto& node : insertedNodes)
|
||||
list.insertAfter(&head, &node);
|
||||
|
||||
assert(list.size() == insertedNodeCount + 2);
|
||||
assert(list.front() == &head);
|
||||
assert(list.back() == &tail);
|
||||
assert(LabeledList<TestNode>::previous(&head) == nullptr);
|
||||
assert(LabeledList<TestNode>::next(&tail) == nullptr);
|
||||
assertStrictlyIncreasingLabels(list);
|
||||
|
||||
auto* firstInserted = LabeledList<TestNode>::next(&head);
|
||||
auto* secondInserted = LabeledList<TestNode>::next(firstInserted);
|
||||
list.moveBefore(firstInserted, secondInserted);
|
||||
list.moveAfter(&head, nullptr);
|
||||
list.moveAfter(&tail, LabeledList<TestNode>::previous(&tail));
|
||||
|
||||
assert(list.front() == &head);
|
||||
assert(list.back() == &tail);
|
||||
assert(firstInserted == &insertedNodes.back());
|
||||
assert(secondInserted == &insertedNodes[insertedNodeCount - 2]);
|
||||
assertStrictlyIncreasingLabels(list);
|
||||
|
||||
int expectedId = insertedNodeCount;
|
||||
auto it = std::next(list.begin());
|
||||
for (; it != list.end() && &*it != &tail; ++it, --expectedId)
|
||||
assert(it->id == expectedId);
|
||||
assert(expectedId == 0);
|
||||
list.clear();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
(void) argc;
|
||||
(void) argv;
|
||||
|
||||
int failures = 0;
|
||||
failures += testLabeledListBasicMutation();
|
||||
failures += testLabeledListRelabelingAndNoopMoves();
|
||||
if (failures != 0) {
|
||||
std::cerr << failures << " test failures\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
Reference in New Issue
Block a user