Add bf16/fp16 support for amp with mps device #3373
Draft
SunMarc wants to merge 6 commits into main from amp-mps
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@SunMarc Hi Marc, I have installed this version of accelerate together with PyTorch 2.6.0 to use the Trainer on an MPS device, but I get the following error. Could you please help me check it out?
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 30
20 trainer = Trainer(
21 model=model,
22 args=training_args,
(...)
27 compute_metrics=compute_metrics
28 )
29 logger.info("Start training")
---> 30 trainer.train()
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/transformers/trainer.py:1885, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1883 hf_hub_utils.enable_progress_bars()
1884 else:
-> 1885 return inner_training_loop(
1886 args=args,
1887 resume_from_checkpoint=resume_from_checkpoint,
1888 trial=trial,
1889 ignore_keys_for_eval=ignore_keys_for_eval,
1890 )
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/transformers/trainer.py:2216, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2213 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
2215 with self.accelerator.accumulate(model):
-> 2216 tr_loss_step = self.training_step(model, inputs)
2218 if (
2219 args.logging_nan_inf_filter
2220 and not is_torch_xla_available()
2221 and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
2222 ):
2223 # if loss is nan or inf simply add the average of previous logged losses
2224 tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/transformers/trainer.py:3250, in Trainer.training_step(***failed resolving arguments***)
3248 scaled_loss.backward()
3249 else:
-> 3250 self.accelerator.backward(loss)
3252 return loss.detach() / self.args.gradient_accumulation_steps
File ~/GitHub/accelerate/src/accelerate/accelerator.py:2250, in Accelerator.backward(self, loss, **kwargs)
2248 self.lomo_backward(loss, learning_rate)
2249 else:
-> 2250 loss.backward(**kwargs)
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/torch/_tensor.py:626, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
616 if has_torch_function_unary(self):
617 return handle_torch_function(
618 Tensor.backward,
619 (self,),
(...)
624 inputs=inputs,
625 )
--> 626 torch.autograd.backward(
627 self, gradient, retain_graph, create_graph, inputs=inputs
628 )
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/torch/autograd/__init__.py:347, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
342 retain_graph = create_graph
344 # The reason we repeat the same comment below is that
345 # some Python versions print out the first line of a multi-line function
346 # calls in the traceback and some print out the last line
--> 347 _engine_run_backward(
348 tensors,
349 grad_tensors_,
350 retain_graph,
351 create_graph,
352 inputs,
353 allow_unreachable=True,
354 accumulate_grad=True,
355 )
File ~/miniforge3/envs/IOAI/lib/python3.12/site-packages/torch/autograd/graph.py:823, in _engine_run_backward(t_outputs, *args, **kwargs)
821 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
822 try:
--> 823 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
824 t_outputs, *args, **kwargs
825 ) # Calls into the C++ engine to run the backward pass
826 finally:
827 if attach_logging_hooks:
RuntimeError: Expected scalar_type == ScalarType::Float || inputTensor.scalar_type() == ScalarType::Int || scalar_type == ScalarType::Bool to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
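For reference, here is a minimal sketch that exercises the same autocast-plus-backward path on MPS outside of Trainer. The tiny linear model, batch shapes, and the fp16 dtype are placeholders for illustration (not taken from the run above), and it assumes a PyTorch build where `torch.autocast` accepts `device_type="mps"`:

```python
import torch
import torch.nn.functional as F

# Placeholder model and data; only the autocast + backward pattern matters here.
device = torch.device("mps")
model = torch.nn.Linear(16, 4).to(device)
x = torch.randn(2, 16, device=device)
target = torch.randint(0, 4, (2,), device=device)

# Same pattern Accelerate/Trainer use: forward under autocast,
# then backward on the (possibly half-precision) loss.
with torch.autocast(device_type="mps", dtype=torch.float16):
    loss = F.cross_entropy(model(x), target)

loss.backward()  # the RuntimeError in the traceback above surfaces at this step
```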
What does this PR do?
This PR adds MPS mixed-precision autocast support.
Draft until we get support for GradScaler with autocast. Right now, support for bf16 ops on MPS is still a bit limited, but the PyTorch team is working on improving the coverage.
Feel free to test the PR and try bf16 for now.
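As a rough sketch of what trying the branch could look like (the model, optimizer, and data below are placeholders, and the snippet assumes the usual `Accelerator(mixed_precision=...)` API rather than anything specific to this PR):

```python
import torch
from accelerate import Accelerator

# "bf16" is the precision this PR targets on MPS; "fp16" is expected to need
# the pending GradScaler support mentioned above.
accelerator = Accelerator(mixed_precision="bf16")

# Placeholder model/optimizer/data for illustration only.
model = torch.nn.Linear(128, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

x = torch.randn(8, 128, device=accelerator.device)
labels = torch.randint(0, 2, (8,), device=accelerator.device)

# Forward under autocast, backward through Accelerate.
with accelerator.autocast():
    loss = torch.nn.functional.cross_entropy(model(x), labels)

accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
```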