⬆️ Update dependencies #120

Merged 13 commits on Nov 29, 2024
Changes from 1 commit
@@ -4,6 +4,15 @@
from transformers import GemmaConfig, GenerationConfig, GenerationMixin


class GemmaConfigHf(GemmaConfig, gemma_config.GemmaConfig):
    """This class is used to support both the HF GemmaConfig and the Jetstream Pytorch GemmaConfig at the same time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = None

class GemmaModelHf(GemmaModel, GenerationMixin):
"""Transformer module that uses HF GemmaConfig instead of Jetstream Pytorch GemmaConfig + device.

@@ -16,24 +25,8 @@ def __init__(
        device,
        env,
    ):
        self.config = config
        self.generation_config = GenerationConfig.from_model_config(config)

        args = gemma_config.GemmaConfig(
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            head_dim=config.head_dim,
            rms_norm_eps=config.rms_norm_eps,
            dtype="bfloat16",
            quant=False,  # No quantization support for now
            tokenizer=None,
        )

        args = GemmaConfigHf(**config.to_dict())
        args.device = device
        super().__init__(args, env)

@@ -4,6 +4,20 @@
from transformers import GenerationConfig, GenerationMixin, MixtralConfig


class MixtralConfigHf(MixtralConfig, mixtral_config.ModelArgs):
Review comment:
This is a bit brittle, as setting one argument will not update its aliased value. What is the benefit of adding this?
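
For instance, a rough hypothetical snippet (not from this PR) of the kind of drift being described, using the class added in this diff:

```python
from transformers import MixtralConfig

hf_config = MixtralConfig(num_hidden_layers=8)
cfg = MixtralConfigHf(**hf_config.to_dict())  # constructed the same way as later in this diff
print(cfg.n_layer)           # 8, derived from num_hidden_layers in __init__

cfg.num_hidden_layers = 16   # update the HF-side attribute afterwards...
print(cfg.n_layer)           # ...still 8: the Jetstream-side alias is not refreshed
```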

Collaborator Author:
The Transformer class from Jetstream is the model class for Mistral, and it has a config attribute that used to be an instance of ModelArgs. When I use the model in TGI and want to use the GenerationMixin methods for sampling, those methods read that config and expect it to be a transformers MixtralConfig instance, and they fail because it is not. This is why I came up with this solution. I could try using composition instead of inheritance to see if I can find a cleaner solution, though.
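
A rough sketch of what that composition route might look like (hypothetical, not part of this PR): it reuses the previous ModelArgs construction from this diff, the jetstream_pt import path is an assumption, and whether GenerationMixin would accept the wrapped config would still need checking.

```python
from transformers import MixtralConfig
from jetstream_pt.third_party.mixtral import config as mixtral_config  # assumed import path


class MixtralConfigAdapter:
    """Wraps an HF MixtralConfig and derives the Jetstream ModelArgs on demand."""

    def __init__(self, hf_config: MixtralConfig, device=None):
        self.hf = hf_config
        self.device = device

    @property
    def jetstream_args(self) -> mixtral_config.ModelArgs:
        # Rebuilt on each access, so later edits to the HF config are reflected
        # instead of going stale as with the aliased attributes above.
        return mixtral_config.ModelArgs(
            block_size=self.hf.max_position_embeddings,
            vocab_size=self.hf.vocab_size,
            n_layer=self.hf.num_hidden_layers,
            n_head=self.hf.num_attention_heads,
            dim=self.hf.hidden_size,
            intermediate_size=self.hf.intermediate_size,
            n_local_heads=self.hf.num_local_experts or self.hf.num_attention_heads,
            num_activated_experts=self.hf.num_experts_per_tok,
            device=self.device,
        )
```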

"""This class is used to support both the HF MixtralConfig and the Jetstream Pytorch ModelArgs at the same time.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_size = self.max_position_embeddings
self.n_layer = self.num_hidden_layers
self.n_head = self.num_attention_heads
self.dim = self.hidden_size
self.n_local_heads = self.num_local_experts or self.num_attention_heads
self.num_activated_experts = self.num_experts_per_tok
self.__post_init__()

class MixtralModelHf(Transformer, GenerationMixin):
"""Transformer module that uses HF MixtralConfig instead of Jetstream Pytorch MixtralConfig + device.
"""
@@ -14,20 +28,9 @@ def __init__(
        device,
        env,
    ):
        self.config = config
        self.generation_config = GenerationConfig.from_model_config(config)

        args = mixtral_config.ModelArgs(
            block_size=config.max_position_embeddings,
            vocab_size=config.vocab_size,
            n_layer=config.num_hidden_layers,
            n_head=config.num_attention_heads,
            dim=config.hidden_size,
            intermediate_size=config.intermediate_size,
            n_local_heads=config.num_local_experts or config.num_attention_heads,
            num_activated_experts=config.num_experts_per_tok,
            device=device,
        )
        args = MixtralConfigHf(**config.to_dict())
        args.device = device
        super().__init__(args, env)

