Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update colab examples #86

Merged
merged 7 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions examples/language-modeling/gemma_tuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@
"outputs": [],
"source": [
"from optimum.tpu import fsdp_v2\n",
"\n",
"\n",
"fsdp_v2.use_fsdp_v2()"
]
},
Expand All @@ -141,6 +143,8 @@
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"\n",
"dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")"
]
},
Expand Down Expand Up @@ -199,6 +203,8 @@
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"\n",
"model_id = \"google/gemma-2b\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
Expand Down Expand Up @@ -249,8 +255,11 @@
"metadata": {},
"outputs": [],
"source": [
"from optimum.tpu import AutoModelForCausalLM\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False)"
"import torch\n",
"from transformers import AutoModelForCausalLM\n",
"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)"
]
},
{
Expand All @@ -270,6 +279,7 @@
"source": [
"from peft import LoraConfig\n",
"\n",
"\n",
"# Set up PEFT LoRA for fine-tuning.\n",
"lora_config = LoraConfig(\n",
" r=8,\n",
Expand All @@ -293,8 +303,9 @@
"metadata": {},
"outputs": [],
"source": [
"from trl import SFTTrainer\n",
"from transformers import TrainingArguments\n",
"from trl import SFTTrainer\n",
"\n",
"\n",
"# Set up the FSDP arguments\n",
"fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)\n",
Expand Down
5 changes: 2 additions & 3 deletions examples/language-modeling/llama_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,13 @@ Then, the tokenizer and model need to be loaded. We will choose [`meta-llama/Met

```python
import torch
from transformers import AutoTokenizer
from optimum.tpu import AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Add custom token for padding Llama
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
model = AutoModelForCausalLM.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
```

To tune the model with the [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset, you can load it and obtain the `quote` column:
Expand Down
8 changes: 6 additions & 2 deletions optimum/tpu/fsdp_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
model_type = model.config.model_type
matched_model = False
if model_type == "gemma":
from transformers import GemmaForCausalLM as HFGemmaForCausalLLM

from .modeling_gemma import GemmaForCausalLM

if isinstance(model, GemmaForCausalLM):
if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
logger = logging.get_logger(__name__)
from torch_xla import __version__ as xla_version
if xla_version == "2.3.0":
Expand All @@ -94,9 +96,11 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
cls_to_wrap = "GemmaDecoderLayer"
matched_model = True
elif model_type == "llama":
from transformers import LlamaForCausalLM as HFLlamaForCausalLLM

from .modeling_llama import LlamaForCausalLM

if isinstance(model, LlamaForCausalLM):
if isinstance(model, LlamaForCausalLM) or isinstance(model, HFLlamaForCausalLLM):
cls_to_wrap = "LlamaDecoderLayer"
matched_model = True

Expand Down
Loading