Merge pull request #1343 from Hypothesis-Z/fix-transformers-4.48.0
fix transformers 4.48.0
hanhainebula authored Feb 7, 2025
2 parents 0082bcf + 99dfb3d commit 42a25b1
Showing 3 changed files with 6 additions and 9 deletions.
8 changes: 1 addition & 7 deletions
@@ -53,7 +53,7 @@
)
from .gemma_config import CostWiseGemmaConfig
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2RotaryEmbedding, rotate_half, apply_rotary_pos_emb
-from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2FlashAttention2, Gemma2SdpaAttention, GEMMA2_ATTENTION_CLASSES, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
+from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING

if is_flash_attn_2_available():
@@ -105,12 +105,6 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

-GEMMA2_ATTENTION_CLASSES = {
-    "eager": Gemma2Attention,
-    "flash_attention_2": Gemma2FlashAttention2,
-    "sdpa": Gemma2SdpaAttention,
-}
-

_CONFIG_FOR_DOC = "CostWiseGemmaConfig"

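In transformers 4.48.0 the per-backend attention subclasses were folded into a single Gemma2Attention, so Gemma2FlashAttention2, Gemma2SdpaAttention, and the GEMMA2_ATTENTION_CLASSES mapping are no longer importable; the commit therefore drops both the imports and the local copy of the mapping. For code that also has to keep running against older transformers releases, a guarded import along these lines would keep the mapping available (a minimal sketch only, not what this commit does):

# Hypothetical compatibility shim (not part of this commit): keep the
# GEMMA2_ATTENTION_CLASSES mapping usable on both old and new transformers.
try:
    from transformers.models.gemma2.modeling_gemma2 import (
        Gemma2Attention,
        Gemma2FlashAttention2,
        Gemma2SdpaAttention,
    )
    GEMMA2_ATTENTION_CLASSES = {
        "eager": Gemma2Attention,
        "flash_attention_2": Gemma2FlashAttention2,
        "sdpa": Gemma2SdpaAttention,
    }
except ImportError:
    # transformers >= 4.48.0: the attention backend is selected inside
    # Gemma2Attention itself, so every key maps to the same class.
    from transformers.models.gemma2.modeling_gemma2 import Gemma2Attention
    GEMMA2_ATTENTION_CLASSES = {
        key: Gemma2Attention for key in ("eager", "flash_attention_2", "sdpa")
    }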
5 changes: 4 additions & 1 deletion
@@ -41,7 +41,7 @@
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
@@ -63,6 +63,9 @@

# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
+from packaging import version
+parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
+is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
if is_torch_fx_available():
if not is_torch_greater_or_equal_than_1_13:
import torch.fx
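transformers 4.48.0 also stops exporting is_torch_greater_or_equal_than_1_13 from transformers.pytorch_utils, so the commit recomputes the flag locally with packaging before the torch.fx guard runs. A small standalone sketch of the same check (the helper name torch_at_least is illustrative, not from the repository):

# Illustrative helper mirroring the inlined check: compare the installed torch
# base version (ignoring local/dev suffixes such as "+cu121") to a minimum.
import torch
from packaging import version

def torch_at_least(minimum: str) -> bool:
    installed = version.parse(version.parse(torch.__version__).base_version)
    return installed >= version.parse(minimum)

if not torch_at_least("1.13"):
    print("old torch: the module wraps the mask helper for torch.fx tracing")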
2 changes: 1 addition & 1 deletion setup.py
@@ -15,7 +15,7 @@
include_package_data=True,
install_requires=[
'torch>=1.6.0',
-'transformers==4.44.2',
+'transformers>=4.44.2',
'datasets==2.19.0',
'accelerate>=0.20.1',
'sentence_transformers',
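Loosening the transformers pin from ==4.44.2 to >=4.44.2 lets an install resolve to 4.48.x, which the code changes above now accommodate. A quick, purely illustrative check with packaging shows what the relaxed specifier admits:

# Illustrative only: the exact pin rejects 4.48.0, the relaxed bound accepts it.
from packaging.specifiers import SpecifierSet

print("4.48.0" in SpecifierSet("==4.44.2"))  # False
print("4.48.0" in SpecifierSet(">=4.44.2"))  # True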
