In this post, I walk through the motivation, complexities, and implementation details of building torchvista, an open-source package to interactively visualize the forward pass of any PyTorch model from within web-based notebooks.
To get a sense of how torchvista works while reading this post, you can check out:
- GitHub page if you want to install it via pip and use it from web-based notebooks (Jupyter, Colab, Kaggle, VSCode, etc.)
- An interactive demo page with various well-known models visualized
- A Google Colab tutorial
- A video demo:
Motivation
PyTorch models can get very large and complex, and making sense of one from the code alone can be a tiresome or even intractable exercise. A graph-like visualization of the model is just what we need to make this easier.
While there exist tools like Netron, pytorchviz, and torchview that make this easier, my motivation for building torchvista was that I found them lacking in some or all of these requirements:
- Interaction support: The visualized graph should be interactive and not a static image. It should be a structure you can zoom, drag, expand/collapse, etc. Models can get very large, and if all you see is a large static image of the graph, how can you really explore it?
- Modular exploration: Large PyTorch models are modular in both concept and implementation. For example, consider a module that has a Sequential module containing a few Attention blocks, each of which in turn has Fully connected blocks that contain Linear layers with activation functions, and so on. The tool should let you tap into this modular structure, and not just present a low-level graph of tensor links.

- Notebook support: We tend to prototype and build our models in notebooks. If a tool were provided as a standalone application that required you to build your model and load it in order to visualize it, the feedback loop would simply be too long. So the tool ideally has to work from within notebooks.

- Error debugging support: While building models from scratch, we often run into many errors before the model can run a full forward pass end-to-end. So the visualization tool should be error tolerant and show you a partial visualization graph even when there are errors, so that you can debug the error.
Image: torch.cat failed due to mismatched tensor shapes
- Forward pass tracing: PyTorch natively exposes a backward pass graph through its autograd system, which the package pytorchviz exposes as a graph, but that is different from the forward pass. When we build, study, and imagine models, we think more about the forward pass, and this can be very useful to visualize.
Building torchvista
Basic API
The goal was to have a simple API that works with almost any PyTorch model.
import torch
from transformers import XLNetModel
from torchvista import trace_model
model = XLNetModel.from_pretrained("xlnet-base-cased")
example_input = torch.randint(0, 32000, (1, 10))
# Trace it!
trace_model(model, example_input)
With one line of code calling trace_model(<model_instance>, <input>), it should simply produce an interactive visualization of the forward pass.
Steps involved
Behind the scenes, torchvista works in two phases when called:
- Tracing: This is where torchvista extracts a graph data structure from the forward pass of the model. PyTorch does not inherently expose this graph structure (although it does expose a graph for the backward pass), so torchvista has to build this data structure by itself.
- Visualization: Once the graph is extracted, torchvista has to produce the actual visualization as an interactive graph. torchvista's tracer does this by loading a template HTML file (with JS embedded inside it) and injecting the serialized graph data structures as strings into the template, which the browser engine then loads.
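Here is a minimal sketch of the injection step. It assumes a made-up template with a `__GRAPH_DATA__` placeholder, because this post does not show torchvista's real template or placeholder names. The graph is serialized with `json.dumps`. `</` is then escaped, so that a module or tensor name can never close the surrounding `<script>` tag early. `build_html` is a hypothetical helper name.

```python
import json

# Hypothetical template: torchvista ships its own HTML/JS template; this one is made up.
TEMPLATE = """<div id="graph"></div>
<script>
  const graphData = __GRAPH_DATA__;
  // ...viz-js layout and d3 rendering would go here...
</script>"""

def build_html(adj_list, hierarchy_map, template=TEMPLATE):
    """Serialize the traced graph and inject it into the HTML template as a JS literal."""
    payload = json.dumps({"adj_list": adj_list, "hierarchy_map": hierarchy_map})
    # "<\/" is still "</" inside a JS string, but it cannot terminate the <script> element
    payload = payload.replace("</", "<\\/")
    return template.replace("__GRAPH_DATA__", payload)

html = build_html({"input_0": ["linear_1"]}, {"linear_1": "Linear_1"})
```

In a notebook, the resulting string could then be displayed with `IPython.display.display(IPython.display.HTML(html))`.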

Tracing
Tracing is essentially done by (temporarily) wrapping all the important and known tensor operations, as well as standard PyTorch modules. The goal of wrapping is to modify these functions so that, when called, they also do the bookkeeping necessary for tracing.
Structure of the graph
The graph we extract from the model is a directed graph where:
- The nodes are the various tensor operations and built-in PyTorch modules that get called during the forward pass.
- Additionally, input and output tensors, and constant-valued tensors, are also nodes in the graph.
- An edge exists from one node to another for each tensor sent from the former to the latter.
- The edge label is the shape of the associated tensor.

However, the structure of our graph can be more complicated, because most PyTorch modules call tensor operations and sometimes other modules' forward methods. This means we have to maintain a graph structure that holds enough information to explore it visually at any level of depth.

Therefore, the structure that torchvista extracts consists of two main data structures:
- An adjacency list of the lowest-level operations/modules that get called.
input_0 -> [ linear ]
linear -> [ __add__ ]
__getitem__ -> [ __add__ ]
__add__ -> [ multi_head_attention_forward ]
multi_head_attention_forward -> [ dropout ]
dropout -> [ __add__ ]
- A hierarchy map that maps each node to its parent module container (if present).
linear -> Linear
multi_head_attention_forward -> MultiheadAttention
MultiheadAttention -> TransformerEncoderLayer
TransformerEncoderLayer -> TransformerEncoder
With both of these, we can construct any desired view of the forward pass in the visualization layer.
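Here is a hedged sketch of one such view, which collapses chosen modules into single boxes. It assumes the adjacency list and hierarchy map are plain dicts keyed by node name. The data is a toy version of the listings above, with these assumptions for the example:

- The two `__add__` calls are renamed `add_1`/`add_2`, since each call is its own node.
- `dropout` is assumed to sit directly inside `TransformerEncoderLayer`.

`ancestors` and `collapsed_view` are hypothetical names, not torchvista's API.

```python
# Toy data shaped like the adjacency list and hierarchy map above
adj_list = {
    "input_0": ["linear"],
    "linear": ["add_1"],
    "__getitem__": ["add_1"],
    "add_1": ["multi_head_attention_forward"],
    "multi_head_attention_forward": ["dropout"],
    "dropout": ["add_2"],
}
hierarchy_map = {
    "linear": "Linear",
    "multi_head_attention_forward": "MultiheadAttention",
    "dropout": "TransformerEncoderLayer",  # assumed for the example
    "MultiheadAttention": "TransformerEncoderLayer",
    "TransformerEncoderLayer": "TransformerEncoder",
}

def ancestors(node, hierarchy):
    """Containers enclosing `node`, innermost first."""
    chain = []
    while node in hierarchy:
        node = hierarchy[node]
        chain.append(node)
    return chain

def collapsed_view(adj, hierarchy, collapsed):
    """Edges of the graph when every module in `collapsed` is drawn as a single box."""
    def representative(node):
        rep = node
        for container in ancestors(node, hierarchy):
            if container in collapsed:
                rep = container  # keep overwriting, so the outermost collapsed container wins
        return rep

    view = {}
    for src, dsts in adj.items():
        s = representative(src)
        for dst in dsts:
            d = representative(dst)
            # Edges inside a collapsed box disappear; duplicate edges are merged
            if s != d and d not in view.setdefault(s, []):
                view[s].append(d)
    return view

view = collapsed_view(adj_list, hierarchy_map, {"TransformerEncoderLayer"})
```

Expanding a module then amounts to removing it from `collapsed` and recomputing the view. This is one reason to keep both structures rather than a single pre-flattened graph.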
Wrapping operations and modules
The whole idea behind wrapping is to do some bookkeeping before and after the actual operation. When the operation is called, our wrapped function gets called instead, and the bookkeeping is performed. The goals of bookkeeping are:
- Record connections between nodes based on tensor references.
- Record tensor shapes to show as edge labels.
- Record the module hierarchy when modules are nested inside one another.
Here is a simplified code snippet of how wrapping works:
original_operations = {}
def wrap_operation(module, operation):
    func_name = operation.__name__
    original_operations[get_hashable_key(module, operation)] = operation
    def wrapped_operation(*args, **kwargs):
        # Do the necessary pre-call bookkeeping
        do_pre_call_bookkeeping()
        # Call the original operation
        result = operation(*args, **kwargs)
        # Do the necessary post-call bookkeeping
        do_post_call_bookkeeping()
        return result
    setattr(module, func_name, wrapped_operation)
for module, operation in LONG_LIST_OF_PYTORCH_OPS:
    wrap_operation(module, operation)
When trace_model is about to finish, we must reset everything back to its original state:
for module, operation in LONG_LIST_OF_PYTORCH_OPS:
    setattr(module, operation.__name__,
            original_operations[get_hashable_key(module, operation)])
The forward() methods of built-in PyTorch modules like Linear, Conv2d, etc. are wrapped in the same way.
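The snippet above leaves out its helpers. Below is a self-contained sketch of the same wrap-then-restore pattern, written as a context manager so that restoration happens in a `finally` block even if the forward pass raises. To stay runnable without PyTorch, it patches `math.sqrt` and the `forward` method of a toy class standing in for a built-in module. `call_log`, `patched`, and `ToyLinear` are hypothetical names, and a real tracer would do graph bookkeeping where this one only appends to a log.

```python
import math
from contextlib import contextmanager

call_log = []  # hypothetical bookkeeping store: (phase, name) for every wrapped call

def make_wrapper(name, original):
    def wrapped(*args, **kwargs):
        call_log.append(("pre", name))     # pre-call bookkeeping
        result = original(*args, **kwargs)
        call_log.append(("post", name))    # post-call bookkeeping
        return result
    return wrapped

@contextmanager
def patched(targets):
    """Temporarily replace each (owner, attribute_name) with a wrapper, always restoring it."""
    originals = {}
    try:
        for owner, name in targets:
            original = getattr(owner, name)
            originals[(owner, name)] = original
            setattr(owner, name, make_wrapper(name, original))
        yield
    finally:
        # Runs on normal exit and on exceptions, so nothing stays patched
        for (owner, name), original in originals.items():
            setattr(owner, name, original)

class ToyLinear:
    """Stand-in for a built-in module such as nn.Linear; only forward() matters here."""
    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return self.weight * x

original_sqrt, original_forward = math.sqrt, ToyLinear.forward
with patched([(math, "sqrt"), (ToyLinear, "forward")]):
    y = ToyLinear(3).forward(math.sqrt(16.0))
```

Keying `originals` by `(owner, attribute_name)` plays the role that `get_hashable_key` plays in the snippet above.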
Connections between nodes
As stated previously, an edge exists between two nodes if a tensor was sent from one to the other. This is the basis for creating connections between nodes while building the graph.
Here is a simplified code snippet of how this works:
from collections import defaultdict
adj_list = defaultdict(list)
def do_post_call_bookkeeping(module, operation, tensor_output):
    # Set a "marker" on the output tensor so that whoever consumes it
    # knows which operation produced it
    tensor_output._source_node = get_hashable_key(module, operation)
def do_pre_call_bookkeeping(module, operation, tensor_input):
    source_node = tensor_input._source_node
    # Add an edge from the producer of the tensor to this node (the consumer)
    adj_list[source_node].append(get_hashable_key(module, operation))
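Here is a runnable version of the same idea. It assumes a `FakeTensor` stand-in for `torch.Tensor` and a hypothetical `traced_call` helper in place of the wrapped operations. It also records each tensor's shape as the edge label. Every call gets its own node ID (`linear_1`, `layer_norm_2`, ...), in the style of the node names in torchvista's DOT output.

```python
import itertools
from collections import defaultdict

class FakeTensor:
    """Stand-in for torch.Tensor: a shape plus the `_source_node` marker."""
    def __init__(self, shape, source_node=None):
        self.shape = tuple(shape)
        self._source_node = source_node

adj_list = defaultdict(list)  # producer node -> consumer nodes
edge_labels = {}              # (producer, consumer) -> shape of the tensor on that edge
_call_counters = defaultdict(lambda: itertools.count(1))

def traced_call(op_name, fn, *inputs):
    # Every call becomes its own node: linear_1, linear_2, ...
    node = f"{op_name}_{next(_call_counters[op_name])}"
    # Pre-call bookkeeping: link each input's producer to this node (the consumer)
    for t in inputs:
        if t._source_node is not None:
            adj_list[t._source_node].append(node)
            edge_labels[(t._source_node, node)] = t.shape
    output = fn(*inputs)
    # Post-call bookkeeping: mark the output with the node that produced it
    output._source_node = node
    return output

x = FakeTensor((1, 16), source_node="input_0")
h = traced_call("linear", lambda t: FakeTensor((t.shape[0], 32)), x)
traced_call("layer_norm", lambda t: FakeTensor(t.shape), h)
traced_call("layer_norm", lambda t: FakeTensor(t.shape), h)
```

Because `h` is consumed twice, `linear_1` fans out to two nodes, and both edges carry the label `(1, 32)`.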

Module hierarchy map
When we wrap modules, things have to be done a bit differently to build the module hierarchy map. The idea is to maintain a stack of the modules currently being called, so that the top of the stack always represents the immediate parent in the hierarchy map.
Here is a simplified code snippet of how this works:
hierarchy_map = {}
module_call_stack = []
def do_pre_call_bookkeeping_for_module(package, module, tensor_input):
    # Push this module onto the stack of modules currently executing
    module_call_stack.append(get_hashable_key(package, module))
def do_post_call_bookkeeping_for_module(package, module, tensor_output):
    module_call_stack.pop()
    # The top of the stack is now the parent node (the outermost module has none)
    if module_call_stack:
        hierarchy_map[get_hashable_key(package, module)] = module_call_stack[-1]
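Here is the same stack discipline as a self-contained sketch with no PyTorch dependency. `call_module` and `record_op` are hypothetical stand-ins for the wrapped module `forward()` methods and the wrapped tensor operations. Popping in a `finally` keeps the stack consistent even when a module raises, which matters for the error handling described later. The toy calls rebuild the hierarchy listing shown earlier.

```python
hierarchy_map = {}
module_call_stack = []

def record_op(op_name):
    # A leaf operation's parent is whichever module is currently on top of the stack
    if module_call_stack:
        hierarchy_map[op_name] = module_call_stack[-1]

def call_module(name, body):
    module_call_stack.append(name)  # pre-call: this module is now executing
    try:
        return body()
    finally:
        module_call_stack.pop()     # post-call: the new top (if any) is the parent
        if module_call_stack:
            hierarchy_map[name] = module_call_stack[-1]

def layer_forward():
    call_module("Linear", lambda: record_op("linear"))
    call_module("MultiheadAttention", lambda: record_op("multi_head_attention_forward"))

call_module("TransformerEncoder",
            lambda: call_module("TransformerEncoderLayer", layer_forward))
```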
Visualization
This part is handled entirely in JavaScript, because the visualization happens in web-based notebooks. The key libraries used here are:
- graphviz: for generating the layout of the graph (viz-js is the JS port)
- d3: for drawing the interactive graph on a canvas
- IPython: to render HTML contents inside a notebook
Graph Layout
Getting the layout of the graph right is an extremely complex problem. The main goal is for the graph to have a top-to-bottom “flow” of edges and, most importantly, for there to be no overlap between the various nodes, edges, and edge labels.
This is made all the more complex when we are working with a “hierarchical” graph, where modules are drawn as “container” boxes that show the underlying nodes and subcomponents inside them.

Fortunately, graphviz (viz-js) comes to the rescue. graphviz uses a language called the “DOT language”, in which we specify how the graph layout should be built.
Here is a sample of the DOT syntax for such a graph:
# Edges and nodes
"input_0" [width=1.2, height=0.5];
"output_0" [width=1.2, height=0.5];
"input_0" -> "linear_1"[label="(1, 16)", fontsize="10", edge_data_id="5623840688" ];
"linear_1" -> "layer_norm_1"[label="(1, 32)", fontsize="10", edge_data_id="5801314448" ];
"linear_1" -> "layer_norm_2"[label="(1, 32)", fontsize="10", edge_data_id="5801314448" ];
...
# Module hierarchy specified using clusters
subgraph cluster_FeatureEncoder_1 {
  label="FeatureEncoder_1";
  style=rounded;
  subgraph cluster_MiddleBlock_1 {
    label="MiddleBlock_1";
    style=rounded;
    subgraph cluster_InnerBlock_1 {
      label="InnerBlock_1";
      style=rounded;
      subgraph cluster_LayerNorm_1 {
        label="LayerNorm_1";
        style=rounded;
        "layer_norm_1";
      }
      subgraph cluster_TinyBranch_1 {
        label="TinyBranch_1";
        style=rounded;
        subgraph cluster_MicroBranch_1 {
          label="MicroBranch_1";
          style=rounded;
          subgraph cluster_Linear_2 {
            label="Linear_2";
            style=rounded;
            "linear_2";
          }
          ...
Once this DOT representation is generated from our adjacency list and hierarchy map, graphviz produces a layout with the positions and sizes of all nodes and the paths of all edges.
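Here is a hedged sketch of that conversion in Python. It assumes the adjacency list, an edge-label dict keyed by `(producer, consumer)`, and the hierarchy map are plain dicts. The function does three things:

- Emits the edges with shape labels.
- Inverts the hierarchy map into container-to-children lists.
- Recurses to produce nested `cluster_` subgraphs. graphviz's dot layout draws a box around any subgraph whose name starts with `cluster`.

`to_dot` is a hypothetical name, and attributes torchvista emits, such as `edge_data_id`, are omitted.

```python
def to_dot(adj, edge_labels, hierarchy):
    """Emit DOT: edges with shape labels, plus one nested cluster per module."""
    lines = ["digraph G {", "  rankdir=TB;"]  # top-to-bottom flow
    nodes = set(adj) | {dst for dsts in adj.values() for dst in dsts}
    for src, dsts in adj.items():
        for dst in dsts:
            label = edge_labels.get((src, dst))
            attrs = f' [label="{label}", fontsize="10"]' if label is not None else ""
            lines.append(f'  "{src}" -> "{dst}"{attrs};')

    # Invert the hierarchy map: container -> its direct children
    children = {}
    for child, parent in hierarchy.items():
        children.setdefault(parent, []).append(child)

    def emit_cluster(container, depth):
        pad = "  " * depth
        lines.append(f'{pad}subgraph "cluster_{container}" {{')
        lines.append(f'{pad}  label="{container}";')
        lines.append(f"{pad}  style=rounded;")
        for child in children[container]:
            if child in children:      # nested module -> nested cluster
                emit_cluster(child, depth + 1)
            elif child in nodes:       # leaf operation -> node inside this cluster
                lines.append(f'{pad}  "{child}";')
        lines.append(f"{pad}}}")

    for container in children:
        if container not in hierarchy:  # top-level modules have no parent
            emit_cluster(container, 1)
    lines.append("}")
    return "\n".join(lines)

dot = to_dot(
    {"input_0": ["linear_1"], "linear_1": ["layer_norm_1", "layer_norm_2"]},
    {("input_0", "linear_1"): (1, 16), ("linear_1", "layer_norm_1"): (1, 32),
     ("linear_1", "layer_norm_2"): (1, 32)},
    {"linear_1": "Linear_1", "layer_norm_1": "LayerNorm_1",
     "Linear_1": "FeatureEncoder_1", "LayerNorm_1": "FeatureEncoder_1"},
)
print(dot)
```

The resulting string is what would be handed to graphviz (viz-js in the browser) to compute the layout.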
Rendering
Once the layout is generated, d3 is used to render the graph visually. Everything is drawn on a canvas (which is easy to make draggable and zoomable), and we set various event handlers to detect user clicks.
When the user clicks to expand or collapse a module (using the ‘+’ and ‘-’ buttons), torchvista records which node the action was performed on. It then re-renders the graph, because the layout has to be reconstructed. Finally, it automatically pans and zooms to an appropriate level based on the recorded pre-click position.
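Restoring the viewport after a re-layout boils down to a little arithmetic. Assume the d3-zoom convention that a transform `(tx, ty, k)` maps layout coordinates to screen coordinates as `screen = k * layout + t`. Keeping the clicked node under the same screen point then means solving that equation for the new translation. torchvista does this in JavaScript. The sketch below only expresses the math in Python, and `anchored_transform` is a hypothetical name.

```python
def anchored_transform(old_transform, old_pos, new_pos, new_scale=None):
    """Pan (and optionally rescale) so a node stays at the same screen point after re-layout.

    A transform (tx, ty, k) maps layout coordinates to screen coordinates as
    screen = k * layout + t.
    """
    tx, ty, k = old_transform
    k_new = k if new_scale is None else new_scale
    # Where the clicked node was on screen before the click
    sx, sy = k * old_pos[0] + tx, k * old_pos[1] + ty
    # Solve screen = k_new * new_pos + t_new for t_new
    return (sx - k_new * new_pos[0], sy - k_new * new_pos[1], k_new)
```

For example, with an old transform of `(10, 20, 2)`, a node at `(5, 5)` sits at screen point `(20, 30)`. If the new layout moves that node to `(100, 50)`, the new translation must be `(-180, -70)`.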
Rendering a graph with d3 is a very detailed topic that is otherwise not unique to torchvista, so I leave out the details from this post.
[Bonus] Handling errors in PyTorch models
When users trace their PyTorch models (especially while developing them), the models sometimes throw errors. It would have been easy for torchvista to just give up when this happens and make the user fix the error before they could use torchvista. Instead, torchvista helps debug these errors by doing best-effort tracing of the model. The idea is simple: trace as much as possible until the error happens, and render the graph with just that much (with visual indicators showing where the error occurred). Then raise the exception, so that the user can also see the stack trace as they normally would.


Here is a simplified code snippet of how this works:
def trace_model(*args, **kwargs):
    exception = None
    try:
        # All the tracing code
        ...
    except Exception as e:
        exception = e
    finally:
        # Do all the necessary cleanups (unwrapping all the operations and modules)
        ...
    # Render the (possibly partial) graph
    ...
    if exception is not None:
        raise exception
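Here is a self-contained sketch of best-effort tracing. It assumes the forward pass can be viewed as a list of named steps, and that "rendering" just snapshots the partial graph. The `finally` cleanup step is omitted for brevity. `best_effort_trace` and `render` are hypothetical names. Note that `raise exception` re-raises the original exception object, whose traceback still points at the failing line.

```python
rendered = []  # stand-in for "draw the graph in the notebook"

def render(graph):
    rendered.append({"nodes": list(graph["nodes"]), "error_node": graph["error_node"]})

def best_effort_trace(steps):
    """Run (name, fn) steps in order, render whatever was traced, then re-raise any error."""
    graph = {"nodes": [], "error_node": None}
    exception = None
    value = None
    try:
        for name, fn in steps:
            graph["nodes"].append(name)  # record the node before running it
            value = fn(value)
    except Exception as e:
        exception = e
        graph["error_node"] = graph["nodes"][-1]  # the step running when it failed
    render(graph)  # the partial graph is shown even on failure
    if exception is not None:
        raise exception
    return value
```

For instance, a step that concatenates a list with a string raises `TypeError`. The graph traced up to and including that step is still rendered, with that step marked as the error node, and only then does the error reach the user.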
Wrapping up
This post shed some light on the journey of building a PyTorch visualization package. We first talked about the very specific motivation for building such a tool, comparing it with other similar tools. Then we discussed the design and implementation of torchvista in two parts. The first part covered tracing the forward pass of a PyTorch model by (temporarily) wrapping operations and modules. This extracts detailed information about the model's forward pass: not only the connections between the various operations, but also the module hierarchy. In the second part, we went over the visualization layer and the complexities of layout generation, which were solved through the right choice of libraries.
torchvista is open source, and all contributions, including feedback, issues, and pull requests, are welcome. I hope torchvista helps people of all levels of expertise to build and visualize their models (regardless of model size), to showcase their work, and to teach others about machine learning models.
Future directions
Possible future improvements to torchvista include:
- Adding support for “rolling”: if the same substructure of a model is repeated multiple times, it is shown just once with a count of how many times it repeats
- Systematic exploration of state-of-the-art models to ensure all their tensor operations are adequately covered
- Support for exporting static images of models as PNG or PDF files
- Efficiency and speed improvements
References
- Open source libraries used:
- DOT language from graphviz
- Other similar visualization tools:
- torchvista:
All images unless otherwise stated are by the author.