A question about how to increase batch size. #130
Replies: 7 comments
-
Hi. In plain PyTorch, gradient accumulation looks like this:

    model = ANY_MODEL  # placeholder for any torch.nn.Module
    roll_overs = 0
    optimizer.zero_grad()
    for data in data_loader:
        pred = model(data['x'])
        # scale each micro-batch loss so the accumulated gradient
        # matches the gradient of one large batch
        loss = loss_func(pred, data['y']) / gradient_accumulation_steps
        loss.backward()  # gradients accumulate in .grad between optimizer steps
        roll_overs += len(data['x'])
        if roll_overs % (gradient_accumulation_steps * batch_size) == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

But in EasyDeL things work differently. Here's how the training step is implemented in EasyDeL:

    def casual_language_model_train_step(state, batch):
        """
        The casual_language_model_train_step function is a training step function that takes in the current state
        of the model and a batch of data. It calculates the loss and accuracy for this batch
        and returns an updated state with new parameters based on these gradients.

        :param state: Store the model parameters
        :param batch: Pass the data to the model, dict with
            input_ids(bs, seq_len), labels(bs, seq_len-1), attention_mask(bs, seq_len)
        :return: A tuple of (state, loss, accuracy)
        """
        batch = with_sharding_constraint(batch, partition_spec)

        def calculate_loss(params):
            labels = batch.pop("labels")  # already shifted left
            logits = state.apply_fn(params=params, **batch, return_dict=True).logits
            loss_normalizing_factor = (
                SpecialLossNormalizingFactor.NUM_REAL_TARGET_TOKENS
            )
            # loss_weights is 1 unless the label is <= 0 or the attention mask is 0
            loss_weights = jnp.where(
                (batch["attention_mask"][:, 1:] != 0) & (labels > 0), 1, 0
            )
            lnf, weights = get_loss_normalizing_factor_and_weights(
                loss_normalizing_factor,
                {
                    "decoder_target_tokens": labels,
                    "decoder_loss_weights": loss_weights,
                },
            )
            (
                loss,
                z_loss_computed,
                weight_sum,
                accuracy,
            ) = compute_weighted_cross_entropy_and_accuracy(
                logits=logits[:, :-1, :],
                targets=labels,
                weights=weights,
                label_smoothing=label_smoothing_factor,
                z_loss=z_loss,
                loss_normalizing_factor=lnf,
            )
            return loss, accuracy

        grad_fn = jax.value_and_grad(calculate_loss, has_aux=True)
        (loss__, accuracy__), grad = grad_fn(state.params)
        state = state.apply_gradients(grads=grad)
        return state, loss__, accuracy__
    # (in EasyDeL, this function is defined inside a factory that closes over
    # partition_spec, label_smoothing_factor and z_loss, and the factory ends
    # with `return casual_language_model_train_step`)
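If you want the PyTorch-style behaviour on top of a step function like this, one common approach in the JAX ecosystem is optax.MultiSteps, which accumulates gradients inside the optimizer state and only applies them on every k-th call. Here is a minimal, self-contained sketch; the params, loss_fn and shapes are toy placeholders for illustration, not EasyDeL's actual trainer wiring:

    import jax
    import jax.numpy as jnp
    import optax

    gradient_accumulation_steps = 4

    # Toy stand-ins so the sketch runs end to end.
    params = {"w": jnp.zeros((4,))}

    def loss_fn(params, batch):
        return jnp.mean((batch["x"] @ params["w"] - batch["y"]) ** 2)

    # Wrap any inner optimizer; updates are accumulated and only applied on
    # every `every_k_schedule`-th call, the calls in between are no-ops.
    optimizer = optax.MultiSteps(
        optax.adamw(learning_rate=2e-5),
        every_k_schedule=gradient_accumulation_steps,
    )
    opt_state = optimizer.init(params)

    @jax.jit
    def accumulating_train_step(params, opt_state, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    # usage: params, opt_state, loss = accumulating_train_step(params, opt_state, {"x": xs, "y": ys})

The catch is that MultiSteps has to keep a gradient-sized accumulator alive between steps, so gradient accumulation in JAX costs extra device memory rather than being essentially free as in the PyTorch loop above; that is consistent with the TPU out-of-memory report further down this thread.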
-
I understand now. Thank you for your detailed explanation.
-
Interesting, this same issue had me puzzled as well.

    trained_tokens = (
        current_step *
        self.arguments.total_batch_size *
        self.arguments.gradient_accumulation_steps *
        self.arguments.max_sequence_length
    )
-
Related discussion: google-deepmind/optax#320 (comment)
-
I'm working on getting that fixed, but it seems like there are a lot of issues in the Trainer and EasyState backbone.
-
I have changed that to:

    trained_tokens = (
        current_step *
        self.arguments.total_batch_size *
        self.arguments.max_sequence_length
    )
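For what it's worth, a quick sanity check of why the gradient_accumulation_steps factor comes out of the formula, assuming current_step increments once per micro-batch (my reading of the loop, not a confirmed detail of the trainer):

    # Hypothetical numbers, assuming `current_step` counts micro-batches:
    gradient_accumulation_steps = 4
    total_batch_size = 1
    max_sequence_length = 2048
    current_step = 8  # 8 micro-batches consumed == 2 optimizer updates

    tokens_actually_seen = current_step * total_batch_size * max_sequence_length
    old_formula = tokens_actually_seen * gradient_accumulation_steps

    assert tokens_actually_seen == 16_384
    assert old_formula == 65_536  # counts every micro-batch 4 times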
-
@IvoryTower800 you can now try training your model with gradient accumulation. Let me know if there's any other issue or question.
-
Describe the bug
Hi, I tried to finetune the gemma-2b model with sharding_array=(1, 1, 1, -1) on a Kaggle TPU VM v3-8.
There are two batch-size-related parameters in TrainArguments: total_batch_size and gradient_accumulation_steps.
If I set both of them to 1, it works well, and the total TPU memory used is 14.4GB (reported by track_memory=True).
However, when I set total_batch_size=1 and gradient_accumulation_steps=4, it says TPU memory is exhausted.
I'm confused about the batch size when training on TPU with EasyDeL: if I finetune a model using transformers on a GPU, I can set gradient_accumulation_steps as high as I want and it won't increase GPU VRAM usage, but on TPU it does.
Do I misunderstand gradient_accumulation_steps in EasyDeL? Could you please tell me how to increase my actual batch size when finetuning?
For example, suppose I want my final batch size to equal 32.
To Reproduce
    train_arguments = TrainArguments(
        model_class=type(model),
        model_name="gemma_2b_it",
        num_train_epochs=1,
        configs_to_initialize_model_class=configs_to_initialize_model_class,
        custom_rule=model.config.get_partition_rules(True),
        learning_rate=2e-5,
        learning_rate_end=2e-7,
        max_sequence_length=max_length,
        optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion" and "adafactor" are supported
        scheduler=EasyDelSchedulers.LINEAR,  # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear" are supported
        weight_decay=0.01,
        total_batch_size=1,
        max_training_steps=None,  # None to let the trainer decide
        do_train=True,
        do_eval=False,  # optional but supported
        backend="tpu",  # the default backend is cpu, so you must say whether you want tpu, cpu or gpu
        max_length=max_length,  # note that you have to change this in the model config too
        gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
        sharding_array=(1, 1, 1, -1),  # how to shard the model across CPUs, GPUs or TPUs;
        # (1, 1, 1, -1) keeps everything in sequence and model parallel and shares data between devices
        use_pjit_attention_force=False,
        remove_ckpt_after_load=True,
        init_input_shape=(1, max_length),
        gradient_accumulation_steps=4,
        loss_re_mat="",
        dtype=jnp.bfloat16,
        track_memory=True,
        use_wandb=True,  # set to False to disable WANDB logging
    )
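On the last question: the effective batch size per optimizer update is total_batch_size × gradient_accumulation_steps, so a final batch size of 32 can be expressed either way (a sketch of the arithmetic only; the exact memory behaviour depends on the EasyDeL version, and the accumulator overhead discussed above still applies):

    # Effective batch size per optimizer update = total_batch_size * gradient_accumulation_steps.
    # (a) one big batch per step, highest activation memory:
    total_batch_size, gradient_accumulation_steps = 32, 1
    assert total_batch_size * gradient_accumulation_steps == 32

    # (b) four micro-batches of 8, lower activation memory per step:
    total_batch_size, gradient_accumulation_steps = 8, 4
    assert total_batch_size * gradient_accumulation_steps == 32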