Skip to content

Commit

Permalink
Upgrade tqdm (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
zombie-einstein authored Oct 7, 2024
1 parent b7c5d0e commit 7c14673
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 33 deletions.
19 changes: 10 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ classifiers = [
python = "^3.10"
jax = "^0.4.30"
chex = "^0.1.86"
jax-tqdm = "^0.2.1"
jax-tqdm = "^0.3.0"
evosax = "^0.1.6"
flax = "^0.8.5"

Expand Down
24 changes: 18 additions & 6 deletions src/esquilax/batch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax

from .sim import Sim, TSimParams
from .utils.functions import get_size


def batch_sim_runner(
Expand Down Expand Up @@ -69,19 +70,30 @@ def batch_sim_runner(
if params is None and param_samples is None:
params = sim.default_params()

def inner(k, _params):
def inner(k, i, _params):
_, records = sim.init_and_run(
n_steps, k, show_progress=show_progress, params=_params, **step_kwargs
n_steps,
k,
show_progress=show_progress,
pbar_id=i,
params=_params,
**step_kwargs,
)
return records

def sample_params(k, _params):
def sample_params(k, _params, pbar_offset):
pbar_offset = pbar_offset * n_samples
keys = jax.random.split(k, n_samples)
return jax.vmap(inner, in_axes=(0, None))(keys, _params)
return jax.vmap(inner, in_axes=(0, 0, None))(
keys, pbar_offset + jax.numpy.arange(n_samples), _params
)

if params is not None:
batch_records = sample_params(key, params)
batch_records = sample_params(key, params, 0)
else:
batch_records = jax.vmap(sample_params, in_axes=(None, 0))(key, param_samples)
n_params = get_size(param_samples)
batch_records = jax.vmap(sample_params, in_axes=(None, 0, 0))(
key, param_samples, jax.numpy.arange(n_params)
)

return batch_records
27 changes: 17 additions & 10 deletions src/esquilax/ml/rl/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def generate_samples(
agent_states: TypedPyTree[Union[AgentState, BatchAgentState]],
greedy: bool = False,
show_progress: bool = False,
pbar_id: int = 0,
) -> Tuple[Trajectory, TEnvState]:
"""
Run the environment forward generating trajectory and state records
Expand All @@ -124,6 +125,9 @@ def generate_samples(
actions from a greedy policy.
show_progress
If ``True`` a progress bar will show execution progress.
pbar_id
Optional progress bar index. Can be used to print
to multiple progress bars.
Returns
-------
Expand All @@ -134,13 +138,16 @@ def generate_samples(
"""
step_fun = step(agents, env_params, env, greedy=greedy)

k_reset, k_run = jax.random.split(key, 2)
obs, env_state = env.reset(k_reset, env_params)
init = (k_run, env_state, obs, agent_states)

if show_progress:
step_fun = jax_tqdm.scan_tqdm(n_env_steps, desc="Step")(step_fun)
init = jax_tqdm.PBar(id=pbar_id, carry=init)

k_reset, k_run = jax.random.split(key, 2)
obs, env_state = env.reset(k_reset, env_params)
_, (trajectories, env_states) = jax.lax.scan(
step_fun, (k_run, env_state, obs, agent_states), jnp.arange(n_env_steps)
step_fun, init, jnp.arange(n_env_steps)
)
return trajectories, env_states

Expand Down Expand Up @@ -201,9 +208,9 @@ def batch_generate_samples(
show_progress=show_progress,
)
keys = jax.random.split(key, n_env)
trajectories, env_states = jax.vmap(sampling_func, in_axes=(0, None))(
keys, agent_states
)
trajectories, env_states = jax.vmap(
lambda k, a, i: sampling_func(k, a, pbar_id=i), in_axes=(0, None, 0)
)(keys, agent_states, jnp.arange(n_env))
return trajectories, env_states


Expand Down Expand Up @@ -363,8 +370,8 @@ def test(
show_progress=show_progress,
)

def sample_trajectories(_k, _agent_states):
_trajectories, _states = sampling_func(_k, _agent_states)
def sample_trajectories(_k, i, _agent_states):
_trajectories, _states = sampling_func(_k, _agent_states, pbar_id=i)
if return_trajectories:
return _states, _trajectories
else:
Expand All @@ -377,8 +384,8 @@ def sample_trajectories(_k, _agent_states):

k_sample = jax.random.split(key, n_env)

states, recorded = jax.vmap(sample_trajectories, in_axes=(0, None))(
k_sample, agent_states
states, recorded = jax.vmap(sample_trajectories, in_axes=(0, 0, None))(
k_sample, jnp.arange(n_env), agent_states
)

return states, recorded
Expand Down
16 changes: 13 additions & 3 deletions src/esquilax/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def sim_runner(
n_steps: int,
rng: Union[chex.PRNGKey, int],
show_progress: bool = True,
pbar_id: int = 0,
**static_kwargs,
) -> Tuple[Any, Any, chex.PRNGKey]:
"""
Expand Down Expand Up @@ -59,6 +60,9 @@ def step(i, k, params, state, **static_kwargs):
Either an integer random seed, or a JAX PRNGKey.
show_progress
If ``True`` a progress bar will be shown.
pbar_id
Optional progress bar index, can be used to print
multiple progress bars.
**static_kwargs
Any keyword static values passed to the step function.
These should be used for any values or functionality required
Expand All @@ -83,11 +87,17 @@ def step(carry, i):
new_state, records = step_fun(i, step_key, params, state)
return (k, new_state), records

init = (key, initial_state)

if show_progress:
step = jax_tqdm.scan_tqdm(n_steps, desc="Step")(step)
init = jax_tqdm.PBar(id=pbar_id, carry=init)

final_value, record_history = jax.lax.scan(step, init, jax.numpy.arange(n_steps))

(key, final_state), record_history = jax.lax.scan(
step, (key, initial_state), jax.numpy.arange(n_steps)
)
if show_progress:
(key, final_state) = final_value.carry
else:
(key, final_state) = final_value

return final_state, record_history, key
10 changes: 10 additions & 0 deletions src/esquilax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def run(
params: TSimParams,
initial_state: TSimState,
show_progress: bool = True,
pbar_id: int = 0,
**step_kwargs: Any
) -> Tuple[TSimState, chex.ArrayTree, chex.PRNGKey]:
"""
Expand All @@ -126,6 +127,9 @@ def run(
show_progress
If ``True`` a progress bar will be displayed.
Default ``True``
pbar_id
Optional progress bar index, can be used to print
multiple progress bars.
**step_kwargs
Any additional keyword arguments passed to the
step function. Arguments are static over the
Expand All @@ -147,6 +151,7 @@ def run(
n_steps,
key,
show_progress=show_progress,
pbar_id=pbar_id,
)

return final_state, records, k
Expand All @@ -158,6 +163,7 @@ def init_and_run(
key: chex.PRNGKey,
show_progress: bool = True,
params: Optional[TSimParams] = None,
pbar_id: int = 0,
**step_kwargs
) -> Tuple[TSimState, chex.ArrayTree]:
"""
Expand All @@ -175,6 +181,9 @@ def init_and_run(
params
Optional simulation parameters, if not provided
default sim parameters will be used.
pbar_id
Optional progress bar index, can be used to print
multiple progress bars.
**step_kwargs
Any additional keyword arguments passed to the
step function. Arguments are static over the
Expand All @@ -199,6 +208,7 @@ def init_and_run(
params,
initial_state,
show_progress=show_progress,
pbar_id=pbar_id,
**step_kwargs
)

Expand Down
15 changes: 11 additions & 4 deletions tests/test_batch_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def step(self, i, k, params, state):
return TestSim()


def test_single_params(sim):
@pytest.mark.parametrize("show_progress", [True, False])
def test_single_params(sim, show_progress):
n_samples = 2
n_steps = 10

Expand All @@ -29,23 +30,29 @@ def test_single_params(sim):
n_samples,
n_steps,
101,
show_progress=False,
show_progress=show_progress,
)

assert isinstance(results, tuple)
assert results[0].shape == (n_samples, n_steps)
assert results[1].shape == (n_samples, n_steps)


def test_param_set(sim):
@pytest.mark.parametrize("show_progress", [True, False])
def test_param_set(sim, show_progress):
n_params = 3
n_samples = 2
n_steps = 10

param_set = jnp.arange(n_params)

results = esquilax.batch_sim_runner(
sim, n_samples, n_steps, 101, show_progress=False, param_samples=param_set
sim,
n_samples,
n_steps,
101,
show_progress=show_progress,
param_samples=param_set,
)

assert isinstance(results, tuple)
Expand Down

0 comments on commit 7c14673

Please sign in to comment.