Skip to content

Commit

Permalink
Restructure transforms (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein authored Oct 13, 2024
1 parent 0e725b4 commit 4afd821
Show file tree
Hide file tree
Showing 13 changed files with 637 additions and 516 deletions.
18 changes: 12 additions & 6 deletions docs/source/examples/evo_boids.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ First we import JAX, Chex, Flax, and Esquilax
.. testsetup:: evo_boids

from typing import Callable, Tuple
from functools import partial

.. testcode:: evo_boids

Expand Down Expand Up @@ -85,10 +86,11 @@ add aggregates position and velocity data from neighbours in-range

.. testcode:: evo_boids

@esquilax.transforms.spatial(
10,
(jnp.add, jnp.add, jnp.add, jnp.add),
(0, jnp.zeros(2), 0.0, 0.0),
@partial(
esquilax.transforms.spatial,
n_bins=10,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), 0.0, 0.0),
include_self=False,
)
def observe(_k: chex.PRNGKey, _params: Params, _a: Boid, b: Boid):
Expand Down Expand Up @@ -180,8 +182,12 @@ to calculate reward contributions

.. testcode:: evo_boids

@esquilax.transforms.spatial(
5, jnp.add, 0.0, include_self=False,
@partial(
esquilax.transforms.spatial,
n_bins=5,
reduction=jnp.add,
default=0.0,
include_self=False,
)
def reward(_k: chex.PRNGKey, params: Params, a: chex.Array, b: chex.Array):
d = esquilax.utils.shortest_distance(a, b, norm=True)
Expand Down
10 changes: 6 additions & 4 deletions docs/source/examples/hard_coded_boids.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ We first import JAX, `Chex <https://chex.readthedocs.io/en/latest/>`_, and Esqui

.. testcode:: hard_coded_boids

from functools import partial
import chex
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -54,10 +55,11 @@ Firstly agents observe the state of neighbours within a given range

.. testcode:: hard_coded_boids

@esquilax.transforms.spatial(
5,
(jnp.add, jnp.add, jnp.add, jnp.add),
(0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)),
@partial(
esquilax.transforms.spatial,
n_bins=5,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)),
include_self=False,
)
def observe(_key: chex.PRNGKey, params: Params, a: Boid, b: Boid):
Expand Down
18 changes: 12 additions & 6 deletions docs/source/examples/rl_boids.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ example, but wrap them up in an environment class
.. testsetup:: rl_boids

from typing import Callable, Tuple, Any
from functools import partial
import chex
import evosax
import flax
Expand Down Expand Up @@ -48,10 +49,11 @@ example, but wrap them up in an environment class
x = flax.linen.tanh(x)
return x

@esquilax.transforms.spatial(
10,
(jnp.add, jnp.add, jnp.add, jnp.add),
(0, jnp.zeros(2), 0.0, 0.0),
@partial(
esquilax.transforms.spatial,
n_bins=10,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), 0.0, 0.0),
include_self=False,
)
def observe(_k: chex.PRNGKey, _params: Params, _a: Boid, b: Boid):
Expand Down Expand Up @@ -110,8 +112,12 @@ example, but wrap them up in an environment class
)
return (pos + d_pos) % 1.0

@esquilax.transforms.spatial(
5, jnp.add, 0.0, include_self=False,
@partial(
esquilax.transforms.spatial,
n_bins=5,
reduction=jnp.add,
default=0.0,
include_self=False,
)
def reward(_k: chex.PRNGKey, params: Params, a: chex.Array, b: chex.Array):
d = esquilax.utils.shortest_distance(a, b, norm=True)
Expand Down
11 changes: 7 additions & 4 deletions examples/boids/hard_coded.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import chex
import jax
import jax.numpy as jnp
Expand All @@ -21,10 +23,11 @@ class Params:
close_range: float


@esquilax.transforms.spatial(
5,
(jnp.add, jnp.add, jnp.add, jnp.add),
(0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)),
@partial(
esquilax.transforms.spatial,
n_bins=5,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)),
include_self=False,
)
def observe(_key: chex.PRNGKey, params: Params, a: Boids, b: Boids):
Expand Down
19 changes: 11 additions & 8 deletions examples/boids/updates.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Tuple

import chex
Expand Down Expand Up @@ -25,10 +26,11 @@ class Params:
collision_penalty: float = 0.1


@esquilax.transforms.spatial(
10,
(jnp.add, jnp.add, jnp.add, jnp.add),
(0, jnp.zeros(2), 0.0, 0.0),
@partial(
esquilax.transforms.spatial,
n_bins=10,
reduction=(jnp.add, jnp.add, jnp.add, jnp.add),
default=(0, jnp.zeros(2), 0.0, 0.0),
include_self=False,
)
def observe(_k: chex.PRNGKey, _params: Params, a: Boid, b: Boid):
Expand Down Expand Up @@ -87,10 +89,11 @@ def move(_key: chex.PRNGKey, _params: Params, x):
return (pos + d_pos) % 1.0


@esquilax.transforms.spatial(
5,
jnp.add,
0.0,
@partial(
esquilax.transforms.spatial,
n_bins=5,
reduction=jnp.add,
default=0.0,
include_self=False,
)
def rewards(_k: chex.PRNGKey, params: Params, a: chex.Array, b: chex.Array):
Expand Down
6 changes: 5 additions & 1 deletion examples/opinion_dynamics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import chex
import jax
import jax.numpy as jnp
Expand All @@ -18,7 +20,9 @@ class SimState:
weights: chex.Array


@esquilax.transforms.graph_reduce((jnp.add, jnp.add), (0, 0.0))
@partial(
esquilax.transforms.graph_reduce, reduction=(jnp.add, jnp.add), default=(0, 0.0)
)
def collect_opinions(_, params: Params, my_opinion, your_opinion, weight):
d = jnp.abs(my_opinion - your_opinion)
w = params.strength * weight
Expand Down
Loading

0 comments on commit 4afd821

Please sign in to comment.