From 43f69dcc2c201715eef043c83c8e1ccd686f3320 Mon Sep 17 00:00:00 2001 From: VincentZhang Date: Thu, 20 Feb 2025 20:27:37 -0700 Subject: [PATCH] Fix checkpoint flag assignment error --- learning/train_jax_ppo.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/learning/train_jax_ppo.py b/learning/train_jax_ppo.py index 0e04191..8ac0260 100644 --- a/learning/train_jax_ppo.py +++ b/learning/train_jax_ppo.py @@ -244,18 +244,16 @@ def main(argv): # Handle checkpoint loading if _LOAD_CHECKPOINT_PATH.value is not None: # Convert to absolute path - _LOAD_CHECKPOINT_PATH.value = epath.Path( - _LOAD_CHECKPOINT_PATH.value - ).resolve() - if _LOAD_CHECKPOINT_PATH.value.is_dir(): - latest_ckpts = list(_LOAD_CHECKPOINT_PATH.value.glob("*")) + ckpt_path = epath.Path(_LOAD_CHECKPOINT_PATH.value).resolve() + if ckpt_path.is_dir(): + latest_ckpts = list(ckpt_path.glob("*")) latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()] latest_ckpts.sort(key=lambda x: int(x.name)) latest_ckpt = latest_ckpts[-1] restore_checkpoint_path = latest_ckpt print(f"Restoring from: {restore_checkpoint_path}") else: - restore_checkpoint_path = _LOAD_CHECKPOINT_PATH.value + restore_checkpoint_path = ckpt_path print(f"Restoring from checkpoint: {restore_checkpoint_path}") else: print("No checkpoint path provided, not restoring from checkpoint")