Pass down TrainingArgs instance to iterate_batches function and TrainingCallback methods

Addresses ml-explore#1224
chimezie committed Jan 28, 2025
commit a928bba (1 parent: 7a83077)
Showing 1 changed file with 11 additions and 6 deletions: llms/mlx_lm/tuner/trainer.py
@@ -5,7 +5,7 @@
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -76,7 +76,9 @@ def default_loss(model, inputs, targets, lengths):
     return ce, ntoks
 
 
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+def iterate_batches(
+    dataset, tokenizer, batch_size, max_seq_length, train=False, args=None
+):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
     if len(dataset) < batch_size:
@@ -167,11 +169,13 @@ def evaluate(
 
 class TrainingCallback:
 
-    def on_train_loss_report(self, train_info: dict):
+    def on_train_loss_report(
+        self, train_info: dict, args: Optional[TrainingArgs] = None
+    ):
         """Called to report training loss at specified intervals."""
         pass
 
-    def on_val_loss_report(self, val_info: dict):
+    def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
         """Called to report validation loss at specified intervals or the beginning."""
         pass
 
@@ -227,6 +231,7 @@ def step(batch):
             batch_size=args.batch_size,
             max_seq_length=args.max_seq_length,
             train=True,
+            args=args,
         ),
     ):
         # Report validation loss if needed, the first validation loss
@@ -258,7 +263,7 @@ def step(batch):
                     "val_loss": val_loss,
                     "val_time": val_time,
                 }
-                training_callback.on_val_loss_report(val_info)
+                training_callback.on_val_loss_report(val_info, args=args)
 
             start = time.perf_counter()
 
@@ -301,7 +306,7 @@ def step(batch):
                     "trained_tokens": trained_tokens,
                     "peak_memory": peak_mem,
                 }
-                training_callback.on_train_loss_report(train_info)
+                training_callback.on_train_loss_report(train_info, args=args)
 
             losses = 0
             n_tokens = 0
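With this change, the trainer forwards its TrainingArgs both to the batch iterator and to the callback hooks, so a custom callback can read hyperparameters such as the batch size or reporting interval without capturing them separately. Below is a minimal sketch of a callback written against the new signatures; the `LossLogger` name is hypothetical, and only the `val_loss` and `val_time` keys of the reporting dicts are confirmed by this diff, so the train-side dict is printed as-is. Since `args` defaults to `None`, the sketch also guards against callers that predate this commit.

from typing import Optional

from mlx_lm.tuner.trainer import TrainingArgs, TrainingCallback


class LossLogger(TrainingCallback):
    """Hypothetical callback using the TrainingArgs now passed by the trainer."""

    def on_train_loss_report(
        self, train_info: dict, args: Optional[TrainingArgs] = None
    ):
        # `args` is None when an older trainer (pre-commit) invokes the hook.
        batch_size = args.batch_size if args is not None else "?"
        print(f"train report (batch_size={batch_size}): {train_info}")

    def on_val_loss_report(self, val_info: dict, args: Optional[TrainingArgs] = None):
        # `val_loss` and `val_time` are the keys visible in this diff.
        print(f"val_loss={val_info['val_loss']:.4f} in {val_info['val_time']:.1f}s")

Note that after this commit the trainer always calls the hooks with `args=args`, so an existing subclass that still declares the old two-argument methods will raise a TypeError; such subclasses should add the new keyword parameter (or accept `**kwargs`).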
