* fix(build): add loguru dependency for optimum-tpu
* feat: added generation with gemma example
* feat: import xla_model_parallel.py
  Originally imported from the gemma repository https://github.com/google/gemma_pytorch.git at commit cf8658c.
* feat: add implementation for parallel model proxy
  This allows executing models in parallel and interacting with them from the caller thread. To see its debug output, launch the test with debugging enabled:
  DEBUG=1 pytest -s tests/test_parallel_proxy.py
* feat(CI): added optimum tests to CI
* chore: adapt xla_model_parallel style to avoid CI complaints
* chore: rename ModelProxy -> DistributedModel
* chore: replace custom debug logging with loguru
* chore(CI): remove accelerate installation from tests workflow
* chore: remove unused function
Commit 7b48145 (1 parent: ab6b6a5). Showing 8 changed files with 1,098 additions and 1 deletion.
@@ -0,0 +1,33 @@
name: Optimum TPU tests

on:
  push:
    branches: [ main ]
    paths:
      - "tests/**"
  pull_request:
    branches: [ main ]
    paths:
      - "tests/**"

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  do-the-job:
    name: Run optimum tpu tests
    runs-on: optimum-tpu
    container:
      # Use an image that works with TPU and PyTorch 2.3.0 (the release image was not working)
      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla@sha256:8f1dcd5b03f993e4da5c20d17c77aff6a5f22d5455f8eb042d2e4b16ac460526
      options: --shm-size "16gb" --ipc host --privileged
    env:
      PJRT_DEVICE: TPU
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Build and test optimum tpu
        run: |
          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tests
@@ -0,0 +1,3 @@
# 🤗 Transformers example scripts on Google TPU with Optimum TPU

These examples show how to run inference on Google TPU using Optimum TPU.
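
As a rough sketch of the pattern these examples follow (the checkpoint name and prompt below are placeholders, and the exact invocation may differ from the scripts in this directory), a single greedy step with TpuModelForCausalLM looks roughly like this:

import torch
from transformers import AutoTokenizer
from optimum.tpu.modeling import TpuModelForCausalLM

# Assumptions: a TPU runtime is attached and the checkpoint is accessible.
model_id = "google/gemma-2b"  # same checkpoint used by the generation example below
model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Once upon a time,", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs, return_dict=False)[0]
# Greedy choice of the next token from the last position of the sequence.
next_token_id = torch.argmax(logits[:, -1], dim=-1)
print(tokenizer.decode(next_token_id))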
@@ -0,0 +1,130 @@
#!/usr/bin/python

import torch
import time
import datetime
import os
import platform
from typing import List
import torch_xla.core.xla_model as xm
from optimum.tpu.modeling import TpuModelForCausalLM
from transformers import AutoTokenizer, StaticCache


os.environ["PJRT_DEVICE"] = "TPU"


def sample_greedy(logits):
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id


def decode_one_tokens(model, cur_token, input_pos, cache_position, step):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
    )[0]
    new_token = sample_greedy(logits)
    return new_token


def conditional_compile(func):
    # Compile with the openxla backend only when DBG_COMPILE is set in the environment.
    if "DBG_COMPILE" in os.environ:
        compiled = torch.compile(func, backend="openxla")
        return compiled
    return func


def summary(values: List[float]):
    values.sort()
    n = len(values)
    if n % 2 == 0:
        median = (values[n // 2 - 1] + values[n // 2]) / 2
    else:
        median = values[n // 2]
    total = sum(values)
    mean = sum(values) / n
    print(f"Decode time: {total}, average: {mean}, median: {median}")


def main():
    prg_start = time.time()
    model_id = "google/gemma-2b"
    torch_dtype = torch.bfloat16

    model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
    device = model.device
    model = model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompts = ["Here's a funny thing:", "Once upon a time,"]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    batch_size, sequence_length = inputs["input_ids"].shape
    max_cache_length = 1024
    max_new_tokens = 20

    start = time.time()
    model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
    end = time.time()
    print(f"Model cache setup took {end - start} seconds.")
    start = time.time()
    cache_position = torch.arange(sequence_length, device=device)
    generated_ids = torch.zeros(
        (batch_size, sequence_length + max_new_tokens + 1),
        dtype=torch.int,
        device=device,
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)

    # Prefill: process the whole prompt once to populate the static KV cache.
    attention_mask = inputs["attention_mask"]
    # Position ids follow the attention mask so padded tokens do not shift positions.
    pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    logits = model(
        **inputs,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
        position_ids=pos_ids,
    )[0]
    next_token = sample_greedy(logits)
    xm.mark_step()
    generated_ids[:, sequence_length] = next_token[:, 0]
    end = time.time()
    print(f"Prefill took {end - start} seconds.")

    pos_ids = pos_ids.max(axis=-1)[0].unsqueeze(1) + 1

    model = conditional_compile(model)
    cache_position = torch.tensor([sequence_length + 1], device=device)
    decode_times = []
    for i in range(max_new_tokens):
        step_start = time.time()
        next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position, i)
        generated_ids[:, cache_position] = next_token

        cache_position += 1
        pos_ids += 1
        xm.mark_step()
        step_end = time.time()
        step_time = step_end - step_start
        decode_times.append(step_time)
        print(f"Step {i} took {step_time} seconds.")
    summary(decode_times)

    print(f"Decoding start at {datetime.datetime.now()}")

    decoded_texts = tokenizer.batch_decode(generated_ids)
    for i, text in enumerate(decoded_texts):
        print(i, text)

    end = time.time()
    print(f"Program run in {end - prg_start} seconds. Device: {device} System: {platform.system()}")


if __name__ == "__main__":
    with torch.no_grad():
        main()
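
For reference, sample_greedy just takes the argmax over the vocabulary at the last position of each sequence; a toy run on hand-made logits (no TPU needed, shapes chosen arbitrarily) shows the expected (batch, 1) output:

import torch

# Toy logits of shape (batch=2, seq_len=2, vocab=5); values chosen by hand.
logits = torch.tensor([
    [[0.0, 0.9, 0.1, 0.0, 0.0],
     [0.0, 0.0, 0.0, 0.2, 0.8]],
    [[0.3, 0.3, 0.3, 0.1, 0.0],
     [0.7, 0.1, 0.1, 0.1, 0.0]],
])
next_logits = logits[:, -1]                                        # (2, 5): last position per sequence
next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()   # (2, 1) column of token ids
print(next_token_id)  # tensor([[4], [0]], dtype=torch.int32)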
@@ -0,0 +1,156 @@
# ruff: noqa: E402
import torch
import os
from enum import Enum

os.environ["PJRT_DEVICE"] = "TPU"

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch.multiprocessing as mp

from optimum.tpu.modeling import TpuModelForCausalLM
from typing import Dict
from loguru import logger


class ModelCommand(Enum):
    LEAVE = 0
    PREFILL = 1
    DECODE = 2


class RootMailbox:
    def __init__(self, manager: mp.Manager):
        self.root_bell = manager.Event()
        self.root_command = manager.list()
        self.model_ready = manager.Event()
        self.output_data = manager.Value(torch.Tensor, torch.tensor([]))

    def send(self, command: ModelCommand, data: Dict = None):
        # First wait until model is ready to receive commands
        logger.debug(f" MM Command {command} waiting for model to be ready")
        self.model_ready.wait()
        self.model_ready.clear()

        self.root_command[:] = [command, data]
        self.root_bell.set()
        logger.debug(f" MM Command {command} sent")
        # wait again until model is ready, meaning command has been processed
        self.model_ready.wait()
        ret = self.output_data.get()
        logger.debug(f" MM Command {command} output shape {ret.shape}")
        return ret


class AgentMailbox:
    def __init__(self, root_mailbox: RootMailbox):
        self.root_bell = root_mailbox.root_bell
        self.root_command = root_mailbox.root_command
        self.model_ready = root_mailbox.model_ready
        self.output_data = root_mailbox.output_data

    def receive(self):
        self.root_bell.wait()
        self.root_bell.clear()
        return self.root_command

    def send(self, data: torch.Tensor):
        logger.debug(f" MM Enqueueing data {data.shape}")
        # Data needs to be moved to CPU before setting it
        self.output_data.set(data.cpu())
        logger.debug(" MM Enqueueing data done")

    @property
    def command_data(self):
        command = self.root_command[0]
        data = self.root_command[1]
        return command, data


def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
    device = xm.xla_device()
    world_size = xm.xrt_world_size()
    # create agent mailbox out of root's one
    mailbox = AgentMailbox(root_mailbox)

    logger.debug(
        f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
        + f"world size {world_size}"
    )

    # Model loading and sharding should happen here
    model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    model = model.eval()
    model.to(device)

    def get_next_token(inputs):
        # move inputs to device in a new dict to avoid conflicts
        model_inputs = {}
        for key, value in inputs.items():
            model_inputs[key] = value.to(device)
        outputs = model(**model_inputs, return_dict=False)[0]
        xm.mark_step()
        # consider adding a rendezvous here
        if rank == 0:
            logger.debug(f"Rank {rank} getting tokens")
            next_token = sample_fn(outputs)
            xm.mark_step()
            logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}")
            mailbox.send(next_token)

    while True:
        if rank == 0:
            mailbox.model_ready.set()
            logger.debug(f"Rank {rank} waiting for commands")
            mailbox.receive()
        # Wait for rank 0 to receive command
        xm.rendezvous("start")

        logger.debug(f"Rank {rank} waiting for command at rendezvous")
        command, inputs = mailbox.command_data
        if command == ModelCommand.PREFILL:
            logger.debug(f"Rank {rank} PREFILL")
            get_next_token(inputs)
        elif command == ModelCommand.DECODE:
            logger.debug(f"Rank {rank} DECODE")
            get_next_token(inputs)
        elif command == ModelCommand.LEAVE:
            logger.debug(f"Rank {rank} LEAVE")
            # Set model to ready
            mailbox.model_ready.set()
            break


def model_loop_fn(*args):
    """Spawn processes in the TPUs forwarding arguments"""
    xmp.spawn(_mp_fn, args=(args), join=True, daemon=False)


class DistributedModel:
    def __init__(self, model_id: str, sample_fn: callable):
        manager = mp.Manager()
        self.mailbox = RootMailbox(manager)

        self.model_loop = mp.Process(target=model_loop_fn, args=(model_id, self.mailbox, sample_fn))
        self.model_loop.start()

    def prefill(self, **model_args):
        assert self.mailbox is not None, "DistributedModel is not initialized"
        return self.mailbox.send(ModelCommand.PREFILL, model_args)

    def decode(self, **model_args):
        assert self.mailbox is not None, "DistributedModel is not initialized"
        return self.mailbox.send(ModelCommand.DECODE, model_args)

    def leave(self):
        if self.mailbox is None:
            return
        self.mailbox.send(ModelCommand.LEAVE)
        logger.debug("Joining...")
        self.model_loop.join()
        logger.debug("Model loop finished")
        self.mailbox = None

    def __del__(self):
        self.leave()
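
To make the intended call pattern concrete, here is a minimal, hypothetical driver for DistributedModel (the checkpoint, the sampling helper, and the import path of the class are assumptions based on this diff, not a documented API):

import torch
from transformers import AutoTokenizer

# DistributedModel is the class defined above; the module path to import it
# from is not shown in this diff.

def sample_greedy(logits):
    # Same greedy sampling helper used by the generation example above.
    return torch.argmax(logits[:, -1], dim=-1)[:, None].int()

if __name__ == "__main__":
    model_id = "google/gemma-2b"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = DistributedModel(model_id, sample_greedy)

    # prefill/decode forward keyword tensors to every TPU rank; the tensors are
    # moved onto the XLA devices inside _mp_fn.
    inputs = tokenizer(["Once upon a time,"], return_tensors="pt")
    next_token = model.prefill(**inputs)
    print(tokenizer.decode(next_token[0]))

    model.leave()  # send LEAVE and join the spawned model processes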