262 lines
8.7 KiB
C++
262 lines
8.7 KiB
C++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/IR/Block.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/Format.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
// Graphviz node id of an operation: the letter 'x' followed by the operation's
// address in lowercase hexadecimal (no "0x" prefix). The id is built as ONE
// llvm::format object and the whole expansion is parenthesized, so the macro is
// a self-contained expression instead of a bare `'x' << ...` fragment that only
// works as a direct operand of a `<<` chain. The emitted text is unchanged.
#define FORMAT_OPERATION(op)                                                                                           \
  (llvm::format("x%llx", static_cast<unsigned long long>(reinterpret_cast<size_t>(op))))

// Graphviz node id of the `argumentNum`-th input slot of the operation pointed to
// by `computeOpPointer`. llvm::format forwards its arguments to a printf-style
// formatter, so they are cast to exactly the types "%p" (void pointer) and "%u"
// (unsigned int) expect; callers pass Operation* and sometimes size_t counters.
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum)                                                                 \
  (llvm::format("Arg_%p_%u", static_cast<const void*>(computeOpPointer), static_cast<unsigned>(argumentNum)))
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
/**
 * Module pass that prints the Spatial-level contents of the PIM entry function
 * as a Graphviz DOT graph on an output stream (llvm::errs() by default).
 *
 * Node ids are produced by the two macros defined at the top of this file:
 *  - FORMAT_OPERATION(op):   node of an operation (or the "Out" node of a concat).
 *  - FORMAT_ARGUMENT(op, n): node of the n-th input slot of `op`.
 *
 * The pass only keeps a reference to the stream; the stream must outlive the pass.
 */
struct SpatialToGraphvizPass : public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {

  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass)

  StringRef getArgument() const override { return "convert-spatial-to-graphviz"; }

  // The previous text ("Lower ONNX ops to Spatial ops.") described a different pass.
  StringRef getDescription() const override {
    return "Print the Spatial ops of the PIM entry function as a Graphviz DOT graph.";
  }

  /**
   * @param os Stream the DOT text is written to (stored by reference).
   */
  explicit SpatialToGraphvizPass(raw_ostream& os = llvm::errs())
      : os(os) {}

  // Copies share the same output stream as the pass they were copied from.
  SpatialToGraphvizPass(const SpatialToGraphvizPass& pass)
      : SpatialToGraphvizPass(pass.os) {}

  void runOnOperation() final;

private:
  raw_ostream& os;

  /**
   * Draws the subgraph for a given spatial::SpatWeightedCompute, including:
   *  1. Input nodes (one per block argument) and edges to the argument's users
   *  2. One node per operation of the body block, with edges to its users
   *  3. Edges from the yield (output) to the users of the compute op's results
   *
   * Signals pass failure if the body block is not terminated by a
   * spatial::SpatYieldOp (the cluster has already been printed at that point).
   *
   * @param op The spatial::SpatWeightedCompute to draw the subgraph for.
   * @param computeNum The number of the compute operation (used in the
   *                   cluster name and label).
   */
  void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
    os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
       << "\t\tstyle=filled;\n"
       << "\t\tcolor=lightblue;\n";

    Block& block = op.getBody().front();

    // Inputs.
    // NOTE(review): edges coming from outside the cluster target
    // FORMAT_ARGUMENT(owner, operandNumber); this presumes block argument i
    // corresponds to operand i of the compute op - confirm against the op definition.
    for (BlockArgument input : block.getArguments()) {
      // `unsigned` on purpose: the id is formatted with "%u" and must match the
      // ids built from OpOperand::getOperandNumber() elsewhere.
      const unsigned inputNum = input.getArgNumber();
      auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum);

      os << "\t\t" << fromOp << " [label=\"Arg" << inputNum << "\",shape=box];\n";
      for (Operation* userOp : input.getUsers())
        os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
    }

    // Iterate operations
    for (Operation& childOp : block.getOperations()) {
      os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" << childOp.getName() << "\"];\n";

      drawEdgesFromOpToItsUsers(&childOp);
    }

    os << "\t}\n";

    // Draw edges from the yield to the users of this computeOp
    Operation* yieldOp = block.getTerminator();
    if (!isa<spatial::SpatYieldOp>(yieldOp)) {
      yieldOp->emitError("Terminator of block must be YieldOp ???");
      signalPassFailure();
      return;
    }

    for (OpResult computeOpResult : op->getResults()) {
      for (OpOperand& computeOpUse : computeOpResult.getUses()) {
        auto toOp = FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber());
        os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n";
      }
    }
  }

  /**
   * @brief Draws the subgraph for a concatOp.
   *
   * The subgraph consists of one node per operand of the concatOp plus one
   * output node. Every input node is linked to the output node, and the output
   * node is linked to the input slot of each user of the concatOp's first result.
   *
   * @param concatOp The concatOp for which the subgraph is drawn.
   * @param concatOpNum The number of the concatOp (used in the cluster name
   *                    and label).
   */
  void drawConcatOpSubgraph(Operation* concatOp, size_t concatOpNum) {
    os << "\tsubgraph clusterconcat" << concatOpNum << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
       << "\t\tstyle=filled;\n"
       << "\t\tcolor=orange;\n";

    // Inputs
    for (unsigned inputNum = 0, numInputs = concatOp->getNumOperands(); inputNum < numInputs; ++inputNum) {
      auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum);

      os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n";
      // Exactly one edge from this input slot to the concat's own output node.
      // Walking all users of the operand value instead would also link this
      // slot to unrelated consumers of that value, and would duplicate the
      // edge when the concat uses the same value more than once.
      os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(concatOp) << ";\n";
    }

    // Output
    os << "\t\t" << FORMAT_OPERATION(concatOp) << " [label=Out];\n";

    os << "\t}\n";

    // Edges from output to users
    for (OpOperand& concatUse : concatOp->getResult(0).getUses()) {
      os << "\t" << FORMAT_OPERATION(concatOp) << " -> "
         << FORMAT_ARGUMENT(concatUse.getOwner(), concatUse.getOperandNumber()) << ";\n";
    }
  }

  /**
   * Draws the ExtractSliceOp in the graph visualization.
   *
   * Emits one node labelled "Slice: <static offsets attribute>" and one edge
   * from it to the input slot (FORMAT_ARGUMENT) of every use of the slice result.
   *
   * @param sliceOp The tensor::ExtractSliceOp to be drawn in the graph
   *                visualization.
   */
  void drawExtractSliceOp(tensor::ExtractSliceOp sliceOp) {
    auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0u);
    os << "\t" << nodeId << " [label=\"Slice: ";
    sliceOp.getStaticOffsetsAttr().print(os);
    os << "\",color=lawngreen];\n";

    for (OpOperand& sliceUse : sliceOp.getResult().getUses()) {
      os << "\t" << nodeId << " -> " << FORMAT_ARGUMENT(sliceUse.getOwner(), sliceUse.getOperandNumber())
         << ";\n";
    }
  }

  /**
   * Draws an ExtractSliceOp as a bias tile.
   *
   * Emits one node labelled "Bias: <static offsets attribute>" and, unlike
   * drawExtractSliceOp, links it directly to the operation node
   * (FORMAT_OPERATION) of every user rather than to an input slot.
   *
   * @param sliceOp The tensor::ExtractSliceOp to be drawn as a bias tile.
   */
  void drawBiasTileOp(tensor::ExtractSliceOp sliceOp) {
    auto nodeId = FORMAT_ARGUMENT(sliceOp.getOperation(), 0u);
    os << "\t" << nodeId << " [label=\"Bias: ";
    sliceOp.getStaticOffsetsAttr().print(os);
    os << "\",color=lightpink];\n";

    for (Operation* user : sliceOp.getResult().getUsers())
      os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n";
  }

  /**
   * Draws edges from the given operation to its users.
   *
   * One edge is emitted per (result, user) pair reported by the use list.
   *
   * @param fromOp The operation from which the edges are drawn.
   */
  void drawEdgesFromOpToItsUsers(mlir::Operation* fromOp) {
    for (OpResult result : fromOp->getResults())
      for (Operation* userOp : result.getUsers())
        os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " << FORMAT_OPERATION(userOp) << ";\n";
  }

  /**
   * Draws input node and edges for the given `funcOp`.
   *
   * A single "Module Input" node is emitted; every use of a function argument
   * becomes an edge from that node to the using input slot, labelled with the
   * index of the function argument.
   *
   * @param funcOp The `funcOp` for which to draw input nodes and edges.
   */
  void drawInputNodesAndEdges(func::FuncOp& funcOp) {
    os << "\tinput [label=\"Module Input\",color=green];\n";

    for (BlockArgument arg : funcOp.getArguments()) {
      for (OpOperand& argUse : arg.getUses()) {
        os << "\tinput -> " << FORMAT_ARGUMENT(argUse.getOwner(), argUse.getOperandNumber())
           << "[label=" << arg.getArgNumber() << "];\n";
      }
    }
  }
};
|
|
|
|
void SpatialToGraphvizPass::runOnOperation() {
|
|
ModuleOp module = getOperation();
|
|
|
|
auto entryFunc = getPimEntryFunc(module);
|
|
if (failed(entryFunc)) {
|
|
signalPassFailure();
|
|
return;
|
|
}
|
|
func::FuncOp func = *entryFunc;
|
|
|
|
os << "digraph G {\n"
|
|
<< "\tnode [style=filled,color=white];\n";
|
|
|
|
size_t computeNum = 0;
|
|
size_t concatNum = 0;
|
|
|
|
// Iterate over the ComputeOps within FuncOp:
|
|
// 1. Print their subgraph
|
|
// 2. Print the edges from its inputs to its outputs
|
|
for (Operation& op : func.getOps()) {
|
|
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
|
|
drawComputeOpSubgraph(computeOp, computeNum++);
|
|
}
|
|
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
|
drawConcatOpSubgraph(concatOp, concatNum++);
|
|
}
|
|
else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
|
|
drawConcatOpSubgraph(imgConcatOp, concatNum++);
|
|
}
|
|
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
|
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
|
|
if (producerOp) {
|
|
// Skip extractSliceOp if producer is constant weights (ONNXConstantOp)
|
|
if (llvm::isa<ONNXConstantOp>(producerOp))
|
|
continue;
|
|
// If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect
|
|
// directly to its user, which is not a ComputeOp argument.
|
|
if (llvm::isa<tosa::ReshapeOp>(producerOp)) {
|
|
drawBiasTileOp(extractSliceOp);
|
|
continue;
|
|
}
|
|
}
|
|
|
|
drawExtractSliceOp(extractSliceOp);
|
|
}
|
|
}
|
|
|
|
// Draw input node, and edges to it users
|
|
drawInputNodesAndEdges(func);
|
|
|
|
// Draw output node (use the return Operation - argument number=0 - as nodeId)
|
|
auto returnOp = func.getBody().front().getTerminator();
|
|
os << '\t' << FORMAT_ARGUMENT(returnOp, 0) << " [label=\"Module Output\",color=green];\n";
|
|
|
|
os << "}\n";
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// Creates a SpatialToGraphvizPass constructed with its default arguments.
std::unique_ptr<Pass> createSpatialToGraphvizPass() {
  auto graphvizPass = std::make_unique<SpatialToGraphvizPass>();
  return graphvizPass;
}
|
|
|
|
} // namespace onnx_mlir
|