Xla parallel proxy (#12)
* fix(build): add loguru dependency for optimum-tpu

* feat: added generation with gemma example

* feat: import xla_model_parallel.py

Originally imported from the gemma repository:

https://github.com/google/gemma_pytorch.git

At commit cf8658c.

* feat: add implementation for parallel model proxy

This will allow executing models in parallel and interacting with them
from the caller thread (see the usage sketch after this list).
To see how it works with debug output enabled, launch the test this way:

DEBUG=1 pytest -s tests/test_parallel_proxy.py

* feat(CI): added optimum tests to CI

* chore: adapt xla_model_parallel style to avoid CI complaints

* chore: Rename ModelProxy -> DistributedModel

* chore: replace custom debug logging with loguru

* chore(CI): remove accelerate installation from tests workflow

* chore: remove unused function
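
The sketch below illustrates how a caller could drive the new proxy (the DistributedModel class added in optimum/tpu/distributed_model.py further down). It is a minimal, hypothetical example rather than the actual tests/test_parallel_proxy.py; the model id and the greedy sampling function simply mirror the gemma example included in this commit.

import torch
from transformers import AutoTokenizer

from optimum.tpu.distributed_model import DistributedModel


def sample_greedy(logits):
    # pick the most likely next token for each sequence in the batch
    return torch.argmax(logits[:, -1], dim=-1)[:, None].int()


def main():
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    model = DistributedModel("google/gemma-2b", sample_greedy)

    inputs = tokenizer(["Once upon a time,"], return_tensors="pt")
    # prefill runs the forward pass on every TPU worker; rank 0 returns the sampled token on CPU
    next_token = model.prefill(**inputs)

    # append the new token and ask for one more (the proxy recomputes the full forward pass)
    input_ids = torch.cat([inputs["input_ids"], next_token.to(torch.long)], dim=-1)
    attention_mask = torch.cat(
        [inputs["attention_mask"], torch.ones_like(next_token, dtype=torch.long)], dim=-1
    )
    next_token = model.decode(input_ids=input_ids, attention_mask=attention_mask)
    print(tokenizer.batch_decode(next_token))

    model.leave()  # ask the worker processes to exit and join them


if __name__ == "__main__":
    main()
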
tengomucho authored Apr 8, 2024
1 parent ab6b6a5 commit 7b48145
Showing 8 changed files with 1,098 additions and 1 deletion.
33 changes: 33 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu.yml
@@ -0,0 +1,33 @@
name: Optimum TPU tests

on:
push:
branches: [ main ]
paths:
- "tests/**"
pull_request:
branches: [ main ]
paths:
- "tests/**"

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
do-the-job:
name: Run optimum tpu tests
runs-on: optimum-tpu
container:
# Use an image that works with TPU and PyTorch 2.3.0 (the release image was not working)
image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla@sha256:8f1dcd5b03f993e4da5c20d17c77aff6a5f22d5455f8eb042d2e4b16ac460526
options: --shm-size "16gb" --ipc host --privileged
env:
PJRT_DEVICE: TPU
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Build and test optimum tpu
run: |
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tests
3 changes: 3 additions & 0 deletions Makefile
@@ -72,6 +72,9 @@ pypi_upload: ${PACKAGE_DIST} ${PACKAGE_WHEEL}
test_installs:
python -m pip install .[tpu,tests]

tests: test_installs
python -m pytest -sv tests

# Stand-alone TGI server for unit tests outside of TGI container
tgi_server:
python -m pip install -r text-generation-inference/server/build-requirements.txt
3 changes: 3 additions & 0 deletions examples/README.md
@@ -0,0 +1,3 @@
# 🤗 Transformers example scripts on Google TPU with Optimum TPU

These examples show how to run inference on Google TPU using Optimum TPU.
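
For orientation, this is a condensed sketch of the pattern the gemma example below follows (the model id and prompt are taken from that example; it is illustrative only, not a separate supported script):

import torch
from transformers import AutoTokenizer

from optimum.tpu.modeling import TpuModelForCausalLM

model = TpuModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
inputs = tokenizer(["Once upon a time,"], return_tensors="pt").to(model.device)

with torch.no_grad():
    # single forward pass; the full example below adds a static KV cache and a decode loop
    logits = model(**inputs, return_dict=False)[0]

next_token_id = torch.argmax(logits[:, -1], dim=-1)  # greedy pick, as in the example script
print(tokenizer.decode(next_token_id))
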
130 changes: 130 additions & 0 deletions examples/text-generation/generation_gemma.py
@@ -0,0 +1,130 @@
#!/usr/bin/python

import torch
import time
import datetime
import os
import platform
from typing import List
import torch_xla.core.xla_model as xm
from optimum.tpu.modeling import TpuModelForCausalLM
from transformers import AutoTokenizer, StaticCache


os.environ["PJRT_DEVICE"] = "TPU"


def sample_greedy(logits):
next_logits = logits[:, -1]
next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
return next_token_id


def decode_one_tokens(model, cur_token, input_pos, cache_position, step):
logits = model(
cur_token,
position_ids=input_pos,
cache_position=cache_position,
return_dict=False,
use_cache=True,
)[0]
new_token = sample_greedy(logits)
return new_token


def conditional_compile(func):
if "DBG_COMPILE" in os.environ:
compiled = torch.compile(func, backend="openxla")
return compiled
return func


def summary(values: List[float]):
values.sort()
n = len(values)
if n % 2 == 0:
median = (values[n // 2 - 1] + values[n // 2]) / 2
else:
median = values[n // 2]
total = sum(values)
mean = sum(values) / n
print(f"Decode time: {total}, average: {mean}, median: {median}")


def main():
prg_start = time.time()
model_id = "google/gemma-2b"
torch_dtype = torch.bfloat16

model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
device = model.device
model = model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)
prompts = ["Here's a funny thing:", "Once upon a time,"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
batch_size, sequence_length = inputs["input_ids"].shape
max_cache_length = 1024
max_new_tokens = 20

start = time.time()
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
end = time.time()
print(f"Model cache setup took {end - start} seconds.")
start = time.time()
cache_position = torch.arange(sequence_length, device=device)
generated_ids = torch.zeros(
(batch_size, sequence_length + max_new_tokens + 1),
dtype=torch.int,
device=device,
)
generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)

# prefill here
attention_mask = inputs["attention_mask"]
pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
logits = model(
**inputs,
cache_position=cache_position,
return_dict=False,
use_cache=True,
position_ids=pos_ids,
)[0]
next_token = sample_greedy(logits)
xm.mark_step()
generated_ids[:, sequence_length] = next_token[:, 0]
end = time.time()
print(f"Prefill took {end - start} seconds.")

pos_ids = pos_ids.max(axis=-1)[0].unsqueeze(1) + 1

model = conditional_compile(model)
cache_position = torch.tensor([sequence_length + 1], device=device)
decode_times = []
for i in range(max_new_tokens):
step_start = time.time()
next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position, i)
generated_ids[:, cache_position] = next_token

cache_position += 1
pos_ids += 1
xm.mark_step()
step_end = time.time()
step_time = step_end - step_start
decode_times.append(step_time)
print(f"Step {i} took {step_time} seconds.")
summary(decode_times)

print(f"Decoding start at {datetime.datetime.now()}")

decoded_texts = tokenizer.batch_decode(generated_ids)
for i, text in enumerate(decoded_texts):
print(i, text)

end = time.time()
print(f"Program run in {end - prg_start} seconds. Device: {device} System: {platform.system()}")


if __name__ == "__main__":
with torch.no_grad():
main()
156 changes: 156 additions & 0 deletions optimum/tpu/distributed_model.py
@@ -0,0 +1,156 @@
# ruff: noqa: E402
import torch
import os
from enum import Enum

os.environ["PJRT_DEVICE"] = "TPU"

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch.multiprocessing as mp

from optimum.tpu.modeling import TpuModelForCausalLM
from typing import Dict
from loguru import logger


class ModelCommand(Enum):
LEAVE = 0
PREFILL = 1
DECODE = 2


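# Handshake between the caller process (RootMailbox) and the spawned TPU worker
# processes (AgentMailbox), as implemented below:
#   1. rank 0 sets `model_ready` whenever the workers are idle;
#   2. the caller writes [command, data] into `root_command` and sets `root_bell`;
#   3. rank 0 wakes up on the bell, all ranks rendezvous and run the command;
#   4. rank 0 copies the sampled tokens to CPU into `output_data` and, on the next
#      loop iteration, sets `model_ready` again so the caller can read the result.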
class RootMailbox:
def __init__(self, manager: mp.Manager):
self.root_bell = manager.Event()
self.root_command = manager.list()
self.model_ready = manager.Event()
self.output_data = manager.Value(torch.Tensor, torch.tensor([]))

def send(self, command: ModelCommand, data: Dict = None):
# First wait until model is ready to receive commands
logger.debug(f" MM Command {command} waiting for model to be ready")
self.model_ready.wait()
self.model_ready.clear()

self.root_command[:] = [command, data]
self.root_bell.set()
logger.debug(f" MM Command {command} sent")
# wait again until model is ready, meaning command has been processed
self.model_ready.wait()
ret = self.output_data.get()
logger.debug(f" MM Command {command} output shape {ret.shape}")
return ret


class AgentMailbox:
def __init__(self, root_mailbox: RootMailbox):
self.root_bell = root_mailbox.root_bell
self.root_command = root_mailbox.root_command
self.model_ready = root_mailbox.model_ready
self.output_data = root_mailbox.output_data

def receive(self):
self.root_bell.wait()
self.root_bell.clear()
return self.root_command

def send(self, data: torch.Tensor):
logger.debug(f" MM Enqueueing data {data.shape}")
# Data needs to be moved to CPU before setting it
self.output_data.set(data.cpu())
logger.debug(" MM Enqueueing data done")

@property
def command_data(self):
command = self.root_command[0]
data = self.root_command[1]
return command, data


def _mp_fn(rank, model_id, root_mailbox: RootMailbox, sample_fn: callable):
device = xm.xla_device()
world_size = xm.xrt_world_size()
# create agent mailbox out of root's one
mailbox = AgentMailbox(root_mailbox)

logger.debug(
f"Rank {rank} on {device} real device {xm.xla_real_devices([device])} ordinal {xm.get_ordinal()} "
+ f"world size {world_size}"
)

# Model loading and sharding should happen here
model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
model = model.eval()
model.to(device)

def get_next_token(inputs):
# move inputs to device in a new dict to avoid conflicts
model_inputs = {}
for key, value in inputs.items():
model_inputs[key] = value.to(device)
outputs = model(**model_inputs, return_dict=False)[0]
xm.mark_step()
# consider adding a rendezvous here
if rank == 0:
logger.debug(f"Rank {rank} getting tokens")
next_token = sample_fn(outputs)
xm.mark_step()
logger.debug(f"Rank {rank} sending next_tokens {next_token.shape}")
mailbox.send(next_token)

while True:
if rank == 0:
mailbox.model_ready.set()
logger.debug(f"Rank {rank} waiting for commands")
mailbox.receive()
# Wait for rank 0 to receive command
xm.rendezvous("start")

logger.debug(f"Rank {rank} waiting for command at rendezvous")
command, inputs = mailbox.command_data
if command == ModelCommand.PREFILL:
logger.debug(f"Rank {rank} PREFILL")
get_next_token(inputs)
elif command == ModelCommand.DECODE:
logger.debug(f"Rank {rank} DECODE")
get_next_token(inputs)
elif command == ModelCommand.LEAVE:
logger.debug(f"Rank {rank} LEAVE")
# Set model to ready
mailbox.model_ready.set()
break


def model_loop_fn(*args):
"""Spawn processes in the TPUs forwarding arguments"""
xmp.spawn(_mp_fn, args=(args), join=True, daemon=False)


class DistributedModel:
def __init__(self, model_id: str, sample_fn: callable):
manager = mp.Manager()
self.mailbox = RootMailbox(manager)

self.model_loop = mp.Process(target=model_loop_fn, args=(model_id, self.mailbox, sample_fn))
self.model_loop.start()

def prefill(self, **model_args):
assert self.mailbox is not None, "DistributedModel is not initialized"
return self.mailbox.send(ModelCommand.PREFILL, model_args)

def decode(self, **model_args):
assert self.mailbox is not None, "DistributedModel is not initialized"
return self.mailbox.send(ModelCommand.DECODE, model_args)

def leave(self):
if self.mailbox is None:
return
self.mailbox.send(ModelCommand.LEAVE)
logger.debug("Joining...")
self.model_loop.join()
logger.debug("Model loop finished")
self.mailbox = None

def __del__(self):
self.leave()