All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s
66 lines
2.1 KiB
C++
66 lines
2.1 KiB
C++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"

#include <cassert>
#include <iterator>
#include <tuple>

#include "../SpatialOps.hpp"
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "src/Support/TypeUtilities.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
namespace spatial {
|
|
|
|
using namespace mlir;
|
|
|
|
/// Finds the SpatWeightedCompute that ultimately produces the value defined by
/// `op`, looking through any chain of tensor.extract_slice ops by following
/// each slice's source operand back to its defining op.
///
/// Returns a null SpatWeightedCompute when:
///  - `op` is null,
///  - a slice's source has no defining op (e.g. it is a block argument), or
///  - the first op that is not a tensor.extract_slice is not a
///    SpatWeightedCompute.
SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
  for (Operation* current = op; current != nullptr;) {
    auto slice = llvm::dyn_cast<tensor::ExtractSliceOp>(current);
    if (!slice)
      // End of the slice chain: this is either the compute or something else.
      return llvm::dyn_cast<SpatWeightedCompute>(current);
    current = slice.getSource().getDefiningOp();
  }
  return {};
}
|
|
|
|
/// Builds the data-dependency graph between the SpatWeightedCompute ops found
/// directly in the regions of `entryOp`, runs the DCP algorithm on it and
/// returns its result.
///
/// Graph construction:
///  - Nodes are the SpatWeightedCompute ops, numbered in region/block order.
///  - For every input of a node whose producer is another node, an edge
///    (producerIndex, consumerIndex, bytes) is added. The producer is found
///    by looking through any chain of tensor.extract_slice ops (see
///    getOriginalSpatWeightCompute). `bytes` is the sum of getSizeInBytes
///    over all of the producer's results.
///
/// NOTE(review): one edge is emitted per consuming input. A consumer that
/// reads several results of the same producer (or the same value twice) gets
/// several parallel edges, each carrying the producer's full output size.
/// Confirm that GraphDCP expects this.
DCPAnalysisResult DCPAnalysis::runAnalysis() {
  // (source node index, destination node index, transferred bytes)
  using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;

  llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
  llvm::SmallVector<EdgesIndex, 10> edges;

  for (Region& region : entryOp->getRegions())
    for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
      spatWeightedComputes.push_back(spatWeightedCompute);

  // Index every node once. This avoids a linear search over all computes for
  // each consumed input, which made graph construction quadratic.
  llvm::DenseMap<Operation*, int64_t> nodeIndex;
  nodeIndex.reserve(spatWeightedComputes.size());
  for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes))
    nodeIndex[spatWeightedCompute.getOperation()] = static_cast<int64_t>(index);

  for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
    for (Value input : spatWeightedCompute.getInputs()) {
      SpatWeightedCompute producer = getOriginalSpatWeightCompute(input.getDefiningOp());
      if (!producer)
        continue;

      auto producerIt = nodeIndex.find(producer.getOperation());
      assert(producerIt != nodeIndex.end() && "producer is not a node of the DCP graph");
      // With asserts disabled, a producer outside the collected set would
      // otherwise turn into an out-of-range node index. Such a producer is
      // not a graph node, so there is no edge to record.
      if (producerIt == nodeIndex.end())
        continue;

      int64_t totalSize = 0;
      for (Value output : producer.getResults())
        totalSize += getSizeInBytes(cast<ShapedType>(output.getType()));

      edges.push_back({producerIt->second, static_cast<int64_t>(indexEndEdge), totalSize});
    }
  }

  GraphDCP graphDCP(spatWeightedComputes, edges);
  graphDCP.DCP();
  return graphDCP.getResult();
}
|
|
|
|
} // namespace spatial
|
|
} // namespace onnx_mlir
|