diff --git a/examples/language-modeling/gemma_tuning.ipynb b/examples/language-modeling/gemma_tuning.ipynb
index 9fbcaf79..fc7f4717 100644
--- a/examples/language-modeling/gemma_tuning.ipynb
+++ b/examples/language-modeling/gemma_tuning.ipynb
@@ -118,6 +118,8 @@
    "outputs": [],
    "source": [
     "from optimum.tpu import fsdp_v2\n",
+    "\n",
+    "\n",
     "fsdp_v2.use_fsdp_v2()"
    ]
   },
@@ -141,6 +143,8 @@
    "outputs": [],
    "source": [
     "from datasets import load_dataset\n",
+    "\n",
+    "\n",
     "dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")"
    ]
   },
@@ -199,6 +203,8 @@
    "outputs": [],
    "source": [
     "from transformers import AutoTokenizer\n",
+    "\n",
+    "\n",
     "model_id = \"google/gemma-2b\"\n",
     "\n",
     "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
@@ -249,8 +255,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from transformers import AutoModelForCausalLM\n",
     "import torch\n",
+    "from transformers import AutoModelForCausalLM\n",
+    "\n",
+    "\n",
     "model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)"
    ]
   },
@@ -271,6 +279,7 @@
    "source": [
     "from peft import LoraConfig\n",
     "\n",
+    "\n",
     "# Set up PEFT LoRA for fine-tuning.\n",
     "lora_config = LoraConfig(\n",
     "    r=8,\n",
@@ -294,8 +303,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from trl import SFTTrainer\n",
     "from transformers import TrainingArguments\n",
+    "from trl import SFTTrainer\n",
+    "\n",
     "\n",
     "# Set up the FSDP arguments\n",
     "fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)\n",
diff --git a/optimum/tpu/fsdp_v2.py b/optimum/tpu/fsdp_v2.py
index 2b2bcbe9..8a138793 100644
--- a/optimum/tpu/fsdp_v2.py
+++ b/optimum/tpu/fsdp_v2.py
@@ -82,9 +82,10 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
     model_type = model.config.model_type
     matched_model = False
     if model_type == "gemma":
-        from .modeling_gemma import GemmaForCausalLM
         from transformers import GemmaForCausalLM as HFGemmaForCausalLLM
 
+        from .modeling_gemma import GemmaForCausalLM
+
         if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
             logger = logging.get_logger(__name__)
             from torch_xla import __version__ as xla_version
@@ -95,9 +96,10 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
             cls_to_wrap = "GemmaDecoderLayer"
             matched_model = True
     elif model_type == "llama":
-        from .modeling_llama import LlamaForCausalLM
         from transformers import LlamaForCausalLM as HFLlamaForCausalLLM
 
+        from .modeling_llama import LlamaForCausalLM
+
         if isinstance(model, LlamaForCausalLM) or isinstance(model, HFLlamaForCausalLLM):
             cls_to_wrap = "LlamaDecoderLayer"
             matched_model = True