
Resetting environment state in PPO training_step function #580

Open · nic-barbara opened this issue Feb 14, 2025 · 0 comments

nic-barbara (Contributor) commented Feb 14, 2025

Hi there,

I'm currently working on a recurrent version of the PPO training algorithm in brax. For recurrent PPO, it's best to evaluate each training step (within a training epoch) over a complete trajectory of the environment, i.e. starting just after the environment was last terminated/reset and running until it terminates again.
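To make that concrete, here's a minimal self-contained sketch of the idea (a toy dict-based environment, not brax — reset, step, and the random "policy" below are all stand-ins I made up for illustration): the unroll begins immediately after a reset, which is where the recurrent policy's hidden state should be (re-)initialised.

import jax
import jax.numpy as jnp

def reset(key):
    # Toy environment state; a stand-in for brax's env.reset.
    return {'obs': jax.random.normal(key, (3,)), 't': jnp.array(0)}

def step(state, action):
    # Toy dynamics; a stand-in for brax's env.step.
    return {'obs': state['obs'] + action, 't': state['t'] + 1}

def unroll_from_reset(key, unroll_length=10):
    # Reset first, then unroll, so the trajectory starts right after a reset.
    key_reset, key_steps = jax.random.split(key)
    state = reset(key_reset)

    def f(state, key_t):
        action = jax.random.normal(key_t, (3,))  # stand-in for the policy
        return step(state, action), state['obs']

    return jax.lax.scan(f, state, jax.random.split(key_steps, unroll_length))

final_state, obs_traj = unroll_from_reset(jax.random.PRNGKey(0))
print(obs_traj.shape)  # (10, 3)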

To test things out, I'm looking to reset the environment in the PPO training_step function. However, I'm a little confused about how to do this efficiently and correctly. I've tried replacing these lines:

training_state, state, key = carry
key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)

with this:

training_state, _, key = carry
key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)
state = reset_fn(key_envs)

where reset_fn = jax.jit(jax.vmap(env.reset)). Doing this, I end up with an error about inconsistent array sizes (the 256 comes from using num_envs = 256 for training):

  File "/home/nbar5346/code/Robust-RL/robustrl/experiment.py", line 219, in train
    _, params, metrics = train_fn(environment=self.env,
  File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 575, in train
    training_epoch_with_timing(training_state, env_state, epoch_keys)
  File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 487, in training_epoch_with_timing
    result = training_epoch(training_state, env_state, key)
  File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 466, in training_epoch
    (training_state, state, _), loss_metrics = jax.lax.scan(
  File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 428, in training_step
    (state, _), data = jax.lax.scan(f, init_carry, (),
  File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 416, in f
    next_state, data = acting.generate_unroll(
  File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/acting.py", line 86, in generate_unroll
    (final_state, _), data = jax.lax.scan(
  File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/acting.py", line 82, in f
    nstate, transition = actor_step(
  File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/acting.py", line 50, in actor_step
    nstate = env.step(env_state, actions)
  File "/home/nbar5346/code/Robust-RL/venv/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 122, in step
    state = self.env.step(state, action)
  File "/home/nbar5346/code/Robust-RL/venv/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 228, in step
    res = jax.vmap(step, in_axes=[self._in_axes, 0, 0])(
ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (154 of them) had size 1, e.g. axis 0 of argument s.pipeline_state.solver_niter of type int32[1,256];
  * some axes (7 of them) had size 256, e.g. axis 0 of argument sys.body_ipos of type float32[256,3,3]
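For reference, dumping the shape of every leaf in the env state pytree makes the mismatch visible (e.g. the [1, 256] leading axes on the pipeline state versus the [256, ...] axes on the randomized sys fields):

# Print the shape of every array in the (pytree) env state to see which
# leading axes don't line up with the randomized system batch axis.
print(jax.tree_util.tree_map(lambda x: x.shape, state))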

I'm using my own versions of the train.py and acting.py files, lightly modified from brax v0.10.5. The changes are minimal, so I doubt they are the cause of this issue; without the additional reset, training works as normal.
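For what it's worth, one variant I've been considering (assuming the training wrappers in brax/envs/wrappers/training.py already batch reset — VmapWrapper vmaps env.reset over a batch of keys — so no extra jax.vmap is needed) is something like:

training_state, _, key = carry
key_sgd, key_generate_unroll, key_reset, new_key = jax.random.split(key, 4)
# Fresh per-env keys each training step; `env` here is the already-wrapped
# training env, whose reset handles the num_envs batch axis itself.
state = env.reset(jax.random.split(key_reset, num_envs))

but I'm not sure that's the intended way to do it either.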

Any help would be greatly appreciated, and I'm happy to provide more info as needed!
