TorchPrime

Note: this repository is experimental. We don't yet provide correctness guarantees for the model implementations.

TorchPrime is a reference model implementation for PyTorch on TPU/GPU using torch_xla and torchax. It is designed to showcase best practices for high-performance model training with these frameworks.

Examples

Here is a simple example of training on a single TPU VM. It assumes that you have already installed torch_xla and torchax following their respective project READMEs.

Install torchprime:

git clone https://github.com/AI-Hypercomputer/torchprime.git
cd torchprime
pip install -e '.[dev]'
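
Optionally, sanity-check the environment (a minimal check, assuming torch_xla and torchax were installed per their READMEs):

python3 -c 'import torch_xla, torchax; print(torch_xla.__version__)'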

Local training

Train Llama 3 8B using torch_xla:

export HF_TOKEN='...your huggingface token...'
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py

Train Llama 3 8B using torchax:

python3 torchprime/experimental/torchax_models/run.py global_batch_size=16

Refer to README.md in torchprime/torch_xla_models and torchprime/experimental/torchax_models for more details.

Configuring training

torchprime uses hydra to read configurations (e.g. model name, batch size) from the command line and .yaml files.

In both the torch_xla_models and torchax_models directories, you'll find a configs/default.yaml, which specifies the default configuration for the trainer. You may override configs on the command line with key=value syntax. For example, the following command trains Mixtral 8x7B with a global batch size of 256 and sets the FSDP SPMD mesh axis length to 64:

python3 torchprime/torch_xla_models/train.py \
    model=mixtral-8x7b \
    global_batch_size=256 \
    mesh.fsdp=64

You may refer to the hydra docs for other ways to specify configs.
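
For instance, hydra supports sweeping a value across several runs with --multirun, and adding a key that is absent from default.yaml with a leading + (a sketch; max_steps is an illustrative key that may not exist in this config):

# Launch one training run per listed batch size (hydra multirun)
python3 torchprime/torch_xla_models/train.py --multirun global_batch_size=128,256

# Add a key not present in default.yaml via hydra's + prefix
# (max_steps is hypothetical here; substitute a real key)
python3 torchprime/torch_xla_models/train.py +max_steps=100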

Distributed training

torchprime uses xpk as the standard path for iterating on distributed training code.

First, teach torchprime about the XPK cluster it is using, the artifact storage location, and so on. You only need to do this on first clone, or when switching to a different topology or cluster. Example:

tp use \
    --cluster <XPK CLUSTER NAME> \
    --project my-gcp-project \
    --zone us-east5-b \
    --num-slices 1 \
    --tpu-type v6e-256 \
    --artifact-dir gs://bucket/dir

Then prepend tp run to any Python command you would like to run remotely, including its arguments, e.g.:

tp run torchprime/experimental/torchax_models/run.py global_batch_size=256

tp run will broadcast this command to all VMs in the XPK cluster, which is the convention for running SPMD distributed workloads.

Env vars passed to the workload

tp run will pick up these environment variables locally and forward them to the distributed workload, if set:

HF_TOKEN: Hugging Face token, as used in the local training example above
XLA_IR_DEBUG: torch_xla debugging flag, as shown in the local training example above
XLA_HLO_DEBUG: torch_xla debugging flag, as shown in the local training example above
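
For example, to forward your Hugging Face token and the XLA debugging flags to a remote workload:

export HF_TOKEN='...your huggingface token...'
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
tp run torchprime/torch_xla_models/train.py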

Structure

This repo contains a set of reference models that we have optimized to run well on TPU. For each model, we provide the best-performing scaling configuration (parallelism techniques, checkpointing, etc.) on various hardware for ease of reproducibility.

docs contains guides for optimizing performance and debugging issues.

torchprime/launcher contains scripts to train a model on a large TPU cluster.

torchprime/data contains dataset and data loading utilities.

torchprime/torch_xla_models contains model implementations using torch_xla.

torchprime/experimental/torchax_models contains model implementations using torchax.

Finally, each model may also provide a GPU "original" version that illustrates and attributes where the model code came from, if applicable. This also helps showcase the changes we made to make the code performant on TPU. The original version is not meant to be run.

Run Hugging Face transformer models

torchprime supports running Hugging Face models by taking advantage of tp run. To use Hugging Face models, clone huggingface/transformers under torchprime and name it local_transformers. This allows you to pick any branch of transformers, or modify its code, for an experiment:

git clone https://github.com/huggingface/transformers.git local_transformers

If local_transformers doesn't exist, torchprime will automatically clone the repo and build the Docker image for the experiment. To switch to Hugging Face models, add the --use-hf flag to the tp run command:

tp run --use-hf torchprime/hf_models/train.py

Contributing

Contributions are welcome! Please feel free to submit a pull request.

When developing, use pip install -e '.[dev]' to install dev dependencies such as the linter and formatter.

How to run tests:

pytest

How to run some of the tests, and re-run them whenever you change a file:

tp -i test ... # replace with path to tests/directories

How to format:

ruff format

How to lint:

ruff check [--fix]

You can install a Ruff VSCode plugin to check errors and format files from the editor.

License

This project is licensed under the New BSD License - see the LICENSE file for details.

For more information on PyTorch/XLA, visit the official documentation.
