#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"

#include <cassert>
#include <cstdint>
#include <tuple>

#include "../SpatialOps.hpp"
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "src/Support/TypeUtilities.hpp"

namespace onnx_mlir {
namespace spatial {

using namespace mlir;

/// Returns the SpatWeightedCompute that ultimately produces the value defined
/// by `op`, looking through any chain of tensor.extract_slice ops in between.
///
/// Returns a null op when `op` is null, when the chain reaches a value with no
/// defining op (e.g. a block argument), or when it ends at any other kind of
/// operation.
SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
  if (!op)
    return {};
  // Walk up through slices to the op that owns the underlying tensor.
  while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
    op = extract.getSource().getDefiningOp();
    if (!op)
      return {};
  }
  if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op))
    return res;
  return {};
}

/// Builds the task graph of SpatWeightedCompute ops that are direct children
/// of `entryOp`'s regions, runs DCP scheduling on it and returns the result.
///
/// Nodes are identified by their position in region order. An edge
/// (producer, consumer, bytes) is added when a consumer input is derived
/// (directly or through extract_slice) from a producer's result; `bytes` is
/// the summed size of *all* of the producer's results, as computed by
/// getSizeInBytes.
DCPAnalysisResult DCPAnalysis::runAnalysis() {
  // (producer index, consumer index, edge weight in bytes)
  using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;

  llvm::SmallVector<SpatWeightedCompute> spatWeightedComputes;
  llvm::SmallVector<EdgesIndex> edges;

  for (Region& region : entryOp->getRegions())
    for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
      spatWeightedComputes.push_back(spatWeightedCompute);

  // Node index of every compute op; replaces a linear search per input.
  llvm::DenseMap<Operation*, int64_t> nodeIndex;
  nodeIndex.reserve(spatWeightedComputes.size());
  for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes))
    nodeIndex[spatWeightedCompute.getOperation()] = static_cast<int64_t>(index);

  for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
    // A consumer may read several results (or several slices of one result)
    // of the same producer. The edge weight already covers all producer
    // results, so emit at most one edge per producer to avoid counting the
    // same transfer multiple times.
    llvm::SmallDenseSet<int64_t, 4> connectedProducers;

    for (Value input : spatWeightedCompute.getInputs()) {
      SpatWeightedCompute spatWeightedComputeArgOp = getOriginalSpatWeightCompute(input.getDefiningOp());
      if (!spatWeightedComputeArgOp)
        continue;

      auto found = nodeIndex.find(spatWeightedComputeArgOp.getOperation());
      assert(found != nodeIndex.end() && "producer compute is not a direct child of the analysed op");
      // In release builds the assert vanishes; never emit an edge that points
      // outside the node list.
      if (found == nodeIndex.end())
        continue;

      const int64_t indexStartEdge = found->second;
      if (!connectedProducers.insert(indexStartEdge).second)
        continue;

      int64_t totalSize = 0;
      for (Value output : spatWeightedComputeArgOp.getResults()) {
        ShapedType result = cast<ShapedType>(output.getType());
        totalSize += getSizeInBytes(result);
      }
      edges.emplace_back(indexStartEdge, static_cast<int64_t>(indexEndEdge), totalSize);
    }
  }

  GraphDCP graphDCP(spatWeightedComputes, edges);
  graphDCP.DCP();
  return graphDCP.getResult();
}

} // namespace spatial
} // namespace onnx_mlir