Skip to content

Commit 34e3dc7

Browse files
committed Oct 7, 2024
Fix: allow additional kwargs to be passed to train(), remove the unnecessary explicit destructor call on the runtime (now handled by the context manager), and improve output formatting when checkpointing.
1 parent a828d26 commit 34e3dc7

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed
 

‎relexi/rl/ppo/train.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def train( config_file
8484
,mpi_launch_mpmd = False
8585
,strategy = None
8686
,debug = 0
87+
,**kwargs
8788
):
8889
"""
8990
Main training routine. Here, the (FLEXI) environment, the art. neural networks, the optimizer,...
@@ -343,9 +344,9 @@ def train( config_file
343344

344345
# Checkpoint the policy every ckpt_interval iterations
345346
if (i % ckpt_interval) == 0:
346-
rlxout.info('Saving checkpoint to: ' + ckpt_dir, newline=False)
347+
rlxout.info('Saving checkpoint to: ' + ckpt_dir)
347348
train_checkpointer.save(global_step)
348-
rlxout.info('Saving current model to: ' + save_dir)
349+
rlxout.info('Saving current model to: ' + save_dir, newline=False)
349350
actor_net.model.save(os.path.join(save_dir,f'model_{global_step.numpy():06d}'))
350351

351352
# Flush summary to TensorBoard
@@ -358,6 +359,3 @@ def train( config_file
358359
# Close all
359360
del my_env
360361
del my_eval_env
361-
362-
del runtime
363-
time.sleep(2.) # Wait for orchestrator to be properly closed

0 commit comments

Comments (0)
Failed to load comments.