From f9a659373233572c76bb59faa872bea407607bc4 Mon Sep 17 00:00:00 2001
From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com>
Date: Mon, 29 Apr 2024 09:53:38 +0200
Subject: [PATCH] Fix tests with do_sample=True (#30)

* fix(tests): update expected result when do_sample=True

* chore: try to set seed on xla to stabilize tests
---
 optimum/tpu/generation/token_selector.py     | 2 ++
 text-generation-inference/tests/test_gpt2.py | 6 +++---
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/optimum/tpu/generation/token_selector.py b/optimum/tpu/generation/token_selector.py
index 97469f47..2ebc2d4a 100644
--- a/optimum/tpu/generation/token_selector.py
+++ b/optimum/tpu/generation/token_selector.py
@@ -3,6 +3,7 @@
 from typing import Optional
 
 import torch
+import torch_xla.core.xla_model as xm
 from transformers.generation import (
     GenerationConfig,
     GenerationMixin,
@@ -52,6 +53,7 @@ def __init__(
         self.eos_token_id = eos_token_id
         self.pad_token_id = pad_token_id
         self.logits_warper = logits_warper
+        xm.set_rng_state(seed)
         self.generator = torch.Generator()
         self.generator.manual_seed(seed)
 
diff --git a/text-generation-inference/tests/test_gpt2.py b/text-generation-inference/tests/test_gpt2.py
index 1f26b9cf..6e883da4 100644
--- a/text-generation-inference/tests/test_gpt2.py
+++ b/text-generation-inference/tests/test_gpt2.py
@@ -72,8 +72,8 @@ def create_request(
         ],
         [
             "It was a bright cold day in April, and the clocks were striking thirteen.",
-            775,
-            " We",
+            1439,
+            " All",
             True,
         ],
     ],
@@ -113,7 +113,7 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_
         [
            "It was a bright cold day in April, and the clocks were striking thirteen.",
             20,
-            " We sat at the front door and watched the clock on a box of yellow paper and found it almost",
+            " All day the sun had set, as was well-known. The first thing I noticed when I",
             True,
         ],
     ],
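
A minimal sketch of the idea behind "try to set seed on xla to stabilize tests": with do_sample=True the selected token depends on RNG state, so the fixed expectations added in test_gpt2.py only hold if both the XLA device RNG and the CPU torch.Generator are seeded. The helper below is hypothetical (the function name, the default seed, and the standalone multinomial call are illustrative assumptions, not the library's actual sampling path).

# Sketch only, not part of the patch: illustrates why seeding both RNGs
# makes sampled outputs reproducible across test runs.
import torch
import torch_xla.core.xla_model as xm


def sample_next_token(logits: torch.Tensor, seed: int = 42) -> int:
    """Sample one token id from a 1-D row of logits, reproducibly."""
    # Seed the XLA device RNG (the line the patch adds in TokenSelector.__init__),
    # so random ops executed on the XLA device start from a fixed state.
    xm.set_rng_state(seed)
    # Seed the CPU generator as well, mirroring the existing torch.Generator usage.
    generator = torch.Generator()
    generator.manual_seed(seed)
    probs = torch.softmax(logits.float(), dim=-1)
    # With both RNGs pinned, repeated runs draw the same token id, which is what
    # allows the tests to assert a single expected id/text even with sampling on.
    return int(torch.multinomial(probs.cpu(), num_samples=1, generator=generator))

Under this assumption, the updated expectations in the patch (token 1439 / " All" instead of 775 / " We") are simply the values produced once the seeding is in place, rather than values that vary from run to run.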