Parallel sharding (#21)
* chore: update transformers dependency

* feat: import transformer's gemma modeling code

It will be adapted for sharding. Only the imports have been
adjusted, and only the code relevant to GemmaForCausalLM has been added.

* chore: rename model Gemma -> TpuGemma to prepare for changes

* feat(DistributedModel): added config property

* chore: rename test_parallel_proxy.py -> test_distributed_model.py

* fix: use AutoModelForCausalLM instead of TpuModelForCausalLM

* feat: AutoModelForCausalLM will choose TpuGemmaForCausalLM if possible

* fix(TpuGemma): avoid using device_map when loading model

It seems that the device_map parameter triggers a chain of calls that
try to use accelerate to load the model with less memory. The problem
is that this path skips the load-state pre-hooks, making it impossible
to load the weights.

* feat(gemma): sharding o_proj

It will now run in parallel; more changes to come (a sketch of this style of tensor-parallel sharding follows the change list below).

* feat(gemma): sharding on q_proj

* feat(gemma): sharding on k and v proj

* feat(gemma): sharding on mlp gate and up proj

* feat(gemma): sharding on mlp down proj

* feat: model is loaded using torch_dtype from config

This will lead to loading the model in bfloat16 when specified in the
config.

* fix: remove useless import

* feat(tests): added test showing gemma7b sharding and prefill works

* chore: config_name_to_class uses config.model_type now

* fix: get_generation_mode is now a method of generation_config

API change when transformers was updated.

* fix(TGI server): fix slot.stopped, which changed after the transformers update

* fix(generator): fix sample generation again

I wrongly used the model's generation config instead of the one passed
to the token selector.

* fix: better handle torch_dtype

bfloat16 will be set by default for Gemma models; other models will
still load in float32 by default.

* fix: remove unused import
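
The sharding commits above split each attention projection (q_proj, k_proj, v_proj, o_proj) and MLP projection (gate_proj, up_proj, down_proj) of the Gemma decoder across TPU devices. The sketch below is a hypothetical illustration of the usual column-/row-parallel split for linear layers on XLA devices; the class names are invented and this is not the repository's TpuGemma implementation.

```python
# Hypothetical tensor-parallel sharding of linear projections on TPU/XLA devices.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm


class ColumnParallelLinear(nn.Module):
    """Keeps only this device's slice of the output features."""

    def __init__(self, linear: nn.Linear, rank: int, world_size: int):
        super().__init__()
        assert linear.out_features % world_size == 0
        shard = linear.out_features // world_size
        # Slice rows of the weight matrix: each device produces a slice of the outputs.
        self.weight = nn.Parameter(linear.weight[rank * shard:(rank + 1) * shard, :].clone())

    def forward(self, x):
        # No collective needed here as long as the following row-parallel
        # layer consumes the matching slice of features.
        return torch.nn.functional.linear(x, self.weight)


class RowParallelLinear(nn.Module):
    """Keeps only this device's slice of the input features."""

    def __init__(self, linear: nn.Linear, rank: int, world_size: int):
        super().__init__()
        assert linear.in_features % world_size == 0
        shard = linear.in_features // world_size
        # Slice columns of the weight matrix: each device computes a partial sum.
        self.weight = nn.Parameter(linear.weight[:, rank * shard:(rank + 1) * shard].clone())

    def forward(self, x):
        # Combine the partial results with an all-reduce across the TPU mesh.
        partial = torch.nn.functional.linear(x, self.weight)
        return xm.all_reduce(xm.REDUCE_SUM, partial)
```

In the usual Megatron-style layout, the q/k/v and gate/up projections are column-parallel and o_proj/down_proj are row-parallel, so each attention or MLP block needs only one all-reduce to recombine the partial results.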
tengomucho authored Apr 10, 2024
1 parent d5b921e commit 8e12733
Showing 13 changed files with 1,438 additions and 29 deletions.
4 changes: 2 additions & 2 deletions examples/text-generation/generation_gemma.py
@@ -7,7 +7,7 @@
import platform
from typing import List
import torch_xla.core.xla_model as xm
from optimum.tpu.modeling import TpuModelForCausalLM
from optimum.tpu.modeling import AutoModelForCausalLM
from transformers import AutoTokenizer, StaticCache


@@ -56,7 +56,7 @@ def main():
model_id = "google/gemma-2b"
torch_dtype = torch.bfloat16

model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
device = model.device
model = model.eval()

24 changes: 20 additions & 4 deletions optimum/tpu/distributed_model.py
@@ -2,16 +2,17 @@
import torch
import os
from enum import Enum
from typing import Dict
from loguru import logger

os.environ["PJRT_DEVICE"] = "TPU"

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch.multiprocessing as mp

from optimum.tpu.modeling import TpuModelForCausalLM
from typing import Dict
from loguru import logger
from optimum.tpu.modeling import AutoModelForCausalLM
from transformers import PretrainedConfig


class ModelCommand(Enum):
@@ -26,6 +26,14 @@ def __init__(self, manager: mp.Manager):
self.root_command = manager.list()
self.model_ready = manager.Event()
self.output_data = manager.Value(torch.Tensor, torch.tensor([]))
self.model_config = manager.Value(PretrainedConfig, None)

@property
def config(self):
while True:
config = self.model_config.get()
if config is not None:
return config

def send(self, command: ModelCommand, data: Dict = None):
# First wait until model is ready to receive commands
@@ -49,6 +58,7 @@ def __init__(self, root_mailbox: RootMailbox):
self.root_command = root_mailbox.root_command
self.model_ready = root_mailbox.model_ready
self.output_data = root_mailbox.output_data
self.model_config = root_mailbox.model_config

def receive(self):
self.root_bell.wait()
@@ -80,9 +90,11 @@ def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
)

# Model loading and sharding should happen here
model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.eval()
model.to(device)
if rank == 0:
mailbox.model_config.set(model.config)

def get_next_token(inputs):
# move inputs to device in a new dict to avoid conflicts
@@ -152,5 +164,9 @@ def leave(self):
logger.debug("Model loop finished")
self.mailbox = None

@property
def config(self):
return self.mailbox.config

def __del__(self):
self.leave()
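
The distributed_model.py changes above pass the model config from the rank-0 worker back to the root process through a Manager-backed value, with a property that waits until it has been published. Below is a standalone sketch of that handshake pattern with illustrative names; it is not the repository code.

```python
# Minimal sketch: a worker publishes a value through a Manager-backed Value,
# and the root process polls until it is available.
import multiprocessing as mp
import time


def worker(shared_config):
    time.sleep(0.5)                              # pretend to load the model
    shared_config.set({"model_type": "gemma"})   # rank 0 publishes the config


def wait_for_config(shared_config):
    # Mirrors the busy-wait in RootMailbox.config: return as soon as a
    # non-None value has been published by the worker.
    while True:
        config = shared_config.get()
        if config is not None:
            return config


if __name__ == "__main__":
    manager = mp.Manager()
    shared_config = manager.Value(object, None)  # starts unset, like model_config above
    p = mp.Process(target=worker, args=(shared_config,))
    p.start()
    print(wait_for_config(shared_config)["model_type"])
    p.join()
```

In the commit, DistributedModel.config simply delegates to the mailbox property, so callers can read the sharded model's config without touching the worker processes.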
2 changes: 1 addition & 1 deletion optimum/tpu/generation/token_selector.py
@@ -138,7 +138,7 @@ def create(
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id

generation_mode = model._get_generation_mode(generation_config, None)
generation_mode = generation_config.get_generation_mode()
if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
raise ValueError("Unsupported generation mode")

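
The token_selector.py change reflects a transformers API move: the generation mode is now derived from the GenerationConfig itself rather than through the model's private helper. A quick check of the new call, assuming a transformers version recent enough to expose it:

```python
from transformers import GenerationConfig
from transformers.generation import GenerationMode

greedy = GenerationConfig()                  # defaults: num_beams=1, do_sample=False
sampling = GenerationConfig(do_sample=True)

assert greedy.get_generation_mode() == GenerationMode.GREEDY_SEARCH
assert sampling.get_generation_mode() == GenerationMode.SAMPLE
```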
24 changes: 13 additions & 11 deletions optimum/tpu/modeling.py
@@ -19,12 +19,18 @@
from typing import Any

from loguru import logger
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM
from transformers.utils import is_accelerate_available
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM, AutoConfig

from optimum.tpu.modeling_gemma import TpuGemmaForCausalLM


def config_name_to_class(pretrained_model_name_or_path: str):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
if config.model_type == "gemma":
return TpuGemmaForCausalLM
return BaseAutoModelForCausalLM


# TODO: For now TpuModelForCausalLM is just a shallow wrapper of
# AutoModelForCausalLM, later this could be replaced by a custom class.
class AutoModelForCausalLM(BaseAutoModelForCausalLM):

@classmethod
@@ -45,13 +51,9 @@ def from_pretrained(
logger.debug(f"Device set to: {device}")
else:
device = "xla"
if is_accelerate_available():
model = BaseAutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, device_map=device, *model_args, **kwargs
)
else:
model = BaseAutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model.to(device)
cls = config_name_to_class(pretrained_model_name_or_path)
model = cls.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
model.to(device)
# Update config with specific data
if task is not None or getattr(model.config, "task", None) is None:
model.config.task = task
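
With the modeling.py change above, the optimum.tpu AutoModelForCausalLM wrapper dispatches Gemma checkpoints to TpuGemmaForCausalLM and everything else to the stock transformers auto class, then moves the model to the XLA device instead of relying on accelerate's device_map. A hypothetical usage sketch (model id and dtype are illustrative, and running it requires a TPU/XLA environment):

```python
import torch
from optimum.tpu.modeling import AutoModelForCausalLM

# Gemma checkpoints resolve to the sharding-aware TpuGemmaForCausalLM class.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
print(type(model).__name__)  # expected: TpuGemmaForCausalLM
print(model.device)          # an xla device
```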