
Commit 5150f80
Merge pull request #71 from Xiao-Chenguang/func-doc-string
Update the doc strings
Xiao-Chenguang authored Dec 3, 2024
2 parents cd13e88 + 069dca2 commit 5150f80
Showing 3 changed files with 13 additions and 4 deletions.
1 change: 1 addition & 0 deletions fedmind/algs/fedprox.py
@@ -48,6 +48,7 @@ def _train_client(
         criterion: The loss function to use.
         epochs: The number of epochs to train the model.
         logger: The logger object to log the training process.
+        config: The configuration dict.

     Returns:
         A dictionary containing the trained model parameters.
2 changes: 2 additions & 0 deletions fedmind/algs/mfl.py
@@ -92,6 +92,8 @@ def _train_client(
         criterion: The loss function to use.
         epochs: The number of epochs to train the model.
         logger: The logger object to log the training process.
+        config: The configuration dict.
+        momentum_buffer: The momentum buffer to use.

     Returns:
         A dictionary containing the trained model parameters.
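The FedProx and MFL hunks above only show the tail of each docstring. For reference, here is a minimal, hypothetical stub of how the full MFL _train_client signature and Google-style docstring might read after this commit; the summary line and every argument not visible in the diff (model, gm_params, train_loader, optim) are assumptions for illustration, not project code.

from typing import Any


def _train_client(
    model: Any,
    gm_params: dict,
    train_loader: Any,
    optim: dict,
    criterion: Any,
    epochs: int,
    logger: Any,
    config: dict,
    momentum_buffer: dict,
) -> dict:
    """Train the global model on one client's local data (assumed summary line).

    Args:
        model: The model to train.
        gm_params: The global model parameters to load before training (assumed).
        train_loader: The DataLoader with the client's training data (assumed).
        optim: dictionary containing the optimizer parameters.
        criterion: The loss function to use.
        epochs: The number of epochs to train the model.
        logger: The logger object to log the training process.
        config: The configuration dict.
        momentum_buffer: The momentum buffer to use.

    Returns:
        A dictionary containing the trained model parameters.
    """
    raise NotImplementedError  # illustration of the docstring layout only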
14 changes: 10 additions & 4 deletions fedmind/server.py
@@ -190,9 +190,10 @@ def _test_server(
         test_loader: The DataLoader object that contains the test data.
         criterion: The loss function to use.
         logger: The logger object to log the testing process.
+        config: The configuration dictionary.

     Returns:
-        The evaluation metrics.
+        The evaluation metrics dict.
     """

     total_loss = 0
@@ -300,9 +301,10 @@ def _train_client(
         criterion: The loss function to use.
         epochs: The number of epochs to train the model.
         logger: The logger object to log the training process.
+        config: The configuration dictionary.

     Returns:
-        A dictionary containing the trained model parameters.
+        A dictionary containing the trained model parameters, training loss and more.
     """
     # Train the model
     model.load_state_dict(gm_params)
@@ -342,17 +344,21 @@ def _create_worker_process(
     log_level: int,
     config: EasyDict,
 ):
-    """Train process for multi-process environment.
+    """Train subprocess for multi-process environment.

     Args:
         worker_id: The worker process id.
         task_queue: The task queue for task distribution.
-        result_queue: The result queue for result collection.
+        result_queue: The result queue for training.
+        test_queue: The result queue for testing.
         client_func: The client function to train the model.
+        test_func: The function to test the model.
         model: The model to train.
         optim: dictionary containing the optimizer parameters.
         criterion: The loss function to use.
         epochs: The number of epochs to train the model.
+        log_level: The logging level.
+        config: The configuration dictionary.
     """
     logging.basicConfig(
         level=log_level,
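The last hunk only updates the docstring of _create_worker_process; its body is not part of this diff. Below is a minimal sketch of the worker loop such a docstring describes, assuming a (kind, payload) task format and a "STOP" shutdown sentinel; the call signatures of client_func and test_func are likewise assumptions, not fedmind's actual API.

import logging
from multiprocessing import Queue
from typing import Any, Callable


def worker_loop(
    worker_id: int,
    task_queue: Queue,
    result_queue: Queue,
    test_queue: Queue,
    client_func: Callable[..., dict],
    test_func: Callable[..., dict],
    model: Any,
    optim: dict,
    criterion: Any,
    epochs: int,
    log_level: int,
    config: dict,
) -> None:
    # Each worker configures its own logging, mirroring the basicConfig call
    # visible in the hunk above.
    logging.basicConfig(level=log_level)
    logger = logging.getLogger(f"worker-{worker_id}")

    while True:
        task = task_queue.get()  # block until the server distributes a task
        if task == "STOP":       # assumed shutdown sentinel
            break
        kind, payload = task     # assumed (kind, payload) task format
        if kind == "train":
            # payload is assumed to carry the global parameters to start from
            update = client_func(model, payload, optim, criterion, epochs, logger, config)
            result_queue.put(update)   # training results flow back on result_queue
        elif kind == "test":
            metrics = test_func(model, payload, criterion, logger, config)
            test_queue.put(metrics)    # evaluation metrics flow back on test_queue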
