diff --git a/.github/workflows/test-pytorch-xla-tpu-tgi.yml b/.github/workflows/test-pytorch-xla-tpu-tgi.yml index 387ac07d..538f2f02 100644 --- a/.github/workflows/test-pytorch-xla-tpu-tgi.yml +++ b/.github/workflows/test-pytorch-xla-tpu-tgi.yml @@ -29,6 +29,8 @@ jobs: #uses: actions/checkout@v4 - name: slepp run: | + apt update + apt install pip pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install --upgrade pip python -c 'import jax; print("TPU cores:", jax.device_count())'