/*
 * SPDX-License-Identifier: Apache-2.0
 */

// Standalone test driver for the PIM accelerator's LabeledList container and
// the GraphDCP scheduling graph. Checks are counted rather than asserted so
// that the driver still verifies (and reports) failures in NDEBUG builds.

#include "src/Accelerators/PIM/Common/LabeledList.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp"

#include <cstddef>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <vector>

using onnx_mlir::LabeledList;
using onnx_mlir::LabeledListNode;

namespace {

/// Total number of failed TEST_CHECKs recorded so far, across all tests.
int checkFailures = 0;

/// Records the outcome of one expectation.
///
/// Unlike `assert`, this is never compiled out under NDEBUG and never aborts:
/// a failure is reported on stderr and counted, and the test keeps running so
/// that `main` can report every failure.
///
/// @param condition  result of the checked expression.
/// @param expression source text of the expression, for the report.
/// @param line       source line of the check, for the report.
/// @return `condition`, so callers can stop early when continuing is unsafe.
bool check(bool condition, const char* expression, int line) {
  if (!condition) {
    std::cerr << "  TEST_CHECK failed (line " << line << "): " << expression
              << '\n';
    ++checkFailures;
  }
  return condition;
}

// Evaluates `expr` exactly once and records a failure if it is false.
#define TEST_CHECK(expr) check(static_cast<bool>(expr), #expr, __LINE__)

/// Minimal list element whose `id` identifies it when verifying list order.
// NOTE(review): assumes LabeledListNode / LabeledList are parameterized on the
// element type (CRTP-style `<TestNode>`) — confirm against LabeledList.hpp.
struct TestNode : public LabeledListNode<TestNode> {
  explicit TestNode(int id) : id(id) {}
  int id;
};

using TestList = LabeledList<TestNode>;

/// Checks that iterating `list` yields exactly the ids in `expectedOrder`.
void assertOrder(TestList& list, std::initializer_list<int> expectedOrder) {
  auto expectedIt = expectedOrder.begin();
  for (auto& node : list) {
    // The list is longer than expected; stop before reading past the end.
    if (!TEST_CHECK(expectedIt != expectedOrder.end()))
      return;
    TEST_CHECK(node.id == *expectedIt);
    ++expectedIt;
  }
  TEST_CHECK(expectedIt == expectedOrder.end());
}

/// Exercises insertion, neighbour navigation, order labels, moves, removal
/// and clear() on a LabeledList.
/// @return the number of checks that failed in this test.
int testLabeledList() {
  std::cout << "testLabeledList:\n";
  const int failuresBefore = checkFailures;

  TestList list;
  TestNode n1(1);
  TestNode n2(2);
  TestNode n3(3);
  TestNode n4(4);
  TestNode n5(5);

  list.pushBack(&n1);
  list.pushBack(&n3);
  list.insertAfter(&n1, &n2);
  list.pushFront(&n4);
  // Inserting before nullptr appends at the end (n5 ends up last below).
  list.insertBefore(nullptr, &n5);
  assertOrder(list, {4, 1, 2, 3, 5});

  TEST_CHECK(TestList::next(&n4) == &n1);
  TEST_CHECK(TestList::previous(&n1) == &n4);
  TEST_CHECK(TestList::next(&n5) == nullptr);
  TEST_CHECK(list.comesBefore(&n1, &n3));
  TEST_CHECK(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));

  list.moveBefore(&n5, &n2);
  assertOrder(list, {4, 1, 5, 2, 3});

  list.moveAfter(&n4, &n3);
  assertOrder(list, {1, 5, 2, 3, 4});

  list.remove(&n2);
  TEST_CHECK(!n2.isLinked());
  assertOrder(list, {1, 5, 3, 4});

  // clear() must unlink every remaining node.
  list.clear();
  TEST_CHECK(list.empty());
  TEST_CHECK(!n1.isLinked());
  TEST_CHECK(!n3.isLinked());
  TEST_CHECK(!n4.isLinked());
  TEST_CHECK(!n5.isLinked());

  return checkFailures - failuresBefore;
}

/// Expected values for one entry returned by GraphDCP::getScheduledTasks();
/// each field is compared with the entry's member of the same name.
struct ExpectedScheduledTask {
  std::size_t nodeIndex;
  int aest;
  int alst;
  int weight;
};

/// Prints one CPU's schedule to stderr, one task per line.
/// @param label word inserted in the header, e.g. "actual" or "computed".
template <typename Tasks>
void printSchedule(CPU cpu, const Tasks& tasks, const char* label) {
  std::cerr << "CPU " << cpu << ' ' << label << " schedule:\n";
  for (const auto& task : tasks) {
    std::cerr << "  " << task.nodeIndex << ") aest: " << task.aest
              << " alst: " << task.alst << " weight: " << task.weight << '\n';
  }
}

/// Checks that `cpu`'s schedule in `graph` matches `expectedTasks` in order,
/// field by field. On any mismatch the actual schedule is printed once.
void assertScheduledTasks(GraphDCP& graph, CPU cpu,
    std::initializer_list<ExpectedScheduledTask> expectedTasks) {
  auto actualTasks = graph.getScheduledTasks(cpu);
  bool matches = TEST_CHECK(actualTasks.size() == expectedTasks.size());

  auto expectedIt = expectedTasks.begin();
  for (const auto& actualTask : actualTasks) {
    // A length mismatch has already been recorded; never read past the end.
    if (expectedIt == expectedTasks.end())
      break;
    // TEST_CHECK comes first so every field is checked even after a mismatch.
    matches = TEST_CHECK(actualTask.nodeIndex == expectedIt->nodeIndex) &&
              matches;
    matches = TEST_CHECK(actualTask.aest == expectedIt->aest) && matches;
    matches = TEST_CHECK(actualTask.alst == expectedIt->alst) && matches;
    matches = TEST_CHECK(actualTask.weight == expectedIt->weight) && matches;
    ++expectedIt;
  }

  if (!matches)
    printSchedule(cpu, actualTasks, "actual");
}

/// Builds an 18-node task graph with fixed node weights and edge costs,
/// runs DCP() on it and checks the resulting per-CPU schedules.
/// @return the number of checks that failed in this test.
int testDCPGraphFixture() {
  std::cout << "testDCPGraphFixture:\n";
  const int failuresBefore = checkFailures;

  // Node i has weight nodeWeights[i].
  // NOTE(review): element type assumed to be int — confirm against GraphDCP.
  const std::vector<int> nodeWeights = {
      80, 40, 40, 40, 40, 40, 60, 30, 30,
      30, 30, 40, 20, 20, 20, 20, 10, 10,
  };
  GraphDCP graph(nodeWeights, {});

  // makeEdge(from, to, cost) — argument meaning presumed; confirm in Graph.hpp.
  // NOTE(review): edge 0->1 is added twice with different costs. The expected
  // aest of node 1 below (200 = node 0's weight 80 + 120) implies the 120 cost
  // is the effective one; confirm whether the cost-3 call is intentional
  // (e.g. exercising makeEdge overwrite) or a leftover.
  graph.makeEdge(0, 1, 3);
  graph.makeEdge(0, 1, 120);
  graph.makeEdge(0, 2, 120);
  graph.makeEdge(0, 3, 120);
  graph.makeEdge(0, 4, 120);
  graph.makeEdge(0, 5, 120);
  graph.makeEdge(0, 6, 120);
  graph.makeEdge(2, 6, 80);
  graph.makeEdge(2, 7, 80);
  graph.makeEdge(3, 8, 80);
  graph.makeEdge(4, 9, 80);
  graph.makeEdge(5, 10, 80);
  graph.makeEdge(6, 7, 120);
  graph.makeEdge(6, 8, 120);
  graph.makeEdge(6, 9, 120);
  graph.makeEdge(6, 10, 120);
  graph.makeEdge(6, 11, 120);
  graph.makeEdge(8, 11, 80);
  graph.makeEdge(8, 12, 80);
  graph.makeEdge(9, 13, 80);
  graph.makeEdge(10, 14, 80);
  graph.makeEdge(11, 12, 120);
  graph.makeEdge(11, 13, 120);
  graph.makeEdge(11, 14, 120);
  graph.makeEdge(11, 15, 120);
  graph.makeEdge(13, 15, 80);
  graph.makeEdge(13, 16, 80);
  graph.makeEdge(14, 17, 80);
  graph.makeEdge(15, 16, 120);
  graph.makeEdge(15, 17, 120);

  graph.DCP();

  if (!TEST_CHECK(graph.cpuCount() == 4)) {
    // The per-CPU expectations below assume exactly four CPUs. Dump what was
    // computed instead of querying CPUs that may not exist.
    for (CPU cpu = 0; cpu < graph.cpuCount(); ++cpu)
      printSchedule(cpu, graph.getScheduledTasks(cpu), "computed");
    return checkFailures - failuresBefore;
  }

  // Each entry: {nodeIndex, aest, alst, weight}.
  assertScheduledTasks(graph, 3,
      {
          {1, 200, 370, 40},
      });
  assertScheduledTasks(graph, 2,
      {
          {5, 200, 260, 40},
          {10, 300, 300, 30},
      });
  assertScheduledTasks(graph, 1,
      {
          {4, 200, 210, 40},
          {7, 300, 380, 30},
      });
  assertScheduledTasks(graph, 0,
      {
          {0, 0, 0, 80},
          {2, 80, 80, 40},
          {6, 120, 120, 60},
          {3, 180, 200, 40},
          {8, 220, 240, 30},
          {11, 250, 270, 40},
          {12, 290, 310, 20},
          {9, 320, 330, 30},
          {13, 350, 360, 20},
          {15, 370, 380, 20},
          {16, 390, 400, 10},
          {14, 410, 410, 20},
          {17, 430, 430, 10},
      });

  return checkFailures - failuresBefore;
}

} // namespace

/// Runs every test and exits with EXIT_FAILURE if any check failed.
int main() {
  int failures = 0;
  failures += testLabeledList();
  failures += testDCPGraphFixture();

  if (failures != 0) {
    std::cerr << failures << " check(s) failed\n";
    return EXIT_FAILURE;
  }
  return EXIT_SUCCESS;
}