Add Mistral support 💨 (#54)
* feat(modeling): import Mistral support

Imported from transformers sha1: a2ede6667 (current main branch).
This makes it possible to use the recent static cache support.
The only changes are:

- fixed the import paths,
- added a workaround that avoids importing SlidingWindowCache and keeps
the changes to the file minimal.

* feat(mistral): add inference sharding on Linear modules

* feat(examples): generalize text generation example to other models

This allows the same example to be used with other models, such as mistralai/Mistral-7B-v0.3.

* feat(inference): use Linear when world_size is 1

There is no point in running code that synchronizes multiple TPUs when only
one is in use (a sketch follows this commit message).

* refactor(tests): try to reduce repetition for decode tests

* test(tgi): added test for Mistral-7B-v0.3

* feat(tests): delete generator to prevent getting stuck when failing

* chore(doc): updated the mention of Mistral

* feat(model): filter only for model safetensors

This prevents needlessly downloading consolidated weights, as happens with
the Mistral repo.
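
The sharding and world_size bullets above boil down to one decision point: shard a linear layer across devices, or keep a plain one. A minimal sketch of the idea in PyTorch; `ShardedLinear` and `make_linear` are illustrative names and a simplified stand-in, not the actual optimum-tpu code:

```python
import torch
import torch.nn as nn


class ShardedLinear(nn.Module):
    """Illustrative stand-in for a tensor-parallel linear layer: each device
    holds a slice of the output features and a collective op merges results.
    Not the real optimum-tpu implementation."""

    def __init__(self, in_features: int, out_features: int, world_size: int):
        super().__init__()
        assert out_features % world_size == 0, "features must shard evenly"
        self.local = nn.Linear(in_features, out_features // world_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real implementation would all-gather the shards across devices here.
        return self.local(x)


def make_linear(in_features: int, out_features: int, world_size: int) -> nn.Module:
    # The commit's point: with a single device there is nothing to shard or
    # synchronize, so a plain nn.Linear is simpler and skips collective-op overhead.
    if world_size == 1:
        return nn.Linear(in_features, out_features)
    return ShardedLinear(in_features, out_features, world_size)
```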
tengomucho authored Jun 17, 2024 · 1 parent a660f80 · commit 3900bd7
Showing 12 changed files with 1,793 additions and 330 deletions.
README.md — 3 additions, 3 deletions

```diff
@@ -22,9 +22,9 @@ working closely with Google and Google Cloud to make this a reality.
 ## Supported Model and Tasks
 
 We currently support a few LLM models targeting text generation scenarios:
-- Gemma (2b, 7b)
-- Llama2 (7b) and Llama3 (8b)
-- Mistral (soon)
+- 💎 Gemma (2b, 7b)
+- 🦙 Llama2 (7b) and Llama3 (8b)
+- 💨 Mistral (7b)
 
 
 ## Installation
```
Text generation example script (file path not shown in this view)

```diff
@@ -1,5 +1,6 @@
 #!/usr/bin/python
 
+import argparse
 import datetime
 import os
 import platform
@@ -55,20 +56,33 @@ def summary(values: List[float]):
 
 
 def main():
+    parser = argparse.ArgumentParser(description="Text generation example")
+    parser.add_argument("--model_id", type=str,
+                        default="google/gemma-2b",
+                        help="Model ID (e.g.: google/gemma-2b, mistralai/Mistral-7B-v0.3)")
+    parser.add_argument("--max_new_tokens", type=int, default=20, help="Number of tokens to generate")
+    parser.add_argument("--max_cache_length", type=int, default=256, help="Maximum cache length for the model")
+    args = parser.parse_args()
+
     prg_start = time.time()
-    model_id = "google/gemma-2b"
+    print(f"⏳ Loading model {args.model_id}...")
+    model_id = args.model_id
     torch_dtype = torch.bfloat16
 
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
     device = model.device
     model = model.eval()
+    print(f"✅ Model loaded in {time.time() - prg_start} seconds.")
+
     tokenizer = AutoTokenizer.from_pretrained(model_id)
+    # Set pad token for cases where it is None, e.g. for Mistral
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
     prompts = ["Here's a funny thing:", "Once upon a time,"]
     inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
     batch_size, sequence_length = inputs["input_ids"].shape
     max_cache_length = 1024
-    max_new_tokens = 20
+    max_new_tokens = args.max_new_tokens
 
     # setup static cache
     past_key_values = StaticCache(
```
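
The hunk above is truncated at the StaticCache construction. For context, a typical setup with transformers' StaticCache of that era looks like the sketch below, using the variables defined earlier in the example; this is the usual pattern, not necessarily the file's exact continuation:

```python
from transformers import StaticCache

# Typical static cache setup (sketch): preallocate the KV buffers once so
# that tensor shapes stay fixed across decoding steps, which is what makes
# XLA/TPU compilation effective.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=batch_size,
    max_cache_len=max_cache_length,
    device=device,
    dtype=torch_dtype,
)
```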
optimum/tpu/model.py — 1 addition, 1 deletion

```diff
@@ -49,7 +49,7 @@ def fetch_model(
     local_path = snapshot_download(
         repo_id=model_id,
         revision=revision,
-        allow_patterns=["config.json", "*.safetensors", SAFE_WEIGHTS_INDEX_NAME, "tokenizer*"],
+        allow_patterns=["config.json", "model*.safetensors", SAFE_WEIGHTS_INDEX_NAME, "tokenizer*"],
     )
     end = time.time()
     logger.info(f"Model successfully fetched in {end - start:.2f} s.")
```
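
The motivation for the narrower pattern: Mistral repositories ship the weights twice, as sharded `model-*.safetensors` files and as a single `consolidated.safetensors`, and the old `*.safetensors` glob downloaded both. A small illustration with a hypothetical file listing:

```python
import fnmatch

# Hypothetical file listing, modeled on a Mistral-style repo that ships
# both sharded weights and a consolidated copy.
files = [
    "config.json",
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors",
    "consolidated.safetensors",  # duplicate weights, wasteful to download
]

# Old pattern: also matches consolidated.safetensors.
print(fnmatch.filter(files, "*.safetensors"))
# New pattern: only the sharded model files.
print(fnmatch.filter(files, "model*.safetensors"))
```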
optimum/tpu/modeling.py — 4 additions, 0 deletions

```diff
@@ -33,6 +33,10 @@ def config_name_to_class(pretrained_model_name_or_path: str):
         from .modeling_llama import LlamaForCausalLM
 
         return LlamaForCausalLM
+    if config.model_type == "mistral":
+        from .modeling_mistral import MistralForCausalLM
+
+        return MistralForCausalLM
     return BaseAutoModelForCausalLM
```
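
This hunk extends a config-driven dispatch: read `model_type` from the checkpoint's config and return the matching TPU-optimized class, falling back to the generic auto class. A simplified, self-contained sketch of the pattern; it substitutes transformers' own classes for the repo's `.modeling_*` imports, and the real module handles other model types (e.g. gemma) the same way:

```python
from transformers import AutoConfig
from transformers import AutoModelForCausalLM as BaseAutoModelForCausalLM


def config_name_to_class(pretrained_model_name_or_path: str):
    # Sketch: dispatch on the model_type declared in the checkpoint's config.
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
    if config.model_type == "mistral":
        # Stand-in for `from .modeling_mistral import MistralForCausalLM`.
        from transformers import MistralForCausalLM
        return MistralForCausalLM
    # Anything without a TPU-specific implementation uses the generic class.
    return BaseAutoModelForCausalLM
```

With this in place, `config_name_to_class("mistralai/Mistral-7B-v0.3")` resolves to the Mistral implementation rather than the generic fallback.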
(Diffs for the remaining eight changed files were not rendered in this view.)
