-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path__init__.py
75 lines (56 loc) · 1.6 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import typing
import jax
import pydot
from .dot import draw_dot_graph
def draw(f, collapse_primitives=True, show_avals=True) -> typing.Callable:
    """
    Visualise a JAX computation graph

    Wraps a JAX jit compiled function in a new function that, when
    called, traces ``f`` with ``jax.make_jaxpr`` on the supplied
    arguments and returns the resulting computation graph as a
    pydot graph. Nothing is displayed; see ``view_pydot`` for
    showing the graph in a notebook.

    Examples
    --------
    .. highlight:: python
    .. code-block:: python

        import jax
        import jpviz

        @jax.jit
        def foo(x):
            return 2 * x

        @jax.jit
        def bar(x):
            x = foo(x)
            return x - 1

        g = jpviz.draw(bar)(jax.numpy.arange(10))

    Parameters
    ----------
    f:
        JAX jit compiled function
    collapse_primitives: bool
        If `True` sub-functions that contain only JAX primitives
        will be collapsed into a single node in the generated
        graph
    show_avals: bool
        If `True` then type information will be
        included on node labels

    Returns
    -------
    typing.Callable
        Wrapped function that, when called with concrete
        values, generates and returns the corresponding
        visualisation (a pydot graph) of the computation graph
    """

    def _inner_draw(*args, **kwargs) -> pydot.Graph:
        # Trace ``f`` with the caller's arguments to obtain its jaxpr.
        jaxpr = jax.make_jaxpr(f)(*args, **kwargs)
        # Rendering options are forwarded positionally to the dot builder.
        return draw_dot_graph(jaxpr, collapse_primitives, show_avals)

    return _inner_draw
def view_pydot(dot_graph: pydot.Dot) -> None:
    """
    Display a pydot graph inline in a Jupyter notebook

    The graph is rendered to PNG bytes with ``create_png`` and the
    image is handed to IPython's ``display``.

    Parameters
    ----------
    dot_graph: pydot.Dot
        Pydot graph as generated by `draw`
    """
    # Local import: IPython is only needed when this helper is actually called.
    from IPython.display import Image, display

    png_bytes = dot_graph.create_png()
    png_image = Image(png_bytes)
    display(png_image)