Xla parallel proxy #12

Merged: 10 commits on Apr 8, 2024
34 changes: 34 additions & 0 deletions .github/workflows/test-pytorch-xla-tpu.yml
@@ -0,0 +1,34 @@
name: Optimum TPU tests

on:
  push:
    branches: [ main ]
    paths:
      - "tests/**"
  pull_request:
    branches: [ main ]
    paths:
      - "tests/**"

# Cancel in-flight runs for the same workflow and branch when new commits arrive.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  do-the-job:
    name: Run optimum tpu tests
    runs-on: optimum-tpu
    container:
      # Use an image that works on TPU with PyTorch 2.3.0 (the release image was not working)
      image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla@sha256:8f1dcd5b03f993e4da5c20d17c77aff6a5f22d5455f8eb042d2e4b16ac460526
      options: --shm-size "16gb" --ipc host --privileged
      env:
        PJRT_DEVICE: TPU
    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Build and test optimum tpu
        run: |
          pip install accelerate==0.27.2
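          # HF_TOKEN authenticates to the Hugging Face Hub so the tests can download gated checkpoints (e.g. google/gemma).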
          HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_TPU_CI }} make tests
3 changes: 3 additions & 0 deletions Makefile
@@ -72,6 +72,9 @@ pypi_upload: ${PACKAGE_DIST} ${PACKAGE_WHEEL}
test_installs:
	python -m pip install .[tpu,tests]

tests: test_installs
	python -m pytest -sv tests

# Stand-alone TGI server for unit tests outside of TGI container
tgi_server:
	python -m pip install -r text-generation-inference/server/build-requirements.txt
3 changes: 3 additions & 0 deletions examples/README.md
@@ -0,0 +1,3 @@
# 🤗 Transformers example scripts on Google TPU with Optimum TPU

These examples show how to run inference on Google TPU using Optimum TPU.
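For a flavor of what the scripts below do, here is a minimal sketch (not part of the example scripts themselves): it loads `google/gemma-2b` through `optimum.tpu.modeling.TpuModelForCausalLM`, tokenizes a prompt, and greedily picks a single next token from one forward pass. It assumes a TPU VM with `PJRT_DEVICE=TPU` set and Hub access to the gated `google/gemma-2b` checkpoint.

```python
# Minimal sketch: one greedy decoding step with optimum-tpu on a TPU VM.
# Assumes PJRT_DEVICE=TPU is set and the (gated) google/gemma-2b checkpoint is accessible.
import torch
from transformers import AutoTokenizer
from optimum.tpu.modeling import TpuModelForCausalLM

model_id = "google/gemma-2b"
model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tokenize a prompt and move it to the model's (TPU) device.
inputs = tokenizer("Once upon a time,", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs, return_dict=False, use_cache=False)[0]

# Greedy choice: arg-max over the vocabulary at the last position.
next_token_id = torch.argmax(logits[:, -1], dim=-1)
print(tokenizer.decode(next_token_id))
```

The full `generation_gemma.py` script below builds on the same pieces, adding a static KV cache and a step-by-step greedy decode loop.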
130 changes: 130 additions & 0 deletions examples/text-generation/generation_gemma.py
@@ -0,0 +1,130 @@
#!/usr/bin/python
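# Example: greedy text generation with google/gemma-2b on a TPU, using a static KV cache.
# Set DBG_COMPILE=1 to compile the model with torch.compile (openxla backend) before decoding.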

import datetime
import os
import platform
import time
from typing import List

import torch
import torch_xla.core.xla_model as xm
from optimum.tpu.modeling import TpuModelForCausalLM
from transformers import AutoTokenizer, StaticCache


# Select the TPU backend for PyTorch/XLA's PJRT runtime.
os.environ["PJRT_DEVICE"] = "TPU"


def sample_greedy(logits):
    # Greedy decoding: pick the arg-max token from the last position's logits.
    next_logits = logits[:, -1]
    next_token_id = torch.argmax(next_logits, dim=-1)[:, None].int()
    return next_token_id


def decode_one_tokens(model, cur_token, input_pos, cache_position, step):
    # Single decode step: feed the last generated token and return the next one.
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
    )[0]
    new_token = sample_greedy(logits)
    return new_token


def conditional_compile(func):
    # Wrap the model in torch.compile (openxla backend) only when DBG_COMPILE is set.
    if "DBG_COMPILE" in os.environ:
        compiled = torch.compile(func, backend="openxla")
        return compiled
    return func


def summary(values: List[float]):
    # Report total, mean and median of the per-step decode times.
    values.sort()
    n = len(values)
    if n % 2 == 0:
        median = (values[n // 2 - 1] + values[n // 2]) / 2
    else:
        median = values[n // 2]
    total = sum(values)
    mean = sum(values) / n
    print(f"Decode time: {total}, average: {mean}, median: {median}")


def main():
    prg_start = time.time()
    model_id = "google/gemma-2b"
    torch_dtype = torch.bfloat16

    model = TpuModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
    device = model.device
    model = model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompts = ["Here's a funny thing:", "Once upon a time,"]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
    batch_size, sequence_length = inputs["input_ids"].shape
    max_cache_length = 1024
    max_new_tokens = 20

    # Allocate a static KV cache large enough for the whole generation up front.
    start = time.time()
    model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
    end = time.time()
    print(f"Model cache setup took {end - start} seconds.")
    start = time.time()
    cache_position = torch.arange(sequence_length, device=device)
    generated_ids = torch.zeros(
        (batch_size, sequence_length + max_new_tokens + 1),
        dtype=torch.int,
        device=device,
    )
    generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)

    # Prefill: run the whole prompt once to populate the KV cache and sample the first new token.
    attention_mask = inputs["attention_mask"]
    pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
    logits = model(
        **inputs,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
        position_ids=pos_ids,
    )[0]
    next_token = sample_greedy(logits)
    xm.mark_step()
    generated_ids[:, sequence_length] = next_token[:, 0]
    end = time.time()
    print(f"Prefill took {end - start} seconds.")

    # Next position id per sequence: one past the last non-padded prompt token.
    pos_ids = pos_ids.max(axis=-1)[0].unsqueeze(1) + 1

    model = conditional_compile(model)
    cache_position = torch.tensor([sequence_length + 1], device=device)
    decode_times = []
    # Decode one token at a time; xm.mark_step() cuts the XLA graph and triggers execution each step.
    for i in range(max_new_tokens):
        step_start = time.time()
        next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position, i)
        generated_ids[:, cache_position] = next_token

        cache_position += 1
        pos_ids += 1
        xm.mark_step()
        step_end = time.time()
        step_time = step_end - step_start
        decode_times.append(step_time)
        print(f"Step {i} took {step_time} seconds.")
    summary(decode_times)

    print(f"Decoding start at {datetime.datetime.now()}")

    decoded_texts = tokenizer.batch_decode(generated_ids)
    for i, text in enumerate(decoded_texts):
        print(i, text)

    end = time.time()
    print(f"Program run in {end - prg_start} seconds. Device: {device} System: {platform.system()}")


if __name__ == "__main__":
    with torch.no_grad():
        main()