Skip to content

Commit

Permalink
fix torch_dtype in estimate memory (#3383)
Browse files Browse the repository at this point in the history
* fix torch_dtype

* style

* add comments

* style
  • Loading branch information
SunMarc authored Feb 7, 2025
1 parent 81d8a03 commit f19b957
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/accelerate/commands/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from huggingface_hub import model_info
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError

Expand Down Expand Up @@ -62,7 +63,8 @@ def check_has_model(error):

def create_empty_model(model_name: str, library_name: str, trust_remote_code: bool = False, access_token: str = None):
"""
Creates an empty model from its parent library on the `Hub` to calculate the overall memory consumption.
Creates an empty model in full precision from its parent library on the `Hub` to calculate the overall memory
consumption.
Args:
model_name (`str`):
Expand Down Expand Up @@ -120,7 +122,8 @@ def create_empty_model(model_name: str, library_name: str, trust_remote_code: bo
break
if value is not None:
constructor = getattr(transformers, value)
model = constructor.from_config(config, trust_remote_code=trust_remote_code)
# we need to pass the dtype, otherwise it is going to use the torch_dtype that is saved in the config
model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code)
elif library_name == "timm":
if not is_timm_available():
raise ImportError(
Expand Down

0 comments on commit f19b957

Please sign in to comment.