diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 39e048a6039..2f662081207 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -1678,6 +1678,16 @@ def __post_init__(self): ) self.sync_module_states = True + if self.cpu_ram_efficient_loading != bool( + str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) + ): + env_var = env_prefix + "CPU_RAM_EFFICIENT_LOADING" + warnings.warn( + f"The `cpu_ram_efficient_loading` flag for `FullyShardedDataParallelPlugin` does not match the environment variable {env_var}. " + "Setting environment variable to match `cpu_ram_efficient_loading`." + ) + os.environ[env_var] = str(self.cpu_ram_efficient_loading) + if isinstance(self.mixed_precision_policy, dict): self.set_mixed_precision(self.mixed_precision_policy)