Skip to content

Commit

Permalink
Variable space dimensions (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein authored Nov 19, 2024
1 parent 3d55d6b commit 1984da2
Show file tree
Hide file tree
Showing 4 changed files with 303 additions and 58 deletions.
145 changes: 104 additions & 41 deletions src/esquilax/transforms/_space.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
from math import floor
from typing import Any, Callable, Optional, Tuple
from math import floor, isclose, prod
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import chex
import jax
Expand All @@ -11,11 +11,11 @@


def _sort_agents(
n_bins: int, width: float, pos: chex.Array, agents: chex.Array
n_bins: Tuple[int, int], width: float, pos: chex.Array, agents: chex.Array
) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, chex.ArrayTree]:
co_ords, idxs = utils.space.get_bins(pos, n_bins, width)
sort_idxs = jnp.argsort(idxs)
_, bins = utils.graph.index_bins(idxs, n_bins**2)
_, bins = utils.graph.index_bins(idxs, prod(n_bins))
sorted_co_ords = co_ords[sort_idxs]
sorted_idxs = idxs[sort_idxs]
sorted_pos = pos[sort_idxs]
Expand All @@ -24,7 +24,64 @@ def _sort_agents(
return sorted_co_ords, sort_idxs, bins, sorted_idxs, sorted_pos, sorted_agents


def _argument_checks(
def _process_parameters(
i_range: float,
dims: Union[float, Sequence[float]],
n_bins: Optional[int | Sequence[int]],
) -> Tuple[Tuple[int, int], int, chex.Array]:
if isinstance(dims, Sequence):
assert (
len(dims) == 2
), f"2 spatial dimensions should be provided got {len(dims)}"

if n_bins is not None:
assert isinstance(
n_bins, Sequence
), f"n_bins should be a sequence if dims is a sequence, got {type(n_bins)}"
assert (
len(n_bins) == 2
), f"Number of bins should be provided for 2 dimensions, got {len(n_bins)}"
assert (
n_bins[0] > 0 and n_bins[1] > 0
), f"n_bins should all be greater than 0, got {n_bins}"
w1, w2 = dims[0] / n_bins[0], dims[1] / n_bins[1]
assert w1 == w2, (
"Dimensions of cells should be equal in "
f"both dimensions got {w1} and {w2}"
)
n_bins = (n_bins[0], n_bins[1])
else:
assert (
i_range is not None
), "If n_bins is not provided, i_range should be provided"
n0 = dims[0] / i_range
n1 = dims[1] / i_range
assert isclose(round(n0), n0) and isclose(
round(n1), n1
), "Dimensions should be a multiple of i_range"
n_bins = (round(n0), round(n1))

width = dims[0] / n_bins[0]
dims = jnp.array(dims)

else:
if n_bins is not None:
assert isinstance(
n_bins, int
), "n_bins should be an integer value if dims is a float"
assert n_bins > 0, f"n_bins should be greater than 0, got {n_bins}"
n_bins = (n_bins, n_bins)
else:
n_bins = floor(dims / i_range)
n_bins = (n_bins, n_bins)

width = dims / n_bins[0]
dims = jnp.array([dims, dims])

return n_bins, width, dims


def _check_arguments(
pos: chex.Array,
pos_b: Optional[chex.Array],
agent_a: chex.ArrayTree,
Expand Down Expand Up @@ -59,8 +116,9 @@ def spatial(
default: Default,
include_self: bool = False,
topology: str = "moore",
n_bins: Optional[int] = None,
n_bins: Optional[int | Sequence[int]] = None,
i_range: Optional[float] = None,
dims: Union[float, Sequence[float]] = 1.0,
) -> Callable:
"""
Apply a function between agents based on spatial proximity
Expand All @@ -74,7 +132,9 @@ def spatial(
This implementation currently assumes a 2-dimensional
space with continues boundary conditions (i.e. wrapped
on a torus).
on a torus). The shape/dimensions of the space
can be controlled with the `dims` parameter, by default
it is a unit square region.
.. note::
Expand Down Expand Up @@ -127,7 +187,7 @@ def foo(_k, p, a, b):
so in this case agent ``0`` does not have any neighbours in
its cell, but agents ``1`` and ``2`` observe each other.
The transform can also be used as a decoratot using
The transform can also be used as a decorator using
:py:meth:`functools.partial`. Arguments and return values
can be PyTrees or multidimensional arrays. Arguments can
also be ``None`` if not used
Expand Down Expand Up @@ -239,21 +299,20 @@ def f(
same number of cells. Each cell can only interact
with adjacent cells, so this value also consequently
also controls the number of interactions. If not provided
the minimum number of bins if derived from ``i_range``.
the minimum number of bins if derived from ``i_range``. For a square
space ``n_bins`` can be a single intiger, or a pair of integers
for the number of bins along each cell. The number of cells for
each dimension should result in square cells.
dims
Dimensions of the space, either a float edge length for a
square space, or a pait (tuple or list) of dimension.
Default value is a square space of size 1.0.
"""
if n_bins is None:
assert (
i_range is not None
), "If n_bins is not provided, i_range should be provided"
n_bins = floor(1.0 / i_range)
else:
assert n_bins > 0, f"n_bins should be greater than 0, got {f}"

width = 1.0 / n_bins
n_bins, width, dims = _process_parameters(i_range, dims, n_bins)
i_range = width if i_range is None else i_range
i_range = i_range**2

assert n_bins > 0, f"n_bins should be greater than 0, got {f}"
chex.assert_trees_all_equal_structs(
reduction, default
), "Reduction and default PyTrees should have the same structure"
Expand All @@ -272,7 +331,7 @@ def _spatial(
pos_b: Optional[chex.Array] = None,
**static_kwargs,
) -> Any:
_argument_checks(pos, pos_b, agents_a, agents_b)
_check_arguments(pos, pos_b, agents_a, agents_b)

same_types = pos_b is None

Expand Down Expand Up @@ -312,9 +371,11 @@ def inner(
) -> Tuple[chex.PRNGKey, Any]:
_k, _r = carry
_pos_b = sorted_pos_b[j]
d = utils.space.shortest_distance(pos_a, _pos_b, 1.0, norm=False)
d = utils.space.shortest_distance(
pos_a, _pos_b, length=dims, norm=False
)
return jax.lax.cond(
d < i_range, interact, lambda _, _x, _z: (_x, _z), j, _k, _r
d <= i_range, interact, lambda _, _x, _z: (_x, _z), j, _k, _r
)

if (not same_types) or include_self:
Expand Down Expand Up @@ -367,8 +428,9 @@ def nearest_neighbour(
*,
default: Default,
topology: str = "moore",
n_bins: Optional[int] = None,
n_bins: Optional[int | Sequence[int]] = None,
i_range: Optional[float] = None,
dims: Union[float, Sequence[float]] = 1.0,
) -> Callable:
"""
Apply a function between an agent and its closest neighbour
Expand All @@ -382,7 +444,9 @@ def nearest_neighbour(
This implementation currently assumes a 2-dimensional
space with continues boundary conditions (i.e. wrapped
on a torus).
on a torus). The shape/dimensions of the space
can be controlled with the `dims` parameter, by default
it is a unit square region.
.. note::
Expand Down Expand Up @@ -539,16 +603,12 @@ def f(
with adjacent cells, so this value also consequently
also controls the number of interactions. If not provided
the minimum number of bins if derived from ``i_range``.
dims
Dimensions of the space, either a float edge length for a
square space, or a pait (tuple or list) of dimension.
Default value is a square space of size 1.0.
"""
if n_bins is None:
assert (
i_range is not None
), "If n_bins is not provided, i_range should be provided"
n_bins = floor(1.0 / i_range)
else:
assert n_bins > 0, f"n_bins should be greater than 0, got {f}"

width = 1.0 / n_bins
n_bins, width, dims = _process_parameters(i_range, dims, n_bins)
i_range = width if i_range is None else i_range
i_range = i_range**2

Expand All @@ -566,7 +626,7 @@ def _nearest_neighbour(
pos_b: Optional[chex.Array] = None,
**static_kwargs,
) -> Any:
_argument_checks(pos, pos_b, agents_a, agents_b)
_check_arguments(pos, pos_b, agents_a, agents_b)

same_types = pos_b is None

Expand All @@ -576,14 +636,15 @@ def _nearest_neighbour(
bins_a,
sorted_idxs_a,
sorted_pos_a,
_,
sorted_agents_a,
) = _sort_agents(n_bins, width, pos, agents_a)

if same_types:
bins_b = bins_a
sorted_pos_b = sorted_pos_a
sorted_agents_b = jax.tree_util.tree_map(lambda y: y[sort_idxs_a], agents_b)
else:
_, _, bins_b, _, sorted_pos_b, _ = _sort_agents(
_, _, bins_b, _, sorted_pos_b, sorted_agents_b = _sort_agents(
n_bins, width, pos_b, agents_b
)

Expand All @@ -593,23 +654,25 @@ def cell(i: int, bin_range: chex.Array) -> Tuple[int, float]:
def inner(j: int, carry: Tuple[int, float]) -> Tuple[int, float]:
_best_idx, _best_d = carry
_pos_b = sorted_pos_b[j]
_d = utils.space.shortest_distance(pos_a, _pos_b, 1.0, norm=False)
_d = utils.space.shortest_distance(
pos_a, _pos_b, length=dims, norm=False
)
return jax.lax.cond(
jnp.logical_and(_d < i_range, _d < _best_d),
jnp.logical_and(_d <= i_range, _d < _best_d),
lambda: (j, _d),
lambda: (_best_idx, _best_d),
)

if not same_types:
best_idx, best_d = jax.lax.fori_loop(
bin_range[0], bin_range[1], inner, (-1, 1.0)
bin_range[0], bin_range[1], inner, (-1, jnp.inf)
)
else:
best_idx, best_d = jax.lax.fori_loop(
bin_range[0],
jnp.minimum(i, bin_range[1]),
inner,
(-1, 1.0),
(-1, jnp.inf),
)
best_idx, best_d = jax.lax.fori_loop(
jnp.maximum(i + 1, bin_range[0]),
Expand All @@ -629,15 +692,14 @@ def agent_reduce(i: int, co_ords: chex.Array) -> int:
return min_idx

n_agents = pos.shape[0]
keys = jax.random.split(key, n_agents)
nearest_idxs = jax.vmap(agent_reduce, in_axes=(0, 0))(
jnp.arange(n_agents), co_ords_a
)
inv_sort = jnp.argsort(sort_idxs_a)
nearest_idxs = nearest_idxs[inv_sort]

def apply(k, a, idx_b):
b = jax.tree.map(lambda x: x.at[idx_b].get(), agents_b)
b = jax.tree.map(lambda x: x.at[idx_b].get(), sorted_agents_b)
return partial(f, **static_kwargs)(k, params, a, b)

def check(k, a, idx_b):
Expand All @@ -650,6 +712,7 @@ def check(k, a, idx_b):
idx_b,
)

keys = jax.random.split(key, n_agents)
results = jax.vmap(check)(keys, agents_a, nearest_idxs)

return results
Expand Down
17 changes: 9 additions & 8 deletions src/esquilax/utils/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def get_bins(
x: chex.Array, n_cells: int, width: float
x: chex.Array, n_cells: Tuple[int, int], width: float
) -> Tuple[chex.Array, chex.Array]:
"""
Assign co-ordinates to a grid-cell
Expand All @@ -25,21 +25,23 @@ def get_bins(
x
2d array of co-ordinates, in shape ``[n, 2]``.
n_cells
Number of cells along dimensions.
Number of cells along each dimensions.
width
Width of space.
Width of a cell.
Returns
-------
jax.numpy.ndarray
Array of cell indices for each position
"""
y = jnp.floor_divide(x, width).astype(jnp.int32)
i = y[:, 0] * n_cells + y[:, 1]
i = y[:, 0] * n_cells[1] + y[:, 1]
return y, i


def neighbour_indices(x: chex.Array, offsets: chex.Array, n_bins: int) -> chex.Array:
def neighbour_indices(
x: chex.Array, offsets: chex.Array, n_bins: Tuple[int, int]
) -> chex.Array:
"""
Apply offsets to co-ordinates to get neighbouring bin indices
Expand All @@ -58,8 +60,7 @@ def neighbour_indices(x: chex.Array, offsets: chex.Array, n_bins: int) -> chex.A
Bin indices of neighbouring cells.
"""
offset_x = x + offsets
offset_x = offset_x % n_bins
return offset_x[:, 0] * n_bins + offset_x[:, 1]
return (offset_x[:, 0] % n_bins[0]) * n_bins[1] + (offset_x[:, 1] % n_bins[1])


def get_neighbours_offsets(topology: str) -> chex.Array:
Expand Down Expand Up @@ -130,7 +131,7 @@ def shortest_vector(a: chex.Array, b: chex.Array, length: float = 1.0) -> chex.A
def shortest_distance(
a: Union[float, chex.Array],
b: Union[float, chex.Array],
length: float = 1.0,
length: Union[float, chex.Array] = 1.0,
norm: bool = True,
) -> Union[float, chex.Array]:
"""
Expand Down
Loading

0 comments on commit 1984da2

Please sign in to comment.