
Potential Overflow in TrainingState.env_steps on GPU due to jnp.int32 Default #578

Open
vincentzhang opened this issue Feb 13, 2025 · 2 comments

Comments


vincentzhang commented Feb 13, 2025

In the Brax PPO training implementation, the TrainingState.env_steps field (see train.py#L60 and #L613) is a jnp.ndarray that defaults to jnp.int32 on GPU. If a training run exceeds 2³¹ – 1 (2,147,483,647) steps, the counter overflows and wraps around to a negative value, which can lead to unexpected behavior in long-running training sessions.

The overflow problem has been reported in the Mujoco Playground repository (#48). To mitigate this, one option might be to initialize env_steps with a 64‑bit integer (e.g. jnp.array(0, dtype=jnp.int64)).
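
For reference, here is a minimal sketch of the failure mode, assuming JAX's default configuration (the `jax_enable_x64` flag left off). It shows an int32 counter wrapping at 2³¹ – 1, and also that requesting `jnp.int64` without that flag still yields an int32 array, so the dtype change on its own may not be enough. The variable names are illustrative and not taken from train.py.

```python
import warnings

import jax.numpy as jnp

# An int32 counter sitting at the largest representable value.
env_steps = jnp.array(2**31 - 1, dtype=jnp.int32)
wrapped = env_steps + jnp.int32(1)  # wraps around to the most negative int32

# With x64 disabled (the default), a requested int64 is truncated to int32.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # JAX warns about the truncation
    requested_int64 = jnp.array(0, dtype=jnp.int64)

print(int(wrapped), requested_int64.dtype)
```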

@erikfrey
Collaborator

Unfortunately this is a pretty fundamental JAX limitation - you can't mix 32-bit and 64-bit precision. So if you want a 64-bit step count, training is going to slow down a lot because everything else will become 64-bit too.

What we could do, if some intrepid soul would like to try it, is make the step count a custom big int. You will need to store two int32s, and when you want the step count you'd need to do:

step_count = (num1 << 32) + num2

Then you'd have to add some logic to increment num1 and num2 appropriately.
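
One possible shape for that increment logic, as a rough sketch rather than anything Brax actually does. It assumes two uint32 words instead of int32s, so the low word can use all 32 bits and a wrap can be detected with a single comparison; with signed int32s the low word would have to be limited to 31 bits. The names `add_steps` and `to_python_int` are hypothetical. The words are combined on the host, where Python integers do not overflow.

```python
import jax
import jax.numpy as jnp
import numpy as np


def add_steps(hi, lo, delta):
    """Add `delta` (a uint32 value) to the counter stored as (hi, lo)."""
    new_lo = lo + delta                       # wraps modulo 2**32 on overflow
    carry = (new_lo < lo).astype(jnp.uint32)  # 1 exactly when the low word wrapped
    return hi + carry, new_lo


def to_python_int(hi, lo):
    """Combine the two words on the host, where Python ints are unbounded."""
    return (int(hi) << 32) + int(lo)


hi = jnp.zeros((), dtype=jnp.uint32)
lo = jnp.asarray(np.uint32(2**32 - 10))  # low word just below its wrap point
hi, lo = jax.jit(add_steps)(hi, lo, jnp.uint32(25))
print(to_python_int(hi, lo))
```

The carry check only holds while each increment is smaller than 2³², which should be the case for a per-update step count.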

@vincentzhang
Author

Thanks a lot for the suggestion. I'll look into it sometime next week.
