diff --git a/examples/fedavg_demo.py b/examples/fedavg_demo.py index 69dd3be..6990052 100644 --- a/examples/fedavg_demo.py +++ b/examples/fedavg_demo.py @@ -19,8 +19,8 @@ def test_fedavg(): org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor()) test_ds = MNIST("dataset", train=False, download=True, transform=ToTensor()) - effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT # type: ignore - idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) # type: ignore + effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT + idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()] fed_loader = [DataLoader(ds, batch_size=32, shuffle=True) for ds in fed_dss] @@ -45,7 +45,7 @@ def test_fedavg(): test_loader=test_loader, criterion=criterion, args=args, - ).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) # type: ignore + ).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) if __name__ == "__main__": diff --git a/examples/fedprox_demo.py b/examples/fedprox_demo.py index 4622fe4..0c1c689 100644 --- a/examples/fedprox_demo.py +++ b/examples/fedprox_demo.py @@ -20,8 +20,8 @@ def test_fedavg(): org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor()) test_ds = MNIST("dataset", train=False, download=True, transform=ToTensor()) - effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT # type: ignore - idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) # type: ignore + effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT + idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()] fed_loader = [DataLoader(ds, batch_size=32, shuffle=True) for ds in fed_dss] @@ -46,7 +46,7 @@ def test_fedavg(): test_loader=test_loader, criterion=criterion, args=args, - ).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) # type: ignore + ).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) if __name__ == "__main__": diff --git a/fedmind/algs/fedprox.py b/fedmind/algs/fedprox.py index ab9eca8..6d4b7a0 100644 --- a/fedmind/algs/fedprox.py +++ b/fedmind/algs/fedprox.py @@ -54,7 +54,7 @@ def _train_client( Returns: A dictionary containing the trained model parameters. """ - mu = args.PROX_MU # type: ignore + mu = args.PROX_MU # Train the model model.load_state_dict(gm_params) diff --git a/fedmind/server.py b/fedmind/server.py index 3307837..f25b8f1 100644 --- a/fedmind/server.py +++ b/fedmind/server.py @@ -42,7 +42,7 @@ def __init__( self.args = args self.gm_params = self.model.state_dict(destination=StateDict()) - optim: dict = self.args.OPTIM # type: ignore + optim: dict = self.args.OPTIM if optim["NAME"] == "SGD": self.optimizer = SGD(self.model.parameters(), lr=optim["LR"]) else: @@ -57,13 +57,13 @@ def __init__( ) logging.basicConfig( - level=args.LOG_LEVEL, # type: ignore + level=args.LOG_LEVEL, format="%(asctime)s %(levelname)s [%(processName)s] %(message)s", ) self.logger = logging.getLogger("Server") self.logger.info(f"Get following configs:\n{yaml.dump(args.to_dict())}") - if self.args.NUM_PROCESS > 0: # type: ignore + if self.args.NUM_PROCESS > 0: self.__init_mp__() def __init_mp__(self): @@ -78,17 +78,17 @@ def __init_mp__(self): # Start client processes self.processes = [] - for worker_id in range(self.args.NUM_PROCESS): # type: ignore + for worker_id in range(self.args.NUM_PROCESS): args = ( worker_id, self.task_queue, self.result_queue, self._train_client, self.model, - self.args.OPTIM, # type: ignore + self.args.OPTIM, self.criterion, - self.args.CLIENT_EPOCHS, # type: ignore - self.args.LOG_LEVEL, # type: ignore + self.args.CLIENT_EPOCHS, + self.args.LOG_LEVEL, self.args, ) p = mp.Process(target=self._create_worker_process, args=args) @@ -99,7 +99,7 @@ def __del_mp__(self): """Terminate multi-process environment.""" # Terminate all client processes - for _ in range(self.args.NUM_PROCESS): # type: ignore + for _ in range(self.args.NUM_PROCESS): self.task_queue.put("STOP") # Wait for all client processes to finish @@ -176,7 +176,7 @@ def fit(self, pool: int, num_clients: int, num_rounds: int): # 2. Synchornous clients training updates = [] - if self.args.NUM_PROCESS == 0: # type: ignore + if self.args.NUM_PROCESS == 0: # Serial simulation instead of parallel for cid in clients: updates.append( @@ -186,7 +186,7 @@ def fit(self, pool: int, num_clients: int, num_rounds: int): self.fed_loader[cid], self.optimizer, self.criterion, - self.args.CLIENT_EPOCHS, # type: ignore + self.args.CLIENT_EPOCHS, self.logger, self.args, ) @@ -208,7 +208,7 @@ def fit(self, pool: int, num_clients: int, num_rounds: int): self.wb_run.log(train_metrics | test_metrics) # Terminate multi-process environment - if self.args.NUM_PROCESS > 0: # type: ignore + if self.args.NUM_PROCESS > 0: self.__del_mp__() # Finish wandb run and sync diff --git a/fedmind/utils.py b/fedmind/utils.py index 8d005cb..da113a4 100644 --- a/fedmind/utils.py +++ b/fedmind/utils.py @@ -262,6 +262,13 @@ def __setattr__(self, name, value): __setitem__ = __setattr__ + def __getattr__(self, name): + if name in self: + return self[name] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + def update(self, e=None, **f): d = e or dict() d.update(f) diff --git a/pyproject.toml b/pyproject.toml index a2b62b3..052d013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fedmind" -version = "0.1.4" +version = "0.1.5" description = "Federated Learning research framework in your mind" readme = "README.md" requires-python = ">=3.12" diff --git a/test/test_fedavg.py b/test/test_fedavg.py index bf8ac4b..dad938e 100644 --- a/test/test_fedavg.py +++ b/test/test_fedavg.py @@ -19,8 +19,8 @@ def test_fedavg(): org_ds = MNIST("dataset", train=True, download=True, transform=ToTensor()) test_ds = MNIST("dataset", train=False, download=True, transform=ToTensor()) - effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT # type: ignore - idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) # type: ignore + effective_size = len(org_ds) - len(org_ds) % args.NUM_CLIENT + idx_groups = torch.randperm(effective_size).reshape(args.NUM_CLIENT, -1) fed_dss = [ClientDataset(org_ds, idx) for idx in idx_groups.tolist()] fed_loader = [DataLoader(ds, batch_size=32, shuffle=True) for ds in fed_dss] @@ -45,7 +45,7 @@ def test_fedavg(): test_loader=test_loader, criterion=criterion, args=args, - ).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) # type: ignore + ).fit(args.NUM_CLIENT, args.ACTIVE_CLIENT, args.SERVER_EPOCHS) assert True diff --git a/uv.lock b/uv.lock index 8c19044..45f4490 100644 --- a/uv.lock +++ b/uv.lock @@ -88,7 +88,7 @@ wheels = [ [[package]] name = "fedmind" -version = "0.1.4" +version = "0.1.5" source = { editable = "." } dependencies = [ { name = "numpy" },