diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 29c54f9..6289d14 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -14,12 +14,15 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_os_ostream.h" +#include #include #include #include #include #include #include +#include +#include #include #include "DCPGraph/DCPAnalysis.hpp" @@ -36,32 +39,86 @@ void generateReport(func::FuncOp funcOp, const std::string& name) { if (outputDir.empty()) return; - std::string dialectsDir = outputDir + "/dialects/stats"; + std::string dialectsDir = outputDir + "/dcp_graph"; createDirectory(dialectsDir); std::fstream file(dialectsDir + "/" + name + ".txt", std::ios::out); llvm::raw_os_ostream os(file); uint64_t numSpatCompute = 0; - std::vector numWeights; - std::vector numInstructions; + std::vector> collectedData; for (auto spatCompute : funcOp.getOps()) { - numSpatCompute++; - numWeights.push_back(spatCompute.getWeights().size()); - uint64_t numInst = 0; - for(auto& _ : spatCompute.getRegion().front() ){ + uint64_t numInst = 0; + for (auto& _ : spatCompute.getRegion().front()) numInst++; - } - numInstructions.push_back(numInst); + collectedData.push_back({numSpatCompute++, spatCompute.getWeights().size(), numInst}); } + std::stable_sort(collectedData.begin(), + collectedData.end(), + [](std::tuple lft, std::tuple rgt) { + auto [iLft, weightLft, numInstLft] = lft; + auto [iRgt, weightRgt, numInstRgt] = rgt; + + if (numInstLft < numInstRgt) + return false; + else if (numInstRgt < numInstLft) + return true; + + if (weightLft < weightRgt) + return false; + else if (weightRgt < weightLft) + return true; + + if (iLft < iRgt) + return true; + else if (iRgt < iLft) + return false; + + return true; + }); + for (uint64_t cI = 0; cI < numSpatCompute; ++cI) { - os << "Compute " << cI << ":\n"; - os << "\tNumber of instructions " << numInstructions[cI] << "\n"; - os << "\tNumber of used crossbars " << numWeights[cI] << "\n"; + uint64_t lastIndex = cI; + auto [currentComputeId, currentWeight, currentNumInst] = collectedData[cI]; + + for (uint64_t nI = cI + 1; nI < numSpatCompute; ++nI) { + auto [nextComputeId, nextWeight, nextNumInst] = collectedData[nI]; + if (currentWeight == nextWeight && currentNumInst == nextNumInst) + lastIndex = nI; + else + break; + } + + os << "Compute " << currentComputeId; + auto expectedPrintedValue = currentComputeId + 1; + bool rangePrinted = false; + cI++; + for (; cI < lastIndex; ++cI){ + auto candidateToPrint = std::get<0>(collectedData[cI]); + if (candidateToPrint == expectedPrintedValue){ + expectedPrintedValue = candidateToPrint + 1; + rangePrinted = true; + } else { + if (rangePrinted) { + os << " - " << expectedPrintedValue - 1; + } + os << " , " << candidateToPrint; + rangePrinted = false; + expectedPrintedValue = candidateToPrint + 1; + } + } + if (rangePrinted && currentComputeId != expectedPrintedValue - 1){ + os << " - " << expectedPrintedValue - 1; + } + + os << " :\n"; + os << "\tNumber of instructions " << currentNumInst << "\n"; + os << "\tNumber of used crossbars " << currentWeight << "\n"; + cI = lastIndex; } - + os.flush(); file.close(); }