#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

// DOT node id for an operation: the character 'x' followed by the operation's
// address in hex. Expands to two `<<` operands, so it may only be used inside
// an existing `os << ...` chain (it cannot be wrapped in parentheses).
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast(op), 0)

// DOT node id for argument/operand number `argumentNum` of the operation at
// `computeOpPointer`. llvm::format is printf-based, so the arguments must match
// the conversions exactly: callers pass `Operation*` and a mix of `size_t`,
// `unsigned` and `int` indices, and handing a 64-bit `size_t` to `%u` (or a
// non-`void*` pointer to `%p`) is undefined behavior in a varargs call.
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum)                                                                 \
  llvm::format("Arg_%p_%u", static_cast<const void*>(computeOpPointer), static_cast<unsigned>(argumentNum))

using namespace mlir;

namespace onnx_mlir {

namespace {

/**
 * Writes a Graphviz (DOT) rendering of the PIM entry function to an output
 * stream (llvm::errs() by default). Compute ops become clusters containing
 * their block arguments and body operations, concat ops become clusters of
 * input nodes plus an output node, and extract_slice ops (other than those
 * slicing ONNXConstantOp results) become standalone nodes.
 */
struct SpatialToGraphvizPass : public PassWrapper> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass)

  StringRef getArgument() const override { return "convert-spatial-to-graphviz"; }

  StringRef getDescription() const override {
    return "Print the Spatial IR of the PIM entry function as a Graphviz DOT graph.";
  }

  SpatialToGraphvizPass(raw_ostream& os = llvm::errs())
      : os(os) {}
  SpatialToGraphvizPass(const SpatialToGraphvizPass& pass)
      : SpatialToGraphvizPass(pass.os) {}

  void runOnOperation() final;

private:
  /// Destination of the generated DOT text; referenced, not owned, by the pass.
  raw_ostream& os;

  /**
   * Draws the subgraph for a given spatial::SpatWeightedCompute, including:
   * 1. Input nodes (block arguments)
   * 2. Operations
   * 3. Edges between yield (output) and its users
   *
   * Signals pass failure if the body terminator is not the expected yield op;
   * in that case the cluster itself has already been written.
   *
   * @param op The spatial::SpatWeightedCompute to draw the subgraph for.
   * @param computeNum The number of the compute operation.
   */
  void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
    os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
       << "\t\tstyle=filled;\n"
       << "\t\tcolor=lightblue;\n";

    Block& block = op.getBody().front();

    // Inputs: one box per block argument, with edges to its users in the body.
    size_t inputNum = 0;
    for (BlockArgument& input : block.getArguments()) {
      auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum);
      os << "\t\t" << fromOp << " [label=\"Arg" << inputNum << "\",shape=box];\n";
      for (auto userOp : input.getUsers())
        os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
      inputNum++;
    }

    // Body operations and the edges between them.
    for (auto& childOp : block.getOperations()) {
      os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" << childOp.getName() << "\"];\n";
      drawEdgesFromOpToItsUsers(&childOp);
    }

    os << "\t}\n";

    // Draw edges from the yield to the users of this computeOp. Consumers are
    // addressed by (owner, operand number) so they land on the consumer's
    // input node rather than on the consumer op itself.
    Operation* yieldOp = block.getTerminator();
    if (!isa(yieldOp)) {
      yieldOp->emitError("Terminator of block must be YieldOp ???");
      signalPassFailure();
      return;
    }

    for (auto computeOpResult : op->getResults()) {
      for (auto& computeOpUse : computeOpResult.getUses()) {
        auto toOp = FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber());
        os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n";
      }
    }
  }

  /**
   * @brief Draws the subgraph for a concatOp.
   *
   * The subgraph consists of one node per operand of the concatOp, each with
   * an edge into a single output node. Edges are then created from the output
   * node to the input node of each user of the concatOp's first result.
   *
   * @param concatOp The concatOp for which the subgraph is drawn.
   * @param concatOpNum The number of the concatOp.
   */
  void drawConcatOpSubgraph(Operation* concatOp, size_t concatOpNum) {
    os << "\tsubgraph clusterconcat" << concatOpNum << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
       << "\t\tstyle=filled;\n"
       << "\t\tcolor=orange;\n";

    // Inputs. Each input node feeds this concat's own output node. Edges must
    // not go to every user of the operand value: other users (another concat,
    // a compute op) are outside this subgraph, and a compute op has no node of
    // its own, so such edges would create phantom nodes in the DOT output.
    for (unsigned inputNum = 0, numInputs = concatOp->getNumOperands(); inputNum < numInputs; ++inputNum) {
      auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum);
      os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n";
      os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(concatOp) << ";\n";
    }

    // Output
    os << "\t\t" << FORMAT_OPERATION(concatOp) << " [label=Out];\n";
    os << "\t}\n";

    // Edges from output to users
    for (auto& computeOpUse : concatOp->getResult(0).getUses()) {
      os << "\t" << FORMAT_OPERATION(concatOp) << " -> "
         << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber()) << ";\n";
    }
  }

  /**
   * Draws a tensor::ExtractSliceOp as a standalone node labeled with its
   * static offsets, with edges to the input node of each user of its result.
   *
   * @param sliceOp The tensor::ExtractSliceOp to be drawn.
   */
  void drawExtractSliceOp(tensor::ExtractSliceOp sliceOp) {
    auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0);
    os << "\t" << nodeId << " [label=\"Slice: ";
    sliceOp.getStaticOffsetsAttr().print(os);
    os << "\",color=lawngreen];\n";
    for (auto& computeOpUse : sliceOp.getResult().getUses()) {
      os << "\t" << nodeId << " -> " << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber())
         << ";\n";
    }
  }

  /**
   * Draws an extract_slice treated as a bias tile: a node labeled with its
   * static offsets, with edges pointing directly at the node of each user
   * operation (not at a compute-op argument node).
   *
   * @param sliceOp The tensor::ExtractSliceOp to be drawn.
   */
  void drawBiasTileOp(tensor::ExtractSliceOp sliceOp) {
    auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0);
    os << "\t" << nodeId << " [label=\"Bias: ";
    sliceOp.getStaticOffsetsAttr().print(os);
    os << "\",color=lightpink];\n";
    for (auto user : sliceOp.getResult().getUsers())
      os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n";
  }

  /**
   * Draws edges from the given operation to the users of all its results.
   *
   * @param fromOp The operation from which the edges are drawn.
   */
  void drawEdgesFromOpToItsUsers(mlir::Operation* fromOp) {
    for (auto result : fromOp->getResults())
      for (auto userOp : result.getUsers())
        os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " << FORMAT_OPERATION(userOp) << ";\n";
  }

  /**
   * Draws the module input node and one edge per use of each function
   * argument, labeled with the argument's index.
   *
   * @param funcOp The `funcOp` for which to draw input nodes and edges.
   */
  void drawInputNodesAndEdges(func::FuncOp& funcOp) {
    os << "\tinput [label=\"Module Input\",color=green];\n";
    size_t funcOpArgNum = 0;
    for (BlockArgument& arg : funcOp.getArguments()) {
      for (auto& useOp : arg.getUses()) {
        os << "\tinput -> " << FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) << "[label=" << funcOpArgNum
           << "];\n";
      }
      funcOpArgNum++;
    }
  }
};

void SpatialToGraphvizPass::runOnOperation() {
  ModuleOp module = getOperation();

  auto entryFunc = getPimEntryFunc(module);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }
  func::FuncOp func = *entryFunc;

  os << "digraph G {\n"
     << "\tnode [style=filled,color=white];\n";

  size_t computeNum = 0;
  size_t concatNum = 0;

  // Walk the top-level ops of the entry function and draw the ones the graph
  // knows about: compute ops and concat ops as clusters, extract_slice ops as
  // standalone nodes (skipping slices of constant weights).
  for (Operation& op : func.getOps()) {
    if (auto computeOp = dyn_cast(op)) {
      drawComputeOpSubgraph(computeOp, computeNum++);
    }
    else if (auto concatOp = dyn_cast(op)) {
      drawConcatOpSubgraph(concatOp, concatNum++);
    }
    else if (auto imgConcatOp = dyn_cast(op)) {
      drawConcatOpSubgraph(imgConcatOp, concatNum++);
    }
    else if (auto extractSliceOp = dyn_cast(op)) {
      auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
      if (producerOp) {
        // Skip extractSliceOp if producer is constant weights (ONNXConstantOp)
        if (llvm::isa(producerOp))
          continue;
        // If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect
        // directly to its user, which is not a ComputeOp argument.
        if (llvm::isa(producerOp)) {
          drawBiasTileOp(extractSliceOp);
          continue;
        }
      }
      drawExtractSliceOp(extractSliceOp);
    }
  }

  // Draw input node, and edges to its users
  drawInputNodesAndEdges(func);

  // Draw the output node(s). Producers connect to the return op through
  // FORMAT_ARGUMENT(returnOp, operandNumber), so one node is emitted per
  // returned value. Operand 0 is always emitted so that a function returning
  // nothing still shows a "Module Output" node; with a single output the
  // text is exactly "Module Output".
  auto returnOp = func.getBody().front().getTerminator();
  unsigned numOutputs = returnOp->getNumOperands();
  if (numOutputs == 0)
    numOutputs = 1;
  for (unsigned outputNum = 0; outputNum < numOutputs; ++outputNum) {
    os << '\t' << FORMAT_ARGUMENT(returnOp, outputNum) << " [label=\"Module Output";
    if (numOutputs > 1)
      os << ' ' << outputNum;
    os << "\",color=green];\n";
  }

  os << "}\n";
}

} // namespace

std::unique_ptr createSpatialToGraphvizPass() { return std::make_unique(); }

} // namespace onnx_mlir