diff --git a/examples/language-modeling/gemma_tuning.ipynb b/examples/language-modeling/gemma_tuning.ipynb index 537c248f..9fbcaf79 100644 --- a/examples/language-modeling/gemma_tuning.ipynb +++ b/examples/language-modeling/gemma_tuning.ipynb @@ -250,7 +250,8 @@ "outputs": [], "source": [ "from transformers import AutoModelForCausalLM\n", - "model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False)" + "import torch\n", + "model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)" ] }, { diff --git a/examples/language-modeling/llama_tuning.md b/examples/language-modeling/llama_tuning.md index 13bd484e..9d38130a 100644 --- a/examples/language-modeling/llama_tuning.md +++ b/examples/language-modeling/llama_tuning.md @@ -53,7 +53,7 @@ model_id = "meta-llama/Meta-Llama-3-8B" tokenizer = AutoTokenizer.from_pretrained(model_id) # Add custom token for padding Llama tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token}) -model = AutoModelForCausalLM.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) ``` To tune the model with the [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset, you can load it and obtain the `quote` column: