Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove offline training, refactor train.py and logging/checkpointing #670

Open
wants to merge 27 commits into
base: main
Choose a base branch
from

Conversation

aliberts
Copy link
Collaborator

@aliberts aliberts commented Jan 31, 2025

What this does

  • ⚠️ Removes the offline training part from the train.py script: online training will be handled by the training scripts from [WIP] Fix SAC and port HIL SERL #644
  • In consequence, .offline and .online are removed from TrainPipelineConfig. To set the number of offline training step, simply use --steps:
python lerobot/scripts/train.py \
- --offline.steps=200000
+ --steps=200000
  • Adds wandb_utils.py and turns Logger into WandBLogger to remove responsibilities from this class so that it only manages wandb stuff.
  • Replaces training_state serialization with torch.save/load to safetensors.save_file/load_file. We shouldn't use torch.load() for this and in fact it breaks in which breaks in 2.6 due to weights_only=True by default.
/checkpoints/005000
  ├── pretrained_model
- └── training_state.pth
+ └── training_state
+     ├── optimizer_param_groups.json
+     ├── optimizer_state.safetensors
+     ├── rng_state.safetensors
+     ├── scheduler_state.json
+     └── training_step.json
  • Adds train_utils.py to handle training checkpoints logic (including training state).
  • Cleans up functions related to rng and groups them together in random_utils.py.
  • Save checkpoint before eval during training rather than after (safer in case eval crashes)
  • Fixes logging where displayed values would only be the last one measured instead of the average over the steps from previous logging step.
  • Changed the policies main forward() output format for clarity. It now returns a tuple[Tensor, dict | None] instead of just a dict, the first element being the loss:
- output_dict = policy.forward(batch)
- loss = output_dict["loss"]
+ loss, output_dict = policy.forward(batch)
loss.backward()

How it was tested

Adds the following tests:

  • tests/test_schedulers.py
  • tests/test_optimizers.py
  • tests/test_train_utils.py
  • tests/test_random_utils.py
  • tests/test_io_utils.py

How to checkout & try? (for the reviewer)

Examples:

pytest -v \
    tests/test_schedulers.py \
    tests/test_optimizers.py \
    tests/test_train_utils.py \
    tests/test_random_utils.py \
    tests/test_io_utils.py

@aliberts aliberts changed the title Update safetensors `training_state Update training_state serialization to safetensors Jan 31, 2025
@aliberts aliberts changed the title Update training_state serialization to safetensors Refactor Logger Feb 4, 2025
@aliberts aliberts changed the title Refactor Logger Refactor train.py and logging/checkpointing Feb 8, 2025
@aliberts aliberts changed the title Refactor train.py and logging/checkpointing Remove offline training, refactor train.py and logging/checkpointing Feb 8, 2025
@aliberts aliberts added the 🔄 Refactor Refactoring label Feb 8, 2025
@aliberts aliberts requested a review from Cadene February 8, 2025 21:48
@aliberts aliberts marked this pull request as ready for review February 8, 2025 21:48
Copy link
Collaborator

@Cadene Cadene left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful

Could you remove all appearance of ema?
There were added by default

@@ -153,7 +153,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
return loss, None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return loss, None
# no output_dict so returning None
return loss, None

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
🔄 Refactor Refactoring
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants