
Welcome to the Simple Latent Diffusion Model

🌐 README in Korean: KR (Korean version)

This repository provides a lightweight and modular implementation of a Latent Diffusion Model (LDM), which performs efficient image generation by operating in a lower-dimensional latent space instead of pixel space.
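
To make the idea concrete, here is a minimal sketch of a single latent-diffusion training step: images are encoded into latents by a VAE, noised at a random timestep, and a U-Net is trained to predict that noise. This is an illustration only; `vae.encode` and `unet(z_t, t)` are assumed interfaces, not the classes defined in this repository.

import torch
import torch.nn.functional as F

# Illustrative sketch of one latent-diffusion training step.
# `vae` and `unet` are placeholders with assumed interfaces.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)

def ldm_training_loss(vae, unet, images):
    z = vae.encode(images)                              # encode images into latents
    t = torch.randint(0, T, (z.size(0),))               # random timestep per sample
    eps = torch.randn_like(z)                           # Gaussian noise
    a_bar = alphas_bar[t].view(-1, 1, 1, 1)
    z_t = a_bar.sqrt() * z + (1 - a_bar).sqrt() * eps   # diffuse the latents
    return F.mse_loss(unet(z_t, t), eps)                # U-Net predicts the noise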

[Table: for each dataset (Swiss-roll, CIFAR-10, CelebA), the generation process of the latents and examples of generated data]

Text-to-image synthesis using CLIP-guided latent diffusion.

The table below showcases text-to-image generation using CLIP. The dataset used is the Asian Composite Dataset; the input text is in Korean, and English translations are shown below.

English Text (each prompt is paired with a generated image):

- A round face with voluminous, slightly long short hair, along with barely visible vocal cords, gives off a more feminine aura than a masculine one. The well-defined eyes and lips enhance the subject's delicate features, making them appear more refined and intellectual.
- The hairstyle appears slightly unpolished, lacking a refined touch. The slightly upturned eyes give off a sharp and somewhat sensitive impression. Overall, they seem to have a slender physique and appear efficient in handling tasks, though their social interactions may not be particularly smooth.
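
As a rough illustration of how text conditioning can be wired in, the sketch below obtains CLIP text embeddings with the Hugging Face transformers library. This repository's own text encoder and conditioning interface may differ, and the English OpenAI checkpoint here is only a stand-in for the Korean prompts used above.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

# Illustrative sketch: encode a text prompt with CLIP for conditioning.
# The checkpoint and interfaces below are assumptions, not this repo's API.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

prompts = ["a round face with voluminous, slightly long hair"]
tokens = tokenizer(prompts, padding=True, return_tensors="pt")
with torch.no_grad():
    text_emb = text_encoder(**tokens).last_hidden_state  # (batch, tokens, dim)
# `text_emb` would then be fed to a conditional U-Net, e.g. via cross-attention.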

Why doesn't the generated image match the input text?

This is because the sampling was performed without guidance. Using classifier-free guidance (CFG) during sampling can significantly improve sample quality and enhance the performance of conditional generation.
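
For reference, classifier-free guidance combines a conditional and an unconditional noise prediction at each sampling step. The minimal sketch below assumes a conditional `unet(z_t, t, cond)` interface; it is not this repository's exact API.

# Illustrative sketch of classifier-free guidance at sampling time.
# `unet`, `cond`, and `null_cond` (empty-prompt embedding) are assumed inputs.
def cfg_noise_prediction(unet, z_t, t, cond, null_cond, guidance_scale=7.5):
    eps_cond = unet(z_t, t, cond)         # prediction with the text condition
    eps_uncond = unet(z_t, t, null_cond)  # prediction without the condition
    # Move the estimate further in the conditional direction.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)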

Tutorials

Usage

The following example demonstrates how to use the code in this repository.

import torch

from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from helper.data_generator import DataGenerator
from helper.painter import Painter
from helper.trainer import Trainer
from helper.loader import Loader
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.uncond_u_net import UnconditionalUnetwork
from diffusion_model.sampler.ddim import DDIM

# Path to the configuration file
CONFIG_PATH = './configs/cifar10_config.yaml'

# Instantiate helper classes
painter = Painter()
loader = Loader()
data_generator = DataGenerator()

# Load CIFAR-10 dataset
data_loader = data_generator.cifar10(batch_size=128)

# Train the Variational Autoencoder (VAE)
vae = VariationalAutoEncoder(CONFIG_PATH)  # Initialize the VAE model
trainer = Trainer(vae, vae.loss)  # Create a trainer for the VAE
trainer.train(dl=data_loader, epochs=1000, file_name='vae', no_label=True)  # Train the VAE

# Train the Latent Diffusion Model (LDM)
sampler = DDIM(CONFIG_PATH)  # Initialize the DDIM sampler
network = UnconditionalUnetwork(CONFIG_PATH)  # Initialize the U-Net network
ldm = LatentDiffusionModel(network, sampler, vae, image_shape=(3, 32, 32))  # Initialize the LDM
trainer = Trainer(ldm, ldm.loss)  # Create a trainer for the LDM
trainer.train(dl=data_loader, epochs=1000, file_name='ldm', no_label=True)  
# Train the LDM; set 'no_label=False' if the dataset includes labels

# Generate samples using the trained diffusion model
ldm = LatentDiffusionModel(network, sampler, vae, image_shape=(3, 32, 32))  # Re-initialize the LDM
loader.model_load('./diffusion_model/check_points/ldm_epoch1000', ldm, ema=True)  # Load the trained model
sample = ldm(n_samples=4)  # Generate 4 sample images
painter.show_images(sample)  # Display the generated images

References