[BOUNTY] Add Phi-2 #117

Merged · 16 commits · Sep 3, 2024
68 changes: 68 additions & 0 deletions model_demos/nlp_demos/phi2/pytorch_phi2_text_generation.py
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0

# Phi2 Demo - Text Generation

import os
import pybuda

from pybuda.transformers.pipeline import pipeline as pybuda_pipeline
from transformers import AutoTokenizer, PhiConfig, PhiForCausalLM


def run_phi2_causal_lm(batch_size=1):
    # Disable the TT backend timeout
    os.environ["TT_BACKEND_TIMEOUT"] = "0"

    # Set PyBuda configurations
    compiler_cfg = pybuda.config._get_global_compiler_config()
    compiler_cfg.default_df_override = pybuda.DataFormat.Float16_b
    compiler_cfg.enable_auto_fusing = True
    compiler_cfg.balancer_policy = "Ribbon"

    # Set up model configuration
    config = PhiConfig.from_pretrained("microsoft/phi-2")
    config.use_cache = False
    config.return_dict = False

    # Load model and tokenizer with config
    model = PhiForCausalLM.from_pretrained("microsoft/phi-2", config=config)
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
    # The tokenizer ships without a pad token, so reuse EOS for batched generation
    tokenizer.pad_token, tokenizer.pad_token_id = (tokenizer.eos_token, tokenizer.eos_token_id)

    # Disable DynamicCache
    # See: https://github.com/tenstorrent/tt-buda/issues/42
    model._supports_cache_class = False

    # Example usage
    prompt = ["My name is Jim Keller and"] * batch_size

    # Initialize pipeline
    text_generator = pybuda_pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    # Inference on TT device
    response = text_generator(
        prompt,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        max_new_tokens=512,
        num_beams=1,
        do_sample=True,
        no_repeat_ngram_size=5,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=True,
    )

    # Display responses
    for batch_id in range(batch_size):
        print(f"Batch: {batch_id}")
        print(f"Response: {response[batch_id][0]['generated_text']}")
        print()


if __name__ == "__main__":
    run_phi2_causal_lm()
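
To sanity-check the sampling settings off-device, the same prompt can be run through the stock Hugging Face pipeline on CPU. A minimal sketch, assuming only the standard transformers text-generation API; the reduced max_new_tokens is an assumption to keep a CPU run quick, and this snippet is not part of the PR:

# CPU reference run with vanilla transformers (sketch, no PyBuda involved)
from transformers import AutoTokenizer, PhiForCausalLM, pipeline

model = PhiForCausalLM.from_pretrained("microsoft/phi-2")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
tokenizer.pad_token = tokenizer.eos_token  # same EOS-as-pad workaround as the demo

generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
output = generator(
    "My name is Jim Keller and",
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    max_new_tokens=32,  # kept small for CPU; the demo uses 512
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
)
print(output[0]["generated_text"])

If the CPU output looks reasonable, divergences on device point at the compiler configuration rather than the generation parameters.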
1 change: 1 addition & 0 deletions model_demos/pyproject.toml
@@ -87,5 +87,6 @@ markers = [
"yolov6: tests that involve yolov6",
"segformer: tests that involve SegFormer",
"monodle: tests that involve Monodle",
"phi2: tests that involve Phi2",
"yolox: tests that involve YOLOX",
]
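
Registering the marker here keeps pytest from flagging `phi2` as unknown (and is required under `--strict-markers`); the test below can then be selected with a plain `-m phi2` filter, as sketched after the test file.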
10 changes: 10 additions & 0 deletions model_demos/tests/test_pytorch_phi2.py
@@ -0,0 +1,10 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
# SPDX-License-Identifier: Apache-2.0

import pytest

from nlp_demos.phi2.pytorch_phi2_text_generation import run_phi2_causal_lm


@pytest.mark.phi2
def test_phi2_causal_lm_pytorch(clear_pybuda, test_device, batch_size):
    run_phi2_causal_lm(batch_size=batch_size)
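
For completeness, a sketch of selecting just this test through the new marker via pytest's documented pytest.main entry point; the path is an assumption based on the repo layout, and running `pytest -m phi2` from a shell inside model_demos is equivalent:

# Run only tests marked "phi2" (sketch; invoke from the model_demos directory)
import sys

import pytest

sys.exit(pytest.main(["-m", "phi2", "tests/test_pytorch_phi2.py"]))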