# gcp_job.yaml
# Job config for launching a SmolVLM SFT training run on GCP.
#
# Usage:
#   oumi launch up --config configs/recipes/vision/smolvlm/sft/gcp_job.yaml --cluster smolvlm-vision
#
# See Also:
#   - Documentation: https://oumi.ai/docs/en/latest/user_guides/launch/launch.html
#   - Config class: oumi.core.configs.JobConfig
#   - Config source: https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/job_config.py
#   - Other job configs: configs/**/*job.yaml
# Job name.
name: smolvlm-sft-trl-train

# Cloud resources requested for the job. These keys must be nested under
# `resources`; at column 0 they would be parsed as unrelated top-level keys.
resources:
  cloud: gcp
  accelerators: "A100:4"  # Quoted so the colon-separated value stays a string.
  use_spot: false
  disk_size: 1000  # Disk size in GBs

# Top-level key (a sibling of `resources`, not one of its fields).
num_nodes: 1  # Set it to N for multi-node training.
# Upload working directory to remote ~/sky_workdir.
working_dir: .

# Mount local files. Each mount is an entry nested under `file_mounts`.
file_mounts:
  ~/.netrc: ~/.netrc  # WandB credentials
  # Mount HF token, which is needed to download locked-down models from HF Hub.
  # This is created on your local machine by running `huggingface-cli login`.
  ~/.cache/huggingface/token: ~/.cache/huggingface/token
# Environment variables for the job. OUMI_RUN_NAME is read by the `run`
# script to build the training run name.
envs:
  WANDB_PROJECT: oumi-train
  OUMI_RUN_NAME: smolvlm.fft.oumi
# Setup script: installs Oumi, pre-downloads the model, and installs flash-attn.
# The script body must be indented under `setup: |` to be part of the block scalar.
setup: |
  set -e
  # Quote the requirement so the shell never treats `[gpu]` as a glob pattern.
  pip install uv && uv pip install "oumi[gpu]" hf_transfer

  # Pre-download the model from HF Hub. hf_transfer increases download speed
  # compared to downloading the model during training.
  HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download HuggingFaceTB/SmolVLM-Instruct --exclude "onnx/*" "runs/*"

  pip install -U flash-attn --no-build-isolation
# Run script: sources the shared init script, then launches training via
# `oumi distributed torchrun` with command-line overrides on top of train.yaml.
# The overrides cap the run at 20 steps and turn off saving the final model
# (save_steps 0 presumably disables intermediate checkpoints -- confirm).
run: |
  set -e  # Exit if any command failed.
  source ./configs/examples/misc/sky_init.sh

  set -x
  oumi distributed torchrun \
    -m oumi train \
    -c configs/recipes/vision/smolvlm/sft/train.yaml \
    --training.run_name "${OUMI_RUN_NAME}.${SKYPILOT_TASK_ID}" \
    --training.trainer_type OUMI \
    --training.max_steps 20 \
    --training.save_steps 0 \
    --training.save_final_model false

  echo "Node ${SKYPILOT_NODE_RANK} is all done!"