Jaxpr Visualisation Tool
Jaxpr-Viz
JAX Computation Graph Visualisation Tool
Usage
Jaxpr-viz can visualise jit-compiled (and nested) functions. For example, given the functions
import jax
import jax.numpy as jnp

@jax.jit
def foo(x):
    return 2 * x

@jax.jit
def bar(x):
    x = foo(x)
    return x - 1
Jaxpr-viz wraps a function; the wrapped function is then called with concrete arguments:
import jpviz
jpviz.draw(bar)(jnp.arange(10))
produces
[figure: rendered computation graph for bar]
NOTE: For sub-functions to appear as nodes/sub-graphs, they must be decorated with @jax.jit.
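A rough way to see why the decorator matters, using only JAX itself rather than jpviz: a jitted sub-function is kept as a named, nested equation in the traced jaxpr, whereas an undecorated one is traced inline into its caller. This is a sketch that assumes jpviz derives its nodes from this structure; foo_plain and bar_plain are illustrative names, not part of the example above.

```python
import jax
import jax.numpy as jnp


def foo_plain(x):
    # Not jitted: its operations are traced directly into the caller.
    return 2 * x


@jax.jit
def foo(x):
    # Jitted: kept as a separate, named equation in the caller's jaxpr.
    return 2 * x


def bar_plain(x):
    return foo_plain(x) - 1


def bar(x):
    return foo(x) - 1


x = jnp.arange(10)

# jax.make_jaxpr traces a function and returns its jaxpr; str() gives
# the textual form.
with_jit = str(jax.make_jaxpr(bar)(x))
without_jit = str(jax.make_jaxpr(bar_plain)(x))

print(with_jit)     # contains a nested equation labelled name=foo
print(without_jit)  # contains only the primitive operations
```

In the first printout, foo survives as a nested call that a visualiser can draw as its own node; in the second, only the primitive operations remain.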
Visualisation Options
Collapse Nodes
By default, functions composed only of primitive functions are collapsed into a single node (like foo in the above example). The full computation graph can be rendered by setting the collapse_primitives flag to False:
import jpviz
jpviz.draw(bar, collapse_primitives=False)(jnp.arange(10))
produces
[figure: full computation graph for bar, with primitives expanded]
Show Types
By default, type information is included in the node labels. It can be hidden by setting the show_avals flag to False:
import jpviz
jpviz.draw(bar, show_avals=False)(jnp.arange(10))
produces
[figure: computation graph for bar without type information in the node labels]
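The name show_avals presumably refers to JAX's abstract values (avals), which record the shape and dtype of each variable but not its contents. As a sketch using plain JAX rather than jpviz, the avals of the example function's inputs and outputs can be inspected on the traced jaxpr:

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    x = foo(x)
    return x - 1


# Trace bar with a concrete argument to obtain its closed jaxpr.
closed_jaxpr = jax.make_jaxpr(bar)(jnp.arange(10))

# Each aval is an abstract array carrying shape and dtype, but no values.
for aval in closed_jaxpr.in_avals + closed_jaxpr.out_avals:
    print(aval.shape, aval.dtype)
```

This is likely the kind of type information that appears in the node labels when show_avals is left enabled.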
Jupyter Notebook
To show the rendered graph in a Jupyter notebook, use the helper function view_pydot:
dot = jpviz.draw(bar, collapse_primitives=True)(jnp.arange(10))
jpviz.view_pydot(dot)
Developers
Dependencies can be installed with Poetry by running
poetry install
Pre-Commit Hooks
Pre-commit hooks can be installed by running
pre-commit install
Pre-commit checks can then be run using
task lint
Tests
Tests can be run with
task test