
Commit be52cfd
Compare logits.grad instead
Tcc0403 committed Jan 29, 2025
1 parent 74f4ad8 commit be52cfd
Showing 1 changed file with 43 additions and 17 deletions.

test/convergence/test_mini_models_with_logits.py
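The substantive change is the new `output.logits.retain_grad()` call: PyTorch only populates `.grad` on leaf tensors by default, and `logits` is an intermediate (non-leaf) tensor, so without this call `output.logits.grad` would be `None` after `backward()`. A minimal standalone sketch of that behavior (illustrative only; the tensor names are not from the test file):

    import torch

    x = torch.randn(4, 8, requires_grad=True)   # leaf tensor
    w = torch.randn(8, 3, requires_grad=True)   # leaf tensor
    logits = x @ w                              # non-leaf: .grad is discarded by default
    logits.retain_grad()                        # ask autograd to keep this tensor's grad
    loss = logits.sum()
    loss.backward()
    assert logits.grad is not None              # populated only because of retain_grad()
    assert torch.equal(logits.grad, torch.ones_like(logits))  # d(sum)/d(logits) = 1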
@@ -55,7 +55,9 @@
 try:
     # Qwen2-VL is only available in transformers>4.44.2
     from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLConfig
-    from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
+    from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+        Qwen2VLForConditionalGeneration,
+    )
 
     QWEN2_VL_AVAILABLE = True
 except ImportError:
@@ -434,7 +436,9 @@ def run_mini_model(
 
     model = create_model(model_name).to(dtype).to(device)
     train_dataset = load_from_disk(DEFAULT_DATASET_PATH)
-    loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn)
+    loader = DataLoader(
+        train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn
+    )
     loader_iter = iter(loader)
     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
@@ -444,6 +448,7 @@
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
         output = model(**batch)
+        output.logits.retain_grad()  # For comparing logits.grad
         output.loss.backward()
         optimizer.step()
         print(f"Step {i}, Loss: {output.loss.item()}")
@@ -468,7 +473,9 @@ def run_mini_model(
             1e-2,
             1e-2,
             1e-2,
-            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
         ),
         pytest.param(
             "mini_mllama",
@@ -498,7 +505,9 @@ def run_mini_model(
             1e-2,
             1e-2,
             marks=[
-                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+                ),
                 pytest.mark.skipif(
                     not MLLAMA_AVAILABLE,
                     reason="Mllama not available in this version of transformers",
@@ -517,7 +526,9 @@
             1e-2,
             1e-2,
             1e-2,
-            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
         ),
         pytest.param(
             "mini_qwen2_vl",
@@ -547,7 +558,9 @@ def run_mini_model(
             1e-2,
             1e-2,
             marks=[
-                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+                ),
                 pytest.mark.skipif(
                     not QWEN2_VL_AVAILABLE,
                     reason="Qwen2-VL not available in this version of transformers",
@@ -566,7 +579,9 @@ def run_mini_model(
             1e-2,
             1e-2,
             1e-2,
-            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
         ),
         ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
         pytest.param(
@@ -580,7 +595,9 @@
             1e-2,
             1e-2,
             1e-2,
-            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
         ),
         # TODO: mixtral is flaky so disable the test for now
         # ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
@@ -612,7 +629,9 @@ def run_mini_model(
             1e-2,
             1e-2,
             1e-2,
-            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
         ),
         ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
         pytest.param(
@@ -626,7 +645,9 @@
             1e-2,
             1e-2,
             1e-2,
-            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
         ),
         ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
         # TODO: Gemma2 test for bf16 is not passing within the tolerance range, might be casting issue, need to investigate
@@ -661,9 +682,13 @@ def test_mini_model(
 ):
     # Non-liger models should be initialized and tested first to avoid the module being overridden
 
-    expected_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr)
+    expected_output = run_mini_model(
+        model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr
+    )
 
-    actual_output = run_mini_model(model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True)
+    actual_output = run_mini_model(
+        model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True
+    )
 
     # Compare every step of the loss
     assert_verbose_allclose(
@@ -673,12 +698,11 @@
         rtol=loss_rtol,
     )
 
-    # No logits are materialized
-    # import pdb; pdb.set_trace()
-    # Compare the logits from the last step
+    # Compare the logits.grad from the last step instead of logits, liger implementation doesn't keep logits
     assert_verbose_allclose(
-        expected_output["logits"],
-        actual_output["logits"],
+        expected_output["logits"].grad,
+        actual_output["logits"].grad,
         atol=logits_atol,
         rtol=logits_rtol,
     )
@@ -690,4 +714,6 @@ def test_mini_model(
         actual_output["model"].named_parameters(),
         strict=False,
     ):
-        assert_verbose_allclose(expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol)
+        assert_verbose_allclose(
+            expected_param[1], actual_param[1], atol=param_atol, rtol=param_rtol
+        )
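Comparing `logits.grad` is a reasonable equivalence proxy because, for cross-entropy, the gradient at the logits has a closed form, softmax(logits) minus the one-hot target (scaled by the reduction), so two implementations that agree on this gradient agree on how the loss shapes training even when one path (per the comment in the diff, the Liger implementation) does not keep the logits themselves. A small sketch verifying the analytic gradient against autograd (illustrative only, not from the test file):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(4, 10, requires_grad=True)  # leaf tensor, .grad kept automatically
    target = torch.randint(0, 10, (4,))

    loss = F.cross_entropy(logits, target)  # mean reduction over the batch
    loss.backward()

    # Analytic gradient of mean cross-entropy w.r.t. logits:
    # (softmax(logits) - one_hot(target)) / batch_size
    expected = (
        F.softmax(logits.detach(), dim=-1) - F.one_hot(target, num_classes=10).float()
    ) / logits.shape[0]
    assert torch.allclose(logits.grad, expected, atol=1e-6)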
