A Flower-based federated learning application for medical image analysis tasks (classification and segmentation) using PyTorch.
- Two Medical Use Cases:
  - 🧠 Brain Tumor Segmentation (BRATS 3D to 2D slices)
  - 📊 Pathological Classification (MedMNIST dataset)
- Supported Models:
  - Classification: ResNet-18, TinyViT
  - Segmentation: U-Net, SegFormer
- Federated Learning Features:
  - TensorBoard integration for training monitoring
  - Best model checkpointing
  - Centralized evaluation
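To illustrate how best model checkpointing and centralized evaluation fit together, here is a minimal sketch, not the repository's actual implementation. The `BestCheckpoint` class, the `demo` helper, and the metric values are hypothetical; the sketch uses `pickle` so it runs without extra dependencies, whereas a PyTorch project would more typically call `torch.save(model.state_dict(), path)`.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional


class BestCheckpoint:
    """Hypothetical helper: keep only the parameters of the best round so far."""

    def __init__(self, path: Path) -> None:
        self.path = path
        self.best_metric = float("-inf")
        self.best_round: Optional[int] = None

    def update(self, server_round: int, metric: float, params: Any) -> bool:
        """Save `params` if `metric` beats the best seen; return True when saved."""
        if metric <= self.best_metric:
            return False
        self.best_metric = metric
        self.best_round = server_round
        with self.path.open("wb") as fh:
            pickle.dump(params, fh)
        return True


def demo() -> tuple:
    """Feed three made-up centralized evaluation results through the tracker."""
    # (round, metric from centralized evaluation, stand-in for model parameters)
    history = [(1, 0.62, [0.1]), (2, 0.81, [0.2]), (3, 0.78, [0.3])]
    with tempfile.TemporaryDirectory() as tmp:
        ckpt = BestCheckpoint(Path(tmp) / "best.pkl")
        for server_round, metric, params in history:
            ckpt.update(server_round, metric, params)
        # Reload the file to confirm it holds the best round's parameters.
        with ckpt.path.open("rb") as fh:
            restored = pickle.load(fh)
    return ckpt.best_round, ckpt.best_metric, restored
```

In this sketch, round 3 scores lower than round 2, so the file on disk is left untouched and still holds the parameters from round 2.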
git clone https://github.com/MagnusS0/fl-for-AI-health.git
cd fl-for-AI-health
pip install -e .
flwr run .
Make sure you have downloaded the BRATS dataset and set the paths in the .env file.
- Update pyproject.toml:
[tool.flwr.app.components]
# serverapp = "fl_for_ai_health.classification.class_server:app"
# clientapp = "fl_for_ai_health.classification.class_client:app"
serverapp = "fl_for_ai_health.segmentation.seg_server:app"
clientapp = "fl_for_ai_health.segmentation.seg_client:app"
[tool.flwr.app.config]
in-channels = 1
num-classes = 4
model = "u-net" # or "segformer"
- Run simulation:
flwr run .
On the first run, the dataset is converted from 3D volumes to 2D axial slices, which might take some time. Consecutive runs are much faster because the slices already exist. Alternatively, run the dataset script first.
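The 3D-to-2D conversion can be pictured with the following minimal sketch. It is not the code in data/brats.py: the function name `axial_slices`, the choice of the last axis as the axial direction, and the option to skip slices with an empty mask are all assumptions made for illustration.

```python
from typing import Iterator, Tuple

import numpy as np


def axial_slices(
    volume: np.ndarray, mask: np.ndarray, skip_empty: bool = True
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield (image, mask) 2D pairs along the last axis of a 3D volume.

    Assumption: the axial direction is the last axis, i.e. shape (H, W, D).
    """
    if volume.shape != mask.shape:
        raise ValueError("volume and mask must have the same shape")
    for z in range(volume.shape[-1]):
        mask_2d = mask[..., z]
        # Optionally drop slices that contain no labelled voxels.
        if skip_empty and not mask_2d.any():
            continue
        yield volume[..., z], mask_2d
```

Writing each yielded pair to disk once and reading the 2D files on later runs is one way to get the faster consecutive runs described above.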
Key configuration options in pyproject.toml:
[tool.flwr.app.config]
num-server-rounds = 10 # Total federation rounds
fraction-fit = 0.5 # Fraction of clients used for training
local-epochs = 1 # Local client epochs
batch-size = 64 # Training batch size
learning-rate = 4e-3 # Initial learning rate
img-size = 64 # Input image size
model = "tiny-vit" # Model architecture
in-channels = 3 # Input channels
num-classes = 9 # Output classes
TensorBoard logs are saved in tb_logs/. Launch TensorBoard with:
tensorboard --logdir tb_logs/
- MedMNIST:
  - Automatically downloaded via Hugging Face Datasets
  - Preprocessed into train/test/val splits
- BRATS:
  - Requires manual download from Medical Decathlon
  - Preprocessing handled by data/brats.py