From c4dd28a6076cd10234876b6b5c6c6af31842a5f1 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Thu, 2 Jul 2026 17:01:26 +0200 Subject: [PATCH] Export csv graph for gephi --- src/PIM/Common/Support/DebugDump.cpp | 16 +- src/PIM/Common/Support/DebugDump.hpp | 5 + src/PIM/Compiler/PimCompilerOptions.cpp | 12 + src/PIM/Compiler/PimCompilerOptions.hpp | 8 + .../Conversion/ONNXToSpatial/CMakeLists.txt | 1 + .../ONNXToSpatial/LowerSpatialPlansPass.cpp | 14 +- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 3 +- src/PIM/Conversion/ONNXToSpatial/Patterns.cpp | 1 + src/PIM/Conversion/ONNXToSpatial/Patterns.hpp | 1 + .../ONNXToSpatial/Patterns/Tensor/Flatten.cpp | 112 +++ src/PIM/Dialect/Spatial/CMakeLists.txt | 1 + .../MergeComputeNodesPass.cpp | 9 +- .../SpatialDataflowCsvExporter.cpp | 728 ++++++++++++++++++ .../SpatialDataflowCsvExporter.hpp | 25 + 14 files changed, 926 insertions(+), 10 deletions(-) create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Flatten.cpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.cpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.hpp diff --git a/src/PIM/Common/Support/DebugDump.cpp b/src/PIM/Common/Support/DebugDump.cpp index fb47f93..4b7026e 100644 --- a/src/PIM/Common/Support/DebugDump.cpp +++ b/src/PIM/Common/Support/DebugDump.cpp @@ -7,15 +7,21 @@ namespace onnx_mlir { -void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) { +std::fstream openDialectDumpFileWithExtension(const std::string& name, llvm::StringRef destination, llvm::StringRef extension) { std::string outputDir = getOutputDir(); if (outputDir.empty()) + return {}; + + std::string dialectsDir = (outputDir + destination).str(); + createDirectory(dialectsDir); + return std::fstream(dialectsDir + "/" + name + "." + extension.str(), std::ios::out); +} + +void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) { + std::fstream file = openDialectDumpFileWithExtension(name, "/dialects", "mlir"); + if (!file.is_open()) return; - std::string dialectsDir = outputDir + "/dialects"; - createDirectory(dialectsDir); - - std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); llvm::raw_os_ostream os(file); mlir::OpPrintingFlags flags; flags.elideLargeElementsAttrs().enableDebugInfo(true, false); diff --git a/src/PIM/Common/Support/DebugDump.hpp b/src/PIM/Common/Support/DebugDump.hpp index 9f55182..b0e8a99 100644 --- a/src/PIM/Common/Support/DebugDump.hpp +++ b/src/PIM/Common/Support/DebugDump.hpp @@ -1,7 +1,9 @@ #pragma once #include "mlir/IR/BuiltinOps.h" +#include "llvm/ADT/StringRef.h" +#include #include namespace onnx_mlir { @@ -10,4 +12,7 @@ namespace onnx_mlir { /// directory for pass-level debugging. void dumpModule(mlir::ModuleOp moduleOp, const std::string& name); +/// Opens a file under the same dialect dump directory used by dumpModule. +std::fstream openDialectDumpFileWithExtension(const std::string& name,llvm::StringRef destination = "/dialects", llvm::StringRef extension = "mlir"); + } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index 4fed7cb..5e73d58 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -57,6 +57,18 @@ llvm::cl::opt pimConvLowering( llvm::cl::init(PimConvLoweringAuto), llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt pimExportSpatialDataflow( + "pim-export-spatial-dataflow", + llvm::cl::desc("Emit Gephi-importable CSV dataflow reports around MergeComputeNodes materialization"), + llvm::cl::values(clEnumValN(SpatialDataflowExportNone, "none", "Do not emit Spatial dataflow CSV reports")), + llvm::cl::values(clEnumValN(SpatialDataflowExportPre, "pre", "Emit pre-materialization Spatial dataflow CSV reports")), + llvm::cl::values( + clEnumValN(SpatialDataflowExportPost, "post", "Emit post-materialization Spatial dataflow CSV reports")), + llvm::cl::values( + clEnumValN(SpatialDataflowExportBoth, "both", "Emit both pre- and post-materialization Spatial dataflow CSV reports")), + llvm::cl::init(SpatialDataflowExportNone), + llvm::cl::cat(OnnxMlirOptions)); + llvm::cl::opt pimOnlyCodegen("pim-only-codegen", llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"), diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp index 51aa469..b5d931d 100644 --- a/src/PIM/Compiler/PimCompilerOptions.hpp +++ b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -42,11 +42,19 @@ typedef enum { PimConvLoweringTiled2D = 8, } PimConvLoweringType; +typedef enum { + SpatialDataflowExportNone = 0, + SpatialDataflowExportPre = 1, + SpatialDataflowExportPost = 2, + SpatialDataflowExportBoth = 3, +} PimSpatialDataflowExportType; + extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::opt pimEmissionTarget; extern llvm::cl::opt pimMergeScheduler; extern llvm::cl::opt pimMemoryReport; extern llvm::cl::opt pimConvLowering; +extern llvm::cl::opt pimExportSpatialDataflow; extern llvm::cl::opt pimOnlyCodegen; extern llvm::cl::opt pimDisableMemoryCoalescing; diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 320ab70..ae4bc06 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -20,6 +20,7 @@ add_pim_library(OMONNXToSpatial Patterns/NN/Sigmoid.cpp Patterns/NN/Softmax.cpp Patterns/Tensor/Concat.cpp + Patterns/Tensor/Flatten.cpp Patterns/Tensor/Gather.cpp Patterns/Tensor/Resize.cpp Patterns/Tensor/Reshape.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp b/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp index e933cb9..e674fa2 100644 --- a/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/LowerSpatialPlansPass.cpp @@ -16,6 +16,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -392,10 +393,17 @@ struct LowerSpatialPlansPass final : PassWrapper(); - preTarget.addIllegalOp(); + preTarget.addIllegalOp(); RewritePatternSet prePatterns(ctx); populatePrePatterns(prePatterns, ctx); @@ -142,6 +142,7 @@ void ONNXToSpatialPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp index 0a747e9..1abe958 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp @@ -19,6 +19,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateSigmoidPatterns(patterns, ctx); populateSoftmaxPatterns(patterns, ctx); populateConcatPatterns(patterns, ctx); + populateFlattenPatterns(patterns, ctx); populateGatherPatterns(patterns, ctx); populateResizePatterns(patterns, ctx); populateReshapePatterns(patterns, ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index e687e3d..da2a7a6 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -26,6 +26,7 @@ void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateFlattenPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Flatten.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Flatten.cpp new file mode 100644 index 0000000..621558d --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Flatten.cpp @@ -0,0 +1,112 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static FailureOr normalizeFlattenAxis(int64_t axis, int64_t rank) { + int64_t normalizedAxis = axis < 0 ? rank + axis : axis; + if (normalizedAxis < 0 || normalizedAxis > rank) + return failure(); + return normalizedAxis; +} + +static int64_t product(ArrayRef values) { + int64_t result = 1; + for (int64_t value : values) + result *= value; + return result; +} + +static SmallVector getCollapseTo1DReassociation(int64_t rank) { + SmallVector reassociation(1); + reassociation.front().reserve(rank); + for (int64_t dim = 0; dim < rank; ++dim) + reassociation.front().push_back(dim); + return reassociation; +} + +static SmallVector getExpandFrom1DReassociation(int64_t rank) { + SmallVector reassociation(1); + reassociation.front().reserve(rank); + for (int64_t dim = 0; dim < rank; ++dim) + reassociation.front().push_back(dim); + return reassociation; +} + +static Value buildFlatten(Value input, + RankedTensorType sourceType, + RankedTensorType resultType, + int64_t axis, + ConversionPatternRewriter& rewriter, + Location loc) { + if (sourceType == resultType) + return input; + + if (axis > 0 && axis < sourceType.getRank()) { + SmallVector reassociation(2); + for (int64_t dim = 0; dim < axis; ++dim) + reassociation[0].push_back(dim); + for (int64_t dim = axis; dim < sourceType.getRank(); ++dim) + reassociation[1].push_back(dim); + return tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation); + } + + Value flattened = input; + if (sourceType.getRank() != 1) { + auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType()); + flattened = tensor::CollapseShapeOp::create( + rewriter, loc, flatType, flattened, getCollapseTo1DReassociation(sourceType.getRank())); + } + return tensor::ExpandShapeOp::create( + rewriter, loc, resultType, flattened, getExpandFrom1DReassociation(resultType.getRank())); +} + +struct Flatten : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXFlattenOp flattenOp, + ONNXFlattenOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto sourceType = dyn_cast(adaptor.getInput().getType()); + auto resultType = dyn_cast(flattenOp.getOperation()->getResult(0).getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + if (!hasStaticPositiveShape(sourceType) || !hasStaticPositiveShape(resultType) || resultType.getRank() != 2) + return failure(); + + auto axis = normalizeFlattenAxis(flattenOp.getAxis(), sourceType.getRank()); + if (failed(axis)) + return failure(); + + int64_t outerDim = product(sourceType.getShape().take_front(*axis)); + int64_t innerDim = product(sourceType.getShape().drop_front(*axis)); + if (resultType.getShape()[0] != outerDim || resultType.getShape()[1] != innerDim) + return failure(); + + auto replaceWithFlatten = [&](auto build) -> LogicalResult { + Value flattened = materializeOrComputeUnary(adaptor.getInput(), resultType, rewriter, flattenOp.getLoc(), build); + rewriter.replaceOp(flattenOp, flattened); + return success(); + }; + + return replaceWithFlatten([&](Value input) { + return buildFlatten(input, sourceType, resultType, *axis, rewriter, flattenOp.getLoc()); + }); + } +}; + +} // namespace + +void populateFlattenPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } + +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 01625f1..4de8cce 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -11,6 +11,7 @@ add_pim_library(SpatialOps Transforms/MergeComputeNodes/HostOutputFinalization.cpp Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp Transforms/MergeComputeNodes/ProjectedFragments.cpp + Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.cpp Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 4f50638..a9f3995 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -25,6 +25,7 @@ #include #include "MaterializeMergeSchedule.hpp" +#include "SpatialDataflowCsvExporter.hpp" #include "Scheduling/ComputeGraph.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/MergeSchedulingAnalysis.hpp" @@ -364,6 +365,7 @@ public: const spatial::MergeScheduleResult* analysisResult = nullptr; analysisResult = &getAnalysis().getResult(); + spatial::SpatialDataflowExportStage exportMode = spatial::getSpatialDataflowExportStage(); if (failed(spatial::MergeScheduleMaterializer().run(func, *analysisResult, nextChannelId))) { signalPassFailure(); return; @@ -379,7 +381,12 @@ public: signalPassFailure(); return; } - dumpModule(cast(func->getParentOp()), "spatial1_merged"); + if (spatial::shouldExportSpatialDataflowStage(exportMode, spatial::SpatialDataflowExportStage::Post) + && failed(spatial::exportSpatialDataflowCsvPost(func))) { + signalPassFailure(); + return; + } + dumpModule(cast(func->getParentOp()), "spatial2_merged"); generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size()); } }; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.cpp new file mode 100644 index 0000000..7286cec --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.cpp @@ -0,0 +1,728 @@ +#include "SpatialDataflowCsvExporter.hpp" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include +#include + +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" +#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace spatial { + +namespace { + +struct TopLevelOpInfo { + Operation* op = nullptr; + size_t opId = 0; + bool isPost = false; + std::optional scalarCore; +}; + +struct ExpandedNodeInfo { + std::string id; + std::optional core; + std::optional lane; +}; + +struct ChannelSendRecord { + std::string sourceId; + std::optional sourceLane; +}; + +enum class LogicalNodeSelector { + Scalar, + Lane, + RangeRepresentative, +}; + +struct ResolvedProducer { + Operation* op = nullptr; + size_t resultIndex = 0; + LogicalNodeSelector selector = LogicalNodeSelector::Scalar; + uint32_t lane = 0; + uint32_t laneStart = 0; + uint32_t laneCount = 1; +}; + +struct EdgeSource { + std::string id; + std::optional sourceLane; +}; + +std::string csvEscape(StringRef field) { + bool needsQuotes = field.contains(',') || field.contains('"') || field.contains('\n') || field.contains('\r'); + if (!needsQuotes) + return field.str(); + + std::string escaped; + escaped.reserve(field.size() + 2); + escaped.push_back('"'); + for (char ch : field) { + if (ch == '"') + escaped += "\"\""; + else + escaped.push_back(ch); + } + escaped.push_back('"'); + return escaped; +} + +void writeCsvRow(std::fstream& file, ArrayRef fields) { + for (size_t i = 0; i < fields.size(); ++i) { + if (i != 0) + file << ","; + file << csvEscape(fields[i]); + } + file << "\n"; +} + +template +std::string maybeNumber(std::optional value) { + if (!value) + return ""; + return std::to_string(*value); +} + +std::string stringifyType(Type type) { + std::string storage; + llvm::raw_string_ostream os(storage); + type.print(os); + return os.str(); +} + +std::string stringifyValueAsOperand(Value value, AsmState& asmState) { + std::string storage; + llvm::raw_string_ostream os(storage); + value.printAsOperand(os, asmState); + return os.str(); +} + +std::string stringifyResultSsaNames(Operation* op, AsmState* asmState) { + if (!asmState || op->getNumResults() == 0) + return ""; + + std::string storage; + llvm::raw_string_ostream os(storage); + llvm::interleave( + op->getResults(), + [&](Value result) { os << stringifyValueAsOperand(result, *asmState); }, + [&]() { os << ";"; }); + return os.str(); +} + +std::optional getTypeSizeBytes(Type type) { + if (auto shapedType = dyn_cast(type)) { + if (!shapedType.hasStaticShape() || !hasByteSizedElementType(shapedType.getElementType())) + return std::nullopt; + return static_cast(getShapedTypeSizeInBytes(shapedType)); + } + + if (isa(type)) + return static_cast(getElementTypeSizeInBytes(type)); + if (auto intType = dyn_cast(type)) { + if (intType.getWidth() <= 0 || intType.getWidth() % 8 != 0) + return std::nullopt; + return static_cast(getElementTypeSizeInBytes(type)); + } + if (auto floatType = dyn_cast(type)) { + if (floatType.getWidth() <= 0 || floatType.getWidth() % 8 != 0) + return std::nullopt; + return static_cast(getElementTypeSizeInBytes(type)); + } + return std::nullopt; +} + +std::string getScalarId(bool isPost, size_t opId) { + return (isPost ? "sc:" : "gc:") + std::to_string(opId); +} + +std::string getBatchLaneId(bool isPost, size_t opId, uint32_t lane) { + return (isPost ? "scb:" : "gcb:") + std::to_string(opId) + ":" + std::to_string(lane); +} + +template +bool isTopLevelRelevantCompute(Operation& op) { + return isa(&op); +} + +template +FailureOr buildTopLevelOpInfo(Operation& op, bool isPost, size_t opId) { + TopLevelOpInfo info; + info.op = &op; + info.opId = opId; + info.isPost = isPost; + + if constexpr (std::is_same_v) { + if (auto compute = dyn_cast(&op)) { + auto coreId = getOptionalScheduledCoreId(compute, "spatial dataflow export core id"); + if (failed(coreId)) + return failure(); + if (*coreId) + info.scalarCore = **coreId; + } + } + + return info; +} + +template +FailureOr> getBatchLaneCoreIds(BatchOpTy batch) { + if constexpr (std::is_same_v) { + auto coreIds = getOptionalScheduledBatchCoreIds(batch, "spatial dataflow export core ids"); + if (failed(coreIds)) + return failure(); + if (!*coreIds) + return SmallVector {}; + return SmallVector((**coreIds).begin(), (**coreIds).end()); + } + return SmallVector {}; +} + +std::string getExpandedNodeId(const DenseMap, ExpandedNodeInfo>& expandedNodes, + Operation* op, + uint32_t lane) { + auto it = expandedNodes.find({op, lane}); + if (it == expandedNodes.end()) + return ""; + return it->second.id; +} + +void addScalarNodeRow(std::fstream& nodesFile, + DenseMap, ExpandedNodeInfo>& expandedNodes, + const TopLevelOpInfo& info, + AsmState* asmState = nullptr) { + std::string id = getScalarId(info.isPost, info.opId); + SmallVector row {id, std::to_string(info.opId), "", maybeNumber(info.scalarCore)}; + if (asmState) + row.push_back(stringifyResultSsaNames(info.op, asmState)); + writeCsvRow(nodesFile, row); + expandedNodes[{info.op, 0}] = {id, info.scalarCore, std::nullopt}; +} + +template +void addBatchNodeRows(std::fstream& nodesFile, + DenseMap, ExpandedNodeInfo>& expandedNodes, + const TopLevelOpInfo& info, + BatchOpTy batch, + ArrayRef> laneCoreIds, + AsmState* asmState = nullptr) { + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) { + std::string id = getBatchLaneId(info.isPost, info.opId, lane); + SmallVector row {id, + std::to_string(info.opId), + std::to_string(lane), + maybeNumber(laneCoreIds[lane])}; + if (asmState) + row.push_back(stringifyResultSsaNames(info.op, asmState)); + writeCsvRow(nodesFile, row); + expandedNodes[{info.op, lane}] = {id, laneCoreIds[lane], lane}; + } +} + +std::optional evaluateIndexLike(Value value, Value laneArg, uint32_t lane); + +std::optional evaluateIndexLike(Value value, Value laneArg, uint32_t lane) { + if (value == laneArg) + return static_cast(lane); + + if (std::optional constant = matchConstantIndexValue(value)) + return *constant; + + if (auto constant = value.getDefiningOp()) { + if (auto intAttr = dyn_cast(constant.getValue())) + return intAttr.getInt(); + } + + if (auto extract = value.getDefiningOp()) { + auto constant = extract.getTensor().getDefiningOp(); + auto elements = constant ? dyn_cast(constant.getValue()) : nullptr; + auto shapedType = elements ? dyn_cast(elements.getType()) : nullptr; + if (!elements || !shapedType || shapedType.getRank() != 1 || extract.getIndices().size() != 1) + return std::nullopt; + + std::optional index = evaluateIndexLike(extract.getIndices().front(), laneArg, lane); + if (!index || *index < 0 || *index >= static_cast(elements.getNumElements())) + return std::nullopt; + + if (auto denseInts = dyn_cast(elements)) + return (*(denseInts.value_begin() + *index)).getSExtValue(); + return std::nullopt; + } + + if (auto affineApply = value.getDefiningOp()) + if (FailureOr folded = evaluateAffineApply( + affineApply, + [&](Value operand) -> FailureOr { + if (std::optional resolved = evaluateIndexLike(operand, laneArg, lane)) + return *resolved; + return failure(); + }); + succeeded(folded)) { + return *folded; + } + + return std::nullopt; +} + +SmallVector collectPossibleIntValues(Value value, Value laneArg, uint32_t lane) { + if (std::optional exact = evaluateIndexLike(value, laneArg, lane)) + return {*exact}; + + auto extract = value.getDefiningOp(); + auto constant = extract ? extract.getTensor().getDefiningOp() : nullptr; + auto elements = constant ? dyn_cast(constant.getValue()) : nullptr; + if (!elements) + return {}; + + SmallVector values; + if (auto denseInts = dyn_cast(elements)) { + values.reserve(elements.getNumElements()); + for (APInt element : denseInts.getValues()) + if (!llvm::is_contained(values, element.getSExtValue())) + values.push_back(element.getSExtValue()); + } + return values; +} + +template +std::optional getBatchLaneInput(BatchOpTy batch, uint32_t lane, unsigned inputIndex) { + if (batch.getNumResults() != 0) + return batch.getInputs()[inputIndex]; + + size_t laneCount = static_cast(batch.getLaneCount()); + if (laneCount == 0 || batch.getInputs().size() % laneCount != 0) + return std::nullopt; + + size_t inputsPerLane = batch.getInputs().size() / laneCount; + size_t flatIndex = static_cast(lane) * inputsPerLane + inputIndex; + if (flatIndex >= batch.getInputs().size()) + return std::nullopt; + return batch.getInputs()[flatIndex]; +} + +template +unsigned getBatchLaneInputCount(BatchOpTy batch) { + if (batch.getNumResults() != 0) + return batch.getInputs().size(); + + size_t laneCount = static_cast(batch.getLaneCount()); + if (laneCount == 0 || batch.getInputs().size() % laneCount != 0) + return 0; + return static_cast(batch.getInputs().size() / laneCount); +} + +template +std::optional resolveProducerForValue(Value value, std::optional consumerLane) { + Operation* op = value.getDefiningOp(); + if (!op) + return std::nullopt; + + while (auto extract = dyn_cast(op)) { + Value source = extract.getSource(); + Operation* sourceOp = source.getDefiningOp(); + auto sourceBatch = dyn_cast_or_null(sourceOp); + if (sourceBatch && sourceBatch.getNumResults() != 0) { + auto staticOffsets = extract.getStaticOffsets(); + if (!staticOffsets.empty() && staticOffsets.front() != ShapedType::kDynamic) { + uint32_t lane = static_cast(staticOffsets.front()); + return ResolvedProducer {sourceOp, 0, LogicalNodeSelector::Lane, lane, lane, 1}; + } + if (consumerLane) + return ResolvedProducer {sourceOp, 0, LogicalNodeSelector::Lane, *consumerLane, *consumerLane, 1}; + return ResolvedProducer { + sourceOp, 0, LogicalNodeSelector::RangeRepresentative, 0, 0, static_cast(sourceBatch.getLaneCount()) + }; + } + value = source; + op = sourceOp; + if (!op) + return std::nullopt; + } + + if (auto compute = dyn_cast(op)) + return ResolvedProducer { + compute.getOperation(), static_cast(cast(value).getResultNumber()), LogicalNodeSelector::Scalar, 0, 0, 1 + }; + + if (auto batch = dyn_cast(op)) { + if (batch.getNumResults() != 0) { + if (consumerLane) + return ResolvedProducer {op, 0, LogicalNodeSelector::Lane, *consumerLane, *consumerLane, 1}; + return ResolvedProducer { + op, 0, LogicalNodeSelector::RangeRepresentative, 0, 0, static_cast(batch.getLaneCount()) + }; + } + + uint32_t lane = static_cast(cast(value).getResultNumber()); + return ResolvedProducer {op, static_cast(lane), LogicalNodeSelector::Lane, lane, lane, 1}; + } + + return std::nullopt; +} + +SmallVector +resolveProducerSourcesForCsv(const ResolvedProducer& producer, + const DenseMap, ExpandedNodeInfo>& expandedNodes) { + SmallVector sources; + + if (producer.selector == LogicalNodeSelector::Scalar) { + std::string id = getExpandedNodeId(expandedNodes, producer.op, 0); + if (!id.empty()) + sources.push_back({id, std::nullopt}); + return sources; + } + + if (producer.selector == LogicalNodeSelector::Lane) { + std::string id = getExpandedNodeId(expandedNodes, producer.op, producer.lane); + if (!id.empty()) + sources.push_back({id, producer.lane}); + return sources; + } + + for (uint32_t lane = producer.laneStart; lane < producer.laneStart + producer.laneCount; ++lane) { + std::string id = getExpandedNodeId(expandedNodes, producer.op, lane); + if (!id.empty()) + sources.push_back({id, lane}); + } + return sources; +} + +void emitEdgeRow(std::fstream& edgesFile, + StringRef sourceId, + StringRef targetId, + std::optional byteSize, + Type propagatedType, + StringRef stage, + std::optional sourceLane, + std::optional targetLane, + std::optional channelId) { + writeCsvRow(edgesFile, + {sourceId.str(), + targetId.str(), + maybeNumber(byteSize), + stringifyType(propagatedType), + stage.str(), + maybeNumber(sourceLane), + maybeNumber(targetLane), + maybeNumber(channelId)}); +} + +template +LogicalResult emitDataEdges(std::fstream& edgesFile, + const DenseMap& topLevelInfo, + const DenseMap, ExpandedNodeInfo>& expandedNodes, + StringRef stage) { + for (const auto& entry : topLevelInfo) { + Operation* op = entry.first; + const TopLevelOpInfo& info = entry.second; + + if (auto compute = dyn_cast(op)) { + for (Value input : compute.getInputs()) { + if (isa_and_nonnull(input.getDefiningOp())) + continue; + + auto producer = resolveProducerForValue(input, std::nullopt); + if (!producer) + continue; + + SmallVector sources = resolveProducerSourcesForCsv(*producer, expandedNodes); + std::optional byteSize = getTypeSizeBytes(input.getType()); + std::string targetId = getScalarId(info.isPost, info.opId); + for (const EdgeSource& source : sources) + emitEdgeRow(edgesFile, source.id, targetId, byteSize, input.getType(), stage, source.sourceLane, std::nullopt, std::nullopt); + } + continue; + } + + auto batch = dyn_cast(op); + if (!batch) + continue; + + unsigned inputCount = getBatchLaneInputCount(batch); + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) { + std::string targetId = getBatchLaneId(info.isPost, info.opId, lane); + for (unsigned inputIndex = 0; inputIndex < inputCount; ++inputIndex) { + std::optional input = getBatchLaneInput(batch, lane, inputIndex); + if (!input || isa_and_nonnull((*input).getDefiningOp())) + continue; + + auto producer = resolveProducerForValue(*input, lane); + if (!producer) + continue; + + SmallVector sources = resolveProducerSourcesForCsv(*producer, expandedNodes); + std::optional byteSize = getTypeSizeBytes((*input).getType()); + for (const EdgeSource& source : sources) + emitEdgeRow(edgesFile, source.id, targetId, byteSize, (*input).getType(), stage, source.sourceLane, lane, std::nullopt); + } + } + } + + return success(); +} + +template +void collectChannelSends(DenseMap>& sendsByChannelId, + const DenseMap, ExpandedNodeInfo>& expandedNodes, + BatchOpTy batch) { + std::optional laneArg = batch.getLaneArgument(); + if (!laneArg) + return; + + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) { + std::string sourceId = getExpandedNodeId(expandedNodes, batch.getOperation(), lane); + if (sourceId.empty()) + continue; + batch.getBody().walk([&](SpatChannelSendOp send) { + std::optional channelId = evaluateIndexLike(send.getChannelId(), *laneArg, lane); + if (!channelId) + return; + sendsByChannelId[*channelId].push_back({sourceId, lane}); + }); + } +} + +void collectChannelSends(DenseMap>& sendsByChannelId, + const DenseMap, ExpandedNodeInfo>& expandedNodes, + SpatScheduledCompute compute) { + std::string sourceId = getExpandedNodeId(expandedNodes, compute.getOperation(), 0); + if (sourceId.empty()) + return; + compute.getBody().walk([&](SpatChannelSendOp send) { + std::optional channelId = evaluateIndexLike(send.getChannelId(), Value(), 0); + if (!channelId) + return; + sendsByChannelId[*channelId].push_back({sourceId, std::nullopt}); + }); +} + +DenseMap> +buildNodesByCore(const DenseMap, ExpandedNodeInfo>& expandedNodes) { + DenseMap> nodesByCore; + for (const auto& entry : expandedNodes) { + const ExpandedNodeInfo& node = entry.second; + if (!node.core) + continue; + nodesByCore[*node.core].push_back({node.id, node.lane}); + } + return nodesByCore; +} + +template +LogicalResult emitExplicitChannelEdges(std::fstream& edgesFile, + const DenseMap& topLevelInfo, + ResolveChannelSourcesFn&& resolveChannelSources, + StringRef stage) { + for (const auto& entry : topLevelInfo) { + Operation* op = entry.first; + const TopLevelOpInfo& info = entry.second; + + if (auto compute = dyn_cast(op)) { + compute.getBody().walk([&](SpatChannelReceiveOp receive) { + SmallVector sources = resolveChannelSources(receive, 0); + if (sources.empty()) + return; + std::optional channelId = evaluateIndexLike(receive.getChannelId(), Value(), 0); + std::string targetId = getScalarId(info.isPost, info.opId); + std::optional byteSize = getTypeSizeBytes(receive.getType()); + for (const ChannelSendRecord& source : sources) + emitEdgeRow(edgesFile, source.sourceId, targetId, byteSize, receive.getType(), stage, source.sourceLane, std::nullopt, channelId); + }); + continue; + } + + auto batch = dyn_cast(op); + if (!batch) + continue; + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + continue; + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) { + std::string targetId = getBatchLaneId(info.isPost, info.opId, lane); + batch.getBody().walk([&](SpatChannelReceiveOp receive) { + SmallVector sources = resolveChannelSources(receive, lane); + if (sources.empty()) + return; + std::optional channelId = evaluateIndexLike(receive.getChannelId(), *laneArg, lane); + std::optional byteSize = getTypeSizeBytes(receive.getType()); + for (const ChannelSendRecord& source : sources) + emitEdgeRow(edgesFile, source.sourceId, targetId, byteSize, receive.getType(), stage, source.sourceLane, lane, channelId); + }); + } + } + + return success(); +} + +LogicalResult exportStagePre(func::FuncOp func) { + std::fstream nodesFile = openDialectDumpFileWithExtension("spatial1_graph.nodes", "/reports", "csv"); + std::fstream edgesFile = openDialectDumpFileWithExtension("spatial1_graph.edges","/reports", "csv"); + if (!nodesFile.is_open() || !edgesFile.is_open()) + return success(); + + writeCsvRow(nodesFile, {"Id", "op_id", "lane", "core", "ssa_name"}); + writeCsvRow(edgesFile, {"Source", "Target", "Weight", "Type", "stage", "source_lane", "target_lane", "channel_id"}); + + Operation* asmRoot = func.getOperation(); + if (auto moduleOp = func->getParentOfType()) + asmRoot = moduleOp.getOperation(); + OpPrintingFlags flags; + flags.elideLargeElementsAttrs().enableDebugInfo(true, false); + AsmState asmState(asmRoot, flags); + + DenseMap topLevelInfo; + DenseMap, ExpandedNodeInfo> expandedNodes; + + size_t opId = 0; + for (Operation& op : func.getBody().front()) { + if (!isTopLevelRelevantCompute(op)) + continue; + FailureOr info = buildTopLevelOpInfo(op, false, opId++); + if (failed(info)) + return failure(); + topLevelInfo[&op] = *info; + + if (auto compute = dyn_cast(&op)) { + addScalarNodeRow(nodesFile, expandedNodes, *info, &asmState); + continue; + } + + auto batch = cast(&op); + SmallVector, 8> laneCoreIds(batch.getLaneCount()); + addBatchNodeRows(nodesFile, expandedNodes, *info, batch, laneCoreIds, &asmState); + } + + return emitDataEdges(edgesFile, topLevelInfo, expandedNodes, "pre"); +} + +LogicalResult exportStagePost(func::FuncOp func) { + std::fstream nodesFile = openDialectDumpFileWithExtension("spatial2_merged.nodes", "/reports", "csv"); + std::fstream edgesFile = openDialectDumpFileWithExtension("spatial2_merged.edges", "/reports", "csv"); + if (!nodesFile.is_open() || !edgesFile.is_open()) + return success(); + + writeCsvRow(nodesFile, {"Id", "op_id", "lane", "core"}); + writeCsvRow(edgesFile, {"Source", "Target", "Weight", "Type", "stage", "source_lane", "target_lane", "channel_id"}); + + DenseMap topLevelInfo; + DenseMap, ExpandedNodeInfo> expandedNodes; + + size_t opId = 0; + for (Operation& op : func.getBody().front()) { + if (!isTopLevelRelevantCompute(op)) + continue; + FailureOr info = buildTopLevelOpInfo(op, true, opId++); + if (failed(info)) + return failure(); + topLevelInfo[&op] = *info; + + if (isa(&op)) { + addScalarNodeRow(nodesFile, expandedNodes, *info); + continue; + } + + auto batch = cast(&op); + auto coreIds = getBatchLaneCoreIds(batch); + if (failed(coreIds)) + return failure(); + SmallVector, 8> laneCoreIds(batch.getLaneCount()); + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) + if (lane < coreIds->size()) + laneCoreIds[lane] = (*coreIds)[lane]; + addBatchNodeRows(nodesFile, expandedNodes, *info, batch, laneCoreIds); + } + + if (failed(emitDataEdges(edgesFile, topLevelInfo, expandedNodes, "post"))) + return failure(); + + DenseMap> sendsByChannelId; + for (const auto& entry : topLevelInfo) { + Operation* op = entry.first; + if (auto compute = dyn_cast(op)) + collectChannelSends(sendsByChannelId, expandedNodes, compute); + else if (auto batch = dyn_cast(op)) + collectChannelSends(sendsByChannelId, expandedNodes, batch); + } + + DenseMap> nodesByCore = buildNodesByCore(expandedNodes); + auto resolveChannelSources = [&](SpatChannelReceiveOp receive, uint32_t lane) { + SmallVector sources; + + Value laneArg; + if (auto owner = receive->getParentOfType()) + if (auto maybeLaneArg = owner.getLaneArgument()) + laneArg = *maybeLaneArg; + + if (std::optional channelId = evaluateIndexLike(receive.getChannelId(), laneArg, lane)) { + if (auto it = sendsByChannelId.find(*channelId); it != sendsByChannelId.end()) + return it->second; + } + + for (int64_t sourceCore : collectPossibleIntValues(receive.getSourceCoreId(), laneArg, lane)) { + auto it = nodesByCore.find(static_cast(sourceCore)); + if (it == nodesByCore.end()) + continue; + llvm::append_range(sources, it->second); + } + return sources; + }; + + return emitExplicitChannelEdges( + edgesFile, topLevelInfo, resolveChannelSources, "post"); +} + +} // namespace + +SpatialDataflowExportStage getSpatialDataflowExportStage() { + switch (pimExportSpatialDataflow.getValue()) { + case SpatialDataflowExportNone: return SpatialDataflowExportStage::None; + case SpatialDataflowExportPre: return SpatialDataflowExportStage::Pre; + case SpatialDataflowExportPost: return SpatialDataflowExportStage::Post; + case SpatialDataflowExportBoth: return SpatialDataflowExportStage::Both; + } + llvm_unreachable("unknown spatial dataflow export mode"); +} + +bool shouldExportSpatialDataflowStage(SpatialDataflowExportStage mode, SpatialDataflowExportStage stage) { + switch (mode) { + case SpatialDataflowExportStage::None: return false; + case SpatialDataflowExportStage::Pre: return stage == SpatialDataflowExportStage::Pre; + case SpatialDataflowExportStage::Post: return stage == SpatialDataflowExportStage::Post; + case SpatialDataflowExportStage::Both: + return stage == SpatialDataflowExportStage::Pre || stage == SpatialDataflowExportStage::Post; + } + return false; +} + +LogicalResult exportSpatialDataflowCsvPre(func::FuncOp func) { return exportStagePre(func); } + +LogicalResult exportSpatialDataflowCsvPost(func::FuncOp func) { return exportStagePost(func); } + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.hpp new file mode 100644 index 0000000..aef6efc --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/SpatialDataflowCsvExporter.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Support/LogicalResult.h" + + +namespace onnx_mlir { +namespace spatial { + +enum class SpatialDataflowExportStage { + None, + Pre, + Post, + Both, +}; + +SpatialDataflowExportStage getSpatialDataflowExportStage(); + +mlir::LogicalResult exportSpatialDataflowCsvPre(mlir::func::FuncOp func); +mlir::LogicalResult exportSpatialDataflowCsvPost(mlir::func::FuncOp func); + +bool shouldExportSpatialDataflowStage(SpatialDataflowExportStage mode, SpatialDataflowExportStage stage); + +} // namespace spatial +} // namespace onnx_mlir