acting.Evaluator.run_evaluation compilation hangs #569
I just tried narrowing down the error to a specific version and here's what I observed:
It appears that it's … Perhaps this is something that should be reported to JAX. Unfortunately, I'm also unable to reproduce this issue outside of Brax + MJX. Any leads would be much appreciated.
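Narrowing a regression down to a specific release is a bisection over an ordered list of versions. The following is a minimal sketch, assuming the regression is monotonic (every version after the first bad one is also bad). The function first_bad_version and the predicate is_bad are hypothetical helpers, not part of Brax or JAX; in practice is_bad would install the given version (e.g. in a fresh virtualenv) and run the repro under a timeout.

```python
from typing import Callable, Optional, Sequence


def first_bad_version(
    versions: Sequence[str], is_bad: Callable[[str], bool]
) -> Optional[str]:
    """Return the first version for which `is_bad` is True, or None.

    Assumes `versions` is ordered oldest to newest and that the regression
    is monotonic: once a version is bad, every later version is bad too.
    """
    # Invariant: versions[:lo] are known good, versions[hi:] are known bad.
    lo, hi = 0, len(versions)
    while lo < hi:
        mid = (lo + hi) // 2
        if is_bad(versions[mid]):
            hi = mid  # versions[mid:] are bad
        else:
            lo = mid + 1  # versions[:mid + 1] are good
    return versions[lo] if lo < len(versions) else None


# Hypothetical usage with placeholder version labels and a fake predicate.
candidates = ["v1", "v2", "v3", "v4", "v5"]
known_bad = {"v4", "v5"}
print(first_bad_version(candidates, lambda v: v in known_bad))  # → v4
```

Each probe is an expensive install-and-run cycle, so bisection needs only about log2(n) runs instead of testing every release.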
I haven't had the chance to create a minimal repro yet, but I've also had issues training MJX/Brax with jax[cuda12_local] > 0.4.35. I haven't seen the hanging evaluator compilation, but for some environments training collapses to a bad local minimum and never recovers.
Thanks @hartikainen for the detailed report. I also ran into this issue, and set …
Posted here: jax-ml/jax#26162
Hey folks,

I tried updating our Brax version from 0.11.0 to 0.12.1 but ran into issues while doing so. Our trainer, which uses PPO and has been working fine until now, now freezes because acting.Evaluator.run_evaluation takes forever to compile (literally, at least 12 hours). I've spent a couple of full working days chasing down this bug, and there seem to be some pretty weird dynamics in play here, so I haven't been able to narrow it down to a single underlying cause.

Here's the model I've been using:
Example script
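One way to tell whether the time goes into tracing/lowering or into XLA compilation itself is JAX's ahead-of-time API, jax.jit(f).lower(*args).compile(), which splits the two stages. This is a minimal sketch, not the reporter's script: time_jit_stages and toy_unroll are hypothetical names, and toy_unroll is only a stand-in for an evaluation unroll, not Brax's generate_eval_unroll. In practice you would pass the real function and example arguments.

```python
import time

import jax
import jax.numpy as jnp


def time_jit_stages(fn, *args):
    """Time lowering and XLA compilation of `fn` separately (hypothetical helper)."""
    jitted = jax.jit(fn)

    t0 = time.perf_counter()
    lowered = jitted.lower(*args)  # trace the Python function and lower it
    t1 = time.perf_counter()
    compiled = lowered.compile()  # run the XLA compiler; a compiler hang blocks here
    t2 = time.perf_counter()

    return compiled, {"lower_s": t1 - t0, "compile_s": t2 - t1}


def toy_unroll(state):
    """Hypothetical stand-in for an evaluation unroll: a 100-step lax.scan."""

    def step(carry, _):
        carry = jnp.tanh(carry) * 0.9
        return carry, carry.sum()

    return jax.lax.scan(step, state, None, length=100)


compiled, timings = time_jit_stages(toy_unroll, jnp.ones((8,)))
print(timings)
final, sums = compiled(jnp.ones((8,)))  # the Compiled object is callable
```

If lower_s is small and the process never returns from compile(), the time is being spent inside XLA rather than in Python tracing.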
And here's what I've observed so far. First, when running this file, the compilation takes a long time. I have not run this particular script for more than 10 minutes, but the original training pipeline in our code base ran for more than 12 hours without the evaluation-loop compilation finishing (with brax==0.11.0 it takes a couple of minutes).

Second, I believe, although I'm not sure, that the underlying issue might have something to do with the mocap bodies in the environment. If I comment out the mocap code above, the code runs fine:
The reason I'm not sure about this connection is that the use of mocap_pos is not the culprit in itself: I tried to reproduce the issue by swapping the above environment for the Humanoid environment from the MJX tutorial notebook, with a similar mocap body added to it, and that compiled fine.

jax-ml/jax#6823 and jax-ml/jax#9651 seemed like they could be related, so I tried running the code with JAX_ENABLE_MLIR=0 OMP_NUM_THREADS=1 (as suggested in those issues), but that did not fix the problem.

I also tried the recommended XLA_FLAGS=--xla_dump_to=/tmp/foo. I don't know exactly how to interpret the results, however. What I notice is that the working run outputs files up to module_0211.jit__..., whereas the hanging run only outputs files up to module_0109.jit_generate_eval_unroll..., which suggests that the compilation gets stuck at generate_eval_unroll. I've manually verified that this is indeed the case.

Does anyone have pointers on how to fix or further investigate this issue?
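To compare dump directories from a working and a hanging run without scanning file listings by hand, a small helper can report the highest-numbered module in each. This is a sketch that assumes the module_NNNN.<name>... file naming seen above; last_dumped_module is a hypothetical helper, not part of JAX or XLA.

```python
import re
from pathlib import Path
from typing import Optional, Tuple

# Assumed naming scheme, based on names like "module_0109.jit_generate_eval_unroll...":
# group 1 is the module number, group 2 the module name up to the next dot.
_MODULE_RE = re.compile(r"^module_(\d+)\.([^.]+)")


def last_dumped_module(dump_dir: str) -> Optional[Tuple[int, str]]:
    """Return (module_number, module_name) of the highest-numbered dump file, or None."""
    best: Optional[Tuple[int, str]] = None
    for path in Path(dump_dir).iterdir():
        match = _MODULE_RE.match(path.name)
        if match is None:
            continue  # skip files that don't follow the module_NNNN.<name> pattern
        entry = (int(match.group(1)), match.group(2))
        if best is None or entry[0] > best[0]:
            best = entry
    return best
```

Usage would be one call per run, e.g. with each run dumping to its own (hypothetical) directory such as /tmp/foo_working and /tmp/foo_hanging. XLA_FLAGS has to be in the environment before JAX initializes its backend, so setting it on the command line (XLA_FLAGS=--xla_dump_to=/tmp/foo_hanging python train.py) is the simplest option.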
My setup (I've installed jax with the jax[cuda12-local] extras):