Skip to content

Commit

Permalink
Make graph flatteners use true key types
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Aug 28, 2024
1 parent 4939f9f commit 736fd5c
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 90 deletions.
4 changes: 2 additions & 2 deletions tests/test_flax_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from tjax import register_graph_as_nnx_node
from tjax import GraphEdgeKey, register_graph_as_nnx_node

try:
import networkx as nx
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(self) -> None:

def test_flatten(graph: nx.DiGraph[Any]) -> None:
_, state, _ = nnx.graph.flatten(graph)
substate = state['a⟶b']
substate = state[GraphEdgeKey('a', 'b')]
assert isinstance(substate, nnx.State)
variable = substate['x']
assert isinstance(variable, nnx.VariableState)
Expand Down
8 changes: 7 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,10 @@ def test_flatten_flavors(graph: nx.DiGraph[Any]) -> None:
assert hash(tree_def_a) == hash(tree_def_b)
assert values_a == list(values_b)
assert values_a == [2.0, 3.0, 4.0, 5.0, 7.0]
assert key_paths[3][0] == 'a⟶b'
simplified = {"".join(str(x) for x in key_path)
for key_path in key_paths}
assert simplified == {".node['a']['y']",
".node['b']['z']",
".node['c']['w']",
".edge['a', 'b']['x']",
".edge['c', 'b']['x']"}
7 changes: 5 additions & 2 deletions tjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from ._src.display.print_generic import print_generic
from ._src.dtype_tools import cast_to_result_type, result_type
from ._src.dtypes import default_atol, default_rtol, default_tols
from ._src.graph import (graph_arrow, graph_edge_name, register_graph_as_jax_pytree,
register_graph_as_nnx_node)
from ._src.graph import (GraphEdgeKey, GraphNodeKey, UndirectedGraphEdgeKey, graph_arrow,
graph_edge_name, register_graph_as_jax_pytree, register_graph_as_nnx_node)
from ._src.leaky_integral import (diffused_leaky_integrate, leaky_covariance, leaky_data_weight,
leaky_integrate, leaky_integrate_time_series)
from ._src.math_tools import (abs_square, create_diagonal_array, divide_nonnegative, divide_where,
Expand All @@ -37,6 +37,8 @@
'Complex',
'ComplexArray',
'ComplexNumeric',
'GraphEdgeKey',
'GraphNodeKey',
'Integral',
'IntegralArray',
'IntegralNumeric',
Expand Down Expand Up @@ -66,6 +68,7 @@
'ShapeLike',
'SliceLike',
'TapFunctionTransforms',
'UndirectedGraphEdgeKey',
'abs_square',
'abstract_custom_jvp',
'abstract_jit',
Expand Down
204 changes: 119 additions & 85 deletions tjax/_src/graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Generator, Iterable, MutableSet
from typing import Any, TypeVar
from collections.abc import Callable, Hashable, Iterable, MutableSet, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, cast, override

from jax.tree_util import register_pytree_with_keys
from rich.tree import Tree
Expand All @@ -20,6 +21,52 @@ def graph_edge_name(arrow: str, source: str, target: str) -> str:
return f"{source}{arrow}{target}"


@dataclass(frozen=True)
class GraphNodeKey:
"""Struct for use with :func:`jax.tree_util.register_pytree_with_keys`."""
key: Hashable

@override
def __str__(self) -> str:
return f'.node[{self.key!r}]'

def __lt__(self, value: GraphNodeKey | GraphEdgeKey, /) -> bool:
return (not isinstance(value, GraphNodeKey)
or self.key < value.key) # type: ignore # pyright: ignore


@dataclass(frozen=True)
class GraphEdgeKey:
"""Struct for use with :func:`jax.tree_util.register_pytree_with_keys`."""
source: Hashable
target: Hashable

@override
def __str__(self) -> str:
return f'.edge[{self.source!r}, {self.target!r}]'

def __lt__(self, value: GraphNodeKey | GraphEdgeKey, /) -> bool:
return isinstance(value, GraphEdgeKey) and (
(self.source, self.target) < (value.source, value.target))


@dataclass(frozen=True)
class UndirectedGraphEdgeKey(GraphEdgeKey):
@override
def __hash__(self) -> int:
return hash(self.source) ^ hash(self.target)

@override
def __lt__(self, value: GraphNodeKey | GraphEdgeKey, /) -> bool:
return isinstance(value, GraphEdgeKey) and (
sorted((self.source, self.target)) # type: ignore # pyright: ignore
< sorted((value.source, value.target))) # type: ignore # pyright: ignore


GraphAuxData: TypeAlias = list[GraphNodeKey | GraphEdgeKey] # The auxilliary data for Jax.
GraphData: TypeAlias = dict[str, Any] # What networkx stores in the graph.


try:
import networkx as nx
except ImportError:
Expand All @@ -29,71 +76,56 @@ def register_graph_as_jax_pytree(graph_type: type[Any]) -> None:
def register_graph_as_nnx_node(graph_type: type[Any]) -> None:
raise RuntimeError(msg)
else:
if TYPE_CHECKING:
Graph: TypeAlias = nx.Graph[Any]
else:
Graph: TypeAlias = nx.Graph
T = TypeVar('T', bound="nx.Graph[Any]")

def validate_node_name(name: Any,
arrow: str
) -> None:
if not isinstance(name, str):
raise TypeError
if arrow in name:
raise ValueError

def flatten_helper(graph: nx.Graph[Any],
arrow: str
) -> Generator[tuple[str, Any], None, None]:
for name, data in sorted(graph.nodes.data()):
validate_node_name(name, arrow)
yield name, data

def undirected_edge_key(source_target_data: tuple[str, str, Any]) -> tuple[str, ...]:
source, target, _ = source_target_data
return tuple(sorted([source, target]))
edge_data = graph.edges.data()
def nodes_helper(graph: nx.Graph[Any]) -> tuple[list[GraphNodeKey], list[PyTree]]:
node_data = graph.nodes.data()
node_keys = [GraphNodeKey(name) for name, _ in node_data]
node_values = [value for _, value in node_data]
return node_keys, node_values

def edges_helper(graph: nx.Graph[Any]) -> tuple[list[GraphEdgeKey], list[PyTree]]:
directed = isinstance(graph, nx.DiGraph)
if not directed:
edge_data = sorted(edge_data, key=undirected_edge_key)
for source, target, data in edge_data:
validate_node_name(source, arrow)
validate_node_name(target, arrow)
new_source, new_target = ((source, target)
if directed
else sorted([source, target]))
yield graph_edge_name(arrow, new_source, new_target), data

def init_graph(graph: nx.Graph[Any],
items: Iterable[tuple[str, Any]],
arrow: str
) -> None:
for key, value in items:
if not isinstance(value, dict):
raise TypeError
if arrow in key:
source, target = key.split(arrow, 2)
graph.add_edge(source, target, **value)
else:
graph.add_node(key, **value)
edge_data = graph.edges.data()
edge_key: Callable[[Hashable, Hashable], GraphEdgeKey] = (
GraphEdgeKey if directed else UndirectedGraphEdgeKey)
edge_keys = [edge_key(source, target) for source, target, _ in edge_data]
edge_data = [data for _, _, data in edge_data]
return edge_keys, edge_data

def register_graph_as_jax_pytree(graph_type: type[nx.Graph[Any]]) -> None: # pyright: ignore
arrow = graph_arrow(issubclass(graph_type, nx.DiGraph))
def flatten_tree(graph: nx.Graph[Any], /) -> tuple[list[GraphData], GraphAuxData]:
node_keys, node_values = nodes_helper(graph)
edge_keys, edge_values = edges_helper(graph)
return node_values + edge_values, node_keys + edge_keys

def unflatten_tree(names: tuple[str, ...], values: Iterable[Any], /) -> nx.Graph[Any]:
def register_graph_as_jax_pytree(graph_type: type[nx.Graph[Any]]) -> None: # pyright: ignore
def unflatten_tree(keys: GraphAuxData, values: Iterable[GraphData], /) -> nx.Graph[Any]:
graph = graph_type()
init_graph(graph, zip(names, values, strict=True), arrow)
for key, value in zip(keys, values, strict=True):
match key:
case GraphNodeKey(name):
graph.add_node(name, **value)
case GraphEdgeKey(source, target):
graph.add_edge(source, target, **value)
return graph

def flatten_with_keys(graph: nx.Graph[Any], /
) -> tuple[Iterable[tuple[str, Any]], tuple[str, ...]]:
names_and_values = tuple(flatten_helper(graph, arrow))
names = tuple(name for name, _ in names_and_values)
return (names_and_values, names)

def flatten_tree(graph: nx.Graph[Any], /) -> tuple[Iterable[PyTree], tuple[str, ...]]:
names_and_values = tuple(flatten_helper(graph, arrow))
names, values = zip(*names_and_values, strict=True)
return values, names

register_pytree_with_keys(graph_type, flatten_with_keys, unflatten_tree, flatten_tree)
) -> tuple[Iterable[tuple[GraphNodeKey | GraphEdgeKey, GraphData]],
GraphAuxData]:
values, keys = flatten_tree(graph)
return (zip(keys, values, strict=True), keys)

flatten_with_keys_ = cast(
Callable[[Graph], tuple[Iterable[tuple[Hashable, Any]], Hashable]],
flatten_with_keys)
unflatten_tree_ = cast(Callable[[Hashable, Any], Graph], unflatten_tree)
flatten_tree_ = cast(Callable[[Graph], tuple[Iterable[Any], Hashable]],
flatten_tree)
register_pytree_with_keys(graph_type, flatten_with_keys_, unflatten_tree_, flatten_tree_)

@display_generic.register(nx.Graph)
def _(value: nx.Graph[Any],
Expand Down Expand Up @@ -124,44 +156,46 @@ def register_graph_as_nnx_node(graph_type: type[Any]) -> None:
raise RuntimeError(msg)
else:
def register_graph_as_nnx_node(graph_type: type[T]) -> None: # pyright: ignore
arrow = graph_arrow(issubclass(graph_type, nx.DiGraph))

# flatten: Callable[[Node], tuple[Sequence[tuple[str, Leaf]], AuxData]],
def flatten_graph(graph: T, /) -> tuple[tuple[tuple[str, Any], ...], None]:
t = tuple(flatten_helper(graph, arrow))
return t, None
# flatten: Callable[[Node], tuple[Sequence[tuple[Key, Leaf]], AuxData]],
def flatten_graph(graph: T, /
) -> tuple[Sequence[tuple[Key, Any]], None]:
values, keys = flatten_tree(graph)
return list(zip(keys, values, strict=True)), None

# set_key: Callable[[Node, Key, Leaf], None],
def set_key_graph(graph: T, key: Key, value: Any, /) -> None:
if not isinstance(value, dict):
raise TypeError
assert isinstance(key, str)
d: dict[str, Any]
if arrow in key:
source, target = key.split(arrow, 2)
if not graph.has_edge(source, target):
graph.add_edge(source, target, **value)
return
d = graph.edges[source, target]
elif not graph.has_node(key):
graph.add_node(key, **value)
return
else:
d = graph.nodes[key]
match key:
case GraphNodeKey(name):
if not graph.has_node(name):
graph.add_node(name, **value)
return
d = graph.nodes[name]
case GraphEdgeKey(source, target):
if not graph.has_edge(source, target):
graph.add_edge(source, target, **value)
return
d = graph.edges[source, target]
case _:
raise TypeError
d.clear()
d.update(value)

# pop_key: Callable[[Node, Key], Leaf],
def pop_key_graph(graph: T, key: Key, /) -> Any:
assert isinstance(key, str)
if arrow in key:
source, target = key.split(arrow, 2)
retval = graph.edges[source, target]
graph.remove_edge(source, target)
return retval
retval = graph.nodes[key]
graph.remove_node(key)
return retval
match key:
case GraphNodeKey(name):
retval = graph.nodes[name]
graph.remove_node(name)
return retval
case GraphEdgeKey(source, target):
retval = graph.edges[source, target]
graph.remove_edge(source, target)
return retval
case _:
raise TypeError

# create_empty: Callable[[AuxData], Node],
def create_empty_graph(metadata: None, /) -> T:
Expand Down

0 comments on commit 736fd5c

Please sign in to comment.