Create an offset neural network in the flow matching model that acts as a deterministic offset to overcome oscillation (branch offset).
During training, reassign each noise sample to the image at the shortest distance, imitating the idea of OT-CFM (branch conv).
Use a convolution-based method to find the shortest-distance matches without computing a very high-dimensional distance matrix (branch conv).
Use importance sampling when sampling the time t (branch logit).
Overall, the advantages show up mainly in the early stage of convergence and in better qualitative results.
(xiaweijiexox is the student author; the main innovation is in /model/train_flow_latent.py. The baseline is https://github.com/VinAIResearch/LFM.)
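The actual changes live in /model/train_flow_latent.py and are not reproduced here. As a rough, hypothetical illustration of the conv and logit branch ideas, the NumPy sketch below (1) computes pairwise squared distances with the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y, so only a B x B matrix is formed rather than a B x B x D tensor of differences (this is not the branch's convolution method); (2) greedily pairs each image with its nearest unused noise sample, a cheap stand-in for the exact minibatch OT matching used in OT-CFM; and (3) draws t from a mixture of a uniform and a logit-normal proposal and returns importance weights w = 1/q(t), so that averaging w times the per-sample loss still estimates the loss under uniform t. All function names, the greedy matching and the logit-normal proposal are assumptions.

```python
import numpy as np

# Hypothetical helpers for illustration only; not taken from train_flow_latent.py.


def pairwise_sq_dists(a, b):
    """Squared Euclidean distances between the rows of a (B, D) and b (B, D).

    Uses ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, so only a (B, B) matrix is
    built instead of a (B, B, D) tensor of differences.
    """
    a2 = np.sum(a * a, axis=1)[:, None]
    b2 = np.sum(b * b, axis=1)[None, :]
    d = a2 + b2 - 2.0 * a @ b.T
    return np.maximum(d, 0.0)  # clip tiny negatives caused by round-off


def greedy_reassign_noise(noise, images):
    """Reorder noise so that noise[j] ends up paired with a nearby images[j].

    Greedy approximation of minibatch OT (not an exact assignment): repeatedly
    take the globally closest remaining (noise, image) pair.
    """
    b = noise.shape[0]
    d = pairwise_sq_dists(noise.reshape(b, -1), images.reshape(b, -1))
    perm = np.empty(b, dtype=int)
    for _ in range(b):
        i, j = np.unravel_index(np.argmin(d), d.shape)
        perm[j] = i       # image j receives noise sample i
        d[i, :] = np.inf  # noise sample i is used up
        d[:, j] = np.inf  # image j is matched
    return noise[perm]


def logit_normal_pdf(t, mean=0.0, std=1.0):
    """Density of the logit-normal distribution on (0, 1)."""
    logit = np.log(t) - np.log1p(-t)
    return np.exp(-0.5 * ((logit - mean) / std) ** 2) / (
        std * np.sqrt(2.0 * np.pi) * t * (1.0 - t)
    )


def sample_t_importance(n, rng, mean=0.0, std=1.0, alpha=0.1):
    """Sample t from q = alpha * U(0, 1) + (1 - alpha) * logit-normal(mean, std).

    Returns t and weights w = 1 / q(t). E_q[w * f(t)] equals the integral of f
    over [0, 1] (the uniform-t objective), and w <= 1 / alpha.
    """
    use_uniform = rng.random(n) < alpha
    z = rng.normal(mean, std, size=n)
    t_ln = 1.0 / (1.0 + np.exp(-z))  # sigmoid of a Gaussian is logit-normal
    t = np.where(use_uniform, rng.random(n), t_ln)
    t = np.clip(t, 1e-6, 1.0 - 1e-6)  # keep log/logit finite
    q = alpha + (1.0 - alpha) * logit_normal_pdf(t, mean, std)
    return t, 1.0 / q
```

In a training step, the reordered noise would be paired with the image batch to build the interpolant, and the per-sample loss multiplied by w before averaging. The uniform component keeps the weights bounded by 1/alpha; with a pure logit-normal proposal the weights grow without bound near t = 0 and t = 1.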
Python 3.10 and PyTorch 1.13.1/2.0.0 are used in this implementation.
Please install the required libraries:
pip install -r requirements.txt
For CelebA HQ 256, FFHQ 256 and LSUN, please follow NVAE's instructions.
For higher-resolution datasets (CelebA HQ 512 & 1024), please refer to WaveDiff's documentation.
For ImageNet dataset, please download it directly from the official website.
All training scripts are wrapped in run.sh. Simply comment/uncomment the relevant commands and run bash run.sh.
Run run_test.sh (or run_test_cls.sh for class-conditional models) with the corresponding argument file:
bash run_test.sh <path_to_arg_file>
Only one GPU is required.
An argument file specifies the following variables (the values below are for celeb_f8_dit):
MODEL_TYPE=DiT-L/2
EPOCH_ID=475
DATASET=celeba_256
EXP=celeb_f8_dit
METHOD=dopri5
STEPS=0
USE_ORIGIN_ADM=False
IMG_SIZE=256
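The argument files appear to hold one KEY=VALUE assignment per line, as listed above. Assuming that format (run_test.sh itself is not shown here, and real files might quote values or contain other lines), a minimal Python reader could look like this; parse_arg_text, read_arg_file and EXAMPLE are hypothetical names.

```python
from pathlib import Path

# Example content copied from the variables listed above (format assumed).
EXAMPLE = """MODEL_TYPE=DiT-L/2
EPOCH_ID=475
DATASET=celeba_256
EXP=celeb_f8_dit
METHOD=dopri5
STEPS=0
USE_ORIGIN_ADM=False
IMG_SIZE=256
"""


def parse_arg_text(text):
    """Parse KEY=VALUE lines into a dict of strings; blank lines and # comments are skipped."""
    args = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if not sep:
            raise ValueError(f"not a KEY=VALUE line: {raw!r}")
        args[key.strip()] = value.strip()
    return args


def read_arg_file(path):
    """Read a file such as test_args/celeb256_dit.txt (hypothetical helper)."""
    return parse_arg_text(Path(path).read_text())
```

All values come back as strings, so numeric fields such as EPOCH_ID or IMG_SIZE need an explicit int(...) and USE_ORIGIN_ADM a comparison against "True".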
Argument files and checkpoints are provided below:
Exp | Args | FID | Checkpoints |
---|---|---|---|
celeb_f8_dit | test_args/celeb256_dit.txt | 5.26 | model_475.pth |
ffhq_f8_dit | test_args/ffhq_dit.txt | 4.55 | model_475.pth |
bed_f8_dit | test_args/bed_dit.txt | 4.92 | model_550.pth |
church_f8_dit | test_args/church_dit.txt | 5.54 | model_575.pth |
imnet_f8_ditb2 | test_args/imnet_dit.txt | 4.46 | model_875.pth |
celeb512_f8_adm | test_args/celeb512_adm.txt | 6.35 | model_575.pth |
celeba_f8_adm | test_args/celeb256_adm.txt | 5.82 | --- |
ffhq_f8_adm | test_args/ffhq_adm.txt | 5.82 | --- |
bed_f8_adm | test_args/bed_adm.txt | 7.05 | --- |
church_f8_adm | test_args/church_adm.txt | 7.7 | --- |
imnet_f8_adm | test_args/imnet_adm.txt | 8.58 | --- |
Please put the downloaded pre-trained models in the saved_info/latent_flow/<DATASET>/<EXP> directory, where <DATASET> is defined as in bash_scripts/run.sh.
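Assuming checkpoints keep the model_<EPOCH_ID>.pth names from the table above and are looked up under this directory (neither is confirmed by the scripts shown here), the expected location for a given argument file can be computed as follows; checkpoint_path is a hypothetical helper.

```python
from pathlib import Path


def checkpoint_path(dataset, exp, epoch_id, root="saved_info/latent_flow"):
    """Return <root>/<DATASET>/<EXP>/model_<EPOCH_ID>.pth (file naming assumed from the table)."""
    return Path(root) / dataset / exp / f"model_{epoch_id}.pth"


if __name__ == "__main__":
    p = checkpoint_path("celeba_256", "celeb_f8_dit", 475)
    # On POSIX this prints saved_info/latent_flow/celeba_256/celeb_f8_dit/model_475.pth
    print(p, "found" if p.is_file() else "missing")
```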
Utilities
To measure time, please add --measure_time to the script.
To compute the number of function evaluations (NFE) of the adaptive solver (default: dopri5), please add --compute_nfe to the script.
To use a fixed-step solver (e.g. euler or heun), please add --use_karras_samplers and change two arguments as follows:
METHOD=heun
STEPS=50
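For intuition, a fixed-step solver integrates the learned velocity field dx/dt = v(x, t) with a constant number of steps, so its NFE is known in advance: STEPS evaluations for Euler and 2 * STEPS for the plain Heun scheme below. This is a generic NumPy sketch on a uniform time grid, not the repository's Karras-sampler code (which may use a different time grid or skip evaluations); the toy velocity in the usage lines is made up.

```python
import numpy as np


class CountedField:
    """Wrap a velocity function v(x, t) and count how often it is called (NFE)."""

    def __init__(self, fn):
        self.fn = fn
        self.nfe = 0

    def __call__(self, x, t):
        self.nfe += 1
        return self.fn(x, t)


def integrate(v, x0, steps=50, method="heun", t0=0.0, t1=1.0):
    """Solve dx/dt = v(x, t) from t0 to t1 on a uniform grid with Euler or Heun."""
    ts = np.linspace(t0, t1, steps + 1)
    x = np.asarray(x0, dtype=float)
    for t, t_next in zip(ts[:-1], ts[1:]):
        h = t_next - t
        d = v(x, t)
        if method == "euler":
            x = x + h * d
        elif method == "heun":
            x_pred = x + h * d                # Euler predictor
            d_next = v(x_pred, t_next)        # slope at the predicted point
            x = x + 0.5 * h * (d + d_next)    # trapezoidal corrector
        else:
            raise ValueError(f"unknown method: {method}")
    return x


if __name__ == "__main__":
    for method in ("euler", "heun"):
        field = CountedField(lambda x, t: x)  # toy ODE dx/dt = x, exact x(1) = e
        x1 = integrate(field, 1.0, steps=50, method=method)
        print(method, float(x1), "NFE =", field.nfe)
```

Heun doubles the NFE per step but is second-order accurate, which is why it usually reaches a given accuracy with far fewer steps than Euler.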
To evaluate FID scores, please download the pre-computed stats from here and put them in pytorch_fid.
Then run bash run_test_ddp.sh for unconditional generation or bash run_test_cls_ddp.sh for conditional generation. By default, these scripts support multi-GPU sampling with 8 GPUs for faster computation.
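FID compares Gaussian fits to Inception features of real and generated images: FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^{1/2}). The NumPy sketch below evaluates this formula from (mu, sigma) pairs, computing the trace term through the eigenvalues of S_r^{1/2} S_g S_r^{1/2}, which has the same eigenvalues as S_r S_g. It only illustrates the formula; reported numbers should come from the bundled pytorch_fid code.

```python
import numpy as np


def _psd_sqrt(m):
    """Symmetric square root of a positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(m)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    s1_half = _psd_sqrt(sigma1)
    # sigma1^{1/2} sigma2 sigma1^{1/2} is symmetric PSD and shares its eigenvalues with sigma1 sigma2
    middle = s1_half @ sigma2 @ s1_half
    tr_covmean = np.sum(np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```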
Computing stats for new dataset
pytorch_fid/compute_dataset_stat.py is provided for this purpose:
python pytorch_fid/compute_dataset_stat.py \
--dataset <dataset> --datadir <path_to_data> \
--image_size <image_size> --save_path <path_to_save>
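The dataset statistics are the mean vector and covariance matrix of Inception features. Assuming the output follows pytorch_fid's .npz convention with keys mu and sigma (check compute_dataset_stat.py to confirm), the final step amounts to the following; feature extraction is omitted, and feature_stats and save_stats are hypothetical helpers.

```python
import numpy as np


def feature_stats(features):
    """Mean vector and covariance matrix of an (N, D) feature array (N > 1)."""
    features = np.asarray(features, dtype=np.float64)
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)  # unbiased (N - 1) estimator
    return mu, sigma


def save_stats(path, features):
    """Write mu and sigma to an .npz file (key names assumed from pytorch_fid)."""
    mu, sigma = feature_stats(features)
    np.savez(path, mu=mu, sigma=sigma)
    return mu, sigma
```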