faster (and refactored) DCP analysis
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h16m17s

This commit is contained in:
NiccoloN
2026-04-21 12:33:44 +02:00
parent f4c6da8f10
commit 85e2750d6c
20 changed files with 2525 additions and 858 deletions

View File

@@ -0,0 +1,162 @@
#include <cassert>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <vector>
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
using onnx_mlir::LabeledList;
using onnx_mlir::LabeledListNode;
namespace {
struct TestNode : public LabeledListNode<TestNode> {
explicit TestNode(int id)
: id(id) {}
int id;
};
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
auto expectedIt = expectedOrder.begin();
for (auto& node : list) {
assert(expectedIt != expectedOrder.end());
assert(node.id == *expectedIt);
++expectedIt;
}
assert(expectedIt == expectedOrder.end());
}
void assertStrictlyIncreasingLabels(LabeledList<TestNode>& list) {
auto it = list.begin();
if (it == list.end())
return;
auto previousLabel = it->getOrderLabel();
++it;
for (; it != list.end(); ++it) {
assert(previousLabel < it->getOrderLabel());
previousLabel = it->getOrderLabel();
}
}
int testLabeledListBasicMutation() {
std::cout << "testLabeledListBasicMutation:" << std::endl;
LabeledList<TestNode> list;
TestNode n1(1);
TestNode n2(2);
TestNode n3(3);
TestNode n4(4);
TestNode n5(5);
assert(list.empty());
assert(list.front() == nullptr);
assert(list.back() == nullptr);
assert(!list.contains(&n1));
assert(LabeledList<TestNode>::previous(&n1) == nullptr);
assert(LabeledList<TestNode>::next(&n1) == nullptr);
list.pushBack(&n1);
list.pushBack(&n3);
list.insertAfter(&n1, &n2);
list.pushFront(&n4);
list.insertBefore(nullptr, &n5);
assert(list.size() == 5);
assert(list.front() == &n4);
assert(list.back() == &n5);
assert(list.contains(&n2));
assertOrder(list, {4, 1, 2, 3, 5});
assert(LabeledList<TestNode>::next(&n4) == &n1);
assert(LabeledList<TestNode>::previous(&n1) == &n4);
assert(LabeledList<TestNode>::next(&n5) == nullptr);
assert(list.comesBefore(&n1, &n3));
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
list.moveBefore(&n5, &n2);
assertOrder(list, {4, 1, 5, 2, 3});
list.moveAfter(&n4, &n3);
assertOrder(list, {1, 5, 2, 3, 4});
list.remove(&n2);
assert(!n2.isLinked());
assert(!list.contains(&n2));
assertOrder(list, {1, 5, 3, 4});
list.clear();
assert(list.empty());
assert(list.size() == 0);
assert(list.front() == nullptr);
assert(list.back() == nullptr);
assert(!n1.isLinked());
assert(!n3.isLinked());
assert(!n4.isLinked());
assert(!n5.isLinked());
return 0;
}
int testLabeledListRelabelingAndNoopMoves() {
std::cout << "testLabeledListRelabelingAndNoopMoves:" << std::endl;
constexpr int insertedNodeCount = 80;
LabeledList<TestNode> list;
TestNode head(0);
TestNode tail(999);
std::vector<TestNode> insertedNodes;
insertedNodes.reserve(insertedNodeCount);
for (int i = 0; i < insertedNodeCount; ++i)
insertedNodes.emplace_back(i + 1);
list.pushBack(&head);
list.pushBack(&tail);
for (auto& node : insertedNodes)
list.insertAfter(&head, &node);
assert(list.size() == insertedNodeCount + 2);
assert(list.front() == &head);
assert(list.back() == &tail);
assert(LabeledList<TestNode>::previous(&head) == nullptr);
assert(LabeledList<TestNode>::next(&tail) == nullptr);
assertStrictlyIncreasingLabels(list);
auto* firstInserted = LabeledList<TestNode>::next(&head);
auto* secondInserted = LabeledList<TestNode>::next(firstInserted);
list.moveBefore(firstInserted, secondInserted);
list.moveAfter(&head, nullptr);
list.moveAfter(&tail, LabeledList<TestNode>::previous(&tail));
assert(list.front() == &head);
assert(list.back() == &tail);
assert(firstInserted == &insertedNodes.back());
assert(secondInserted == &insertedNodes[insertedNodeCount - 2]);
assertStrictlyIncreasingLabels(list);
int expectedId = insertedNodeCount;
auto it = std::next(list.begin());
for (; it != list.end() && &*it != &tail; ++it, --expectedId)
assert(it->id == expectedId);
assert(expectedId == 0);
list.clear();
return 0;
}
} // namespace
int main(int argc, char* argv[]) {
(void) argc;
(void) argv;
int failures = 0;
failures += testLabeledListBasicMutation();
failures += testLabeledListRelabelingAndNoopMoves();
if (failures != 0) {
std::cerr << failures << " test failures\n";
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}