add 2 unittests
Some checks failed
Validate Operations / validate-operations (push) Failing after 8m9s
Some checks failed
Validate Operations / validate-operations (push) Failing after 8m9s
fix bugs
This commit is contained in:
202
test/PIM/TestPIM.cpp
Normal file
202
test/PIM/TestPIM.cpp
Normal file
@@ -0,0 +1,202 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
using onnx_mlir::LabeledList;
|
||||
using onnx_mlir::LabeledListNode;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestNode : public LabeledListNode<TestNode> {
|
||||
explicit TestNode(int id)
|
||||
: id(id) {}
|
||||
|
||||
int id;
|
||||
};
|
||||
|
||||
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
|
||||
auto expectedIt = expectedOrder.begin();
|
||||
for (auto& node : list) {
|
||||
assert(expectedIt != expectedOrder.end());
|
||||
assert(node.id == *expectedIt);
|
||||
++expectedIt;
|
||||
}
|
||||
assert(expectedIt == expectedOrder.end());
|
||||
}
|
||||
|
||||
int testLabeledList() {
|
||||
std::cout << "testLabeledList:" << std::endl;
|
||||
|
||||
LabeledList<TestNode> list;
|
||||
TestNode n1(1);
|
||||
TestNode n2(2);
|
||||
TestNode n3(3);
|
||||
TestNode n4(4);
|
||||
TestNode n5(5);
|
||||
|
||||
list.pushBack(&n1);
|
||||
list.pushBack(&n3);
|
||||
list.insertAfter(&n1, &n2);
|
||||
list.pushFront(&n4);
|
||||
list.insertBefore(nullptr, &n5);
|
||||
|
||||
assertOrder(list, {4, 1, 2, 3, 5});
|
||||
assert(LabeledList<TestNode>::next(&n4) == &n1);
|
||||
assert(LabeledList<TestNode>::previous(&n1) == &n4);
|
||||
assert(LabeledList<TestNode>::next(&n5) == nullptr);
|
||||
assert(list.comesBefore(&n1, &n3));
|
||||
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
|
||||
|
||||
list.moveBefore(&n5, &n2);
|
||||
assertOrder(list, {4, 1, 5, 2, 3});
|
||||
|
||||
list.moveAfter(&n4, &n3);
|
||||
assertOrder(list, {1, 5, 2, 3, 4});
|
||||
|
||||
list.remove(&n2);
|
||||
assert(!n2.isLinked());
|
||||
assertOrder(list, {1, 5, 3, 4});
|
||||
|
||||
list.clear();
|
||||
assert(list.empty());
|
||||
assert(!n1.isLinked());
|
||||
assert(!n3.isLinked());
|
||||
assert(!n4.isLinked());
|
||||
assert(!n5.isLinked());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct ExpectedScheduledTask {
|
||||
size_t nodeIndex;
|
||||
int aest;
|
||||
int alst;
|
||||
int weight;
|
||||
};
|
||||
|
||||
void assertScheduledTasks(GraphDCP& graph, CPU cpu, std::initializer_list<ExpectedScheduledTask> expectedTasks) {
|
||||
auto actualTasks = graph.getScheduledTasks(cpu);
|
||||
assert(actualTasks.size() == expectedTasks.size());
|
||||
|
||||
auto expectedIt = expectedTasks.begin();
|
||||
for (const auto& actualTask : actualTasks) {
|
||||
assert(expectedIt != expectedTasks.end());
|
||||
if (actualTask.nodeIndex != expectedIt->nodeIndex || actualTask.aest != expectedIt->aest
|
||||
|| actualTask.alst != expectedIt->alst || actualTask.weight != expectedIt->weight) {
|
||||
std::cerr << "CPU " << cpu << " actual schedule:\n";
|
||||
for (const auto& task : actualTasks) {
|
||||
std::cerr << " " << task.nodeIndex << ") aest: " << task.aest << " alst: " << task.alst
|
||||
<< " weight: " << task.weight << '\n';
|
||||
}
|
||||
}
|
||||
assert(actualTask.nodeIndex == expectedIt->nodeIndex);
|
||||
assert(actualTask.aest == expectedIt->aest);
|
||||
assert(actualTask.alst == expectedIt->alst);
|
||||
assert(actualTask.weight == expectedIt->weight);
|
||||
++expectedIt;
|
||||
}
|
||||
}
|
||||
|
||||
int testDCPGraphFixture() {
|
||||
std::cout << "testDCPGraphFixture:" << std::endl;
|
||||
|
||||
const std::vector<Weight_t> nodeWeights = {
|
||||
80, 40, 40, 40, 40, 40, 60, 30, 30, 30,
|
||||
30, 40, 20, 20, 20, 20, 10, 10,
|
||||
};
|
||||
|
||||
GraphDCP graph(nodeWeights, {});
|
||||
graph.makeEdge(0, 1, 3);
|
||||
graph.makeEdge(0, 1, 120);
|
||||
graph.makeEdge(0, 2, 120);
|
||||
graph.makeEdge(0, 3, 120);
|
||||
graph.makeEdge(0, 4, 120);
|
||||
graph.makeEdge(0, 5, 120);
|
||||
graph.makeEdge(0, 6, 120);
|
||||
graph.makeEdge(2, 6, 80);
|
||||
graph.makeEdge(2, 7, 80);
|
||||
graph.makeEdge(3, 8, 80);
|
||||
graph.makeEdge(4, 9, 80);
|
||||
graph.makeEdge(5, 10, 80);
|
||||
graph.makeEdge(6, 7, 120);
|
||||
graph.makeEdge(6, 8, 120);
|
||||
graph.makeEdge(6, 9, 120);
|
||||
graph.makeEdge(6, 10, 120);
|
||||
graph.makeEdge(6, 11, 120);
|
||||
graph.makeEdge(8, 11, 80);
|
||||
graph.makeEdge(8, 12, 80);
|
||||
graph.makeEdge(9, 13, 80);
|
||||
graph.makeEdge(10, 14, 80);
|
||||
graph.makeEdge(11, 12, 120);
|
||||
graph.makeEdge(11, 13, 120);
|
||||
graph.makeEdge(11, 14, 120);
|
||||
graph.makeEdge(11, 15, 120);
|
||||
graph.makeEdge(13, 15, 80);
|
||||
graph.makeEdge(13, 16, 80);
|
||||
graph.makeEdge(14, 17, 80);
|
||||
graph.makeEdge(15, 16, 120);
|
||||
graph.makeEdge(15, 17, 120);
|
||||
|
||||
graph.DCP();
|
||||
for (CPU cpu = 0; cpu < graph.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = graph.getScheduledTasks(cpu);
|
||||
std::cerr << "CPU " << cpu << " computed schedule:\n";
|
||||
for (const auto& task : scheduledTasks) {
|
||||
std::cerr << " " << task.nodeIndex << ") aest: " << task.aest << " alst: " << task.alst
|
||||
<< " weight: " << task.weight << '\n';
|
||||
}
|
||||
}
|
||||
assert(graph.cpuCount() == 4);
|
||||
assertScheduledTasks(graph, 3, {
|
||||
{1, 200, 370, 40},
|
||||
});
|
||||
assertScheduledTasks(graph, 2, {
|
||||
{5, 200, 260, 40},
|
||||
{10, 300, 300, 30},
|
||||
});
|
||||
assertScheduledTasks(graph, 1, {
|
||||
{4, 200, 210, 40},
|
||||
{7, 300, 380, 30},
|
||||
});
|
||||
assertScheduledTasks(graph, 0, {
|
||||
{0, 0, 0, 80},
|
||||
{2, 80, 80, 40},
|
||||
{6, 120, 120, 60},
|
||||
{3, 180, 200, 40},
|
||||
{8, 220, 240, 30},
|
||||
{11, 250, 270, 40},
|
||||
{12, 290, 310, 20},
|
||||
{9, 320, 330, 30},
|
||||
{13, 350, 360, 20},
|
||||
{15, 370, 380, 20},
|
||||
{16, 390, 400, 10},
|
||||
{14, 410, 410, 20},
|
||||
{17, 430, 430, 10},
|
||||
});
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
(void) argc;
|
||||
(void) argv;
|
||||
|
||||
int failures = 0;
|
||||
failures += testLabeledList();
|
||||
failures += testDCPGraphFixture();
|
||||
if (failures != 0) {
|
||||
std::cerr << failures << " test failures\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
Reference in New Issue
Block a user