In the Brax PPO training implementation, the TrainingState.env_steps field (see train.py#L60 and #L613) is a jnp.ndarray whose dtype defaults to jnp.int32 (JAX's default integer type unless 64-bit mode is enabled). If a training run exceeds 2³¹ − 1 (2,147,483,647) steps, the counter overflows and wraps around to a negative value, which can lead to unexpected behavior in long-running training sessions.
The overflow problem has been reported in the Mujoco Playground repository (#48). To mitigate this, one option might be to initialize env_steps with a 64‑bit integer (e.g. jnp.array(0, dtype=jnp.int64)).
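To make the failure mode concrete, here is a small sketch, assuming a default JAX install where jax_enable_x64 is off. It shows that an int32 counter wraps around past 2³¹ − 1, and that requesting jnp.int64 does not help on its own: JAX downcasts it to int32 with a warning.

```python
import warnings

import jax.numpy as jnp

# In JAX's default 32-bit mode, integer counters are int32.
steps = jnp.array(2**31 - 1, dtype=jnp.int32)  # INT32_MAX
wrapped = steps + 1  # XLA integer addition wraps around (two's complement)
print(wrapped)  # -2147483648

# Requesting int64 without enabling x64 mode is downcast to int32 (JAX emits a UserWarning).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    counter = jnp.array(0, dtype=jnp.int64)
print(counter.dtype)  # int32
```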
Unfortunately, this is a pretty fundamental JAX limitation: you can't mix 32- and 64-bit precision. So if you want a 64-bit step count, training is going to slow down a lot, because everything else will become 64-bit too.
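For reference, JAX's 64-bit mode is a global flag. A minimal sketch of enabling it (this should happen at program startup, before any arrays are created) shows that it changes the default dtypes of unrelated arrays as well, not just the step counter. That is where the slowdown comes from: float64 arithmetic and memory traffic are typically much more expensive than float32 on GPUs.

```python
import jax

# Global switch: must be set at startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

steps = jnp.array(0, dtype=jnp.int64)  # now a real 64-bit counter
params = jnp.zeros(3)  # but the default float dtype also becomes float64
print(steps.dtype, params.dtype)  # int64 float64
```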
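What we could do, if some intrepid soul would like to try it, is make the step count a custom big int. You would need to store two int32s (a high word num1 and a low word num2, with num2 treated as unsigned), and when you want the step count you'd combine them on the host with Python ints, which never overflow:
step_count = (int(num1) << 32) + int(num2)
The parentheses matter: in Python, + binds more tightly than <<, so num1 << 32 + num2 would shift num1 by 32 + num2 bits. Then you'd have to add some logic to increment num2 and carry into num1 whenever num2 wraps around.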
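A minimal sketch of that idea in JAX, under a few assumptions: the names StepCounter, init_counter, increment and to_python_int are hypothetical (they are not part of Brax), and both words are stored as uint32 rather than int32, so the low word uses its full 32-bit range and a wraparound can be detected with a single unsigned comparison. Increments are assumed to be small non-negative values, such as per-iteration step counts.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class StepCounter(NamedTuple):
    """Hypothetical 64-bit step counter built from two 32-bit words."""
    hi: jax.Array  # uint32, upper 32 bits
    lo: jax.Array  # uint32, lower 32 bits


def init_counter() -> StepCounter:
    return StepCounter(hi=jnp.uint32(0), lo=jnp.uint32(0))


@jax.jit
def increment(counter: StepCounter, n) -> StepCounter:
    """Adds n >= 0 steps (n must fit in uint32), carrying into hi when lo wraps."""
    n = jnp.asarray(n, dtype=jnp.uint32)
    lo = counter.lo + n  # uint32 addition wraps modulo 2**32
    carry = (lo < counter.lo).astype(jnp.uint32)  # wrapped iff the result got smaller
    return StepCounter(hi=counter.hi + carry, lo=lo)


def to_python_int(counter: StepCounter) -> int:
    """Combines the words on the host, where Python ints cannot overflow."""
    return (int(counter.hi) << 32) | int(counter.lo)


# Usage: ordinary increments.
c = init_counter()
for _ in range(3):
    c = increment(c, 1000)
print(to_python_int(c))  # 3000

# Usage: the low word wraps and carries into the high word.
near_wrap = StepCounter(hi=jnp.uint32(0), lo=jnp.asarray(np.uint32(2**32 - 1)))
after = increment(near_wrap, 3)
print(int(after.hi), int(after.lo), to_python_int(after))  # 1 2 4294967298
```

Because both words stay 32-bit, everything inside jit runs in JAX's default mode; only the final combination into a single number happens on the host.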