Add DCP alghoritm, partial working test
This commit is contained in:
60
src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp
Normal file
60
src/PIM/Dialect/Spatial/DCPGraph/Utils.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../SpatialOps.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
using CPU = int;
|
||||
using Weight_t = int;
|
||||
class TaskDCP;
|
||||
class GraphDCP;
|
||||
using Edge_t = std::pair<TaskDCP*, Weight_t>;
|
||||
using Edge_pair = std::pair<Edge_t, Edge_t>;
|
||||
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
|
||||
|
||||
template <typename T>
|
||||
void fastRemove(std::vector<std::pair<T*, Weight_t>>& vector, T* to_remove) {
|
||||
auto position =
|
||||
std::find_if(vector.begin(), vector.end(), [to_remove](Edge_t edge) { return edge.first == to_remove; });
|
||||
if (position != vector.end()) {
|
||||
std::swap(*(vector.end() - 1), *position);
|
||||
vector.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
inline void fastRemove(std::vector<TaskDCP*>& vector, TaskDCP* to_remove) {
|
||||
auto position =
|
||||
std::find_if(vector.begin(), vector.end(), [to_remove](TaskDCP* element) { return element == to_remove; });
|
||||
if (position != vector.end()) {
|
||||
std::swap(*(vector.end() - 1), *position);
|
||||
vector.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename P>
|
||||
void fastRemove(std::vector<std::pair<T*, Weight_t>>& vector, P position) {
|
||||
if (position != vector.end()) {
|
||||
std::swap(*(vector.end() - 1), *position);
|
||||
vector.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
// TODO Fare qualcosa di sensato
|
||||
inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
|
||||
int64_t tot = 0;
|
||||
for (auto& region : spatWeightedCompute.getBody()) {
|
||||
for (auto& inst : region) {
|
||||
for(auto result : inst.getResults()){
|
||||
if(auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType()))
|
||||
tot += onnx_mlir::getSizeInBytes(element);
|
||||
}
|
||||
}
|
||||
}
|
||||
return tot;
|
||||
}
|
||||
Reference in New Issue
Block a user