Compare commits
3 Commits
933e138012
...
a903e30859
| Author | SHA1 | Date | |
|---|---|---|---|
| a903e30859 | |||
| 197c38f9ca | |||
| 831b7be4e7 |
@@ -31,9 +31,7 @@ public:
|
|||||||
bool isLinked() const { return owner_ != nullptr; }
|
bool isLinked() const { return owner_ != nullptr; }
|
||||||
Label getOrderLabel() const { return label; }
|
Label getOrderLabel() const { return label; }
|
||||||
|
|
||||||
friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt){
|
friend bool operator<(const LabeledListNode& lft, const LabeledListNode& rgt) { return lft.label < rgt.label; }
|
||||||
return lft.label < rgt.label;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const void* owner_ = nullptr;
|
const void* owner_ = nullptr;
|
||||||
@@ -79,17 +77,17 @@ public:
|
|||||||
auto it = node->getIterator();
|
auto it = node->getIterator();
|
||||||
if (it == list->nodes_.begin())
|
if (it == list->nodes_.begin())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return *std::prev(it);
|
return &*std::prev(it);
|
||||||
}
|
}
|
||||||
|
|
||||||
static const NodeT* previous(const NodeT* node) {
|
static const NodeT* previous(const NodeT* node) {
|
||||||
if (!node || !owner(node))
|
if (!node || !owner(node))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
const auto* list = owner(node);
|
const auto* list = owner(node);
|
||||||
auto it = node->getIterator();
|
auto it = const_cast<NodeT*>(node)->getIterator();
|
||||||
if (it == list->nodes_.begin())
|
if (it == list->nodes_.begin())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return *std::prev(it);
|
return &*std::prev(it);
|
||||||
}
|
}
|
||||||
|
|
||||||
static NodeT* next(NodeT* node) {
|
static NodeT* next(NodeT* node) {
|
||||||
@@ -99,29 +97,29 @@ public:
|
|||||||
auto it = std::next(node->getIterator());
|
auto it = std::next(node->getIterator());
|
||||||
if (it == list->nodes_.end())
|
if (it == list->nodes_.end())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return *it;
|
return &*it;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const NodeT* next(const NodeT* node) {
|
static const NodeT* next(const NodeT* node) {
|
||||||
if (!node || !owner(node))
|
if (!node || !owner(node))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
const auto* list = owner(node);
|
const auto* list = owner(node);
|
||||||
auto it = std::next(node->getIterator());
|
auto it = std::next(const_cast<NodeT*>(node)->getIterator());
|
||||||
if (it == list->nodes_.end())
|
if (it == list->nodes_.end())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return *it;
|
return &*it;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool contains(const NodeT* node) const { return node && node->owner_ == this; }
|
bool contains(const NodeT* node) const { return node && node->owner_ == this; }
|
||||||
|
|
||||||
Label getOrderLabel(const NodeT* node) const {
|
Label getOrderLabel(const NodeT* node) const {
|
||||||
assert(contains(node) && "node must belong to this list");
|
assert(contains(node) && "node must belong to this list");
|
||||||
return node->label_;
|
return node->label;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool comesBefore(const NodeT* lhs, const NodeT* rhs) const {
|
bool comesBefore(const NodeT* lhs, const NodeT* rhs) const {
|
||||||
assert(contains(lhs) && contains(rhs) && "nodes must belong to this list");
|
assert(contains(lhs) && contains(rhs) && "nodes must belong to this list");
|
||||||
return lhs->label_ < rhs->label_;
|
return lhs->label < rhs->label;
|
||||||
}
|
}
|
||||||
|
|
||||||
void pushFront(NodeT* node) { insertBefore(front(), node); }
|
void pushFront(NodeT* node) { insertBefore(front(), node); }
|
||||||
@@ -152,7 +150,7 @@ public:
|
|||||||
assert(contains(node) && "node must belong to this list");
|
assert(contains(node) && "node must belong to this list");
|
||||||
nodes_.remove(*node);
|
nodes_.remove(*node);
|
||||||
node->owner_ = nullptr;
|
node->owner_ = nullptr;
|
||||||
node->label_ = 0;
|
node->label = 0;
|
||||||
--size_;
|
--size_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,15 +188,14 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
Iterator begin() { return nodes_.begin(); }
|
Iterator begin() { return nodes_.begin(); }
|
||||||
|
|
||||||
Iterator end() { return nodes_.end(); }
|
Iterator end() { return nodes_.end(); }
|
||||||
|
|
||||||
RIterator rbegin() { return nodes_.rbegin(); }
|
RIterator rbegin() { return nodes_.rbegin(); }
|
||||||
RIterator rend() { return nodes_.rend(); }
|
RIterator rend() { return nodes_.rend(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static const LabeledList* owner(const NodeT* node) { return node->owner_; }
|
static const LabeledList* owner(const NodeT* node) { return static_cast<const LabeledList*>(node->owner_); }
|
||||||
static LabeledList* owner(NodeT* node) { return node->owner_; }
|
static LabeledList* owner(NodeT* node) { return static_cast<LabeledList*>(const_cast<void*>(node->owner_)); }
|
||||||
|
|
||||||
static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; }
|
static Label lowerLabel(const NodeT* node) { return node ? node->label : kLowerSentinel; }
|
||||||
static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; }
|
static Label upperLabel(const NodeT* node) { return node ? node->label : kUpperSentinel; }
|
||||||
|
|||||||
@@ -28,12 +28,12 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
|
|
||||||
if (pimEmissionTarget >= EmitSpatial) {
|
if (pimEmissionTarget >= EmitSpatial) {
|
||||||
pm.addPass(createONNXToSpatialPass());
|
pm.addPass(createONNXToSpatialPass());
|
||||||
|
pm.addPass(createMergeComputeNodesPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPim) {
|
if (pimEmissionTarget >= EmitPim) {
|
||||||
pm.addPass(createMergeComputeNodesPass());
|
|
||||||
pm.addPass(createSpatialToPimPass());
|
pm.addPass(createSpatialToPimPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||||
|
|||||||
@@ -47,6 +47,11 @@ int getTranferCost(TaskDCP* parent, TaskDCP* child) {
|
|||||||
return child_position->second;
|
return child_position->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t GraphDCP::getNodeIndex(const TaskDCP* task) const {
|
||||||
|
assert(task >= nodes.data() && task < nodes.data() + nodes.size() && "task must belong to graph");
|
||||||
|
return static_cast<size_t>(task - nodes.data());
|
||||||
|
}
|
||||||
|
|
||||||
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
|
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
|
||||||
TaskInsertion ret;
|
TaskInsertion ret;
|
||||||
task->setCPU(cpu);
|
task->setCPU(cpu);
|
||||||
@@ -245,12 +250,12 @@ void GraphDCP::topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint) {
|
|||||||
auto moveChildAfterMe = [this](TaskDCP* origTask) -> void {
|
auto moveChildAfterMe = [this](TaskDCP* origTask) -> void {
|
||||||
auto cmp = [](Edge_t lft, Edge_t rgt) { return *rgt.first < *lft.first; };
|
auto cmp = [](Edge_t lft, Edge_t rgt) { return *rgt.first < *lft.first; };
|
||||||
TaskDCP* insertionPoint = origTask;
|
TaskDCP* insertionPoint = origTask;
|
||||||
std::vector<Edge_t>& childEdges = origTask->childs;
|
|
||||||
std::vector<TaskDCP*> worklist;
|
std::vector<TaskDCP*> worklist;
|
||||||
worklist.push_back(origTask);
|
worklist.push_back(origTask);
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
while (i < worklist.size()) {
|
while (i < worklist.size()) {
|
||||||
auto task = worklist[i];
|
auto task = worklist[i];
|
||||||
|
std::vector<Edge_t>& childEdges = task->childs;
|
||||||
// build min heap Complexity 3N
|
// build min heap Complexity 3N
|
||||||
std::make_heap(childEdges.begin(), childEdges.end(), cmp);
|
std::make_heap(childEdges.begin(), childEdges.end(), cmp);
|
||||||
auto lastPoppedIter = childEdges.end();
|
auto lastPoppedIter = childEdges.end();
|
||||||
@@ -279,12 +284,12 @@ void GraphDCP::topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!(*task < *pivotPoint)) {
|
if (!(*task < *pivotPoint))
|
||||||
return;
|
return;
|
||||||
topologicalOrder.moveAfter(task, pivotPoint);
|
|
||||||
if (task->hasChilds())
|
topologicalOrder.moveAfter(task, pivotPoint);
|
||||||
moveChildAfterMe(task);
|
if (task->hasChilds())
|
||||||
}
|
moveChildAfterMe(task);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphDCP::topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint) {
|
void GraphDCP::topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint) {
|
||||||
@@ -292,12 +297,12 @@ void GraphDCP::topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint) {
|
|||||||
auto moveParentBeforeMe = [this](TaskDCP* origTask) -> void {
|
auto moveParentBeforeMe = [this](TaskDCP* origTask) -> void {
|
||||||
auto cmp = [](Edge_t lft, Edge_t rgt) { return *lft.first < *rgt.first; };
|
auto cmp = [](Edge_t lft, Edge_t rgt) { return *lft.first < *rgt.first; };
|
||||||
TaskDCP* insertionPoint = origTask;
|
TaskDCP* insertionPoint = origTask;
|
||||||
std::vector<Edge_t>& parentEdges = origTask->parents;
|
|
||||||
std::vector<TaskDCP*> worklist;
|
std::vector<TaskDCP*> worklist;
|
||||||
worklist.push_back(origTask);
|
worklist.push_back(origTask);
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
while (i < worklist.size()) {
|
while (i < worklist.size()) {
|
||||||
auto task = worklist[i];
|
auto task = worklist[i];
|
||||||
|
std::vector<Edge_t>& parentEdges = task->parents;
|
||||||
// build max heap Complexity 3N
|
// build max heap Complexity 3N
|
||||||
std::make_heap(parentEdges.begin(), parentEdges.end(), cmp);
|
std::make_heap(parentEdges.begin(), parentEdges.end(), cmp);
|
||||||
auto lastPoppedIter = parentEdges.end();
|
auto lastPoppedIter = parentEdges.end();
|
||||||
@@ -326,14 +331,13 @@ void GraphDCP::topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!(*task < *pivotPoint)) {
|
if (!(*pivotPoint < *task))
|
||||||
return;
|
return;
|
||||||
topologicalOrder.moveBefore(task, pivotPoint);
|
|
||||||
if (task->hasParents())
|
|
||||||
moveParentBeforeMe(task);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
topologicalOrder.moveBefore(task, pivotPoint);
|
||||||
|
if (task->hasParents())
|
||||||
|
moveParentBeforeMe(task);
|
||||||
|
}
|
||||||
|
|
||||||
GraphDCP::FindSlot GraphDCP::findSlot(TaskDCP* candidate, CPU cpu, bool push) {
|
GraphDCP::FindSlot GraphDCP::findSlot(TaskDCP* candidate, CPU cpu, bool push) {
|
||||||
int aest_on_cpu = computeAEST(candidate, cpu);
|
int aest_on_cpu = computeAEST(candidate, cpu);
|
||||||
@@ -407,7 +411,6 @@ GraphDCP::FindSlot GraphDCP::findSlot(TaskDCP* candidate, CPU cpu, bool push) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
||||||
|
|
||||||
std::vector<CPU> processors;
|
std::vector<CPU> processors;
|
||||||
processors.reserve(lastCPU());
|
processors.reserve(lastCPU());
|
||||||
for (CPU c = push ? lastCPU() : lastCPU(); c >= 0; c--)
|
for (CPU c = push ? lastCPU() : lastCPU(); c >= 0; c--)
|
||||||
@@ -571,3 +574,15 @@ DCPAnalysisResult GraphDCP::getResult() {
|
|||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<GraphDCP::ScheduledTaskInfo> GraphDCP::getScheduledTasks(CPU cpu) const {
|
||||||
|
std::vector<ScheduledTaskInfo> scheduledTasks;
|
||||||
|
auto cpuIt = mapCPUTasks.find(cpu);
|
||||||
|
if (cpuIt == mapCPUTasks.end())
|
||||||
|
return scheduledTasks;
|
||||||
|
|
||||||
|
scheduledTasks.reserve(cpuIt->second.size());
|
||||||
|
for (auto* task : cpuIt->second)
|
||||||
|
scheduledTasks.push_back({getNodeIndex(task), task->getAEST(), task->getALST(), task->getWeight()});
|
||||||
|
return scheduledTasks;
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,7 +17,15 @@ void removeEdge(TaskDCP* parent, TaskDCP* child);
|
|||||||
int getTranferCost(TaskDCP* parent, TaskDCP* child);
|
int getTranferCost(TaskDCP* parent, TaskDCP* child);
|
||||||
|
|
||||||
class GraphDCP {
|
class GraphDCP {
|
||||||
|
public:
|
||||||
|
struct ScheduledTaskInfo {
|
||||||
|
size_t nodeIndex;
|
||||||
|
int aest;
|
||||||
|
int alst;
|
||||||
|
int weight;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
struct FindSlot {
|
struct FindSlot {
|
||||||
int aest;
|
int aest;
|
||||||
int index;
|
int index;
|
||||||
@@ -35,21 +43,22 @@ class GraphDCP {
|
|||||||
|
|
||||||
std::vector<TaskDCP*> getRoots();
|
std::vector<TaskDCP*> getRoots();
|
||||||
|
|
||||||
long long getUniqueFlag() {return flag++;};
|
long long getUniqueFlag() { return flag++; }
|
||||||
|
|
||||||
void initAEST();
|
void initAEST();
|
||||||
int initDCPL();
|
int initDCPL();
|
||||||
void initALST();
|
void initALST();
|
||||||
|
|
||||||
int computeAEST(TaskDCP* task, CPU cpu);
|
int computeAEST(TaskDCP* task, CPU cpu);
|
||||||
int computeDCPL(TaskDCP* task, CPU cpu);
|
int computeDCPL(TaskDCP* task, CPU cpu);
|
||||||
int getDCPL() {return DCPL;};
|
int getDCPL() { return DCPL; }
|
||||||
|
|
||||||
void initTopological();
|
void initTopological();
|
||||||
void topologicalMoveAfter(TaskDCP* task, TaskDCP * pivotPoint);
|
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint);
|
||||||
void topologicalMoveBefore(TaskDCP* task, TaskDCP * pivotPoint);
|
void topologicalMoveBefore(TaskDCP* task, TaskDCP* pivotPoint);
|
||||||
|
|
||||||
llvm::DenseMap<TaskDCP*, int> computeALST(TaskDCP* task, CPU cpu);
|
llvm::DenseMap<TaskDCP*, int> computeALST(TaskDCP* task, CPU cpu);
|
||||||
|
size_t getNodeIndex(const TaskDCP* task) const;
|
||||||
|
|
||||||
TaskDCP* findCandidate(std::vector<TaskDCP*> nodes);
|
TaskDCP* findCandidate(std::vector<TaskDCP*> nodes);
|
||||||
void selectProcessor(TaskDCP* candidate, bool push);
|
void selectProcessor(TaskDCP* candidate, bool push);
|
||||||
@@ -60,21 +69,29 @@ class GraphDCP {
|
|||||||
|
|
||||||
friend TaskInsertion;
|
friend TaskInsertion;
|
||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void DCP();
|
void DCP();
|
||||||
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
|
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
|
||||||
llvm::ArrayRef<EdgesIndex> edges)
|
llvm::ArrayRef<EdgesIndex> edges)
|
||||||
: nodes(), mapCPUTasks() {
|
: nodes(), mapCPUTasks() {
|
||||||
for (auto spatWeightedCompute : spatWeightedComputes){
|
for (auto spatWeightedCompute : spatWeightedComputes)
|
||||||
nodes.emplace_back(spatWeightedCompute);
|
nodes.emplace_back(spatWeightedCompute);
|
||||||
}
|
for (auto [start, end, weight] : edges)
|
||||||
|
makeEdge(start, end, weight);
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDCP(llvm::ArrayRef<Weight_t> nodeWeights, llvm::ArrayRef<EdgesIndex> edges)
|
||||||
|
: nodes(), mapCPUTasks() {
|
||||||
|
nodes.reserve(nodeWeights.size());
|
||||||
|
for (auto [index, weight] : llvm::enumerate(nodeWeights))
|
||||||
|
nodes.emplace_back(index, weight);
|
||||||
for (auto [start, end, weight] : edges)
|
for (auto [start, end, weight] : edges)
|
||||||
makeEdge(start, end, weight);
|
makeEdge(start, end, weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult getResult();
|
DCPAnalysisResult getResult();
|
||||||
|
std::vector<ScheduledTaskInfo> getScheduledTasks(CPU cpu) const;
|
||||||
|
CPU cpuCount() const { return last_cpu; }
|
||||||
|
|
||||||
void makeEdge(size_t parent_index, size_t child_index, Weight_t weight) {
|
void makeEdge(size_t parent_index, size_t child_index, Weight_t weight) {
|
||||||
addEdge(&nodes[parent_index], &nodes[child_index], weight);
|
addEdge(&nodes[parent_index], &nodes[child_index], weight);
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ void TaskInsertion::rollBack() {
|
|||||||
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
|
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
|
||||||
}
|
}
|
||||||
if (afterNode.has_value()) {
|
if (afterNode.has_value()) {
|
||||||
auto double_edge = *beforeNode;
|
auto double_edge = *afterNode;
|
||||||
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
|
addEdge(double_edge.first.first, double_edge.second.first, double_edge.first.second);
|
||||||
}
|
}
|
||||||
graph->topologicalOrder.moveBefore( taskInserted,&*oldTopologicalPosition );
|
graph->topologicalOrder.moveBefore( taskInserted,&*oldTopologicalPosition );
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstdint>
|
|
||||||
#include <iterator>
|
|
||||||
#include <list>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -21,6 +18,7 @@ class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
|
|||||||
int weight;
|
int weight;
|
||||||
int origWeight;
|
int origWeight;
|
||||||
long long flag = 0;
|
long long flag = 0;
|
||||||
|
int64_t syntheticId = -1;
|
||||||
|
|
||||||
std::optional<Edge_t> addChild(TaskDCP* child, Weight_t weight);
|
std::optional<Edge_t> addChild(TaskDCP* child, Weight_t weight);
|
||||||
std::optional<Edge_t> addChild(TaskDCP& child, Weight_t weight) { return addChild(&child, weight); }
|
std::optional<Edge_t> addChild(TaskDCP& child, Weight_t weight) { return addChild(&child, weight); }
|
||||||
@@ -39,12 +37,27 @@ public:
|
|||||||
std::vector<Edge_t> childs;
|
std::vector<Edge_t> childs;
|
||||||
TaskDCP() = default;
|
TaskDCP() = default;
|
||||||
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
|
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
|
||||||
: onnx_mlir::LabeledListNode<TaskDCP>(), spatWeightedCompute(spatWeightedCompute),
|
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
||||||
|
spatWeightedCompute(spatWeightedCompute),
|
||||||
aest(0),
|
aest(0),
|
||||||
alst(0),
|
alst(0),
|
||||||
scheduledCPU(),
|
scheduledCPU(),
|
||||||
weight(getSpatWeightCompute(spatWeightedCompute)),
|
weight(getSpatWeightCompute(spatWeightedCompute)),
|
||||||
origWeight(weight),
|
origWeight(weight),
|
||||||
|
syntheticId(-1),
|
||||||
|
parents(),
|
||||||
|
childs() {}
|
||||||
|
|
||||||
|
TaskDCP(int64_t id, int weight)
|
||||||
|
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
||||||
|
spatWeightedCompute(),
|
||||||
|
aest(0),
|
||||||
|
alst(0),
|
||||||
|
scheduledCPU(),
|
||||||
|
weight(weight),
|
||||||
|
origWeight(weight),
|
||||||
|
flag(0),
|
||||||
|
syntheticId(id),
|
||||||
parents(),
|
parents(),
|
||||||
childs() {}
|
childs() {}
|
||||||
|
|
||||||
@@ -54,35 +67,35 @@ public:
|
|||||||
void setCPU(CPU cpu) { scheduledCPU = cpu; }
|
void setCPU(CPU cpu) { scheduledCPU = cpu; }
|
||||||
std::optional<CPU> getCPU() const { return scheduledCPU; }
|
std::optional<CPU> getCPU() const { return scheduledCPU; }
|
||||||
void resetCPU() { scheduledCPU = std::nullopt; }
|
void resetCPU() { scheduledCPU = std::nullopt; }
|
||||||
int getWeight() {
|
int getWeight() const {
|
||||||
if (isScheduled())
|
if (isScheduled())
|
||||||
return weight;
|
return weight;
|
||||||
else
|
return origWeight;
|
||||||
return origWeight;
|
|
||||||
}
|
}
|
||||||
void setWeight(int val) { weight = val; }
|
void setWeight(int val) { weight = val; }
|
||||||
void resetWeight() { weight = origWeight; }
|
void resetWeight() { weight = origWeight; }
|
||||||
int computeWeight(GraphDCP* graph, CPU cpu);
|
int computeWeight(GraphDCP* graph, CPU cpu);
|
||||||
|
|
||||||
bool hasParents() { return parents.size() != 0; }
|
bool hasParents() const { return parents.size() != 0; }
|
||||||
bool hasChilds() { return childs.size() != 0; }
|
bool hasChilds() const { return childs.size() != 0; }
|
||||||
|
|
||||||
int getAEST() { return aest; }
|
int getAEST() const { return aest; }
|
||||||
int getALST() { return alst; }
|
int getALST() const { return alst; }
|
||||||
void setAEST(int val) {
|
void setAEST(int val) {
|
||||||
assert(val >= 0);
|
assert(val >= 0);
|
||||||
aest = val;
|
aest = val;
|
||||||
}
|
}
|
||||||
void setALST(int val) {
|
void setALST(int val) { alst = val; }
|
||||||
assert(val >= 0 && val >= aest);
|
|
||||||
alst = val;
|
|
||||||
}
|
|
||||||
bool hasDescendent(TaskDCP* child);
|
bool hasDescendent(TaskDCP* child);
|
||||||
int64_t Id() const { return (int64_t) spatWeightedCompute.getAsOpaquePointer(); }
|
int64_t Id() const {
|
||||||
|
if (spatWeightedCompute)
|
||||||
|
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer());
|
||||||
|
return syntheticId;
|
||||||
|
}
|
||||||
|
|
||||||
bool isCP() const { return alst == aest; }
|
bool isCP() const { return alst == aest; }
|
||||||
bool isScheduled() const { return scheduledCPU.has_value(); }
|
bool isScheduled() const { return scheduledCPU.has_value(); }
|
||||||
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() { return spatWeightedCompute; }
|
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; }
|
||||||
|
|
||||||
void setFlag(long long val) { flag = val; }
|
void setFlag(long long val) { flag = val; }
|
||||||
long long getFlag() const { return flag; }
|
long long getFlag() const { return flag; }
|
||||||
@@ -94,7 +107,6 @@ public:
|
|||||||
friend int getTranferCost(TaskDCP* parent, TaskDCP* child);
|
friend int getTranferCost(TaskDCP* parent, TaskDCP* child);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
struct TaskInsertion {
|
struct TaskInsertion {
|
||||||
std::optional<DoubleEdge> beforeNode;
|
std::optional<DoubleEdge> beforeNode;
|
||||||
std::optional<DoubleEdge> afterNode;
|
std::optional<DoubleEdge> afterNode;
|
||||||
@@ -103,5 +115,5 @@ struct TaskInsertion {
|
|||||||
TaskDCP* taskInserted;
|
TaskDCP* taskInserted;
|
||||||
GraphDCP* graph;
|
GraphDCP* graph;
|
||||||
|
|
||||||
void rollBack();
|
void rollBack();
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -0,0 +1,34 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
add_custom_target(pim-unittest)
|
||||||
|
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
|
||||||
|
|
||||||
|
add_custom_target(check-pim-unittest
|
||||||
|
COMMENT "Running the PIM unit tests"
|
||||||
|
COMMAND "${CMAKE_CTEST_COMMAND}" -L pim-unittest --output-on-failure -C $<CONFIG> --force-new-ctest-process
|
||||||
|
USES_TERMINAL
|
||||||
|
DEPENDS pim-unittest
|
||||||
|
)
|
||||||
|
set_target_properties(check-pim-unittest PROPERTIES FOLDER "Tests")
|
||||||
|
set_target_properties(check-pim-unittest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD ON)
|
||||||
|
|
||||||
|
function(add_pim_unittest test_name)
|
||||||
|
add_onnx_mlir_executable(${test_name} NO_INSTALL ${ARGN})
|
||||||
|
|
||||||
|
add_dependencies(pim-unittest ${test_name})
|
||||||
|
get_target_property(test_suite_folder pim-unittest FOLDER)
|
||||||
|
if (test_suite_folder)
|
||||||
|
set_property(TARGET ${test_name} PROPERTY FOLDER "${test_suite_folder}")
|
||||||
|
endif ()
|
||||||
|
|
||||||
|
add_test(NAME ${test_name} COMMAND ${test_name} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
|
||||||
|
set_tests_properties(${test_name} PROPERTIES LABELS pim-unittest)
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
add_pim_unittest(TestPIM
|
||||||
|
TestPIM.cpp
|
||||||
|
|
||||||
|
LINK_LIBS PRIVATE
|
||||||
|
OMPimCommon
|
||||||
|
SpatialOps
|
||||||
|
)
|
||||||
|
|||||||
@@ -0,0 +1,202 @@
|
|||||||
|
/*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/LabeledList.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp"
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <initializer_list>
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using onnx_mlir::LabeledList;
|
||||||
|
using onnx_mlir::LabeledListNode;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct TestNode : public LabeledListNode<TestNode> {
|
||||||
|
explicit TestNode(int id)
|
||||||
|
: id(id) {}
|
||||||
|
|
||||||
|
int id;
|
||||||
|
};
|
||||||
|
|
||||||
|
void assertOrder(LabeledList<TestNode>& list, std::initializer_list<int> expectedOrder) {
|
||||||
|
auto expectedIt = expectedOrder.begin();
|
||||||
|
for (auto& node : list) {
|
||||||
|
assert(expectedIt != expectedOrder.end());
|
||||||
|
assert(node.id == *expectedIt);
|
||||||
|
++expectedIt;
|
||||||
|
}
|
||||||
|
assert(expectedIt == expectedOrder.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
int testLabeledList() {
|
||||||
|
std::cout << "testLabeledList:" << std::endl;
|
||||||
|
|
||||||
|
LabeledList<TestNode> list;
|
||||||
|
TestNode n1(1);
|
||||||
|
TestNode n2(2);
|
||||||
|
TestNode n3(3);
|
||||||
|
TestNode n4(4);
|
||||||
|
TestNode n5(5);
|
||||||
|
|
||||||
|
list.pushBack(&n1);
|
||||||
|
list.pushBack(&n3);
|
||||||
|
list.insertAfter(&n1, &n2);
|
||||||
|
list.pushFront(&n4);
|
||||||
|
list.insertBefore(nullptr, &n5);
|
||||||
|
|
||||||
|
assertOrder(list, {4, 1, 2, 3, 5});
|
||||||
|
assert(LabeledList<TestNode>::next(&n4) == &n1);
|
||||||
|
assert(LabeledList<TestNode>::previous(&n1) == &n4);
|
||||||
|
assert(LabeledList<TestNode>::next(&n5) == nullptr);
|
||||||
|
assert(list.comesBefore(&n1, &n3));
|
||||||
|
assert(list.getOrderLabel(&n1) < list.getOrderLabel(&n3));
|
||||||
|
|
||||||
|
list.moveBefore(&n5, &n2);
|
||||||
|
assertOrder(list, {4, 1, 5, 2, 3});
|
||||||
|
|
||||||
|
list.moveAfter(&n4, &n3);
|
||||||
|
assertOrder(list, {1, 5, 2, 3, 4});
|
||||||
|
|
||||||
|
list.remove(&n2);
|
||||||
|
assert(!n2.isLinked());
|
||||||
|
assertOrder(list, {1, 5, 3, 4});
|
||||||
|
|
||||||
|
list.clear();
|
||||||
|
assert(list.empty());
|
||||||
|
assert(!n1.isLinked());
|
||||||
|
assert(!n3.isLinked());
|
||||||
|
assert(!n4.isLinked());
|
||||||
|
assert(!n5.isLinked());
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ExpectedScheduledTask {
|
||||||
|
size_t nodeIndex;
|
||||||
|
int aest;
|
||||||
|
int alst;
|
||||||
|
int weight;
|
||||||
|
};
|
||||||
|
|
||||||
|
void assertScheduledTasks(GraphDCP& graph, CPU cpu, std::initializer_list<ExpectedScheduledTask> expectedTasks) {
|
||||||
|
auto actualTasks = graph.getScheduledTasks(cpu);
|
||||||
|
assert(actualTasks.size() == expectedTasks.size());
|
||||||
|
|
||||||
|
auto expectedIt = expectedTasks.begin();
|
||||||
|
for (const auto& actualTask : actualTasks) {
|
||||||
|
assert(expectedIt != expectedTasks.end());
|
||||||
|
if (actualTask.nodeIndex != expectedIt->nodeIndex || actualTask.aest != expectedIt->aest
|
||||||
|
|| actualTask.alst != expectedIt->alst || actualTask.weight != expectedIt->weight) {
|
||||||
|
std::cerr << "CPU " << cpu << " actual schedule:\n";
|
||||||
|
for (const auto& task : actualTasks) {
|
||||||
|
std::cerr << " " << task.nodeIndex << ") aest: " << task.aest << " alst: " << task.alst
|
||||||
|
<< " weight: " << task.weight << '\n';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert(actualTask.nodeIndex == expectedIt->nodeIndex);
|
||||||
|
assert(actualTask.aest == expectedIt->aest);
|
||||||
|
assert(actualTask.alst == expectedIt->alst);
|
||||||
|
assert(actualTask.weight == expectedIt->weight);
|
||||||
|
++expectedIt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int testDCPGraphFixture() {
|
||||||
|
std::cout << "testDCPGraphFixture:" << std::endl;
|
||||||
|
|
||||||
|
const std::vector<Weight_t> nodeWeights = {
|
||||||
|
80, 40, 40, 40, 40, 40, 60, 30, 30, 30,
|
||||||
|
30, 40, 20, 20, 20, 20, 10, 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
GraphDCP graph(nodeWeights, {});
|
||||||
|
graph.makeEdge(0, 1, 3);
|
||||||
|
graph.makeEdge(0, 1, 120);
|
||||||
|
graph.makeEdge(0, 2, 120);
|
||||||
|
graph.makeEdge(0, 3, 120);
|
||||||
|
graph.makeEdge(0, 4, 120);
|
||||||
|
graph.makeEdge(0, 5, 120);
|
||||||
|
graph.makeEdge(0, 6, 120);
|
||||||
|
graph.makeEdge(2, 6, 80);
|
||||||
|
graph.makeEdge(2, 7, 80);
|
||||||
|
graph.makeEdge(3, 8, 80);
|
||||||
|
graph.makeEdge(4, 9, 80);
|
||||||
|
graph.makeEdge(5, 10, 80);
|
||||||
|
graph.makeEdge(6, 7, 120);
|
||||||
|
graph.makeEdge(6, 8, 120);
|
||||||
|
graph.makeEdge(6, 9, 120);
|
||||||
|
graph.makeEdge(6, 10, 120);
|
||||||
|
graph.makeEdge(6, 11, 120);
|
||||||
|
graph.makeEdge(8, 11, 80);
|
||||||
|
graph.makeEdge(8, 12, 80);
|
||||||
|
graph.makeEdge(9, 13, 80);
|
||||||
|
graph.makeEdge(10, 14, 80);
|
||||||
|
graph.makeEdge(11, 12, 120);
|
||||||
|
graph.makeEdge(11, 13, 120);
|
||||||
|
graph.makeEdge(11, 14, 120);
|
||||||
|
graph.makeEdge(11, 15, 120);
|
||||||
|
graph.makeEdge(13, 15, 80);
|
||||||
|
graph.makeEdge(13, 16, 80);
|
||||||
|
graph.makeEdge(14, 17, 80);
|
||||||
|
graph.makeEdge(15, 16, 120);
|
||||||
|
graph.makeEdge(15, 17, 120);
|
||||||
|
|
||||||
|
graph.DCP();
|
||||||
|
for (CPU cpu = 0; cpu < graph.cpuCount(); ++cpu) {
|
||||||
|
auto scheduledTasks = graph.getScheduledTasks(cpu);
|
||||||
|
std::cerr << "CPU " << cpu << " computed schedule:\n";
|
||||||
|
for (const auto& task : scheduledTasks) {
|
||||||
|
std::cerr << " " << task.nodeIndex << ") aest: " << task.aest << " alst: " << task.alst
|
||||||
|
<< " weight: " << task.weight << '\n';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert(graph.cpuCount() == 4);
|
||||||
|
assertScheduledTasks(graph, 3, {
|
||||||
|
{1, 200, 370, 40},
|
||||||
|
});
|
||||||
|
assertScheduledTasks(graph, 2, {
|
||||||
|
{5, 200, 260, 40},
|
||||||
|
{10, 300, 300, 30},
|
||||||
|
});
|
||||||
|
assertScheduledTasks(graph, 1, {
|
||||||
|
{4, 200, 210, 40},
|
||||||
|
{7, 300, 380, 30},
|
||||||
|
});
|
||||||
|
assertScheduledTasks(graph, 0, {
|
||||||
|
{0, 0, 0, 80},
|
||||||
|
{2, 80, 80, 40},
|
||||||
|
{6, 120, 120, 60},
|
||||||
|
{3, 180, 200, 40},
|
||||||
|
{8, 220, 240, 30},
|
||||||
|
{11, 250, 270, 40},
|
||||||
|
{12, 290, 310, 20},
|
||||||
|
{9, 320, 330, 30},
|
||||||
|
{13, 350, 360, 20},
|
||||||
|
{15, 370, 380, 20},
|
||||||
|
{16, 390, 400, 10},
|
||||||
|
{14, 410, 410, 20},
|
||||||
|
{17, 430, 430, 10},
|
||||||
|
});
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
(void) argc;
|
||||||
|
(void) argv;
|
||||||
|
|
||||||
|
int failures = 0;
|
||||||
|
failures += testLabeledList();
|
||||||
|
failures += testDCPGraphFixture();
|
||||||
|
if (failures != 0) {
|
||||||
|
std::cerr << failures << " test failures\n";
|
||||||
|
return EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
||||||
+37
-1
@@ -1,8 +1,41 @@
|
|||||||
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
from subprocess_utils import run_command_with_reporter
|
from subprocess_utils import run_command_with_reporter
|
||||||
|
|
||||||
|
PIM_PASS_LABELS = (
|
||||||
|
("ONNXToSpatialPass", "ONNX to Spatial"),
|
||||||
|
("MergeComputeNodesPass", "Merge Compute Nodes"),
|
||||||
|
("SpatialToPimPass", "Spatial to PIM"),
|
||||||
|
("PimBufferizationPass", "Bufferize PIM"),
|
||||||
|
("HostConstantFoldingPass", "Fold Host Constants"),
|
||||||
|
("MaterializeHostConstantsPass", "Materialize Host Constants"),
|
||||||
|
("VerificationPass", "Verify PIM"),
|
||||||
|
("EmitPimJsonPass", "Emit PIM JSON"),
|
||||||
|
)
|
||||||
|
PIM_PASS_LABEL_BY_SUFFIX = dict(PIM_PASS_LABELS)
|
||||||
|
TIMING_LINE_RE = re.compile(r"^\s*([0-9]+\.[0-9]+)\s+\(\s*[0-9.]+%\)\s+(.+?)\s*$")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_pim_pass_timings(output_text):
|
||||||
|
pass_timings = {}
|
||||||
|
for line in output_text.splitlines():
|
||||||
|
match = TIMING_LINE_RE.match(line)
|
||||||
|
if not match:
|
||||||
|
continue
|
||||||
|
|
||||||
|
duration = float(match.group(1))
|
||||||
|
pass_name = match.group(2)
|
||||||
|
for suffix, label in PIM_PASS_LABEL_BY_SUFFIX.items():
|
||||||
|
if pass_name.endswith(suffix):
|
||||||
|
pass_timings[label] = pass_timings.get(label, 0.0) + duration
|
||||||
|
break
|
||||||
|
|
||||||
|
if not pass_timings:
|
||||||
|
raise RuntimeError("Raptor timing report did not contain any PIM pass timings.")
|
||||||
|
return pass_timings
|
||||||
|
|
||||||
|
|
||||||
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
||||||
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
crossbar_size, crossbar_count, cwd=None, reporter=None):
|
||||||
@@ -16,16 +49,19 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path,
|
|||||||
# "--use-experimental-conv-impl=true",
|
# "--use-experimental-conv-impl=true",
|
||||||
f"--crossbar-size={crossbar_size}",
|
f"--crossbar-size={crossbar_size}",
|
||||||
f"--crossbar-count={crossbar_count}",
|
f"--crossbar-count={crossbar_count}",
|
||||||
|
"--enable-timing",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
run_command_with_reporter(
|
output_text = run_command_with_reporter(
|
||||||
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
[str(raptor_onnx_path)] + [str(arg) for arg in args],
|
||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
reporter=reporter,
|
reporter=reporter,
|
||||||
|
capture_output=True,
|
||||||
)
|
)
|
||||||
if reporter is None:
|
if reporter is None:
|
||||||
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
print(Fore.GREEN + "Raptor execution successful" + Style.RESET_ALL)
|
||||||
|
return _parse_pim_pass_timings(output_text)
|
||||||
except subprocess.CalledProcessError:
|
except subprocess.CalledProcessError:
|
||||||
if reporter is None:
|
if reporter is None:
|
||||||
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
|
print(Fore.RED + "Raptor execution failed" + Style.RESET_ALL)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ def _read_chunk(fd, treat_eio_as_eof=False):
|
|||||||
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
||||||
selector = selectors.DefaultSelector()
|
selector = selectors.DefaultSelector()
|
||||||
recent_output = bytearray()
|
recent_output = bytearray()
|
||||||
|
captured_output = bytearray()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
selector.register(fd, selectors.EVENT_READ)
|
selector.register(fd, selectors.EVENT_READ)
|
||||||
@@ -34,6 +35,7 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
|||||||
reporter._clear()
|
reporter._clear()
|
||||||
os.write(1, data)
|
os.write(1, data)
|
||||||
reporter._render()
|
reporter._render()
|
||||||
|
captured_output.extend(data)
|
||||||
recent_output.extend(data)
|
recent_output.extend(data)
|
||||||
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
if len(recent_output) > MAX_ERROR_OUTPUT_BYTES:
|
||||||
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
del recent_output[:-MAX_ERROR_OUTPUT_BYTES]
|
||||||
@@ -43,12 +45,22 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False):
|
|||||||
return_code = process.wait()
|
return_code = process.wait()
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
raise subprocess.CalledProcessError(return_code, process.args, output=bytes(recent_output))
|
||||||
|
return bytes(captured_output)
|
||||||
|
|
||||||
|
|
||||||
def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False):
|
||||||
if reporter is None:
|
if reporter is None:
|
||||||
|
if capture_output:
|
||||||
|
completed = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
cwd=cwd,
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
)
|
||||||
|
return completed.stdout.decode("utf-8", errors="replace")
|
||||||
subprocess.run(cmd, cwd=cwd, check=True)
|
subprocess.run(cmd, cwd=cwd, check=True)
|
||||||
return
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
master_fd, slave_fd = pty.openpty()
|
master_fd, slave_fd = pty.openpty()
|
||||||
@@ -60,8 +72,8 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
|||||||
stderr=subprocess.STDOUT,
|
stderr=subprocess.STDOUT,
|
||||||
)
|
)
|
||||||
assert process.stdout is not None
|
assert process.stdout is not None
|
||||||
_stream_output(process.stdout.fileno(), process, reporter)
|
output = _stream_output(process.stdout.fileno(), process, reporter)
|
||||||
return
|
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
process = subprocess.Popen(
|
process = subprocess.Popen(
|
||||||
@@ -73,4 +85,5 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None):
|
|||||||
finally:
|
finally:
|
||||||
os.close(slave_fd)
|
os.close(slave_fd)
|
||||||
|
|
||||||
_stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
|
output = _stream_output(master_fd, process, reporter, treat_eio_as_eof=True)
|
||||||
|
return output.decode("utf-8", errors="replace") if capture_output else None
|
||||||
|
|||||||
+34
-2
@@ -8,6 +8,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from colorama import Style, Fore
|
from colorama import Style, Fore
|
||||||
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
from validate_one import ProgressReporter, clean_workspace_artifacts, validate_network
|
||||||
|
from raptor import PIM_PASS_LABELS
|
||||||
|
|
||||||
|
|
||||||
def format_command(cmd):
|
def format_command(cmd):
|
||||||
@@ -41,6 +42,19 @@ def print_validation_error(reporter, rel, exc):
|
|||||||
reporter.resume()
|
reporter.resume()
|
||||||
|
|
||||||
|
|
||||||
|
def print_average_pim_pass_timings(pass_timing_sums, pass_timing_counts, total_timing_sum, timed_benchmark_count):
    """Print the average duration of each PIM pass and the average total.

    Args:
        pass_timing_sums: Mapping of pass label -> summed duration.
        pass_timing_counts: Mapping of pass label -> number of samples that
            contributed to the corresponding sum.
        total_timing_sum: Sum of the per-benchmark totals.
        timed_benchmark_count: Number of benchmarks that produced timings.

    Prints nothing when ``timed_benchmark_count`` is zero. Labels are listed in
    the order of ``PIM_PASS_LABELS``; a label with a zero sample count is
    skipped so that no division by zero occurs.
    """
    if timed_benchmark_count == 0:
        return

    print("\n" + Style.BRIGHT + Fore.CYAN + "Average PIM Pass Timings" + Style.RESET_ALL)
    for _, label in PIM_PASS_LABELS:
        count = pass_timing_counts[label]
        if count == 0:
            continue
        # Each pass is averaged over the samples that actually reported it.
        print(f" {label.ljust(28)} {pass_timing_sums[label] / count:.4f}s")
    # The total is averaged over every benchmark that produced timings.
    print(f" {'Total'.ljust(28)} {total_timing_sum / timed_benchmark_count:.4f}s")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
ap = argparse.ArgumentParser(description="Validate all ONNX operations under the operations/ directory.")
|
||||||
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
ap.add_argument("--raptor-path", help="Path to the Raptor compiler binary.")
|
||||||
@@ -90,11 +104,15 @@ def main():
|
|||||||
print("=" * 72)
|
print("=" * 72)
|
||||||
|
|
||||||
results = {} # relative_path -> passed
|
results = {} # relative_path -> passed
|
||||||
|
pass_timing_sums = {label: 0.0 for _, label in PIM_PASS_LABELS}
|
||||||
|
pass_timing_counts = {label: 0 for _, label in PIM_PASS_LABELS}
|
||||||
|
total_timing_sum = 0.0
|
||||||
|
timed_benchmark_count = 0
|
||||||
reporter = ProgressReporter(len(onnx_files))
|
reporter = ProgressReporter(len(onnx_files))
|
||||||
for index, onnx_path in enumerate(onnx_files, start=1):
|
for index, onnx_path in enumerate(onnx_files, start=1):
|
||||||
rel = onnx_path.relative_to(operations_dir)
|
rel = onnx_path.relative_to(operations_dir)
|
||||||
try:
|
try:
|
||||||
passed = validate_network(
|
result = validate_network(
|
||||||
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir,
|
||||||
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count,
|
||||||
threshold=a.threshold,
|
threshold=a.threshold,
|
||||||
@@ -102,7 +120,15 @@ def main():
|
|||||||
model_index=index,
|
model_index=index,
|
||||||
model_total=len(onnx_files),
|
model_total=len(onnx_files),
|
||||||
)
|
)
|
||||||
results[str(rel)] = passed
|
results[str(rel)] = result.passed
|
||||||
|
if result.pim_pass_timings:
|
||||||
|
benchmark_total = 0.0
|
||||||
|
for label, duration in result.pim_pass_timings.items():
|
||||||
|
pass_timing_sums[label] += duration
|
||||||
|
pass_timing_counts[label] += 1
|
||||||
|
benchmark_total += duration
|
||||||
|
total_timing_sum += benchmark_total
|
||||||
|
timed_benchmark_count += 1
|
||||||
except subprocess.CalledProcessError as exc:
|
except subprocess.CalledProcessError as exc:
|
||||||
results[str(rel)] = False
|
results[str(rel)] = False
|
||||||
print_validation_error(reporter, rel, exc)
|
print_validation_error(reporter, rel, exc)
|
||||||
@@ -131,6 +157,12 @@ def main():
|
|||||||
print(separator)
|
print(separator)
|
||||||
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
print(Style.BRIGHT + f"Passed: {n_passed}" + Style.RESET_ALL)
|
||||||
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
print(Style.BRIGHT + f"Failed: {n_total - n_passed}" + Style.RESET_ALL)
|
||||||
|
print_average_pim_pass_timings(
|
||||||
|
pass_timing_sums,
|
||||||
|
pass_timing_counts,
|
||||||
|
total_timing_sum,
|
||||||
|
timed_benchmark_count,
|
||||||
|
)
|
||||||
|
|
||||||
sys.exit(0 if n_passed == n_total else 1)
|
sys.exit(0 if n_passed == n_total else 1)
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import numpy as np
|
|||||||
import subprocess
|
import subprocess
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from colorama import Style, Fore
|
from colorama import Style, Fore
|
||||||
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
|
from onnx_utils import gen_random_inputs, save_inputs_to_files, onnx_io, write_inputs_to_memory_bin, _ONNX_TO_NP
|
||||||
@@ -25,6 +26,12 @@ STAGE_COUNT = len(STAGE_TITLES)
|
|||||||
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ValidationResult:
    """Outcome of validating a single network."""

    # Pass/fail verdict of the validation run.
    passed: bool
    # Display label -> duration parsed from the compiler's timing output.
    # Defaults to a fresh empty dict per instance.
    pim_pass_timings: dict[str, float] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class ProgressReporter:
|
class ProgressReporter:
|
||||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
||||||
self.total_models = total_models
|
self.total_models = total_models
|
||||||
@@ -267,6 +274,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
reporter.log(Fore.CYAN + f"[{model_index}/{model_total}]" + Style.RESET_ALL +
|
||||||
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
f" {Style.BRIGHT}Validating {network_onnx_path.name}{Style.RESET_ALL}")
|
||||||
failed_with_exception = False
|
failed_with_exception = False
|
||||||
|
pim_pass_timings = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile ONNX")
|
||||||
@@ -299,7 +307,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
reporter.advance()
|
reporter.advance()
|
||||||
|
|
||||||
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM")
|
||||||
compile_with_raptor(
|
pim_pass_timings = compile_with_raptor(
|
||||||
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem,
|
||||||
crossbar_size, crossbar_count,
|
crossbar_size, crossbar_count,
|
||||||
cwd=raptor_dir, reporter=reporter)
|
cwd=raptor_dir, reporter=reporter)
|
||||||
@@ -326,7 +334,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir,
|
|||||||
reporter.record_result(passed)
|
reporter.record_result(passed)
|
||||||
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
status = Fore.GREEN + "PASS" + Style.RESET_ALL if passed else Fore.RED + "FAIL" + Style.RESET_ALL
|
||||||
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
reporter.log(Style.BRIGHT + f"Result: {status}" + Style.RESET_ALL)
|
||||||
return passed
|
return ValidationResult(passed=passed, pim_pass_timings=pim_pass_timings)
|
||||||
except Exception:
|
except Exception:
|
||||||
failed_with_exception = True
|
failed_with_exception = True
|
||||||
reporter.record_result(False)
|
reporter.record_result(False)
|
||||||
@@ -352,4 +360,4 @@ if __name__ == '__main__':
|
|||||||
passed = validate_network(
|
passed = validate_network(
|
||||||
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
|
a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir
|
||||||
)
|
)
|
||||||
raise SystemExit(0 if passed else 1)
|
raise SystemExit(0 if passed.passed else 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user