Skip to content

Commit

Permalink
Merge pull request #63 from vincentzhang:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 730385057
Change-Id: Ic50ede5f9e5242758348f1e070b23f08fa5c2bd8
  • Loading branch information
copybara-github committed Feb 24, 2025
2 parents 1fdd805 + 43f69dc commit 609168a
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions learning/train_jax_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,18 +244,16 @@ def main(argv):
# Handle checkpoint loading
if _LOAD_CHECKPOINT_PATH.value is not None:
# Convert to absolute path
_LOAD_CHECKPOINT_PATH.value = epath.Path(
_LOAD_CHECKPOINT_PATH.value
).resolve()
if _LOAD_CHECKPOINT_PATH.value.is_dir():
latest_ckpts = list(_LOAD_CHECKPOINT_PATH.value.glob("*"))
ckpt_path = epath.Path(_LOAD_CHECKPOINT_PATH.value).resolve()
if ckpt_path.is_dir():
latest_ckpts = list(ckpt_path.glob("*"))
latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()]
latest_ckpts.sort(key=lambda x: int(x.name))
latest_ckpt = latest_ckpts[-1]
restore_checkpoint_path = latest_ckpt
print(f"Restoring from: {restore_checkpoint_path}")
else:
restore_checkpoint_path = _LOAD_CHECKPOINT_PATH.value
restore_checkpoint_path = ckpt_path
print(f"Restoring from checkpoint: {restore_checkpoint_path}")
else:
print("No checkpoint path provided, not restoring from checkpoint")
Expand Down

0 comments on commit 609168a

Please sign in to comment.