Files
Raptor/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp
NiccoloN bb6dcd38a3 replace deprecated "rewriter.create()" pattern
refactor PIM to Pim everywhere except for the accelerator name
2026-03-20 13:30:53 +01:00

262 lines
8.7 KiB
C++

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
// Streams the Graphviz node id of an operation: the character 'x' followed by
// the operation's address in hexadecimal without a "0x" prefix.
// The expansion starts with a bare operand and contains `<<`, so it is only
// valid directly after a stream insertion operator (`os << FORMAT_OPERATION(p)`).
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
// Builds the Graphviz node id of an "argument slot": "Arg_<pointer>_<index>",
// where <pointer> identifies the owning operation and <index> the argument or
// operand position.
// llvm::format forwards its arguments to a printf-style formatter, so they are
// cast here to the exact types "%p" (void pointer) and "%u" (unsigned int)
// expect; callers pass Operation* and a mix of size_t / unsigned / int.
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum)                                                                 \
  llvm::format("Arg_%p_%u", static_cast<const void*>(computeOpPointer), static_cast<unsigned>(argumentNum))
using namespace mlir;
namespace onnx_mlir {
namespace {
/**
 * Module pass that writes a Graphviz (DOT) description of Spatial-level
 * operations to a raw_ostream.
 *
 * Node identifiers are derived from pointer values through the
 * FORMAT_OPERATION (one node per operation) and FORMAT_ARGUMENT (one node per
 * operation/operand-index pair) macros, so an edge may be printed before the
 * node it points to has been declared.
 */
struct SpatialToGraphvizPass : public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass)

  StringRef getArgument() const override { return "convert-spatial-to-graphviz"; }

  // The description must describe this pass: it prints a graph, it does not
  // lower ONNX to Spatial.
  StringRef getDescription() const override { return "Print the Spatial dialect graph in Graphviz DOT format."; }

  /**
   * @param os Stream the DOT text is written to; defaults to llvm::errs().
   *           Only a reference is stored, so the stream must outlive the pass.
   */
  SpatialToGraphvizPass(raw_ostream& os = llvm::errs())
      : os(os) {}

  /// Copies share the same output stream as the pass they were copied from.
  SpatialToGraphvizPass(const SpatialToGraphvizPass& pass)
      : SpatialToGraphvizPass(pass.os) {}

  void runOnOperation() final;

private:
  /// Destination of all emitted DOT text.
  raw_ostream& os;

  /**
   * Draws the subgraph for a given spatial::SpatWeightedCompute, including:
   *  1. Input nodes (one box per block argument of the body's first block)
   *  2. One node per operation in that block, with edges to its users
   *  3. Edges from the yield (terminator) to every use of the compute results
   *
   * If the block terminator is not a spatial::SpatYieldOp an error is emitted
   * on it and the pass is marked as failed; the cluster itself has already
   * been written to the stream at that point.
   *
   * NOTE(review): argument nodes are named after the block-argument index,
   * while edges coming from other drawers target the operand number of the
   * use. These ids only coincide if block argument i corresponds to operand i
   * of the compute op - confirm against the op definition.
   *
   * @param op         The spatial::SpatWeightedCompute to draw the subgraph for.
   * @param computeNum The number of the compute operation (used in the cluster
   *                   name and label).
   */
  void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
    os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
       << "\t\tstyle=filled;\n"
       << "\t\tcolor=lightblue;\n";

    Block& block = op.getBody().front();

    // Inputs. The counter is `unsigned` because FORMAT_ARGUMENT prints it
    // through a "%u" conversion.
    unsigned inputNum = 0;
    for (BlockArgument& input : block.getArguments()) {
      auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum);
      os << "\t\t" << fromOp << " [label=\"Arg" << inputNum << "\",shape=box];\n";
      for (Operation* userOp : input.getUsers())
        os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
      ++inputNum;
    }

    // Iterate operations
    for (Operation& childOp : block.getOperations()) {
      os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" << childOp.getName() << "\"];\n";
      drawEdgesFromOpToItsUsers(&childOp);
    }

    os << "\t}\n";

    // Draw edges from the yield to the users of this computeOp
    Operation* yieldOp = block.getTerminator();
    if (!isa<spatial::SpatYieldOp>(yieldOp)) {
      yieldOp->emitError("Terminator of block must be YieldOp ???");
      signalPassFailure();
      return;
    }

    // Every use of every result becomes an edge from the yield node to the
    // (owner, operand index) argument node of the consumer.
    for (auto computeOpResult : op->getResults()) {
      for (auto& computeOpUse : computeOpResult.getUses()) {
        auto toOp = FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber());
        os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n";
      }
    }
  }

  /**
   * @brief Draws the subgraph for a concatOp.
   *
   * The subgraph consists of a node for each operand of the concatOp, as well
   * as an output node. Each operand node is connected to every user of the
   * corresponding operand value. Outside the cluster, edges are created from
   * the output node to each use of result 0 of the concatOp.
   *
   * @param concatOp    The concatOp for which the subgraph is drawn.
   * @param concatOpNum The number of the concatOp (used in the cluster name
   *                    and label).
   */
  void drawConcatOpSubgraph(Operation* concatOp, size_t concatOpNum) {
    os << "\tsubgraph clusterconcat" << concatOpNum << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
       << "\t\tstyle=filled;\n"
       << "\t\tcolor=orange;\n";

    // Inputs. `unsigned` matches the "%u" conversion used by FORMAT_ARGUMENT.
    unsigned inputNum = 0;
    for (Value input : concatOp->getOperands()) {
      auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum);
      os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n";
      for (Operation* userOp : input.getUsers())
        os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
      ++inputNum;
    }

    // Output
    os << "\t\t" << FORMAT_OPERATION(concatOp) << " [label=Out];\n";
    os << "\t}\n";

    // Edges from output to users
    for (auto& computeOpUse : concatOp->getResult(0).getUses()) {
      os << "\t" << FORMAT_OPERATION(concatOp) << " -> "
         << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber()) << ";\n";
    }
  }

  /**
   * Draws the ExtractSliceOp in the graph visualization.
   *
   * Emits a node whose id is the argument-0 slot of the sliceOp and whose
   * label is "Slice: " followed by the printed static offsets attribute, then
   * one edge per use of the slice result, targeting the (owner, operand index)
   * argument node of the consumer.
   *
   * @param sliceOp The tensor::ExtractSliceOp to be drawn in the graph
   *                visualization.
   */
  void drawExtractSliceOp(tensor::ExtractSliceOp sliceOp) {
    auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0);
    os << "\t" << nodeId << " [label=\"Slice: ";
    sliceOp.getStaticOffsetsAttr().print(os);
    os << "\",color=lawngreen];\n";

    for (auto& computeOpUse : sliceOp.getResult().getUses()) {
      os << "\t" << nodeId << " -> " << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber())
         << ";\n";
    }
  }

  /**
   * Draws an ExtractSliceOp as a bias tile.
   *
   * Same node id and label payload as drawExtractSliceOp, but labelled
   * "Bias: ", coloured lightpink, and connected directly to the operation
   * node of each user (FORMAT_OPERATION) rather than to an argument node.
   *
   * @param sliceOp The tensor::ExtractSliceOp to be drawn as a bias tile.
   */
  void drawBiasTileOp(tensor::ExtractSliceOp sliceOp) {
    auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0);
    os << "\t" << nodeId << " [label=\"Bias: ";
    sliceOp.getStaticOffsetsAttr().print(os);
    os << "\",color=lightpink];\n";

    for (Operation* user : sliceOp.getResult().getUsers())
      os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n";
  }

  /**
   * Draws edges from the given operation to its users.
   *
   * One edge is printed per (result, user) pair, at cluster indentation.
   *
   * @param fromOp The operation from which the edges are drawn.
   */
  void drawEdgesFromOpToItsUsers(mlir::Operation* fromOp) {
    for (auto result : fromOp->getResults())
      for (Operation* userOp : result.getUsers())
        os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " << FORMAT_OPERATION(userOp) << ";\n";
  }

  /**
   * Draws input node and edges for the given `funcOp`.
   *
   * A single "Module Input" node is declared; every use of every function
   * argument becomes an edge from it to the consumer's argument node, labelled
   * with the index of the function argument.
   *
   * @param funcOp The `funcOp` for which to draw input nodes and edges.
   */
  void drawInputNodesAndEdges(func::FuncOp& funcOp) {
    os << "\tinput [label=\"Module Input\",color=green];\n";

    size_t funcOpArgNum = 0;
    for (BlockArgument& arg : funcOp.getArguments()) {
      for (auto& useOp : arg.getUses()) {
        os << "\tinput -> " << FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) << "[label=" << funcOpArgNum
           << "];\n";
      }
      ++funcOpArgNum;
    }
  }
};
void SpatialToGraphvizPass::runOnOperation() {
ModuleOp module = getOperation();
auto entryFunc = getPimEntryFunc(module);
if (failed(entryFunc)) {
signalPassFailure();
return;
}
func::FuncOp func = *entryFunc;
os << "digraph G {\n"
<< "\tnode [style=filled,color=white];\n";
size_t computeNum = 0;
size_t concatNum = 0;
// Iterate over the ComputeOps within FuncOp:
// 1. Print their subgraph
// 2. Print the edges from its inputs to its outputs
for (Operation& op : func.getOps()) {
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
drawComputeOpSubgraph(computeOp, computeNum++);
}
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
drawConcatOpSubgraph(concatOp, concatNum++);
}
else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
drawConcatOpSubgraph(imgConcatOp, concatNum++);
}
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
if (producerOp) {
// Skip extractSliceOp if producer is constant weights (ONNXConstantOp)
if (llvm::isa<ONNXConstantOp>(producerOp))
continue;
// If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect
// directly to its user, which is not a ComputeOp argument.
if (llvm::isa<tosa::ReshapeOp>(producerOp)) {
drawBiasTileOp(extractSliceOp);
continue;
}
}
drawExtractSliceOp(extractSliceOp);
}
}
// Draw input node, and edges to it users
drawInputNodesAndEdges(func);
// Draw output node (use the return Operation - argument number=0 - as nodeId)
auto returnOp = func.getBody().front().getTerminator();
os << '\t' << FORMAT_ARGUMENT(returnOp, 0) << " [label=\"Module Output\",color=green];\n";
os << "}\n";
}
} // namespace
/**
 * Factory for the Spatial-to-Graphviz pass.
 *
 * The pass is default-constructed, so it writes to its default output stream
 * (llvm::errs()).
 *
 * @return A newly allocated SpatialToGraphvizPass, owned by the caller.
 */
std::unique_ptr<Pass> createSpatialToGraphvizPass() {
  std::unique_ptr<Pass> pass = std::make_unique<SpatialToGraphvizPass>();
  return pass;
}
} // namespace onnx_mlir