Additional states #458

Open
yongyanrao opened this issue Oct 5, 2023 · 4 comments

yongyanrao commented Oct 5, 2023

We noticed some additional states for each module, e.g.,

transformer.seq_layers.0.layer.self_attention.layernorm_qkv._extra_state
transformer.seq_layers.0.layer.self_attention.proj._extra_state
transformer.seq_layers.0.layer.layernorm_mlp._extra_state

These states are empty byte strings (b''). We suspect they are related to FP8. How should we deal with them? Should we explicitly remove them, or handle them through some dedicated method?
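
For context, a minimal sketch of how to list these entries (here, model is a placeholder for the Transformer Engine model in question):

# Print every _extra_state entry and its raw value (observed as empty b'' here)
for name, value in model.state_dict().items():
    if name.endswith("._extra_state"):
        print(name, type(value), value)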

ptrendx (Member) commented Oct 5, 2023

Why do you want to remove them? Those states are handled internally by Transformer Engine if FP8 is used.

Teng-xu commented Oct 5, 2023

I am observing the same behavior when training without FP8, and these states cause problems when loading checkpoints into the model, particularly when the checkpoint contains no _extra_state entries. Given that they are all empty, is there a way to deactivate or exclude these fields when training without FP8?
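
For reference, the failure looks roughly like this (the checkpoint path and model are illustrative placeholders, not taken from the actual run):

import torch

checkpoint = torch.load("no_extra_state_ckpt.pt", map_location="cpu")
# strict=True (the default) raises a missing-keys RuntimeError, because the
# model's state_dict expects the *._extra_state entries
model.load_state_dict(checkpoint)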

ksivaman (Member) commented Jan 8, 2024

@Teng-xu @yongyanrao These extra states are indeed part of the additional information needed for FP8 training checkpoints. They can be explicitly removed, but the simplest approach is to load the checkpoint with the strict=False flag in PyTorch's load_state_dict method.
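
A minimal sketch of that suggestion (the checkpoint path and model are placeholders):

import torch

checkpoint = torch.load("checkpoint.pt", map_location="cpu")
# strict=False tolerates _extra_state keys that are missing from the checkpoint;
# inspect the result to confirm nothing else was silently dropped
result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)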

zte-tcb commented Apr 7, 2024

You can read _extra_state with code like the following instead of state.read(); this shows the contents of _extra_state:

import io
import torch

if isinstance(state, io.BytesIO):
    state.seek(0)              # rewind the buffer before deserializing
    state = torch.load(state)  # unpickle the actual extra-state contents
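
Put together as a self-contained sketch (the checkpoint path is a placeholder), this dumps every _extra_state entry in a saved checkpoint:

import io
import torch

state_dict = torch.load("checkpoint.pt", map_location="cpu")
for name, state in state_dict.items():
    if name.endswith("._extra_state") and isinstance(state, io.BytesIO):
        state.seek(0)
        # Typically FP8 scaling metadata when FP8 was used, otherwise empty
        print(name, torch.load(state))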
