Skip to content

Commit

Permalink
fix torch_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
SunMarc committed Feb 6, 2025
1 parent 81d8a03 commit 1069ab0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/accelerate/commands/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
is_timm_available,
is_transformers_available,
)

import torch

if is_transformers_available():
import transformers
Expand Down Expand Up @@ -120,7 +120,7 @@ def create_empty_model(model_name: str, library_name: str, trust_remote_code: bo
break
if value is not None:
constructor = getattr(transformers, value)
model = constructor.from_config(config, trust_remote_code=trust_remote_code)
model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code)
elif library_name == "timm":
if not is_timm_available():
raise ImportError(
Expand Down

0 comments on commit 1069ab0

Please sign in to comment.