Skip to content

Commit

Permalink
Typing fixes & Compilation Checks (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein authored Oct 13, 2024
1 parent 4afd821 commit 2cdffd1
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 34 deletions.
70 changes: 50 additions & 20 deletions src/esquilax/transforms/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
import jax.numpy as jnp

from esquilax import utils
from esquilax.typing import Default, Reduction


def _check_edges(edges: chex.ArrayTree, edge_idxs: chex.Array):
assert edge_idxs.ndim == 2 and edge_idxs.shape[0] == 2, (
"edge_idxs should be a 2d array of node " f"index pairs, got {edge_idxs.shape}"
)
chex.assert_tree_shape_prefix(edges, edge_idxs.shape[1:])


def edge_map(f: Callable) -> Callable:
Expand Down Expand Up @@ -110,13 +118,18 @@ def f(k, params, start, end, edge, **static_kwargs):
def _edge_map(
k: chex.PRNGKey,
params: Any,
starts: Any,
ends: Any,
edges: Any,
starts: chex.ArrayTree,
ends: chex.ArrayTree,
edges: chex.ArrayTree,
*,
edge_idxs: chex.Array,
**static_kwargs,
) -> Any:
) -> chex.ArrayTree:
chex.assert_tree_has_only_ndarrays(starts)
chex.assert_tree_has_only_ndarrays(ends)
chex.assert_tree_has_only_ndarrays(edge_idxs)
_check_edges(edges, edge_idxs)

n = edge_idxs.shape[1]
keys = jax.random.split(k, n)
starts = jax.tree_util.tree_map(
Expand All @@ -136,7 +149,7 @@ def _edge_map(


def graph_reduce(
f: Callable, *, reduction: chex.ArrayTree, default: chex.ArrayTree, n: int = -1
f: Callable, *, reduction: Reduction, default: Default, n: int = -1
) -> Callable:
"""
Map function over graph edges and reduce results to nodes
Expand Down Expand Up @@ -254,13 +267,16 @@ def f(k, params, start, end, edge, **static_kwargs):
def _graph_reduce(
k: chex.PRNGKey,
params: Any,
starts: Any,
ends: Any,
edges: Any,
starts: chex.ArrayTree,
ends: chex.ArrayTree,
edges: chex.ArrayTree,
*,
edge_idxs: chex.Array,
**static_kwargs,
) -> Any:
) -> chex.ArrayTree:
if starts is None:
assert n > 0, "If starts is not provided, n should be provided"

n_results = utils.functions.get_size(starts) if n < 0 else n

edge_results = _edge_map(
Expand All @@ -284,7 +300,7 @@ def reduce(r, x, d):
def random_edge(
f: Callable,
*,
default: Any,
default: Default,
n: int = -1,
) -> Callable:
"""
Expand Down Expand Up @@ -368,13 +384,20 @@ def f(k, params, start, end, edge, **static_kwargs):
def _random_edge(
k: chex.PRNGKey,
params: Any,
starts: Any,
ends: Any,
edges: Any,
starts: chex.ArrayTree,
ends: chex.ArrayTree,
edges: chex.ArrayTree,
*,
edge_idxs: chex.Array,
**static_kwargs,
):
) -> chex.ArrayTree:
if starts is None:
assert n > 0, "If starts is not provided, n should be provided"
chex.assert_tree_has_only_ndarrays(starts)
chex.assert_tree_has_only_ndarrays(ends)
chex.assert_tree_has_only_ndarrays(edges)
_check_edges(edges, edge_idxs)

n_results = utils.functions.get_size(starts) if n < 0 else n
bin_counts, bins = utils.graph.index_bins(edge_idxs[0], n_results)
keys = jax.random.split(k, n_results)
Expand All @@ -388,9 +411,9 @@ def sample(_k, i, a, b):
edge = jax.tree_util.tree_map(lambda x: x[j], edges)
return partial(f, **static_kwargs)(k2, params, start, end, edge)

def select(_k, i, count: int, bin: chex.Array) -> Any:
def select(_k, i, count: int, b: chex.Array) -> Any:
return jax.lax.cond(
count > 0, sample, lambda *_: default, _k, i, bin[0], bin[1]
count > 0, sample, lambda *_: default, _k, i, b[0], b[1]
)

return jax.vmap(select, in_axes=(0, 0, 0, 0))(
Expand All @@ -403,7 +426,7 @@ def select(_k, i, count: int, bin: chex.Array) -> Any:
return _random_edge


def highest_weight(f: Callable, *, default: chex.ArrayTree, n: int = -1) -> Callable:
def highest_weight(f: Callable, *, default: Default, n: int = -1) -> Callable:
"""
Map function over graph edges with the highest weights
Expand Down Expand Up @@ -530,14 +553,21 @@ def f(k, params, start, end, edge, **static_kwargs):
def _highest_weight(
key: chex.PRNGKey,
params: Any,
starts: Any,
ends: Any,
edges: Any,
starts: chex.ArrayTree,
ends: chex.ArrayTree,
edges: chex.ArrayTree,
*,
edge_idxs: chex.Array,
weights: chex.Array,
**static_kwargs,
) -> Any:
if starts is None:
assert n > 0, "If starts is not provided, n should be provided"
chex.assert_tree_has_only_ndarrays(starts)
chex.assert_tree_has_only_ndarrays(ends)
chex.assert_tree_has_only_ndarrays(edges)
_check_edges(edges, edge_idxs)

n_results = utils.functions.get_size(starts) if n < 0 else n
start_nodes = edge_idxs[0]
bin_counts, bins = utils.graph.index_bins(start_nodes, n_results)
Expand Down
5 changes: 4 additions & 1 deletion src/esquilax/transforms/_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def f(k, params, x, **static_kwargs):
keyword_args = utils.functions.get_keyword_args(f)

@partial(jax.jit, static_argnames=keyword_args)
def _self(k: chex.PRNGKey, params: Any, x: Any, **static_kwargs) -> Any:
def _self(
k: chex.PRNGKey, params: Any, x: chex.ArrayTree, **static_kwargs
) -> chex.ArrayTree:
chex.assert_tree_has_only_ndarrays(x)
n = utils.functions.get_size(x)
keys = jax.random.split(k, n)
results = jax.vmap(partial(f, **static_kwargs), in_axes=(0, None, 0))(
Expand Down
49 changes: 42 additions & 7 deletions src/esquilax/transforms/_space.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import typing
from functools import partial
from typing import Any, Callable, Optional, Tuple

Expand All @@ -7,6 +6,7 @@
import jax.numpy as jnp

from esquilax import utils
from esquilax.typing import Default, Reduction


def _sort_agents(
Expand All @@ -23,12 +23,40 @@ def _sort_agents(
return sorted_co_ords, sort_idxs, bins, sorted_idxs, sorted_pos, sorted_agents


def _argument_checks(
pos: chex.Array,
pos_b: Optional[chex.Array],
agent_a: chex.ArrayTree,
agent_b: chex.ArrayTree,
):
assert (
pos.ndim == 2 and pos.shape[1] == 2
), f"pos argument should be an array of 2d coordinates, got shape {pos.shape}"

if pos_b is not None:
assert pos.ndim == 2 and pos.shape[1] == 2, (
"pos_b argument should be an array of "
f"2d coordinates, got shape {pos.shape}"
)
n_b = pos_b.shape[:1]
else:
n_b = pos.shape[:1]

if agent_a is not None:
chex.assert_tree_has_only_ndarrays(agent_a)
chex.assert_tree_shape_prefix(agent_a, pos.shape[:1])

if agent_b is not None:
chex.assert_tree_has_only_ndarrays(agent_a)
chex.assert_tree_shape_prefix(agent_b, n_b)


def spatial(
f: typing.Callable,
f: Callable,
*,
n_bins: int,
reduction: chex.ArrayTree,
default: chex.ArrayTree,
reduction: Reduction,
default: Default,
include_self: bool = False,
topology: str = "moore",
i_range: Optional[float] = None,
Expand Down Expand Up @@ -215,6 +243,7 @@ def f(
i_range = width if i_range is None else i_range
i_range = i_range**2

assert n_bins > 0, f"n_bins should be greater than 0, got {f}"
chex.assert_trees_all_equal_structs(
reduction, default
), "Reduction and default PyTrees should have the same structure"
Expand All @@ -233,6 +262,8 @@ def _spatial(
pos_b: Optional[chex.Array] = None,
**static_kwargs,
) -> Any:
_argument_checks(pos, pos_b, agents_a, agents_b)

same_types = pos_b is None

(
Expand Down Expand Up @@ -270,8 +301,8 @@ def inner(
j: int, carry: Tuple[chex.PRNGKey, Any]
) -> Tuple[chex.PRNGKey, Any]:
_k, _r = carry
pos_b = sorted_pos_b[j]
d = utils.space.shortest_distance(pos_a, pos_b, 1.0, norm=False)
_pos_b = sorted_pos_b[j]
d = utils.space.shortest_distance(pos_a, _pos_b, 1.0, norm=False)
return jax.lax.cond(
d < i_range, interact, lambda _, _x, _z: (_x, _z), j, _k, _r
)
Expand Down Expand Up @@ -325,7 +356,7 @@ def nearest_neighbour(
f: Callable,
*,
n_bins: int,
default: chex.ArrayTree,
default: Default,
topology: str = "moore",
i_range: Optional[float] = None,
) -> Callable:
Expand Down Expand Up @@ -498,6 +529,8 @@ def f(
- ``**static_kwargs``: Any arguments required at compile
time by JAX can be passed as keyword arguments.
"""
assert n_bins > 0, f"n_bins should be greater than 0, got {f}"

width = 1.0 / n_bins
i_range = width if i_range is None else i_range
i_range = i_range**2
Expand All @@ -516,6 +549,8 @@ def _nearest_neighbour(
pos_b: Optional[chex.Array] = None,
**static_kwargs,
) -> Any:
_argument_checks(pos, pos_b, agents_a, agents_b)

same_types = pos_b is None

(
Expand Down
9 changes: 8 additions & 1 deletion src/esquilax/typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""
Generic types
"""
from typing import Collection, TypeVar
from typing import Callable, Collection, TypeVar, Union

import chex
from jax._src.numpy.ufuncs import ufunc

TSimState = TypeVar("TSimState")
"""Generic simulation state"""
Expand All @@ -14,3 +17,7 @@
T = TypeVar("T")
TypedPyTree = T | Collection[T]
"""PyTree with leaves of a single type"""
Reduction = TypedPyTree[Union[Callable | ufunc]]
"""Reduction function(s) type"""
Default = bool | int | float | chex.ArrayTree
"""Default reduction value types"""
5 changes: 4 additions & 1 deletion src/esquilax/utils/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def get_neighbours_offsets(topology: str) -> chex.Array:
offsets = jnp.array(offsets)
else:
raise ValueError(
"Topology should be one of 'same-cell', 'von-neumann' or 'moore'"
(
"Topology should be one of 'same-cell', "
f"'von-neumann' or 'moore' got {topology}"
)
)

return offsets
Expand Down
3 changes: 0 additions & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
# SPDX-FileCopyrightText: 2023-present zombie-einstein <zombie-einstein@proton.me>
#
# SPDX-License-Identifier: MIT
2 changes: 1 addition & 1 deletion tests/test_transforms/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def foo(_, params, a, b):
pos=x,
)

expected = {"a": jnp.array(expected["a"]), "b": expected["b"]}
expected = {"a": jnp.array(expected["a"]), "b": jnp.array(expected["b"])}

assert jnp.array_equal(results["a"], expected["a"])
assert jnp.array_equal(results["b"], expected["b"])
Expand Down

0 comments on commit 2cdffd1

Please sign in to comment.