add 2 unittests
Some checks failed
Validate Operations / validate-operations (push) Failing after 8m9s

fix bugs
This commit is contained in:
NiccoloN
2026-04-16 18:01:38 +02:00
parent 197c38f9ca
commit a903e30859
9 changed files with 336 additions and 58 deletions

View File

@@ -31,9 +31,7 @@ public:
bool isLinked() const { return owner_ != nullptr; } bool isLinked() const { return owner_ != nullptr; }
Label getOrderLabel() const { return label; } Label getOrderLabel() const { return label; }
friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt){ friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt) { return lft.label < rgt.label; }
return lft.label < rgt.label;
}
private: private:
const void* owner_ = nullptr; const void* owner_ = nullptr;
@@ -79,17 +77,17 @@ public:
auto it = node->getIterator(); auto it = node->getIterator();
if (it == list->nodes_.begin()) if (it == list->nodes_.begin())
return nullptr; return nullptr;
return *std::prev(it); return &*std::prev(it);
} }
static const NodeT* previous(const NodeT* node) { static const NodeT* previous(const NodeT* node) {
if (!node || !owner(node)) if (!node || !owner(node))
return nullptr; return nullptr;
const auto* list = owner(node); const auto* list = owner(node);
auto it = node->getIterator(); auto it = const_cast<NodeT*>(node)->getIterator();
if (it == list->nodes_.begin()) if (it == list->nodes_.begin())
return nullptr; return nullptr;
return *std::prev(it); return &*std::prev(it);
} }
static NodeT* next(NodeT* node) { static NodeT* next(NodeT* node) {
@@ -99,29 +97,29 @@ public:
auto it = std::next(node->getIterator()); auto it = std::next(node->getIterator());
if (it == list->nodes_.end()) if (it == list->nodes_.end())
return nullptr; return nullptr;
return *it; return &*it;
} }
static const NodeT* next(const NodeT* node) { static const NodeT* next(const NodeT* node) {
if (!node || !owner(node)) if (!node || !owner(node))
return nullptr; return nullptr;
const auto* list = owner(node); const auto* list = owner(node);
auto it = std::next(node->getIterator()); auto it = std::next(const_cast<NodeT*>(node)->getIterator());
if (it == list->nodes_.end()) if (it == list->nodes_.end())
return nullptr; return nullptr;
return *it; return &*it;
} }
bool contains(const NodeT* node) const { return node && node->owner_ == this; } bool contains(const NodeT* node) const { return node && node->owner_ == this; }
Label getOrderLabel(const NodeT* node) const { Label getOrderLabel(const NodeT* node) const {
assert(contains(node) && "node must belong to this list"); assert(contains(node) && "node must belong to this list");
return node->label_; return node->label;
} }
bool comesBefore(const NodeT* lhs, const NodeT* rhs) const { bool comesBefore(const NodeT* lhs, const NodeT* rhs) const {
assert(contains(lhs) && contains(rhs) && "nodes must belong to this list"); assert(contains(lhs) && contains(rhs) && "nodes must belong to this list");
return lhs->label_ < rhs->label_; return lhs->label < rhs->label;
} }
void pushFront(NodeT* node) { insertBefore(front(), node); } void pushFront(NodeT* node) { insertBefore(front(), node); }
@@ -152,7 +150,7 @@ public:
assert(contains(node) && "node must belong to this list"); assert(contains(node) && "node must belong to this list");
nodes_.remove(*node); nodes_.remove(*node);
node->owner_ = nullptr; node->owner_ = nullptr;
node->label_ = 0; node->label = 0;
--size_; --size_;
} }
@@ -190,15 +188,14 @@ public:
} }
Iterator begin() { return nodes_.begin(); } Iterator begin() { return nodes_.begin(); }
Iterator end() { return nodes_.end(); } Iterator end() { return nodes_.end(); }
RIterator rbegin() { return nodes_.rbegin(); } RIterator rbegin() { return nodes_.rbegin(); }
RIterator rend() { return nodes_.rend(); } RIterator rend() { return nodes_.rend(); }
private: private:
static const LabeledList* owner(const NodeT* node) { return node->owner_; } static const LabeledList* owner(const NodeT* node) { return static_cast<const LabeledList*>(node->owner_); }
static LabeledList* owner(NodeT* node) { return node->owner_; } static LabeledList* owner(NodeT* node) { return static_cast<LabeledList*>(const_cast<void*>(node->owner_)); }
static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; } static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; }
static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; } static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; }

View File

@@ -28,12 +28,12 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) { if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass()); pm.addPass(createONNXToSpatialPass());
pm.addPass(createMergeComputeNodesPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Onnx lowered to Spatial")); pm.addPass(createMessagePass("Onnx lowered to Spatial"));
} }
if (pimEmissionTarget >= EmitPim) { if (pimEmissionTarget >= EmitPim) {
pm.addPass(createMergeComputeNodesPass());
pm.addPass(createSpatialToPimPass()); pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim")); pm.addPass(createMessagePass("Spatial lowered to Pim"));

View File

@@ -47,6 +47,11 @@ int getTranferCost(TaskDCP* parent, TaskDCP* child) {
return child_position->second; return child_position->second;
} }
size_t GraphDCP::getNodeIndex(const TaskDCP* task) const {
assert(task >= nodes.data() && task < nodes.data() + nodes.size() && "task must belong to graph");
return static_cast<size_t>(task - nodes.data());
}
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) { TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
TaskInsertion ret; TaskInsertion ret;
task->setCPU(cpu); task->setCPU(cpu);
@@ -245,12 +250,12 @@ void GraphDCP::topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint) {
auto moveChildAfterMe = [this](TaskDCP* origTask) -> void { auto moveChildAfterMe = [this](TaskDCP* origTask) -> void {
auto cmp = [](Edge_t lft, Edge_t rgt) { return *rgt.first < *lft.first; }; auto cmp = [](Edge_t lft, Edge_t rgt) { return *rgt.first < *lft.first; };
TaskDCP* insertionPoint = origTask; TaskDCP* insertionPoint = origTask;
std::vector<Edge_t>& childEdges = origTask->childs;
std::vector<TaskDCP*> worklist; std::vector<TaskDCP*> worklist;
worklist.push_back(origTask); worklist.push_back(origTask);
size_t i = 0; size_t i = 0;
while (i < worklist.size()) { while (i < worklist.size()) {
auto task = worklist[i]; auto task = worklist[i];
std::vector<Edge_t>& childEdges = task->childs;
// build min heap Complexity 3N // build min heap Complexity 3N
std::make_heap(childEdges.begin(), childEdges.end(), cmp); std::make_heap(childEdges.begin(), childEdges.end(), cmp);
auto lastPoppedIter = childEdges.end(); auto lastPoppedIter = childEdges.end();
@@ -279,12 +284,12 @@ void GraphDCP::topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint) {
} }
}; };
if (!(*task < *pivotPoint)) { if (!(*task < *pivotPoint))
return; return;
topologicalOrder.moveAfter(task, pivotPoint);
if (task->hasChilds()) topologicalOrder.moveAfter(task, pivotPoint);
moveChildAfterMe(task); if (task->hasChilds())
} moveChildAfterMe(task);
} }
void GraphDCP::topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint) { void GraphDCP::topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint) {
@@ -292,12 +297,12 @@ void GraphDCP::topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint) {
auto moveParentBeforeMe = [this](TaskDCP* origTask) -> void { auto moveParentBeforeMe = [this](TaskDCP* origTask) -> void {
auto cmp = [](Edge_t lft, Edge_t rgt) { return *lft.first < *rgt.first; }; auto cmp = [](Edge_t lft, Edge_t rgt) { return *lft.first < *rgt.first; };
TaskDCP* insertionPoint = origTask; TaskDCP* insertionPoint = origTask;
std::vector<Edge_t>& parentEdges = origTask->parents;
std::vector<TaskDCP*> worklist; std::vector<TaskDCP*> worklist;
worklist.push_back(origTask); worklist.push_back(origTask);
size_t i = 0; size_t i = 0;
while (i < worklist.size()) { while (i < worklist.size()) {
auto task = worklist[i]; auto task = worklist[i];
std::vector<Edge_t>& parentEdges = task->parents;
// build max heap Complexity 3N // build max heap Complexity 3N
std::make_heap(parentEdges.begin(), parentEdges.end(), cmp); std::make_heap(parentEdges.begin(), parentEdges.end(), cmp);
auto lastPoppedIter = parentEdges.end(); auto lastPoppedIter = parentEdges.end();
@@ -326,14 +331,13 @@ void GraphDCP::topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint) {
} }
}; };
if (!(*task < *pivotPoint)) { if (!(*pivotPoint < *task))
return; return;
topologicalOrder.moveBefore(task, pivotPoint);
if (task->hasParents())
moveParentBeforeMe(task);
}
}
topologicalOrder.moveBefore(task, pivotPoint);
if (task->hasParents())
moveParentBeforeMe(task);
}
GraphDCP::FindSlot GraphDCP::findSlot(TaskDCP* candidate, CPU cpu, bool push) { GraphDCP::FindSlot GraphDCP::findSlot(TaskDCP* candidate, CPU cpu, bool push) {
int aest_on_cpu = computeAEST(candidate, cpu); int aest_on_cpu = computeAEST(candidate, cpu);
@@ -407,7 +411,6 @@ GraphDCP::FindSlot GraphDCP::findSlot(TaskDCP* candidate, CPU cpu, bool push) {
} }
void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) { void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
std::vector<CPU> processors; std::vector<CPU> processors;
processors.reserve(lastCPU()); processors.reserve(lastCPU());
for (CPU c = push ? lastCPU() : lastCPU(); c >= 0; c--) for (CPU c = push ? lastCPU() : lastCPU(); c >= 0; c--)
@@ -571,3 +574,15 @@ DCPAnalysisResult GraphDCP::getResult() {
return ret; return ret;
} }
std::vector<GraphDCP::ScheduledTaskInfo> GraphDCP::getScheduledTasks(CPU cpu) const {
std::vector<ScheduledTaskInfo> scheduledTasks;
auto cpuIt = mapCPUTasks.find(cpu);
if (cpuIt == mapCPUTasks.end())
return scheduledTasks;
scheduledTasks.reserve(cpuIt->second.size());
for (auto* task : cpuIt->second)
scheduledTasks.push_back({getNodeIndex(task), task->getAEST(), task->getALST(), task->getWeight()});
return scheduledTasks;
}

View File

@@ -17,7 +17,15 @@ void removeEdge(TaskDCP* parent, TaskDCP* child);
int getTranferCost(TaskDCP* parent, TaskDCP* child); int getTranferCost(TaskDCP* parent, TaskDCP* child);
class GraphDCP { class GraphDCP {
public:
struct ScheduledTaskInfo {
size_t nodeIndex;
int aest;
int alst;
int weight;
};
private:
struct FindSlot { struct FindSlot {
int aest; int aest;
int index; int index;
@@ -35,7 +43,7 @@ class GraphDCP {
std::vector<TaskDCP*> getRoots(); std::vector<TaskDCP*> getRoots();
long long getUniqueFlag() {return flag++;}; long long getUniqueFlag() { return flag++; }
void initAEST(); void initAEST();
int initDCPL(); int initDCPL();
@@ -43,13 +51,14 @@ class GraphDCP {
int computeAEST(TaskDCP* task, CPU cpu); int computeAEST(TaskDCP* task, CPU cpu);
int computeDCPL(TaskDCP* task, CPU cpu); int computeDCPL(TaskDCP* task, CPU cpu);
int getDCPL() {return DCPL;}; int getDCPL() { return DCPL; }
void initTopological(); void initTopological();
void topologicalMoveAfter(TaskDCP* task, TaskDCP * pivotPoint); void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint);
void topologicalMoveBefore(TaskDCP* task, TaskDCP * pivotPoint); void topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint);
llvm::DenseMap<TaskDCP*, int> computeALST(TaskDCP* task, CPU cpu); llvm::DenseMap<TaskDCP*, int> computeALST(TaskDCP* task, CPU cpu);
size_t getNodeIndex(const TaskDCP* task) const;
TaskDCP* findCandidate(std::vector<TaskDCP*> nodes); TaskDCP* findCandidate(std::vector<TaskDCP*> nodes);
void selectProcessor(TaskDCP* candidate, bool push); void selectProcessor(TaskDCP* candidate, bool push);
@@ -60,21 +69,29 @@ class GraphDCP {
friend TaskInsertion; friend TaskInsertion;
public: public:
void DCP(); void DCP();
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes, GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
llvm::ArrayRef<EdgesIndex> edges) llvm::ArrayRef<EdgesIndex> edges)
: nodes(), mapCPUTasks() { : nodes(), mapCPUTasks() {
for (auto spatWeightedCompute : spatWeightedComputes){ for (auto spatWeightedCompute : spatWeightedComputes)
nodes.emplace_back(spatWeightedCompute); nodes.emplace_back(spatWeightedCompute);
} for (auto [start, end, weight] : edges)
makeEdge(start, end, weight);
}
GraphDCP(llvm::ArrayRef<Weight_t> nodeWeights, llvm::ArrayRef<EdgesIndex> edges)
: nodes(), mapCPUTasks() {
nodes.reserve(nodeWeights.size());
for (auto [index, weight] : llvm::enumerate(nodeWeights))
nodes.emplace_back(index, weight);
for (auto [start, end, weight] : edges) for (auto [start, end, weight] : edges)
makeEdge(start, end, weight); makeEdge(start, end, weight);
} }
DCPAnalysisResult getResult(); DCPAnalysisResult getResult();
std::vector<ScheduledTaskInfo> getScheduledTasks(CPU cpu) const;
CPU cpuCount() const { return last_cpu; }
void makeEdge(size_t parent_index, size_t child_index, Weight_t weight) { void makeEdge(size_t parent_index, size_t child_index, Weight_t weight) {
addEdge(&nodes[parent_index], &nodes[child_index], weight); addEdge(&nodes[parent_index], &nodes[child_index], weight);

View File

@@ -53,7 +53,7 @@ void TaskInsertion::rollBack() {
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second); addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
} }
if (afterNode.has_value()) { if (afterNode.has_value()) {
auto double_edge = *beforeNode; auto double_edge = *afterNode;
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second); addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
} }
graph->topologicalOrder.moveBefore( taskInserted,&*oldTopologicalPosition ); graph->topologicalOrder.moveBefore( taskInserted,&*oldTopologicalPosition );

View File

@@ -1,9 +1,6 @@
#pragma once #pragma once
#include <cassert> #include <cassert>
#include <cstdint>
#include <iterator>
#include <list>
#include <optional> #include <optional>
#include <vector> #include <vector>
@@ -21,6 +18,7 @@ class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
int weight; int weight;
int origWeight; int origWeight;
long long flag = 0; long long flag = 0;
int64_t syntheticId = -1;
std::optional<Edge_t> addChild(TaskDCP* child, Weight_t weight); std::optional<Edge_t> addChild(TaskDCP* child, Weight_t weight);
std::optional<Edge_t> addChild(TaskDCP& child, Weight_t weight) { return addChild(&child, weight); } std::optional<Edge_t> addChild(TaskDCP& child, Weight_t weight) { return addChild(&child, weight); }
@@ -39,12 +37,27 @@ public:
std::vector<Edge_t> childs; std::vector<Edge_t> childs;
TaskDCP() = default; TaskDCP() = default;
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
: onnx_mlir::LabeledListNode<TaskDCP>(), spatWeightedCompute(spatWeightedCompute), : onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(spatWeightedCompute),
aest(0), aest(0),
alst(0), alst(0),
scheduledCPU(), scheduledCPU(),
weight(getSpatWeightCompute(spatWeightedCompute)), weight(getSpatWeightCompute(spatWeightedCompute)),
origWeight(weight), origWeight(weight),
syntheticId(-1),
parents(),
childs() {}
TaskDCP(int64_t id, int weight)
: onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(),
aest(0),
alst(0),
scheduledCPU(),
weight(weight),
origWeight(weight),
flag(0),
syntheticId(id),
parents(), parents(),
childs() {} childs() {}
@@ -54,35 +67,35 @@ public:
void setCPU(CPU cpu) { scheduledCPU = cpu; } void setCPU(CPU cpu) { scheduledCPU = cpu; }
std::optional<CPU> getCPU() const { return scheduledCPU; } std::optional<CPU> getCPU() const { return scheduledCPU; }
void resetCPU() { scheduledCPU = std::nullopt; } void resetCPU() { scheduledCPU = std::nullopt; }
int getWeight() { int getWeight() const {
if (isScheduled()) if (isScheduled())
return weight; return weight;
else return origWeight;
return origWeight;
} }
void setWeight(int val) { weight = val; } void setWeight(int val) { weight = val; }
void resetWeight() { weight = origWeight; } void resetWeight() { weight = origWeight; }
int computeWeight(GraphDCP* graph, CPU cpu); int computeWeight(GraphDCP* graph, CPU cpu);
bool hasParents() { return parents.size() != 0; } bool hasParents() const { return parents.size() != 0; }
bool hasChilds() { return childs.size() != 0; } bool hasChilds() const { return childs.size() != 0; }
int getAEST() { return aest; } int getAEST() const { return aest; }
int getALST() { return alst; } int getALST() const { return alst; }
void setAEST(int val) { void setAEST(int val) {
assert(val >= 0); assert(val >= 0);
aest = val; aest = val;
} }
void setALST(int val) { void setALST(int val) { alst = val; }
assert(val >= 0 && val >= aest);
alst = val;
}
bool hasDescendent(TaskDCP* child); bool hasDescendent(TaskDCP* child);
int64_t Id() const { return (int64_t) spatWeightedCompute.getAsOpaquePointer(); } int64_t Id() const {
if (spatWeightedCompute)
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer());
return syntheticId;
}
bool isCP() const { return alst == aest; } bool isCP() const { return alst == aest; }
bool isScheduled() const { return scheduledCPU.has_value(); } bool isScheduled() const { return scheduledCPU.has_value(); }
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() { return spatWeightedCompute; } onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; }
void setFlag(long long val) { flag = val; } void setFlag(long long val) { flag = val; }
long long getFlag() const { return flag; } long long getFlag() const { return flag; }
@@ -94,7 +107,6 @@ public:
friend int getTranferCost(TaskDCP* parent, TaskDCP* child); friend int getTranferCost(TaskDCP* parent, TaskDCP* child);
}; };
struct TaskInsertion { struct TaskInsertion {
std::optional<DoubleEdge> beforeNode; std::optional<DoubleEdge> beforeNode;
std::optional<DoubleEdge> afterNode; std::optional<DoubleEdge> afterNode;

View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
add_custom_target(pim-unittest)
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
add_custom_target(check-pim-unittest
COMMENT "Running the PIM unit tests"
COMMAND "${CMAKE_CTEST_COMMAND}" -L pim-unittest --output-on-failure -C $<CONFIG> --force-new-ctest-process
USES_TERMINAL
DEPENDS pim-unittest
)
set_target_properties(check-pim-unittest PROPERTIES FOLDER "Tests")
set_target_properties(check-pim-unittest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD ON)
function(add_pim_unittest test_name)
add_onnx_mlir_executable(${test_name} NO_INSTALL ${ARGN})
add_dependencies(pim-unittest ${test_name})
get_target_property(test_suite_folder pim-unittest FOLDER)
if (test_suite_folder)
set_property(TARGET ${test_name} PROPERTY FOLDER "${test_suite_folder}")
endif ()
add_test(NAME ${test_name} COMMAND ${test_name} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
set_tests_properties(${test_name} PROPERTIES LABELS pim-unittest)
endfunction()
add_pim_unittest(TestPIM
TestPIM.cpp
LINK_LIBS PRIVATE
OMPimCommon
SpatialOps
)

202
test/PIM/TestPIM.cpp Normal file
View File

@@ -0,0 +1,202 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp"
#include <cassert>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <vector>
using onnx_mlir::LabeledList;
using onnx_mlir::LabeledListNode;
namespace {
struct TestNode : public LabeledListNode<TestNode> {
explicit TestNode(int id)
: id(id) {}
int id;
};
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
auto expectedIt = expectedOrder.begin();
for (auto& node : list) {
assert(expectedIt != expectedOrder.end());
assert(node.id == *expectedIt);
++expectedIt;
}
assert(expectedIt == expectedOrder.end());
}
int testLabeledList() {
std::cout << "testLabeledList:" << std::endl;
LabeledList<TestNode> list;
TestNode n1(1);
TestNode n2(2);
TestNode n3(3);
TestNode n4(4);
TestNode n5(5);
list.pushBack(&n1);
list.pushBack(&n3);
list.insertAfter(&n1, &n2);
list.pushFront(&n4);
list.insertBefore(nullptr, &n5);
assertOrder(list, {4, 1, 2, 3, 5});
assert(LabeledList<TestNode>::next(&n4) == &n1);
assert(LabeledList<TestNode>::previous(&n1) == &n4);
assert(LabeledList<TestNode>::next(&n5) == nullptr);
assert(list.comesBefore(&n1, &n3));
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
list.moveBefore(&n5, &n2);
assertOrder(list, {4, 1, 5, 2, 3});
list.moveAfter(&n4, &n3);
assertOrder(list, {1, 5, 2, 3, 4});
list.remove(&n2);
assert(!n2.isLinked());
assertOrder(list, {1, 5, 3, 4});
list.clear();
assert(list.empty());
assert(!n1.isLinked());
assert(!n3.isLinked());
assert(!n4.isLinked());
assert(!n5.isLinked());
return 0;
}
struct ExpectedScheduledTask {
size_t nodeIndex;
int aest;
int alst;
int weight;
};
void assertScheduledTasks(GraphDCP& graph, CPU cpu, std::initializer_list<ExpectedScheduledTask> expectedTasks) {
auto actualTasks = graph.getScheduledTasks(cpu);
assert(actualTasks.size() == expectedTasks.size());
auto expectedIt = expectedTasks.begin();
for (const auto& actualTask : actualTasks) {
assert(expectedIt != expectedTasks.end());
if (actualTask.nodeIndex != expectedIt->nodeIndex || actualTask.aest != expectedIt->aest
|| actualTask.alst != expectedIt->alst || actualTask.weight != expectedIt->weight) {
std::cerr << "CPU " << cpu << " actual schedule:\n";
for (const auto& task : actualTasks) {
std::cerr << " " << task.nodeIndex << ") aest: " << task.aest << " alst: " << task.alst
<< " weight: " << task.weight << '\n';
}
}
assert(actualTask.nodeIndex == expectedIt->nodeIndex);
assert(actualTask.aest == expectedIt->aest);
assert(actualTask.alst == expectedIt->alst);
assert(actualTask.weight == expectedIt->weight);
++expectedIt;
}
}
int testDCPGraphFixture() {
std::cout << "testDCPGraphFixture:" << std::endl;
const std::vector<Weight_t> nodeWeights = {
80, 40, 40, 40, 40, 40, 60, 30, 30, 30,
30, 40, 20, 20, 20, 20, 10, 10,
};
GraphDCP graph(nodeWeights, {});
graph.makeEdge(0, 1, 3);
graph.makeEdge(0, 1, 120);
graph.makeEdge(0, 2, 120);
graph.makeEdge(0, 3, 120);
graph.makeEdge(0, 4, 120);
graph.makeEdge(0, 5, 120);
graph.makeEdge(0, 6, 120);
graph.makeEdge(2, 6, 80);
graph.makeEdge(2, 7, 80);
graph.makeEdge(3, 8, 80);
graph.makeEdge(4, 9, 80);
graph.makeEdge(5, 10, 80);
graph.makeEdge(6, 7, 120);
graph.makeEdge(6, 8, 120);
graph.makeEdge(6, 9, 120);
graph.makeEdge(6, 10, 120);
graph.makeEdge(6, 11, 120);
graph.makeEdge(8, 11, 80);
graph.makeEdge(8, 12, 80);
graph.makeEdge(9, 13, 80);
graph.makeEdge(10, 14, 80);
graph.makeEdge(11, 12, 120);
graph.makeEdge(11, 13, 120);
graph.makeEdge(11, 14, 120);
graph.makeEdge(11, 15, 120);
graph.makeEdge(13, 15, 80);
graph.makeEdge(13, 16, 80);
graph.makeEdge(14, 17, 80);
graph.makeEdge(15, 16, 120);
graph.makeEdge(15, 17, 120);
graph.DCP();
for (CPU cpu = 0; cpu < graph.cpuCount(); ++cpu) {
auto scheduledTasks = graph.getScheduledTasks(cpu);
std::cerr << "CPU " << cpu << " computed schedule:\n";
for (const auto& task : scheduledTasks) {
std::cerr << " " << task.nodeIndex << ") aest: " << task.aest << " alst: " << task.alst
<< " weight: " << task.weight << '\n';
}
}
assert(graph.cpuCount() == 4);
assertScheduledTasks(graph, 3, {
{1, 200, 370, 40},
});
assertScheduledTasks(graph, 2, {
{5, 200, 260, 40},
{10, 300, 300, 30},
});
assertScheduledTasks(graph, 1, {
{4, 200, 210, 40},
{7, 300, 380, 30},
});
assertScheduledTasks(graph, 0, {
{0, 0, 0, 80},
{2, 80, 80, 40},
{6, 120, 120, 60},
{3, 180, 200, 40},
{8, 220, 240, 30},
{11, 250, 270, 40},
{12, 290, 310, 20},
{9, 320, 330, 30},
{13, 350, 360, 20},
{15, 370, 380, 20},
{16, 390, 400, 10},
{14, 410, 410, 20},
{17, 430, 430, 10},
});
return 0;
}
} // namespace
int main(int argc, char* argv[]) {
(void) argc;
(void) argv;
int failures = 0;
failures += testLabeledList();
failures += testDCPGraphFixture();
if (failures != 0) {
std::cerr << failures << " test failures\n";
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}

View File

@@ -6,6 +6,7 @@ from subprocess_utils import run_command_with_reporter
PIM_PASS_LABELS = ( PIM_PASS_LABELS = (
("ONNXToSpatialPass", "ONNX to Spatial"), ("ONNXToSpatialPass", "ONNX to Spatial"),
("MergeComputeNodesPass", "Merge Compute Nodes"),
("SpatialToPimPass", "Spatial to PIM"), ("SpatialToPimPass", "Spatial to PIM"),
("PimBufferizationPass", "Bufferize PIM"), ("PimBufferizationPass", "Bufferize PIM"),
("HostConstantFoldingPass", "Fold Host Constants"), ("HostConstantFoldingPass", "Fold Host Constants"),