- Install dependencies:
pip install -r requirements.txt
- Download the data and run training:
bash scripts/download_data.sh
python train.py --amp
This model was trained from scratch on 267 lung radiology images and their respective bi-colored masks. It scored a Dice coefficient of 0.96035277 over 5 epochs and 37 iterations.
Note: Use Python 3.6 or newer.
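For reference, the Dice coefficient measures the overlap between a predicted mask A and its ground-truth mask B as 2|A∩B| / (|A| + |B|). Below is a minimal sketch on flat binary masks, assuming pixel values of 0 or 1. The function name and the `eps` smoothing term are illustrative and are not taken from this repository's code.

```python
def dice_coefficient(pred, target, eps=1e-6):
    """Dice overlap between two flat binary masks (sequences of 0/1).

    eps is a small smoothing term (an assumption of this sketch) that
    avoids division by zero when both masks are empty.
    """
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    # Pixels that are foreground in both masks.
    intersection = sum(p * t for p, t in zip(pred, target))
    # Total foreground pixels across both masks.
    total = sum(pred) + sum(target)
    return (2 * intersection + eps) / (total + eps)


if __name__ == "__main__":
    predicted = [1, 1, 0, 0]
    truth = [1, 0, 0, 0]
    print(f"Dice: {dice_coefficient(predicted, truth):.4f}")
```

A value of 1 means the two masks match exactly, and a value near 0 means they barely overlap.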
> python train.py -h
usage: train.py [-h] [--epochs E] [--batch-size B] [--learning-rate LR]
                [--load LOAD] [--scale SCALE] [--validation VAL] [--amp]

Train the UNet on images and target masks

optional arguments:
  -h, --help            show this help message and exit
  --epochs E, -e E      Number of epochs
  --batch-size B, -b B  Batch size
  --load LOAD, -f LOAD  Load model from a .pth file
  --scale SCALE, -s SCALE
                        Downscaling factor of the images
  --validation VAL, -v VAL
                        Percent of the data that is used as validation (0-100)
  --amp                 Use mixed precision
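To make the --validation percentage concrete, here is a small sketch of how a percentage could translate into image counts. It assumes the validation count is rounded down; how train.py actually performs the split is not shown here, and the value 10 is only an example, not a documented default.

```python
def split_counts(n_images, val_percent):
    """Return (n_train, n_val) for a validation share given in percent (0-100)."""
    if not 0 <= val_percent <= 100:
        raise ValueError("val_percent must be between 0 and 100")
    # Assumption of this sketch: the validation count is rounded down.
    n_val = int(n_images * val_percent / 100)
    return n_images - n_val, n_val


if __name__ == "__main__":
    # 267 is the dataset size mentioned in this README.
    n_train, n_val = split_counts(267, 10)
    print(f"train: {n_train}, validation: {n_val}")
```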
By default, the scale is 0.5, so if you wish to obtain better results (but use more memory), set it to 1.
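The memory trade-off follows from the pixel count: scaling both sides of an image by a factor multiplies the number of pixels by that factor squared, and activation memory in a convolutional network grows roughly with the pixel count. Here is a rough sketch of the arithmetic. The helper names and the 1024x1024 input size are hypothetical, and the exact rounding used by the data loader may differ.

```python
def scaled_size(width, height, scale):
    """Image size after downscaling both sides by `scale` (0 < scale <= 1)."""
    if not 0 < scale <= 1:
        raise ValueError("scale must be in the range (0, 1]")
    new_w, new_h = int(width * scale), int(height * scale)
    if new_w < 1 or new_h < 1:
        raise ValueError("scale is too small for this image")
    return new_w, new_h


def relative_pixel_count(scale):
    """Fraction of the original pixels that remain after downscaling."""
    return scale * scale


if __name__ == "__main__":
    for s in (0.5, 1.0):
        w, h = scaled_size(1024, 1024, s)
        print(f"scale={s}: {w}x{h}, {relative_pixel_count(s):.0%} of the pixels")
```

Under these assumptions, going from a scale of 0.5 to 1 processes about four times as many pixels per image.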
After training your model and saving it to MODEL.pth, you can easily test the output masks on your images via the CLI.
To predict a single image and save it:
python predict.py -i image.jpg -o output.jpg -m pathToModel
Choose the checkpoint from the highest epoch in model/data/checkpoints as your .pth model.
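Picking the highest-epoch checkpoint can be automated. The sketch below assumes that each checkpoint file name ends with its epoch number (for example, a hypothetical `checkpoint_epoch5.pth`); the actual naming scheme used by train.py is not shown here, so adjust the pattern if it differs.

```python
import re
from pathlib import Path


def latest_checkpoint(checkpoint_dir):
    """Return the .pth file with the largest epoch number in its name, or None."""
    best_path, best_epoch = None, -1
    for path in Path(checkpoint_dir).glob("*.pth"):
        # Assumption: the last run of digits in the file name is the epoch.
        numbers = re.findall(r"\d+", path.stem)
        if not numbers:
            continue
        epoch = int(numbers[-1])
        if epoch > best_epoch:
            best_path, best_epoch = path, epoch
    return best_path


if __name__ == "__main__":
    print(latest_checkpoint("model/data/checkpoints"))
```

The returned path can then be passed to predict.py through the -m option. Comparing epochs as integers matters here, because a plain alphabetical sort would place epoch 10 before epoch 5.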
You can also download the data using the helper script:
bash scripts/download_data.sh
and then run
bash scripts/img_channel_to_1.sh
to convert the images from 1 channel to 3. This avoids the need to change the UNet architecture; time is of the essence right now ;)
The input images and target masks should be in the data/imgs and data/masks folders respectively (note that the imgs and masks folders should not contain any sub-folders or any other files, due to the greedy data-loader).
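Because the data loader picks up everything in these folders, a quick sanity check before training can save a failed run. This is a minimal sketch, not part of the repository: the function name is hypothetical, and it assumes one mask per image, so the two folders should hold the same number of files.

```python
from pathlib import Path


def check_data_dirs(imgs_dir="data/imgs", masks_dir="data/masks"):
    """Return a list of problems found in the image and mask folders."""
    problems = []
    file_counts = []
    for folder in (Path(imgs_dir), Path(masks_dir)):
        if not folder.is_dir():
            problems.append(f"missing folder: {folder}")
            continue
        entries = list(folder.iterdir())
        for entry in entries:
            if entry.is_dir():
                problems.append(f"unexpected sub-folder: {entry}")
            elif entry.name.startswith("."):
                # Hidden files such as .DS_Store would also be loaded.
                problems.append(f"unexpected hidden file: {entry}")
        file_counts.append(sum(1 for e in entries if e.is_file()))
    # Assumption: every image has exactly one mask.
    if len(file_counts) == 2 and file_counts[0] != file_counts[1]:
        problems.append(
            f"{file_counts[0]} image files but {file_counts[1]} mask files"
        )
    return problems


if __name__ == "__main__":
    for problem in check_data_dirs():
        print(problem)
```

An empty result means the folders look consistent under these assumptions.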
I talked to Olaf (creator of UNet) to ask him to help me get into the University of Freiburg for a master's program. He was not impressed. 😄
*Spoiler alert: he didn't help.
Original paper by Olaf Ronneberger, Philipp Fischer, Thomas Brox:
U-Net: Convolutional Networks for Biomedical Image Segmentation