diff --git a/FlagEmbedding/inference/reranker/decoder_only/models/gemma_model.py b/FlagEmbedding/inference/reranker/decoder_only/models/gemma_model.py
index 6c2c450e..38c16efe 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/models/gemma_model.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/models/gemma_model.py
@@ -53,7 +53,7 @@
 )
 from .gemma_config import CostWiseGemmaConfig
 from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2RotaryEmbedding, rotate_half, apply_rotary_pos_emb
-from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2FlashAttention2, Gemma2SdpaAttention, GEMMA2_ATTENTION_CLASSES, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
+from transformers.models.gemma2.modeling_gemma2 import Gemma2MLP, repeat_kv, Gemma2Attention, Gemma2DecoderLayer, GEMMA2_START_DOCSTRING
 from transformers.models.gemma2.modeling_gemma2 import GEMMA2_INPUTS_DOCSTRING
 
 if is_flash_attn_2_available():
@@ -105,12 +105,6 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-GEMMA2_ATTENTION_CLASSES = {
-    "eager": Gemma2Attention,
-    "flash_attention_2": Gemma2FlashAttention2,
-    "sdpa": Gemma2SdpaAttention,
-}
-
 
 _CONFIG_FOR_DOC = "CostWiseGemmaConfig"
 
diff --git a/FlagEmbedding/inference/reranker/decoder_only/models/modeling_minicpm_reranker.py b/FlagEmbedding/inference/reranker/decoder_only/models/modeling_minicpm_reranker.py
index 116828db..6d5f9379 100644
--- a/FlagEmbedding/inference/reranker/decoder_only/models/modeling_minicpm_reranker.py
+++ b/FlagEmbedding/inference/reranker/decoder_only/models/modeling_minicpm_reranker.py
@@ -41,7 +41,7 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
     SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -63,6 +63,9 @@
 
 # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 # It means that the function will not be traced through and simply appear as a node in the graph.
+from packaging import version
+parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)
+is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
 if is_torch_fx_available():
     if not is_torch_greater_or_equal_than_1_13:
         import torch.fx
diff --git a/setup.py b/setup.py
index 45f8e89d..ea1f1789 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@
     include_package_data=True,
     install_requires=[
         'torch>=1.6.0',
-        'transformers==4.44.2',
+        'transformers>=4.44.2',
         'datasets==2.19.0',
         'accelerate>=0.20.1',
         'sentence_transformers',