Examples

PyTorch Examples

Please use torchrun to launch the PyTorch examples. Take test_simple.py as an example:

```shell
# for a single-node environment (2 GPUs, for example)
torchrun --nproc_per_node 2 --master_port 9543 ./torch/test_simple.py

# for a multi-node environment (2 nodes with 2 GPUs each, for example)
# Machine1:
torchrun --nnodes 2 --nproc_per_node 2 --node_rank 0 \
         --master_addr [Machine1 IP] --master_port 9543 \
         ./torch/test_simple.py
# Machine2:
torchrun --nnodes 2 --nproc_per_node 2 --node_rank 1 \
         --master_addr [Machine1 IP] --master_port 9543 \
         ./torch/test_simple.py
```

For more details on torchrun, please refer to the Torch Distributed Elastic documentation.
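Each process started by torchrun receives its placement through environment variables such as `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT`. The sketch below is a hypothetical helper, not part of test_simple.py or EasyDist, that reads them; `read_torchrun_env` and `expected_global_rank` are illustrative names. The rank formula assumes every node runs the same `--nproc_per_node`, as in the two-node example above, where 2 × 2 = 4 processes are started and local rank 1 on Machine2 becomes global rank 3.

```python
import os
from typing import Mapping, NamedTuple


class TorchrunInfo(NamedTuple):
    rank: int          # global rank across all nodes
    local_rank: int    # rank within this node (typically the GPU index)
    world_size: int    # total number of processes
    master_addr: str
    master_port: int


def read_torchrun_env(env: Mapping[str, str] = os.environ) -> TorchrunInfo:
    """Hypothetical helper: read the variables torchrun exports to each worker."""
    return TorchrunInfo(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
        master_addr=env["MASTER_ADDR"],
        master_port=int(env["MASTER_PORT"]),
    )


def expected_global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """Global rank when all nodes run the same number of processes."""
    return node_rank * nproc_per_node + local_rank


if __name__ == "__main__":
    # Simulated environment for local rank 1 on Machine2 in the 2x2 example;
    # the address 10.0.0.1 is a made-up placeholder for [Machine1 IP].
    fake_env = {"RANK": "3", "LOCAL_RANK": "1", "WORLD_SIZE": "4",
                "MASTER_ADDR": "10.0.0.1", "MASTER_PORT": "9543"}
    info = read_torchrun_env(fake_env)
    assert info.rank == expected_global_rank(node_rank=1, nproc_per_node=2, local_rank=1)
    print(info)
```

In an actual PyTorch script, `torch.distributed.init_process_group(backend="nccl")` reads the same variables through its default `env://` initialization, so a helper like this is mainly useful for logging or for selecting the GPU with `torch.cuda.set_device(local_rank)`.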

JAX Examples

Please use mpirun to launch the JAX examples. Take simple_function.py as an example:

```shell
# for a single-node environment (2 GPUs, for example)
mpirun -np 2 python ./examples/jax/simple_function.py
```

For multi-node environments, you may need to read the Multi Process in Jax documentation. Also, the function easydist_setup_jax in easydist/jax/__init__.py may need to be modified to launch processes in a cluster environment such as SLURM.
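When adapting `easydist_setup_jax` to a cluster, each process generally needs its own index and the total number of processes, which correspond to the `process_id` and `num_processes` arguments of `jax.distributed.initialize`. The sketch below is a hypothetical helper (not EasyDist's actual code; `detect_process_info` and `ProcessInfo` are illustrative names) that reads them from the environment variables set by Open MPI's `mpirun` (`OMPI_COMM_WORLD_*`) or by SLURM (`SLURM_PROCID`, `SLURM_NTASKS`, `SLURM_LOCALID`). Other MPI implementations use different variable names.

```python
import os
from typing import Mapping, NamedTuple, Optional


class ProcessInfo(NamedTuple):
    process_id: int      # candidate value for jax.distributed.initialize(process_id=...)
    num_processes: int   # candidate value for jax.distributed.initialize(num_processes=...)
    local_rank: int      # index of this process on its node, e.g. to pick a GPU


# (rank, size, local rank) variable names, checked in order.
# Open MPI is checked first because mpirun may also run inside a SLURM allocation.
_LAUNCHER_VARS = (
    ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE", "OMPI_COMM_WORLD_LOCAL_RANK"),
    ("SLURM_PROCID", "SLURM_NTASKS", "SLURM_LOCALID"),
)


def detect_process_info(env: Mapping[str, str] = os.environ) -> Optional[ProcessInfo]:
    """Hypothetical helper: return rank info from launcher env vars, or None if absent."""
    for rank_var, size_var, local_var in _LAUNCHER_VARS:
        if rank_var in env and size_var in env:
            return ProcessInfo(
                process_id=int(env[rank_var]),
                num_processes=int(env[size_var]),
                # Fall back to 0 if the launcher does not export a local rank.
                local_rank=int(env.get(local_var, "0")),
            )
    return None
```

The result could then be passed as `jax.distributed.initialize(coordinator_address=..., num_processes=info.num_processes, process_id=info.process_id)`, where the coordinator address (host and port of process 0) has to come from your cluster setup. Recent JAX versions can also detect some of these environments automatically when `jax.distributed.initialize()` is called without arguments, so check the JAX documentation for your version.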