diff --git a/learning/train_jax_ppo.py b/learning/train_jax_ppo.py index 0e04191..8ac0260 100644 --- a/learning/train_jax_ppo.py +++ b/learning/train_jax_ppo.py @@ -244,18 +244,16 @@ def main(argv): # Handle checkpoint loading if _LOAD_CHECKPOINT_PATH.value is not None: # Convert to absolute path - _LOAD_CHECKPOINT_PATH.value = epath.Path( - _LOAD_CHECKPOINT_PATH.value - ).resolve() - if _LOAD_CHECKPOINT_PATH.value.is_dir(): - latest_ckpts = list(_LOAD_CHECKPOINT_PATH.value.glob("*")) + ckpt_path = epath.Path(_LOAD_CHECKPOINT_PATH.value).resolve() + if ckpt_path.is_dir(): + latest_ckpts = list(ckpt_path.glob("*")) latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()] latest_ckpts.sort(key=lambda x: int(x.name)) latest_ckpt = latest_ckpts[-1] restore_checkpoint_path = latest_ckpt print(f"Restoring from: {restore_checkpoint_path}") else: - restore_checkpoint_path = _LOAD_CHECKPOINT_PATH.value + restore_checkpoint_path = ckpt_path print(f"Restoring from checkpoint: {restore_checkpoint_path}") else: print("No checkpoint path provided, not restoring from checkpoint")