Hi there,

I'm currently working with a recurrent version of the PPO training algorithm in brax. For recurrent PPO, it's best to evaluate each training step (within a training epoch) over a complete trajectory of the environment, i.e. starting from just after it was last terminated/reset up until it terminates again.

To test things out, I'm looking to reset the environment in the PPO `training_step` function. However, I'm a little confused as to how to do this efficiently and correctly. I've tried replacing these lines:

with this:

where `reset_fn = jax.jit(jax.vmap(env.reset))`. Doing this, I end up with an error about array sizes (the 256 comes from me using `num_envs = 256` for training).
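For concreteness, here is roughly what the replacement looks like. This is a minimal, hypothetical sketch: the environment construction and variable names are illustrative stand-ins, not my actual `training_step` code (which receives the already-wrapped training env).

```python
import jax
from brax import envs

# Illustrative only: a plain brax env stands in for the env object
# that training_step actually receives.
env = envs.get_environment('ant')
num_envs = 256  # matches my training config

# Batched reset: one PRNG key per environment.
reset_fn = jax.jit(jax.vmap(env.reset))
reset_keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
env_state = reset_fn(reset_keys)  # State with a leading axis of size num_envs
```

The full traceback I get is: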
File "/home/nbar5346/code/Robust-RL/robustrl/experiment.py", line 219, in train
_, params, metrics = train_fn(environment=self.env,
File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 575, in train
training_epoch_with_timing(training_state, env_state, epoch_keys)
File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 487, in training_epoch_with_timing
result = training_epoch(training_state, env_state, key)
File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 466, in training_epoch
(training_state, state, _), loss_metrics = jax.lax.scan(
File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 428, in training_step
(state, _), data = jax.lax.scan(f, init_carry, (),
File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/train.py", line 416, in f
next_state, data = acting.generate_unroll(
File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/acting.py", line 86, in generate_unroll
(final_state, _), data = jax.lax.scan(
File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/acting.py", line 82, in f
nstate, transition = actor_step(
File "/home/nbar5346/code/Robust-RL/robustrl/recurrentppo/acting.py", line 50, in actor_step
nstate = env.step(env_state, actions)
File "/home/nbar5346/code/Robust-RL/venv/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 122, in step
state = self.env.step(state, action)
File "/home/nbar5346/code/Robust-RL/venv/lib/python3.10/site-packages/brax/envs/wrappers/training.py", line 228, in step
res = jax.vmap(step, in_axes=[self._in_axes, 0, 0])(
ValueError: vmap got inconsistent sizes for array axes to be mapped:
* most axes (154 of them) had size 1, e.g. axis 0 of argument s.pipeline_state.solver_niter of type int32[1,256];
* some axes (7 of them) had size 256, e.g. axis 0 of argument sys.body_ipos of type float32[256,3,3]
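(For context on the error class itself, this toy snippet, unrelated to brax, reproduces the same `ValueError`: `jax.vmap` requires every mapped argument to have the same size along the mapped axis.)

```python
import jax
import jax.numpy as jnp

# Toy reproduction: axis 0 sizes disagree (1 vs 256), so vmap refuses to map.
a = jnp.zeros((1, 256))
b = jnp.zeros((256, 3))
try:
    jax.vmap(lambda x, y: x.sum() + y.sum())(a, b)
except ValueError as e:
    print(e)  # vmap got inconsistent sizes for array axes to be mapped
```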
I'm using my own versions of the `train.py` and `acting.py` files, which are slightly modified from brax `v0.10.5`. I have made only minimal changes to those files, so I doubt they are the cause of this issue. I can train a policy as normal without introducing the additional reset.

Any help would be greatly appreciated, and I'm happy to provide more info as required!