-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* chore: update transformers dependency * feat: import transformer's gemma modeling code It will be used to adapt it for sharding. Only imports have been adapted, and only code relevant for GemmaForCausalLM has been added. * chore: rename model Gemma -> TpuGemma to prepare for changes * feat(DistributedModel): added config property * chore: rename test_parallel_proxy.py -> test_distributed_model.py * fix: use AutoModelForCausalLM instead of TpuModelForCausalLM * feat: AutoModelForCausalLM will choose TpuGemmaForCausalLM if possible * fix(TpuGemma): avoid using device_map when loading model It seems that device_map parameter triggers a chain of calls that will try to use accelerate to load the model using less memory. The problem is that it skips the load state pre-hooks, making the weights loading impossible. * feat(gemma): sharding o_proj It will now be running in parallel. More changes to come. * feat(gemma): sharding on q_proj * feat(gemma): sharding on k and v proj * feat(gemma): sharding on mlp gate and up proj * feat(gemma): sharding on mlp down proj * feat: model il loaded using pytorch_dtype from config This will lead to loading the model in bfloat16 when specified in the config. * fix: remove useless import * feat(tests): added test showing gemma7b sharding and prefill works * chore: config_name_to_class uses config.model_type now * fix: get_generation_mode is now a method of generation_config API change when transformers was updated. * fix(TGI server): fix slot.stopped changed after transformers update * fix(generator): fix sample generation again I wrongly chose the model's generation config instead of the one to the token selector. * fix: better handle torch_dtype bfloat16 will be set by default in gemma models, other models will still load in float32 by default. * fix: remove unused import
- Loading branch information
1 parent
d5b921e
commit 8e12733
Showing
13 changed files
with
1,438 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.