From 299768a259d3ec2aeb290103bc33ad2708658f16 Mon Sep 17 00:00:00 2001 From: csaybar Date: Thu, 24 Oct 2024 14:45:46 +0000 Subject: [PATCH] deploy: 2362b473f3db51c9f6ad25891d7f4ce426c75dd5 --- sitemap.xml.gz | Bin 268 -> 268 bytes supers2/main.py | 226 ++-- supers2/models/diffusion.py | 216 +++ supers2/models/opensr_diffusion/__init__.py | 0 .../opensr_diffusion/autoencoder/__init__.py | 0 .../autoencoder/autoencoder.py | 602 +++++++++ .../opensr_diffusion/autoencoder/utils.py | 460 +++++++ .../opensr_diffusion/denoiser/__init__.py | 0 .../models/opensr_diffusion/denoiser/unet.py | 832 ++++++++++++ .../models/opensr_diffusion/denoiser/utils.py | 1185 +++++++++++++++++ .../opensr_diffusion/diffusion/__init__.py | 0 .../diffusion/latentdiffusion.py | 932 +++++++++++++ .../opensr_diffusion/diffusion/utils.py | 901 +++++++++++++ supers2/models/utils.py | 170 +++ supers2/setup.py | 37 +- 15 files changed, 5432 insertions(+), 129 deletions(-) create mode 100644 supers2/models/diffusion.py create mode 100644 supers2/models/opensr_diffusion/__init__.py create mode 100644 supers2/models/opensr_diffusion/autoencoder/__init__.py create mode 100644 supers2/models/opensr_diffusion/autoencoder/autoencoder.py create mode 100644 supers2/models/opensr_diffusion/autoencoder/utils.py create mode 100644 supers2/models/opensr_diffusion/denoiser/__init__.py create mode 100644 supers2/models/opensr_diffusion/denoiser/unet.py create mode 100644 supers2/models/opensr_diffusion/denoiser/utils.py create mode 100644 supers2/models/opensr_diffusion/diffusion/__init__.py create mode 100644 supers2/models/opensr_diffusion/diffusion/latentdiffusion.py create mode 100644 supers2/models/opensr_diffusion/diffusion/utils.py create mode 100644 supers2/models/utils.py diff --git a/sitemap.xml.gz b/sitemap.xml.gz index dba3ab5ad23f5d29e89643e0e03151ee22b4811e..d5fc83ca2dfe51969ecf7ab30e1bafe8b9b8cb35 100644 GIT binary patch delta 15 WcmeBS>S1D&@8;l`8M~2S1D&@8;m(7v9Lm#s~l!0s_4N diff --git a/supers2/main.py b/supers2/main.py index cbc14ce..9fca29a 100644 --- a/supers2/main.py +++ b/supers2/main.py @@ -9,57 +9,55 @@ def setmodel( resolution: Literal["2.5m", "5m", "10m"] = "2.5m", - SR_model_name: Literal["cnn", "swin", "mamba"] = "cnn", - SR_model_size: Literal[ - "lightweight", "small", "medium", "expanded", "large" - ] = "small", + SR_model_name: Literal["cnn", "swin", "mamba", "diffusion"] = "cnn", + SR_model_size: Literal["lightweight", "small", "medium", "expanded", "large"] = "small", SR_model_loss: Literal["l1", "superloss", "adversarial"] = "l1", Fusionx2_model_name: Literal["cnn", "swin", "mamba"] = "cnn", - Fusionx2_model_size: Literal[ - "lightweight", "small", "medium", "expanded", "large" - ] = "lightweight", + Fusionx2_model_size: Literal["lightweight", "small", "medium", "expanded", "large"] = "lightweight", Fusionx4_model_name: Literal["cnn", "swin", "mamba"] = "cnn", - Fusionx4_model_size: Literal[ - "lightweight", "small", "medium", "expanded", "large" - ] = "lightweight", + Fusionx4_model_size: Literal["lightweight", "small", "medium", "expanded", "large"] = "lightweight", weights_path: Union[str, pathlib.Path, None] = None, + device: str = "cpu", + **kwargs ) -> dict: """ Sets up models for super-resolution and fusion tasks based on the specified parameters. Args: - resolution (Literal["2.5m", "5m", "10m"], optional): - Target spatial resolution. Determines which models to load. + resolution (Literal["2.5m", "5m", "10m"], optional): + Target spatial resolution. Determines which models to load. 
Defaults to "2.5m". - SR_model_name (Literal["cnn", "swin", "mamba"], optional): - The super-resolution model to use. + SR_model_name (Literal["cnn", "swin", "mamba"], optional): + The super-resolution model to use. Options: "cnn", "swin", "mamba". Defaults to "cnn". - SR_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional): - Size of the super-resolution model. - Options: "lightweight", "small", "medium", "expanded", "large". + SR_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional): + Size of the super-resolution model. + Options: "lightweight", "small", "medium", "expanded", "large". Defaults to "small". - SR_model_loss (Literal["l1", "superloss", "adversarial"], optional): - Loss function used in training the super-resolution model. + SR_model_loss (Literal["l1", "superloss", "adversarial"], optional): + Loss function used in training the super-resolution model. Options: "l1", "superloss", "adversarial". Defaults to "l1". - Fusionx2_model_name (Literal["cnn", "swin", "mamba"], optional): + Fusionx2_model_name (Literal["cnn", "swin", "mamba"], optional): Model for Fusion X2 (e.g., 20m -> 10m resolution). Options: "cnn", "swin", "mamba". Defaults to "cnn". - Fusionx2_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional): - Size of the Fusion X2 model. - Options: "lightweight", "small", "medium", "expanded", "large". + Fusionx2_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional): + Size of the Fusion X2 model. + Options: "lightweight", "small", "medium", "expanded", "large". Defaults to "lightweight". - Fusionx4_model_name (Literal["cnn", "swin", "mamba"], optional): + Fusionx4_model_name (Literal["cnn", "swin", "mamba"], optional): Model for Fusion X4 (e.g., 10m -> 2.5m resolution). Options: "cnn", "swin", "mamba". Defaults to "cnn". - Fusionx4_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional): - Size of the Fusion X4 model. - Options: "lightweight", "small", "medium", "expanded", "large". + Fusionx4_model_size (Literal["lightweight", "small", "medium", "expanded", "large"], optional): + Size of the Fusion X4 model. + Options: "lightweight", "small", "medium", "expanded", "large". Defaults to "lightweight". - weights_path (Union[str, pathlib.Path, None], optional): - Path to the pre-trained model weights. + weights_path (Union[str, pathlib.Path, None], optional): + Path to the pre-trained model weights. Can be a string or pathlib.Path object. Defaults to None. If None, the code will try to retrieve the weights from the official repository. + device (str, optional): Device to use for the models. Defaults to "cpu". + **kwargs: Additional keyword arguments to pass to the models. Returns: dict: A dictionary containing the loaded models for super-resolution and fusion tasks. 
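+
+        Example:
+            A minimal sketch, assuming a Sentinel-2 reflectance tensor in [0, 1];
+            the 10-band layout and 128x128 tile size below are illustrative
+            assumptions, not requirements of the models:
+
+            >>> import torch
+            >>> models = setmodel(resolution="2.5m", SR_model_name="cnn", device="cpu")
+            >>> X = torch.rand(10, 128, 128)  # C x H x W
+            >>> sr = predict(X, resolution="2.5m", models=models)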
@@ -72,7 +70,7 @@ def setmodel( weights_path = pathlib.Path.home() / ".config" / "supers2" weights_path.mkdir(parents=True, exist_ok=True) - # If the resolution is 10m we only run the FusionX2 model that + # If the resolution is 10m we only run the FusionX2 model that # converts 20m bands to 10m if resolution == 10: return { @@ -80,7 +78,7 @@ def setmodel( model_name=Fusionx2_model_name, model_size=Fusionx2_model_size, model_loss="l1", - weights_path=weights_path, + weights_path=weights_path ), "FusionX4": None, "SR": None, @@ -92,33 +90,35 @@ def setmodel( model_name=Fusionx2_model_name, model_size=Fusionx2_model_size, model_loss="l1", - weights_path=weights_path, + weights_path=weights_path ), "FusionX4": load_fusionx4_model( model_name=Fusionx4_model_name, model_size=Fusionx4_model_size, model_loss="l1", - weights_path=weights_path, + weights_path=weights_path ), "SR": load_srx4_model( model_name=SR_model_name, model_size=SR_model_size, model_loss=SR_model_loss, weights_path=weights_path, - ), + device=device, + **kwargs + ) } def predict( X: torch.Tensor, - resolution: Literal["2.5m", "5m", "10m"] = "5m", - models: Optional[dict] = None, + resolution: Literal["2.5m", "5m", "10m"] = "2.5m", + models: Optional[dict] = None ) -> torch.Tensor: - """Generate a new S2 tensor with all the bands on the same resolution + """ Generate a new S2 tensor with all the bands on the same resolution Args: X (torch.Tensor): The input tensor with the S2 bands - resolution (Literal["2.5m", "5m", "10m"], optional): The final resolution of the + resolution (Literal["2.5m", "5m", "10m"], optional): The final resolution of the tensor. Defaults to "2.5m". device (str, optional): The device to use. Defaults to "cpu". @@ -141,10 +141,7 @@ def predict( raise ValueError("Invalid resolution. 
Please select 2.5m, 5m, or 10m.") -def fusionx2( - X: torch.Tensor, - models: dict -) -> torch.Tensor: +def fusionx2(X: torch.Tensor, models: dict) -> torch.Tensor: """Converts 20m bands to 10m resolution Args: @@ -152,61 +149,60 @@ def fusionx2( models (dict): The dictionary with the loaded models Returns: - torch.Tensor: The tensor with the same resolution for all the bands + torch.Tensor: The tensor with the same resolution for all the bands """ # Obtain the device of X device = X.device # Band Selection - index20 = [3, 4, 5, 7, 8, 9] - index10 = [0, 1, 2, 6] - + bands_20m = [3, 4, 5, 7, 8, 9] + bands_10m = [0, 1, 2, 6] + # Set the model fusionmodelx2 = models["FusionX2"].to(device) # Select the 20m bands - bands20_as_10 = X[index20] - - bands20 = torch.nn.functional.interpolate( - bands20_as_10[None], scale_factor=0.5, mode="nearest" + bands_20m_data = X[bands_20m] + + bands_20m_data_real = torch.nn.functional.interpolate( + bands_20m_data[None], + scale_factor=0.5, + mode="nearest" ).squeeze(0) - bands20_in_10 = torch.nn.functional.interpolate( - bands20[None], scale_factor=2, mode="bilinear", antialias=True + bands_20m_data = torch.nn.functional.interpolate( + bands_20m_data_real[None], + scale_factor=2, + mode="bilinear", + antialias=True ).squeeze(0) - + # Select the 10m bands - bands10 = X[index10] - + bands_10m_data = X[bands_10m] + # Concatenate the 20m and 10m bands - input_data = torch.cat([bands20_in_10, bands10], dim=0) - bands20_to_10 = fusionmodelx2(input_data[None]).squeeze(0) - - # Order the channels back - results = torch.stack( - [ - bands10[0], - bands10[1], - bands10[2], - bands20_to_10[0], - bands20_to_10[1], - bands20_to_10[2], - bands10[3], - bands20_to_10[3], - bands20_to_10[4], - bands20_to_10[5], - ], - dim=0, - ) + input_data = torch.cat([bands_20m_data, bands_10m_data], dim=0) + bands_20m_data_to_10 = fusionmodelx2(input_data[None]).squeeze(0) + + # Order the channels back + results = torch.stack([ + bands_10m_data[0], + bands_10m_data[1], + bands_10m_data[2], + bands_20m_data_to_10[0], + bands_20m_data_to_10[1], + bands_20m_data_to_10[2], + bands_10m_data[3], + bands_20m_data_to_10[3], + bands_20m_data_to_10[4], + bands_20m_data_to_10[5], + ], dim=0) return results -def fusionx8( - X: torch.Tensor, - models: dict -) -> torch.Tensor: +def fusionx8(X: torch.Tensor, models: dict) -> torch.Tensor: """Converts 20m bands to 10m resolution Args: @@ -224,54 +220,51 @@ def fusionx8( superX: torch.Tensor = fusionx2(X, models) # Band Selection - index20 = [3, 4, 5, 7, 8, 9] - index10 = [2, 1, 0, 6] # WARNING: The SR model needs RGBNIR bands - + bands_20m = [3, 4, 5, 7, 8, 9] + bands_10m = [2, 1, 0, 6] # WARNING: The SR model needs RGBNIR bands + # Set the SR resolution and x4 fusion model fusionmodelx4 = models["FusionX4"].to(device) srmodelx4 = models["SR"].to(device) - + # Convert the SWIR bands to 2.5m - bands20_to_10 = superX[index20] - bands10_in_2dot5 = torch.nn.functional.interpolate( - bands20_to_10[None], scale_factor=4, mode="bilinear", antialias=True - ).squeeze(0) - + bands_20m_data = superX[bands_20m] + bands_20m_data_up = torch.nn.functional.interpolate( + bands_20m_data[None], + scale_factor=4, + mode="bilinear", + antialias=True + ).squeeze(0) + # Run super-resolution on the 10m bands - bands10 = superX[index10] - bands10_to_2dot5 = srmodelx4(bands10[None]).squeeze(0) - + rgbn_bands_10m_data = superX[bands_10m] + tensor_x4_rgbnir = srmodelx4(rgbn_bands_10m_data[None]).squeeze(0) + # Reorder the bands from RGBNIR to BGRNIR - bands10_to_2dot5 = 
bands10_to_2dot5[[2, 1, 0, 3]] + tensor_x4_rgbnir = tensor_x4_rgbnir[[2, 1, 0, 3]] # Run the fusion x4 model in the SWIR bands (10m to 2.5m) - input_data = torch.cat([bands10_in_2dot5, bands10_to_2dot5], dim=0) - allbands_to_2dot5 = fusionmodelx4(input_data[None]).squeeze(0) - + input_data = torch.cat([bands_20m_data_up, tensor_x4_rgbnir], dim=0) + bands_20m_data_to_25m = fusionmodelx4(input_data[None]).squeeze(0) + # Order the channels back - results = torch.stack( - [ - bands10_to_2dot5[0], - bands10_to_2dot5[1], - bands10_to_2dot5[2], - allbands_to_2dot5[0], - allbands_to_2dot5[1], - allbands_to_2dot5[2], - bands10_to_2dot5[3], - allbands_to_2dot5[3], - allbands_to_2dot5[4], - allbands_to_2dot5[5], - ], - dim=0, - ) + results = torch.stack([ + tensor_x4_rgbnir[0], + tensor_x4_rgbnir[1], + tensor_x4_rgbnir[2], + bands_20m_data_to_25m[0], + bands_20m_data_to_25m[1], + bands_20m_data_to_25m[2], + tensor_x4_rgbnir[3], + bands_20m_data_to_25m[3], + bands_20m_data_to_25m[4], + bands_20m_data_to_25m[5], + ], dim=0) return results -def fusionx4( - X: torch.Tensor, - models: dict -) -> torch.Tensor: +def fusionx4(X: torch.Tensor, models: dict) -> torch.Tensor: """Converts 20m bands to 10m resolution Args: @@ -287,5 +280,8 @@ def fusionx4( # From 2.5m to 5m resolution return torch.nn.functional.interpolate( - superX[None], scale_factor=0.5, mode="bilinear", antialias=True - ).squeeze(0) + superX[None], + scale_factor=0.5, + mode="bilinear", + antialias=True + ).squeeze(0) \ No newline at end of file diff --git a/supers2/models/diffusion.py b/supers2/models/diffusion.py new file mode 100644 index 0000000..e7b716f --- /dev/null +++ b/supers2/models/diffusion.py @@ -0,0 +1,216 @@ +import torch +from supers2.models.utils import assert_tensor_validity +from supers2.models.utils import revert_padding +from supers2.models.opensr_diffusion.diffusion.utils import DDIMSampler +from supers2.models.opensr_diffusion.diffusion.latentdiffusion import LatentDiffusion + +from skimage.exposure import match_histograms + +from tqdm import tqdm +import numpy as np +from opensr_model.utils import linear_transform_4b + + +class SRLatentDiffusion(torch.nn.Module): + def __init__(self, device: str = "cpu"): + super().__init__() + + # Set up the model + first_stage_config, cond_stage_config = self.set_model_settings() + self.model = LatentDiffusion( + first_stage_config, + cond_stage_config, + timesteps=1000, + unet_config=cond_stage_config, + linear_start=0.0015, + linear_end=0.0155, + concat_mode=True, + cond_stage_trainable=False, + first_stage_key="image", + cond_stage_key="LR_image", + ) + self.model.eval() + + for param in self.model.parameters(): + param.requires_grad = False + + # Set up the model for inference + self.device = device # set self device + self.model.device = device # set model device as selected + self.model = self.model.to(device) # move model to device + self.model.eval() # set model state + self._X = None # placeholder for LR image + self.encode_conditioning = True # encode LR images before dif? 
+ + + def set_model_settings(self): + # set up model settings + first_stage_config = { + "embed_dim":4, + "double_z": True, + "z_channels": 4, + "resolution": 256, + "in_channels": 4, + "out_ch": 4, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + } + cond_stage_config = { + "image_size": 64, + "in_channels": 8, + "model_channels": 160, + "out_channels": 4, + "num_res_blocks": 2, + "attention_resolutions": [16, 8], + "channel_mult": [1, 2, 2, 4], + "num_head_channels": 32, + } + self.linear_transform = linear_transform_4b + + return first_stage_config, cond_stage_config + + + def _tensor_encode(self,X: torch.Tensor): + # set copy to model + #X = torch.rand(1, 4, 32, 32) + self._X = X.clone() + # normalize image + X_enc = self.linear_transform(X, stage="norm") + + # encode LR images + X_int = torch.nn.functional.interpolate(X, size=X.shape[-1]*4, mode='bilinear', antialias=True) + + # encode conditioning + X_enc = self.model.first_stage_model.encode(X_int).sample() + + return X_enc + + def _tensor_decode(self, X_enc: torch.Tensor): + + # Decode + X_dec = self.model.decode_first_stage(X_enc) + X_dec = self.linear_transform(X_dec, stage="denorm") + + # Apply spectral correction + for i in range(X_dec.shape[1]): + X_dec[:, i] = self.hq_histogram_matching(X_dec[:, i], self._X[:, i]) + + # If the value is negative, set it to 0 + X_dec[X_dec < 0] = 0 + + return X_dec + + def _prepare_model( + self, + X: torch.Tensor, + eta: float = 1.0, + custom_steps: int = 100, + verbose: bool = False + ): + # Create the DDIM sampler + ddim = DDIMSampler(self.model) + + # make schedule to compute alphas and sigmas + ddim.make_schedule(ddim_num_steps=custom_steps, ddim_eta=eta, verbose=verbose) + + # Create the HR latent image + latent = torch.randn(X.shape, device=X.device) + + # Create the vector with the timesteps + timesteps = ddim.ddim_timesteps + time_range = np.flip(timesteps) + + return ddim, latent, time_range + + @torch.no_grad() + def forward( + self, + X: torch.Tensor, + eta: float = 1.0, + custom_steps: int = 100, + temperature: float = 1.0, + verbose: bool = True + ): + """Obtain the super resolution of the given image. + + Args: + X (torch.Tensor): If a Sentinel-2 L2A image with reflectance values + in the range [0, 1] and shape CxWxH, the super resolution of the image + is returned. If a batch of images with shape BxCxWxH is given, a batch + of super resolved images is returned. + custom_steps (int, optional): Number of steps to run the denoiser. Defaults + to 100. + temperature (float, optional): Temperature to use in the denoiser. + Defaults to 1.0. The higher the temperature, the more stochastic + the denoiser is (random noise gets multiplied by this). + spectral_correction (bool, optional): Apply spectral correction to the SR + image, using the LR image as reference. Defaults to True. + + Returns: + torch.Tensor: The super resolved image or batch of images with a shape of + Cx(Wx4)x(Hx4) or BxCx(Wx4)x(Hx4). 
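+
+        Example:
+            A minimal sketch, assuming a 4-band (RGB-NIR) reflectance tensor in
+            [0, 1]; the 32x32 input size is an illustrative assumption:
+
+            >>> model = SRLatentDiffusion(device="cpu")
+            >>> lr = torch.rand(1, 4, 32, 32)
+            >>> sr = model(lr, custom_steps=100)  # expected shape: (1, 4, 128, 128)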
+ """ + + # Assert shape, size, dimensionality + X, padding = assert_tensor_validity(X) + + # Normalize the image + X = X.clone() + Xnorm = self._tensor_encode(X) + + # ddim, latent and time_range + ddim, latent, time_range = self._prepare_model( + X=Xnorm, eta=eta, custom_steps=custom_steps, verbose=False + ) + iterator = tqdm(time_range, desc="DDIM Sampler", total=custom_steps,disable=not verbose) + + # Iterate over the timesteps + for i, step in enumerate(iterator): + outs = ddim.p_sample_ddim( + x=latent, + c=Xnorm, + t=step, + index=custom_steps - i - 1, + use_original_steps=False, + temperature=temperature + ) + latent, _ = outs + + sr = self._tensor_decode(latent) + sr = revert_padding(sr,padding) + return sr + + + def hq_histogram_matching( + self, image1: torch.Tensor, image2: torch.Tensor + ) -> torch.Tensor: + """Lazy implementation of histogram matching + + Args: + image1 (torch.Tensor): The low-resolution image (C, H, W). + image2 (torch.Tensor): The super-resolved image (C, H, W). + + Returns: + torch.Tensor: The super-resolved image with the histogram of + the target image. + """ + + # Go to numpy + np_image1 = image1.detach().cpu().numpy() + np_image2 = image2.detach().cpu().numpy() + + if np_image1.ndim == 3: + np_image1_hat = match_histograms(np_image1, np_image2, channel_axis=0) + elif np_image1.ndim == 2: + np_image1_hat = match_histograms(np_image1, np_image2, channel_axis=None) + else: + raise ValueError("The input image must have 2 or 3 dimensions.") + + # Go back to torch + image1_hat = torch.from_numpy(np_image1_hat).to(image1.device) + + return image1_hat + \ No newline at end of file diff --git a/supers2/models/opensr_diffusion/__init__.py b/supers2/models/opensr_diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/supers2/models/opensr_diffusion/autoencoder/__init__.py b/supers2/models/opensr_diffusion/autoencoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/supers2/models/opensr_diffusion/autoencoder/autoencoder.py b/supers2/models/opensr_diffusion/autoencoder/autoencoder.py new file mode 100644 index 0000000..1f8b253 --- /dev/null +++ b/supers2/models/opensr_diffusion/autoencoder/autoencoder.py @@ -0,0 +1,602 @@ +from typing import Tuple + +import numpy as np +import torch +from supers2.models.opensr_diffusion.autoencoder.utils import (Downsample, Normalize, ResnetBlock, + Upsample, make_attn, nonlinearity) +from torch import nn + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch: int, + ch_mult: Tuple[int, int, int, int] = (1, 2, 4, 8), + num_res_blocks: int, + attn_resolutions: Tuple[int, ...], + dropout: float = 0.0, + resamp_with_conv: bool = True, + in_channels: int, + resolution: int, + z_channels: int, + double_z: bool = True, + use_linear_attn: bool = False, + attn_type: str = "vanilla", + **ignorekwargs: dict, + ): + """ + Encoder module responsible for downsampling and transforming an input image tensor. + + Args: + ch (int): Base number of channels in the model. + num_res_blocks (int): Number of residual blocks per resolution. + attn_resolutions (tuple of int): Resolutions at which attention should be applied. + in_channels (int): Number of channels in the input data. + resolution (int): The resolution of the input data. + z_channels (int): Number of channels for the latent variable 'z'. + ch_mult (tuple of int, optional): Multipliers for the channels in different blocks. Defaults to (1, 2, 4, 8). + dropout (float, optional): Dropout rate to use in ResNet blocks. Defaults to 0.0. 
+ resamp_with_conv (bool, optional): Whether to use convolution for downsampling. Defaults to True. + double_z (bool, optional): If True, output channels will be doubled for 'z'. Defaults to True. + use_linear_attn (bool, optional): If True, linear attention will be used. Overrides 'attn_type'. Defaults to False. + attn_type (str, optional): Type of attention mechanism. Options are "vanilla" or "linear". Defaults to "vanilla". + ignorekwargs (dict): Ignore extra keyword arguments. + + Examples: + >>> encoder = Encoder(in_channels=3, z_channels=64, ch=32, resolution=64, num_res_blocks=2, attn_resolutions=(16, 8)) + >>> x = torch.randn(1, 3, 64, 64) + >>> z = encoder(x) + """ + super().__init__() + + # If linear attention is used, override the attention type. + if use_linear_attn: + attn_type = "linear" + + # Setting global attributes to create the encoder. + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # Initial convolution for spectral reduction. + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + # Downsampling with residual blocks and optionally attention + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # Upsampling with residual blocks and optionally attention + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # Final convolution to get the latent variable 'z' + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the Encoder. + + Args: + x: Input tensor. + + Returns: + Transformed tensor after passing through the Encoder. + """ + + # timestep embedding (if needed in the next Diffusion runs!) 
+ temb = None + + # Initial downsampling + hs = [self.conv_in(x)] + + # Downsampling through the layers + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # Middle processing with blocks and attention + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # Final transformation to produce the output + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch: int, + out_ch: int, + ch_mult: Tuple[int, int, int, int] = (1, 2, 4, 8), + num_res_blocks: int, + attn_resolutions: Tuple[int, ...], + dropout: float = 0.0, + resamp_with_conv: bool = True, + in_channels: int, + resolution: int, + z_channels: int, + give_pre_end: bool = False, + tanh_out: bool = False, + use_linear_attn: bool = False, + attn_type: str = "vanilla", + **ignorekwargs: dict, + ): + """ + A Decoder class that converts a given encoded data 'z' back to its original state. + + Args: + ch (int): Number of channels in the input data. + out_ch (int): Number of channels in the output data. + num_res_blocks (int): Number of residual blocks in the network. + attn_resolutions (Tuple[int, ...]): Resolutions at which attention mechanisms are applied. + in_channels (int): Number of channels in the encoded data 'z'. + resolution (int): The resolution of the output image. + z_channels (int): Number of channels in the latent space representation. + ch_mult (Tuple[int, int, int, int], optional): Multiplier for channels at different resolution + levels. Defaults to (1, 2, 4, 8). + dropout (float, optional): Dropout rate for regularization. Defaults to 0.0. + resamp_with_conv (bool, optional): Whether to use convolutional layers for upsampling. Defaults to True. + give_pre_end (bool, optional): If set to True, returns the output before the last layer. Useful for further + processing. Defaults to False. + tanh_out (bool, optional): If set to True, applies tanh activation function to the output. Defaults to False. + use_linear_attn (bool, optional): If set to True, uses linear attention mechanism. Defaults to False. + attn_type (str, optional): Type of attention mechanism used ("vanilla" or "linear"). Defaults to "vanilla". + ignorekwargs (dict): Ignore extra keyword arguments. 
+ + Examples: + >>> decoder = Decoder( + ch=32, out_ch=3, z_channels=64, resolution=64, + in_channels=64, num_res_blocks=2, + attn_resolutions=(16, 8) + ) + >>> z = torch.randn(1, 64, 8, 8) + >>> x_reconstructed = decoder(z) + """ + + super().__init__() + + # If linear attention is required, set attention type as 'linear' + if use_linear_attn: + attn_type = "linear" + + # Initialize basic attributes for Decoding + self.ch = ch + self.temb_ch = 0 # Temporal embedding channel + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end # Controls the final output + self.tanh_out = tanh_out # Apply tanh activation at the end + + # Compute input channel multiplier, initial block input channel and current resolution + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # Display z-shape details + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # Conversion layer: From z dimension to block input channels + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # Middle processing blocks + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # Upsampling layers + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + + # Apply ResNet blocks and attention at each resolution + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + + up = nn.Module() + up.block = block + up.attn = attn + + # Upsampling operations + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + + # Keep the order consistent with original resolutions + self.up.insert(0, up) + + # Final normalization and conversion layers + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + """Forward pass of the Decoder. + + Args: + z (torch.Tensor): The latent variable 'z' to be decoded. + + Returns: + torch.Tensor: Transformed tensor after passing through the Decoder. 
+ """ + + self.last_z_shape = z.shape + + # Time-step embedding (not used, in the Decoder part) + temb = None + + # Convert z to block input + h = self.conv_in(z) + + # Middle processing blocks + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # Upsampling steps + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # Final output steps + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + + # Apply tanh activation if required + if self.tanh_out: + h = torch.tanh(h) + + return h + + +class DiagonalGaussianDistribution(object): + """ + Represents a multi-dimensional diagonal Gaussian distribution. + + The distribution is parameterized by means and diagonal variances + (or standard deviations) for each dimension. This means that the + covariance matrix of this Gaussian distribution is diagonal + (i.e., non-diagonal elements are zero). + + Attributes: + parameters (torch.Tensor): A tensor containing concatenated means and log-variances. + mean (torch.Tensor): The mean of the Gaussian distribution. + logvar (torch.Tensor): The logarithm of variances of the Gaussian distribution. + deterministic (bool): If true, the variance is set to zero, making the distribution + deterministic. + std (torch.Tensor): The standard deviation of the Gaussian distribution. + var (torch.Tensor): The variance of the Gaussian distribution. + + Examples: + >>> params = torch.randn((1, 10)) # Assuming 5 for mean and 5 for log variance + >>> dist = DiagonalGaussianDistribution(params) + >>> sample = dist.sample() # Sample from the distribution + """ + + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + """ + Initializes the DiagonalGaussianDistribution. + + Args: + parameters (torch.Tensor): A tensor containing concatenated means and log-variances. + deterministic (bool, optional): If set to true, this distribution becomes + deterministic (i.e., has zero variance). + """ + self.parameters = parameters + self.deterministic = deterministic + + # Split the parameters into means and log-variances + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + + # Limit the log variance values + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + + # Calculate standard deviation & variance from log variance + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + + # If deterministic, set variance and standard deviation to zero + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self) -> torch.Tensor: + """ + Sample from the Gaussian distribution. + + Returns: + torch.Tensor: Sampled tensor. + """ + + # Sample from a standard Gaussian distribution + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + """ + Compute the KL divergence between this Gaussian distribution and another. + + Args: + other (DiagonalGaussianDistribution, optional): The other Gaussian + distribution. If None, computes the KL divergence with a standard + Gaussian (mean 0, variance 1). + + Returns: + torch.Tensor: KL divergence values. 
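+
+        Example:
+            A minimal sketch; the 8 parameter channels split into 4 mean and
+            4 log-variance channels:
+
+            >>> params = torch.randn(1, 8, 16, 16)
+            >>> dist = DiagonalGaussianDistribution(params)
+            >>> kl = dist.kl()  # KL against N(0, I); one value per batch element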
+ """ + if self.deterministic: + return torch.Tensor([0.0]).to(device=self.parameters.device) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: list = [1, 2, 3]) -> torch.Tensor: + """ + Compute the negative log likelihood of a sample under this Gaussian distribution. + + Args: + sample (torch.Tensor): The input sample tensor. + dims (list, optional): The dimensions over which the sum is performed. Defaults + to [1, 2, 3]. + + Returns: + torch.Tensor: Negative log likelihood values. + """ + if self.deterministic: + return torch.Tensor([0.0]).to(device=self.parameters.device) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + """ + Get the mode of the Gaussian distribution (which is equal to its mean). + + Returns: + torch.Tensor: The mode (mean) of the Gaussian distribution. + """ + return self.mean + + +class AutoencoderKL(nn.Module): + """ + Autoencoder with KL divergence regularization. + + This class implements an autoencoder model where the encoder outputs parameters of a + Gaussian distribution, from which the latent representation can be sampled or its + mode can be taken. The decoder then reconstructs the input from the latent + representation. + + Attributes: + encoder (Encoder): Encoder module. + decoder (Decoder): Decoder module. + quant_conv (torch.nn.Conv2d): Convolutional layer used to process encoder outputs + into Gaussian parameters. + post_quant_conv (torch.nn.Conv2d): Convolutional layer used after sampling/mode + from the Gaussian distribution. + embed_dim (int): Embedding dimension of the latent space. + + Examples: + + >>> ddconfig = { + "z_channels": 16, "ch": 32, + "out_ch": 3, "ch_mult": (1, 2, 4, 8), + "resolution": 64, "in_channels": 3, + "double_z": True, "num_res_blocks": 2, + "attn_resolutions": (16, 8) + } + >>> embed_dim = 8 + >>> ae_model = AutoencoderKL(ddconfig, embed_dim) + >>> data = torch.randn((1, 3, 64, 64)) + >>> recon_data, posterior = ae_model(data) + """ + + def __init__(self, ddconfig: dict, embed_dim: int): + """ + Initialize the AutoencoderKL. + + Args: + ddconfig (dict): Configuration dictionary for the encoder and decoder. + embed_dim (int): Embedding dimension of the latent space. + """ + super().__init__() + + # Initialize the encoder and decoder with provided configurations + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + + # Check if the configuration expects double the z_channels + assert ddconfig["double_z"], "ddconfig must have 'double_z' set to True." + + # Define convolutional layers to transform between the latent space and Gaussian parameters + self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + self.embed_dim = embed_dim + + def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution: + """ + Pass the input through the encoder and return the posterior Gaussian + distribution. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + DiagonalGaussianDistribution: Gaussian distribution parameters from the + encoded input. 
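+
+        Example:
+            A minimal sketch, reusing ae_model and data from the class docstring
+            example above:
+
+            >>> posterior = ae_model.encode(data)
+            >>> z = posterior.sample()   # stochastic latent code
+            >>> z_map = posterior.mode() # deterministic alternative (the mean)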
+ """ + # Encoder's output + h = self.encoder(x) + + # Convert encoder's output to Gaussian parameters + moments = self.quant_conv(h) + + # Create a DiagonalGaussianDistribution using the moments + posterior = DiagonalGaussianDistribution(moments) + + return posterior + + def decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Decode the latent representation to reconstruct the input. + + Args: + z (torch.Tensor): Latent representation. + + Returns: + torch.Tensor: Reconstructed tensor. + """ + # Process latent representation through a convolutional layer + z = self.post_quant_conv(z) + + # Decoder's output + dec = self.decoder(z) + + return dec + + def forward(self, input: torch.Tensor, sample_posterior: bool = True) -> tuple: + """ + Forward pass of the autoencoder. + + Encodes the input, samples/modes from the resulting Gaussian distribution, + and then decodes to get the reconstructed input. + + Args: + input (torch.Tensor): Input tensor. + sample_posterior (bool, optional): If True, sample from the posterior Gaussian + distribution. If False, use its mode. Defaults to True. + + Returns: + tuple: Reconstructed tensor and the posterior Gaussian distribution. + """ + + # Encode the input to get the Gaussian distribution parameters + posterior = self.encode(input) + + # Sample from the Gaussian distribution or take its mode + z = posterior.sample() if sample_posterior else posterior.mode() + + # Decode the sampled/mode latent representation + dec = self.decode(z) + + return dec, posterior diff --git a/supers2/models/opensr_diffusion/autoencoder/utils.py b/supers2/models/opensr_diffusion/autoencoder/utils.py new file mode 100644 index 0000000..5b03266 --- /dev/null +++ b/supers2/models/opensr_diffusion/autoencoder/utils.py @@ -0,0 +1,460 @@ +import torch +from einops import rearrange +from torch import nn + + +def Normalize(in_channels: int, num_groups: int = 32) -> torch.nn.GroupNorm: + """ + Returns a GroupNorm layer that normalizes the input tensor along the channel dimension. + + Args: + in_channels (int): Number of channels in the input tensor. + num_groups (int): Number of groups to separate the channels into. Default is 32. + + Returns: + torch.nn.GroupNorm: A GroupNorm layer that normalizes the input tensor along the + channel dimension. + + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> norm_layer = Normalize(in_channels=64, num_groups=16) + >>> output_tensor = norm_layer(input_tensor) + """ + # Create a GroupNorm layer with the specified number of groups and input channels + # Set eps to a small value to avoid division by zero + # Set affine to True to learn scaling and shifting parameters + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +def nonlinearity(x: torch.Tensor) -> torch.Tensor: + """ + Applies a non-linear activation function to the input tensor x. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor with the same shape as the input tensor. + + Example: + >>> input_tensor = torch.randn(10, 20) + >>> output_tensor = nonlinearity(input_tensor) + """ + # Apply the sigmoid function to the input tensor + sigmoid_x = torch.sigmoid(x) + + # Multiply the input tensor by the sigmoid of the input tensor + output_tensor = x * sigmoid_x + + return output_tensor + + +class Downsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool): + """ + Initializes a Downsample module that reduces the spatial dimensions + of the input tensor. 
+ + Args: + in_channels (int): Number of channels in the input tensor. + with_conv (bool): Whether to use a convolutional layer for downsampling. + + Attributes: + conv (torch.nn.Conv2d): Convolutional layer for downsampling. Only used + if with_conv is True. + + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> downsample_module = Downsample(in_channels=64, with_conv=True) + >>> output_tensor = downsample_module(input_tensor) + """ + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Create a convolutional layer for downsampling + # Use kernel size 3, stride 2, and no padding + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies the Downsample module to the input tensor x. + + Args: + x (torch.Tensor): Input tensor with shape (batch_size, in_channels, height, width). + + Returns: + torch.Tensor: Output tensor with shape (batch_size, in_channels, height/2, width/2) + if with_conv is False, or (batch_size, in_channels, (height+1)/2, (width+1)/2) if + with_conv is True. + """ + if self.with_conv: + # Apply asymmetric padding to the input tensor + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + + # Apply the convolutional layer to the padded input tensor + x = self.conv(x) + else: + # Apply average pooling to the input tensor + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + # Return the output tensor + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int, with_conv: bool): + """ + Initializes an Upsample module that increases the spatial dimensions of + the input tensor. + + Args: + in_channels (int): Number of channels in the input tensor. + with_conv (bool): Whether to use a convolutional layer for upsampling. + + Attributes: + with_conv (bool): Whether to use a convolutional layer for upsampling. + conv (torch.nn.Conv2d): Convolutional layer for upsampling. Only used + if with_conv is True. + + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> upsample_module = Upsample(in_channels=64, with_conv=True) + >>> output_tensor = upsample_module(input_tensor) + """ + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Create a convolutional layer for upsampling + # Use kernel size 3, stride 1, and padding 1 + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Applies the Upsample module to the input tensor x. + + Args: + x (torch.Tensor): Input tensor with shape (batch_size, in_channels, + height, width). + + Returns: + torch.Tensor: Output tensor with shape (batch_size, in_channels, height*2, width*2) + if with_conv is False, or (batch_size, in_channels, height*2-1, width*2-1) if + with_conv is True. + """ + # Apply nearest interpolation to the input tensor to double its spatial dimensions + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.with_conv: + # Apply the convolutional layer to the upsampled input tensor + x = self.conv(x) + + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + conv_shortcut: bool = False, + dropout: float, + temb_channels: int = 512, + ): + """ + Initializes a ResnetBlock module that consists of two convolutional layers with batch + normalization and a residual connection. 
+ + Args: + in_channels (int): Number of channels in the input tensor. + out_channels (int, optional): Number of channels in the output tensor. If None, + defaults to in_channels. + conv_shortcut (bool): Whether to use a convolutional layer for the residual connection. + If False, uses a 1x1 convolution. + dropout (float): Dropout probability. + temb_channels (int): Number of channels in the conditioning tensor. If 0, no conditioning + is used. + + Attributes: + in_channels (int): Number of channels in the input tensor. + out_channels (int): Number of channels in the output tensor. + use_conv_shortcut (bool): Whether to use a convolutional layer for the residual connection. + norm1 (utils.Normalize): Batch normalization layer for the first convolutional layer. + conv1 (torch.nn.Conv2d): First convolutional layer. + temb_proj (torch.nn.Linear): Linear projection layer for the conditioning tensor. Only used + if temb_channels > 0. + norm2 (utils.Normalize): Batch normalization layer for the second convolutional layer. + dropout (torch.nn.Dropout): Dropout layer. + conv2 (torch.nn.Conv2d): Second convolutional layer. + conv_shortcut (torch.nn.Conv2d): Convolutional layer for the residual connection. Only + used if use_conv_shortcut is True. + nin_shortcut (torch.nn.Conv2d): 1x1 convolutional layer for the residual connection. Only + used if use_conv_shortcut is False. + """ + super().__init__() + + # Set the number of input and output channels + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + # Batch normalization layer for the first convolutional layer + self.norm1 = Normalize(in_channels) + + # First convolutional layer + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + # Linear projection layer for the conditioning tensor + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + + # BN+Dropout+Conv layer for the last convolutional layer + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + # 3x3 conv for the residual connection + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + # 1x1 conv for the residual connection + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + """ + Applies the ResnetBlock module to the input tensor x. + + Args: + x (torch.Tensor): Input tensor with shape (batch_size, in_channels, height, width). + temb (torch.Tensor): Conditioning tensor with shape (batch_size, temb_channels). + + Returns: + torch.Tensor: Output tensor with the same shape as the input tensor. 
+ + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> resnet_block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.5) + >>> output_tensor = resnet_block(input_tensor, temb=None) + """ + + # BN+Sigmoid+Conv + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + # Linear projection layer for the conditioning tensor + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + # BN+Sigmoid+Dropout+Conv + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + # 3x3 conv for the residual connection + x = self.conv_shortcut(x) + else: + # 1x1 conv for the residual connection + x = self.nin_shortcut(x) + + # Add the residual connection to the output tensor + return x + h + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +def make_attn(in_channels: int, attn_type: str = "vanilla") -> nn.Module: + """ + Creates an attention module of the specified type. + + Args: + in_channels (int): Number of channels in the input tensor. + attn_type (str): Type of attention module to create. Must be one of "vanilla", + "linear", or "none". Defaults to "vanilla". + + Returns: + nn.Module: Attention module. + + Raises: + AssertionError: If attn_type is not one of "vanilla", "linear", or "none". + + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> attn_module = make_attn(in_channels=64, attn_type="vanilla") + >>> output_tensor = attn_module(input_tensor) + """ + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + # Create a vanilla attention module + return AttnBlock(in_channels) + elif attn_type == "none": + # Create an identity module + return nn.Identity(in_channels) + else: + # Create a linear attention module + return LinAttnBlock(in_channels) + + +class AttnBlock(nn.Module): + """ + An attention module that computes attention weights for each spatial location in the input tensor. + + Args: + in_channels (int): Number of channels in the input tensor. + + Attributes: + in_channels (int): Number of channels in the input tensor. + norm (Normalize): Normalization layer for the input tensor. + q (torch.nn.Conv2d): Convolutional layer for computing the query tensor. + k (torch.nn.Conv2d): Convolutional layer for computing the key tensor. + v (torch.nn.Conv2d): Convolutional layer for computing the value tensor. + proj_out (torch.nn.Conv2d): Convolutional layer for projecting the attended tensor. 
+ + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> attn_module = AttnBlock(in_channels=64) + >>> output_tensor = attn_module(input_tensor) + """ + + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the output tensor of the attention module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor with the same shape as the input tensor. + """ + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # reshape q to b,hw,c and transpose to b,c,hw + k = k.reshape(b, c, h * w) # reshape k to b,c,hw + w_ = torch.bmm( + q, k + ) # compute attention weights w[b,i,j] = sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) # scale the attention weights + w_ = torch.nn.functional.softmax( + w_, dim=2 + ) # apply softmax to get the attention probabilities + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # transpose w to b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ) # compute the attended values h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) # reshape h_ to b,c,h,w + h_ = self.proj_out(h_) # project the attended values to the output space + + return x + h_ + + +class LinearAttention(nn.Module): + """ + A linear attention module that computes attention weights for each spatial + location in the input tensor. + + Args: + dim (int): Number of channels in the input tensor. + heads (int): Number of attention heads. Defaults to 4. + dim_head (int): Number of channels per attention head. Defaults to 32. + + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> attn_module = LinearAttention(dim=64, heads=8, dim_head=16) + >>> output_tensor = attn_module(input_tensor) + """ + + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the output tensor of the attention module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor with the same shape as the input tensor. 
+ """ + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) diff --git a/supers2/models/opensr_diffusion/denoiser/__init__.py b/supers2/models/opensr_diffusion/denoiser/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/supers2/models/opensr_diffusion/denoiser/unet.py b/supers2/models/opensr_diffusion/denoiser/unet.py new file mode 100644 index 0000000..1196499 --- /dev/null +++ b/supers2/models/opensr_diffusion/denoiser/unet.py @@ -0,0 +1,832 @@ +from typing import List, Optional, Set, Tuple, Union + +import torch +import torch as th +from einops import rearrange + +from supers2.models.opensr_diffusion.denoiser.utils import (BasicTransformerBlock, Downsample, + Normalize, QKVAttention, + QKVAttentionLegacy, TimestepBlock, + Upsample, checkpoint, conv_nd, + convert_module_to_f16, + convert_module_to_f32, linear, + normalization, timestep_embedding, + zero_module) +from torch import nn + + +class ResBlock(TimestepBlock): + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + out_channels: Optional[int] = None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + ): + """ + A residual block with optional timestep conditioning. + + Args: + channels (int): The number of input channels. + emb_channels (int): The number of timestep embedding channels. + dropout (float): The dropout probability. + out_channels (int, optional): The number of output channels. + Defaults to None (same as input channels). + use_conv (bool, optional): Whether to use a convolutional skip connection. + Defaults to False. + use_scale_shift_norm (bool, optional): Whether to use scale-shift normalization. + Defaults to False. + dims (int, optional): The number of dimensions in the input tensor. + Defaults to 2. + use_checkpoint (bool, optional): Whether to use checkpointing to save memory. + Defaults to False. + up (bool, optional): Whether to use upsampling in the skip connection. Defaults to + False. + down (bool, optional): Whether to use downsampling in the skip connection. Defaults to + False. 
+ + Example: + >>> resblock = ResBlock(channels=64, emb_channels=32, dropout=0.1) + >>> x = torch.randn(1, 64, 32, 32) + >>> emb = torch.randn(1, 32) + >>> out = resblock(x, emb) + >>> print(out.shape) + """ + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + # input layers + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + # skip connection + self.updown = up or down + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + # timestep embedding layers + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + + # output layers + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + # Skip connection + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Args: + x (torch.Tensor): An [N x C x ...] Tensor of features. + emb (torch.Tensor): An [N x emb_channels] Tensor of timestep embeddings. + + Returns: + torch.Tensor: An [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + if self.updown: + # up/downsampling in skip connection + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + # timestep embedding + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + + # scale-shift normalization + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + + # skip connection + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Args: + channels (int): The number of input channels. + num_heads (int, optional): The number of attention heads. Defaults to 1. + num_head_channels (int, optional): The number of channels per attention head. + If not specified, the input channels will be divided equally among the heads. + Defaults to -1. + use_checkpoint (bool, optional): Whether to use checkpointing to save memory. + Defaults to False. + use_new_attention_order (bool, optional): Whether to split the qkv tensor before + splitting the heads. 
If False, the heads will be split before the qkv tensor. + Defaults to False. + + Example: + >>> attention_block = AttentionBlock(channels=64, num_heads=4) + >>> x = torch.randn(1, 64, 32, 32) + >>> out = attention_block(x) + """ + + def __init__( + self, + channels: int, + num_heads: Optional[int] = 1, + num_head_channels: Optional[int] = -1, + use_checkpoint: Optional[bool] = False, + use_new_attention_order: Optional[bool] = False, + ) -> None: + super().__init__() + + # Set the number of input channels and attention heads + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + + # Set whether to use checkpointing and create normalization layer + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + + # Create convolutional layer for qkv tensor and attention module + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + # Create convolutional layer for output projection + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply the attention block to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + return checkpoint(self._forward, (x,), self.parameters(), False) + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply the attention block to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + + # Apply normalization and convolutional layer to qkv tensor + qkv = self.qkv(self.norm(x)) + + # Apply attention module and convolutional layer to output + h = self.attention(qkv) + h = self.proj_out(h) + + # Add input tensor to output and reshape + return (x + h).reshape(b, c, *spatial) + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image. + + Args: + in_channels (int): The number of input channels. + n_heads (int): The number of attention heads. + d_head (int): The number of channels per attention head. + depth (int, optional): The number of transformer blocks. Defaults to 1. + dropout (float, optional): The dropout probability. Defaults to 0. + context_dim (int, optional): The dimension of the context tensor. + If not specified, cross-attention defaults to self-attention. + Defaults to None. 
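+
+    Example:
+        A minimal usage sketch; the shapes and context size below are
+        illustrative only and are not values used elsewhere in this package:
+
+        >>> st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, context_dim=32)
+        >>> x = torch.randn(2, 64, 16, 16)
+        >>> ctx = torch.randn(2, 10, 32)
+        >>> out = st(x, context=ctx)
+        >>> out.shape
+        torch.Size([2, 64, 16, 16])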
+ """ + + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + depth: Optional[int] = 1, + dropout: Optional[float] = 0.0, + context_dim: Optional[int] = None, + ) -> None: + super().__init__() + + # Set the number of input channels and attention heads + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + # Create convolutional layer for input projection + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + # Create list of transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + # Create convolutional layer for output projection + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward( + self, x: torch.Tensor, context: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Apply the spatial transformer block to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + context (torch.Tensor, optional): The context tensor. If not specified, + cross-attention defaults to self-attention. Defaults to None. + + Returns: + torch.Tensor: The output tensor. + """ + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + + # Apply input projection and reshape + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + + # Apply transformer blocks + for block in self.transformer_blocks: + x = block(x, context=context) + + # Reshape and apply output projection + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + # Add input tensor to output + return x + x_in + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + + Args: + nn.Sequential: The sequential module. + TimestepBlock: The timestep block module. + + Example: + >>> model = TimestepEmbedSequential( + ResBlock(channels=64, emb_channels=32, dropout=0.1), + ResBlock(channels=64, emb_channels=32, dropout=0.1) + ) + >>> x = torch.randn(1, 64, 32, 32) + >>> emb = torch.randn(1, 32) + >>> out = model(x, emb) + >>> print(out.shape) + + """ + + def forward( + self, x: torch.Tensor, emb: torch.Tensor, context: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Apply the sequential module to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + emb (torch.Tensor): The timestep embedding tensor. + context (torch.Tensor, optional): The context tensor. Defaults to None. + + Returns: + torch.Tensor: The output tensor. + """ + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + Args: + in_channels (int): The number of channels in the input tensor. + model_channels (int): The base channel count for the model. + out_channels (int): The number of channels in the output tensor. + num_res_blocks (int): The number of residual blocks per downsample. + attention_resolutions (Union[Set[int], List[int], Tuple[int]]): A collection + of downsample rates at which attention will take place. For example, if + this contains 4, then at 4x downsampling, attention will be used. + dropout (float, optional): The dropout probability. 
Defaults to 0. + channel_mult (Tuple[int], optional): The channel multiplier for each level + of the UNet. Defaults to (1, 2, 4, 8). + conv_resample (bool, optional): If True, use learned convolutions for upsampling + and downsampling. Defaults to True. + dims (int, optional): Determines if the signal is 1D, 2D, or 3D. Defaults to 2. + num_classes (int, optional): If specified, then this model will be class-conditional + with `num_classes` classes. Defaults to None. + use_checkpoint (bool, optional): Use gradient checkpointing to reduce memory usage. + Defaults to False. + use_fp16 (bool, optional): Use half-precision floating point. Defaults to False. + num_heads (int, optional): The number of attention heads in each attention layer. + Defaults to -1. + num_head_channels (int, optional): If specified, ignore num_heads and instead use + a fixed channel width per attention head. Defaults to -1. + num_heads_upsample (int, optional): Works with num_heads to set a different number + of heads for upsampling. Deprecated. Defaults to -1. + use_scale_shift_norm (bool, optional): Use a FiLM-like conditioning mechanism. Defaults + to False. + resblock_updown (bool, optional): Use residual blocks for up/downsampling. Defaults to False. + use_new_attention_order (bool, optional): Use a different attention pattern for + potentially increased efficiency. Defaults to False. + use_spatial_transformer (bool, optional): Use a custom transformer support. Defaults to + False. + transformer_depth (int, optional): The depth of the custom transformer support. Defaults + to 1. + context_dim (int, optional): The dimension of the context tensor. Defaults to None. + n_embed (int, optional): Custom support for prediction of discrete ids into codebook + of first stage vq model. Defaults to None. + legacy (bool, optional): Use legacy mode. Defaults to True. + ignorekwargs (dict, optional): Ignore extra keyword arguments. 
+ Example: + >>> cond_stage_config = { + "image_size": 64, + "in_channels": 8, + "model_channels": 160, + "out_channels": 4, + "num_res_blocks": 2, + "attention_resolutions": [16, 8], + "channel_mult": [1, 2, 2, 4], + "num_head_channels": 32 + } + + >>> model = UNetModel(**cond_stage_config) + >>> x = torch.randn(2, 8, 128, 128) + >>> emb = torch.randn(2) + >>> out = model(x, emb) + >>> print(out.shape) + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + num_res_blocks: int, + attention_resolutions: Union[Set[int], List[int], Tuple[int]], + dropout: float = 0, + channel_mult: Tuple[int] = (1, 2, 4, 8), + conv_resample: bool = True, + dims: int = 2, + num_classes: Optional[int] = None, + use_checkpoint: bool = False, + use_fp16: bool = False, + num_heads: int = -1, + num_head_channels: int = -1, + num_heads_upsample: int = -1, + use_scale_shift_norm: bool = False, + resblock_updown: bool = False, + use_new_attention_order: bool = False, + use_spatial_transformer: bool = False, + transformer_depth: int = 1, + context_dim: Optional[int] = None, + n_embed: Optional[int] = None, + legacy: bool = True, + **ignorekwargs: dict, + ): + super().__init__() + + # If num_heads_upsample is not set, set it to num_heads + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + # If num_heads is not set, raise an error if num_head_channels is not set + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + # If num_head_channels is not set, raise an error if num_heads is not set + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + # Set the instance variables + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + # Set up the time embedding layers + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + # If num_classes is not None, set up the label embedding layer + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + # Set up the input blocks + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + + # parameters for the block attention + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + + # Set up the attention blocks + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + + # If the downsample rate is in the attention resolutions, add an attention block + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // 
num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + ) + ) + + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + + # If the downsample rate is not the last one, add a downsample block + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + # Set up the middle block parameters + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + + # If use_spatial_transformer is True, set up the spatial transformer + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + # Set up the output blocks + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + # If the downsample rate is in the attention resolutions, add an attention block + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + + # If the downsample rate is in the attention resolutions, add an attention block + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + ) + ) + + # If the downsample rate is the last one, add an upsample block + 
if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + # Set up the output layer + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + # Set up the codebook id predictor layer + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), conv_nd(dims, model_channels, n_embed, 1) + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + + Args: + x (torch.Tensor): An [N x C x ...] Tensor of inputs. + timesteps (torch.Tensor, optional): A 1-D batch of timesteps. + Defaults to None. + context (torch.Tensor, optional): Conditioning plugged in via crossattn. + Defaults to None. + y (torch.Tensor, optional): An [N] Tensor of labels, if class-conditional. + Defaults to None. + + Returns: + torch.Tensor: An [N x C x ...] Tensor of outputs. 
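+
+        Example:
+            A class-conditional sketch; the configuration values are
+            illustrative only, and `y` is required exactly when
+            `num_classes` is set:
+
+            >>> model = UNetModel(
+            ...     in_channels=8, model_channels=160, out_channels=4,
+            ...     num_res_blocks=2, attention_resolutions=[16, 8],
+            ...     channel_mult=[1, 2, 2, 4], num_head_channels=32,
+            ...     num_classes=10,
+            ... )
+            >>> x = torch.randn(2, 8, 128, 128)
+            >>> t = torch.randint(0, 1000, (2,))
+            >>> y = torch.randint(0, 10, (2,))
+            >>> out = model(x, timesteps=t, y=y)
+            >>> out.shape
+            torch.Size([2, 4, 128, 128])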
+ """ + # print("aaa") + # print(x.shape) + # print(timesteps.shape) + # print("aaa") + # 1 + "a" + + # Check if y is specified only if the model is class-conditional + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + # Initialize a list to store the hidden states of the input blocks + hs = [] + + # Compute the timestep embeddings and time embeddings + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + # Add label embeddings if the model is class-conditional + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + # Convert the input tensor to the specified data type + h = x.type(self.dtype) + + # Pass the input tensor through the input blocks and store the hidden states + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + + # Pass the output of the input blocks through the middle block + h = self.middle_block(h, emb, context) + + # Pass the output of the middle block through the output blocks in reverse order + for module in self.output_blocks: + # Concatenate the output of the current output block with the corresponding + # hidden state from the input blocks + h = th.cat([h, hs.pop()], dim=1) + + # Pass the concatenated tensor through the current output block + h = module(h, emb, context) + + # Convert the output tensor to the same data type as the input tensor + h = h.type(x.dtype) + + # Return the output tensor or the codebook ID predictions if specified + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/supers2/models/opensr_diffusion/denoiser/utils.py b/supers2/models/opensr_diffusion/denoiser/utils.py new file mode 100644 index 0000000..30792b7 --- /dev/null +++ b/supers2/models/opensr_diffusion/denoiser/utils.py @@ -0,0 +1,1185 @@ +import math +from abc import abstractmethod +from typing import Any, Callable, Optional, Tuple, Union + +import numpy as np +import torch +from einops import einsum, rearrange, repeat +from torch import nn +from torch.nn import functional as F + + +def Normalize(in_channels: int, num_groups: int = 32) -> torch.nn.GroupNorm: + """ + Returns a GroupNorm layer that normalizes the input tensor along the channel dimension. + + Args: + in_channels (int): Number of channels in the input tensor. + num_groups (int): Number of groups to separate the channels into. Default is 32. + + Returns: + torch.nn.GroupNorm: A GroupNorm layer that normalizes the input tensor along the + channel dimension. + + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> norm_layer = Normalize(in_channels=64, num_groups=16) + >>> output_tensor = norm_layer(input_tensor) + """ + # Create a GroupNorm layer with the specified number of groups and input channels + # Set eps to a small value to avoid division by zero + # Set affine to True to learn scaling and shifting parameters + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +def conv_nd(dims: int, *args: Union[int, float], **kwargs) -> nn.Module: + """ + Create a 1D, 2D, or 3D convolution module. + + Args: + dims (int): The number of dimensions for the convolution. Must be 1, 2, or 3. + *args (Union[int, float]): Positional arguments to pass to the convolution module constructor. + **kwargs: Keyword arguments to pass to the convolution module constructor. 
+ + Returns: + nn.Module: A convolution module with the specified number of dimensions. + + Raises: + ValueError: If the number of dimensions is not 1, 2, or 3. + + Example: + >>> conv = conv_nd(2, 16, 32, kernel_size=3) + >>> x = torch.randn(1, 16, 32, 32) + >>> out = conv(x) + >>> out.shape + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args: Union[int, float], **kwargs) -> nn.Module: + """ + Create a linear module. + + Args: + *args (Union[int, float]): Positional arguments to pass + to the linear module constructor. + **kwargs: Keyword arguments to pass to the linear module constructor. + + Returns: + nn.Module: A linear module. + + Example: + >>> linear = linear(16, 32) + >>> x = torch.randn(1, 16) + >>> out = linear(x) + >>> out.shape + """ + return nn.Linear(*args, **kwargs) + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + + Args: + module (nn.Module): The module to zero out. + + Returns: + nn.Module: The zeroed-out module. + + Example: + >>> conv = conv_nd(2, 16, 32, kernel_size=3) + >>> conv = zero_module(conv) + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def timestep_embedding( + timesteps: torch.Tensor, + dim: int, + max_period: int = 10000, + repeat_only: bool = False, +) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings. + + Args: + timesteps (torch.Tensor): A 1-D tensor of N indices, one per batch element. These may be fractional. + dim (int): The dimension of the output. + max_period (int): Controls the minimum frequency of the embeddings. + repeat_only (bool): If True, repeat the timestep embeddings instead of computing new ones. + + Returns: + torch.Tensor: An [N x dim] tensor of positional embeddings. + + Example: + >>> timesteps = torch.arange(0, 10) + >>> embeddings = timestep_embedding(timesteps, dim=16) + >>> embeddings.shape + """ + if not repeat_only: + # Compute the frequencies of the sinusoidal embeddings + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + + # Compute the arguments to the sinusoidal functions + args = timesteps[:, None].float() * freqs[None] + + # Compute the sinusoidal embeddings + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # If the output dimension is odd, add a zero column to the embeddings + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + # Repeat the timestep embeddings instead of computing new ones + embedding = repeat(timesteps, "b -> b d", d=dim) + + return embedding + + +def exists(val: Any) -> bool: + """ + Check if a value exists (i.e., is not None). + + Args: + val (Any): The value to check. + + Returns: + bool: True if the value exists, False otherwise. + """ + return val is not None + + +def default(val: Any, d: Any) -> Any: + """ + Return the value if it exists, otherwise return the default value. + + Args: + val (Any): The value to check. + d (Any or Callable): The default value to return if `val` does not exist. If `d` is a callable, it will be called + with no arguments to generate the default value. + + Returns: + Any: The value if it exists, otherwise the default value. 
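+
+    Example:
+        A short sketch covering both branches (plain default vs. callable default):
+
+        >>> default(None, 5)
+        5
+        >>> default(3, lambda: 5)
+        3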
+ """ + if exists(val): + return val + return d() if isinstance(d, Callable) else d + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + Args: + func (callable): The function to evaluate. + inputs (Sequence): The argument sequence to pass to `func`. + params (Sequence): A sequence of parameters `func` depends on but does not explicitly take as arguments. + flag (bool): If False, disable gradient checkpointing. + + Returns: + Any: The output of the function `func`. + + Example: + >>> def my_func(x, y, z): + return x * y + z + >>> x = torch.randn(32, 64) + >>> y = torch.randn(32, 64) + >>> z = torch.randn(32, 64) + >>> output = checkpoint(func=my_func, inputs=(x, y), params=(z,), flag=True) + >>> output.shape + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +def normalization(channels: int) -> nn.Module: + """ + Create a standard normalization layer using group normalization. + + Args: + channels (int): The number of input channels. + + Returns: + nn.Module: A group normalization layer with 32 groups and `channels` input channels. + + Example: + >>> norm = normalization(channels=64) + >>> x = torch.randn(32, 64, 128, 128) + >>> output = norm(x) + >>> output.shape + torch.Size([32, 64, 128, 128]) + """ + + # Create a group normalization layer with 32 groups + return GroupNorm32(32, channels) + + +def count_flops_attn( + model: torch.nn.Module, _x: Tuple[torch.Tensor], y: Tuple[torch.Tensor] +) -> None: + """ + A counter for the `thop` package to count the operations in an attention operation. + + Args: + model (torch.nn.Module): The PyTorch model to count the operations for. + _x (Tuple[torch.Tensor]): The input tensors to the model (not used in this function). + y (Tuple[torch.Tensor]): The output tensors from the model. + + Returns: + None + """ + # Get the batch size, number of channels, and spatial dimensions of the output tensor + b, c, *spatial = y[0].shape + + # Compute the total number of spatial dimensions + num_spatial = int(np.prod(spatial)) + + # We perform two matrix multiplications with the same number of operations. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + + # Add the number of operations to the model's total_ops attribute + model.total_ops += torch.DoubleTensor([matmul_ops]) + + +def convert_module_to_f16(x: nn.Module) -> nn.Module: + """ + Convert a PyTorch module to use 16-bit floating point precision. + + Args: + x (nn.Module): The PyTorch module to convert. + + Returns: + nn.Module: The converted PyTorch module. + """ + pass + + +def convert_module_to_f32(x: nn.Module) -> nn.Module: + """ + Convert a PyTorch module to use 32-bit floating point precision. + + Args: + x (nn.Module): The PyTorch module to convert. + + Returns: + nn.Module: The converted PyTorch module. + """ + pass + + +def avg_pool_nd(dims: int, *args: Union[int, tuple], **kwargs) -> nn.Module: + """ + Create a 1D, 2D, or 3D average pooling module. + + Args: + dims (int): The number of dimensions of the pooling module (1, 2, or 3). + *args (Union[int, tuple]): The positional arguments to pass to the pooling module. + **kwargs: Additional keyword arguments to pass to the pooling module. + + Returns: + nn.Module: A 1D, 2D, or 3D average pooling module. 
+ + Raises: + ValueError: If the number of dimensions is not 1, 2, or 3. + + Example: + >>> pool = avg_pool_nd(2, kernel_size=3, stride=2) + >>> x = torch.randn(1, 3, 32, 32) + >>> y = pool(x) + """ + if dims == 1: + # Create a 1D average pooling module + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + # Create a 2D average pooling module + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + # Create a 3D average pooling module + return nn.AvgPool3d(*args, **kwargs) + else: + # Raise an error if the number of dimensions is not 1, 2, or 3 + raise ValueError(f"Unsupported number of dimensions: {dims}") + + +class AttnBlock(nn.Module): + """ + An attention module that computes attention weights for each spatial location in the input tensor. + + Args: + in_channels (int): Number of channels in the input tensor. + + Attributes: + in_channels (int): Number of channels in the input tensor. + norm (Normalize): Normalization layer for the input tensor. + q (torch.nn.Conv2d): Convolutional layer for computing the query tensor. + k (torch.nn.Conv2d): Convolutional layer for computing the key tensor. + v (torch.nn.Conv2d): Convolutional layer for computing the value tensor. + proj_out (torch.nn.Conv2d): Convolutional layer for projecting the attended tensor. + + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> attn_module = AttnBlock(in_channels=64) + >>> output_tensor = attn_module(input_tensor) + """ + + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the output tensor of the attention module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor with the same shape as the input tensor. + """ + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # reshape q to b,hw,c and transpose to b,c,hw + k = k.reshape(b, c, h * w) # reshape k to b,c,hw + w_ = torch.bmm( + q, k + ) # compute attention weights w[b,i,j] = sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) # scale the attention weights + w_ = torch.nn.functional.softmax( + w_, dim=2 + ) # apply softmax to get the attention probabilities + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # transpose w to b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ) # compute the attended values h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) # reshape h_ to b,c,h,w + h_ = self.proj_out(h_) # project the attended values to the output space + + return x + h_ + + +class LinearAttention(nn.Module): + """ + A linear attention module that computes attention weights for each spatial + location in the input tensor. + + Args: + dim (int): Number of channels in the input tensor. + heads (int): Number of attention heads. Defaults to 4. + dim_head (int): Number of channels per attention head. Defaults to 32. 
+ + Example: + >>> input_tensor = torch.randn(1, 64, 32, 32) + >>> attn_module = LinearAttention(dim=64, heads=8, dim_head=16) + >>> output_tensor = attn_module(input_tensor) + >>> output_tensor.shape + """ + + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the output tensor of the attention module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor with the same shape as the input tensor. + """ + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class TimestepBlock(nn.Module): + """ + Abstract base class for any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward( + self, x: torch.Tensor, emb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply the module to `x` given `emb` timestep embeddings. + + Args: + x (torch.Tensor): The input tensor to the module. + emb (torch.Tensor): The timestep embeddings to apply to the + input tensor. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the output + tensor and the updated timestep embeddings. + """ + + +class CrossAttention(nn.Module): + """ + Cross-attention module that computes attention weights between a query + tensor and a context tensor. + + Args: + query_dim (int): The dimension of the query tensor. + context_dim (int, optional): The dimension of the context tensor. If + None, defaults to `query_dim`. + heads (int, optional): The number of attention heads to use. Defaults + to 8. + dim_head (int, optional): The dimension of each attention head. Defaults + to 64. + dropout (float, optional): The dropout probability to use. Defaults to 0. + + Inputs: + - x (torch.Tensor): The query tensor of shape + `(batch_size, query_seq_len, query_dim)`. + - context (torch.Tensor, optional): The context tensor of shape + `(batch_size, context_seq_len, context_dim)`. If None, defaults to `x`. + - mask (torch.Tensor, optional): A boolean mask of shape + `(batch_size, query_seq_len)` indicating which query elements should + be masked out of the attention computation. + + Outputs: + - torch.Tensor: The output tensor of shape `(batch_size, query_seq_len, query_dim)`. 
+ + Example: + >>> query = torch.randn(2, 10, 64) + >>> context = torch.randn(2, 20, 64) + >>> attn = CrossAttention(query_dim=64, context_dim=64, heads=8, dim_head=64, dropout=0.1) + >>> out = attn(x=query, context=context) + >>> out.shape + torch.Size([2, 10, 64]) + """ + + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + # default to self-attention + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + # Reshape queries, keys, and values for multi-head attention + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + + # Aggregated attention weights + sim = einsum("bid, bjd -> bij", q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, "b ... -> b (...)") + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, "b j -> (b h) () j", h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + + return self.to_out(out) + + +class GEGLU(nn.Module): + """ + Gated Exponential Linear Unit (GEGLU) activation function. + + Applies a linear projection to the input tensor, splits the result + into two halves, and applies the GELU function to one half while + leaving the other half unchanged. The output is the element-wise + product of the two halves. + + Args: + dim_in (int): The number of input features. + dim_out (int): The number of output features. + + Example: + >>> x = torch.randn(32, 64) + >>> gelu = GEGLU(64, 128) + >>> y = gelu(x) + >>> y.shape + """ + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + """ + Apply the GEGLU activation function to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + # Apply linear projection and split into two halves + x, gate = self.proj(x).chunk(2, dim=-1) + + # Apply GELU to one half and leave the other half unchanged + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + """ + Feedforward neural network with an optional Gated Exponential Linear Unit (GEGLU) activation function. + + Applies a linear projection to the input tensor, applies the GELU or GEGLU activation function, and applies a + linear projection to the output tensor. + + Args: + dim (int): The number of input features. + dim_out (int, optional): The number of output features. If not provided, defaults to `dim`. + mult (float, optional): The multiplier for the inner dimension of the linear projections. Defaults to 4. + glu (bool, optional): Whether to use the GEGLU activation function instead of GELU. Defaults to False. + dropout (float, optional): The dropout probability. Defaults to 0. 
+ + Example: + >>> x = torch.randn(32, 64) + >>> ff = FeedForward(64, 128, mult=2, glu=True, dropout=0.1) + >>> y = ff(x) + >>> y.shape + """ + + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + + # Define input projection + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + # Define network layers + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Apply the feedforward neural network to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor. + """ + return self.net(x) + + +class CheckpointFunction(torch.autograd.Function): + """ + A PyTorch autograd function that enables gradient checkpointing. + + Gradient checkpointing is a technique for reducing memory usage during backpropagation by recomputing intermediate + activations on-the-fly instead of storing them in memory. This function implements the forward and backward passes + for gradient checkpointing. + + Args: + run_function (callable): The function to evaluate. + length (int): The number of input tensors to `run_function`. + *args: The input tensors and parameters to `run_function`. + + Returns: + Any: The output of `run_function`. + + Example: + >>> def my_func(x, y, z): + ... return x * y + z + >>> x = torch.randn(32, 64) + >>> y = torch.randn(32, 64) + >>> z = torch.randn(32, 64) + >>> output = CheckpointFunction.apply(my_func, 2, x, y, z) + >>> output.shape + torch.Size([32, 64]) + """ + + @staticmethod + def forward(ctx, run_function, length, *args): + """ + Compute the forward pass of the gradient checkpointing function. + + Args: + ctx (torch.autograd.function._ContextMethodMixin): The context object for the autograd function. + run_function (callable): The function to evaluate. + length (int): The number of input tensors to `run_function`. + *args: The input tensors and parameters to `run_function`. + + Returns: + Any: The output of `run_function`. + """ + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + """ + Compute the backward pass of the gradient checkpointing function. + + Args: + ctx (torch.autograd.function._ContextMethodMixin): The context object for + the autograd function. + *output_grads: The gradients of the output tensors. + + Returns: + tuple: The gradients of the input tensors and parameters. + """ + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
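+            # view_as(x) creates shallow copies that share storage with the
+            # detached inputs but are fresh autograd nodes, so re-running
+            # run_function with gradients enabled does not trip the in-place
+            # restriction described above.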
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +class BasicTransformerBlock(nn.Module): + """ + A basic transformer block consisting of a self-attention layer, a feedforward layer, and another self-attention layer. + + Args: + dim (int): The input and output dimension of the block. + n_heads (int): The number of attention heads to use. + d_head (int): The dimension of each attention head. + dropout (float, optional): The dropout probability to use. Default: 0.0. + context_dim (int, optional): The dimension of the context tensor for the second + self-attention layer. If None, the second layer is a self-attention layer. Default: None. + gated_ff (bool, optional): Whether to use a gated feedforward layer. Default: True. + checkpoint (bool, optional): Whether to use gradient checkpointing to reduce memory + usage. Default: True. + + Inputs: + x (torch.Tensor): The input tensor of shape `(batch_size, seq_len, dim)`. + context (torch.Tensor, optional): The context tensor of shape + `(batch_size, seq_len, context_dim)`. If None, the second self-attention + layer is a self-attention layer. Default: None. + + Outputs: + torch.Tensor: The output tensor of shape `(batch_size, seq_len, dim)`. + + Example: + >>> block = BasicTransformerBlock( + dim=512, n_heads=8, d_head=64, + dropout=0.1, context_dim=256, gated_ff=True, checkpoint=True + ) + >>> x = torch.randn(32, 128, 512) + >>> context = torch.randn(32, 128, 256) + >>> output = block(x, context) + >>> output.shape + torch.Size([32, 128, 512]) + """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int, + dropout: float = 0.0, + context_dim: Optional[int] = None, + gated_ff: bool = True, + checkpoint: bool = False, + ) -> None: + """ + Initialize the basic transformer block. + + Args: + dim (int): The input and output dimension of the block. + n_heads (int): The number of attention heads to use. + d_head (int): The dimension of each attention head. + dropout (float, optional): The dropout probability to use. Default: 0.0. + context_dim (int, optional): The dimension of the context tensor for the second self-attention layer. If None, the second layer is a self-attention layer. Default: None. + gated_ff (bool, optional): Whether to use a gated feedforward layer. Default: True. + checkpoint (bool, optional): Whether to use gradient checkpointing to reduce memory usage. Default: True. + """ + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward( + self, x: torch.Tensor, context: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Compute the forward pass of the basic transformer block. + + Args: + x (torch.Tensor): The input tensor of shape `(batch_size, seq_len, dim)`. + context (torch.Tensor, optional): The context tensor of shape + `(batch_size, seq_len, context_dim)`. 
If None, the second self-attention + layer is a self-attention layer. Default: None. + + Returns: + torch.Tensor: The output tensor of shape `(batch_size, seq_len, dim)`. + """ + return checkpoint( + self._forward, (x, context), self.parameters(), False#self.checkpoint + ) + + def _forward( + self, x: torch.Tensor, context: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Compute the forward pass of the basic transformer block. + + Args: + x (torch.Tensor): The input tensor of shape `(batch_size, seq_len, dim)`. + context (torch.Tensor, optional): The context tensor of shape + `(batch_size, seq_len, context_dim)`. If None, the second self-attention + layer is a self-attention layer. Default: None. + + Returns: + torch.Tensor: The output tensor of shape `(batch_size, seq_len, dim)`. + """ + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class GroupNorm32(nn.GroupNorm): + """ + A subclass of `nn.GroupNorm` that casts the input tensor to float32 before + passing it to the parent class's `forward` method, and then casts the output + back to the original data type of the input tensor. + + Args: + num_groups (int): The number of groups to divide the channels into. + num_channels (int): The number of channels in the input tensor. + eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-5. + affine (bool, optional): Whether to apply learnable affine transformations to + the output. Default: True. + + Inputs: + x (torch.Tensor): The input tensor of shape `(batch_size, num_channels, *)`. + + Outputs: + torch.Tensor: The output tensor of the same shape as the input tensor. + + Example: + >>> norm = GroupNorm32(num_groups=32, num_channels=64, eps=1e-5, affine=True) + >>> x = torch.randn(32, 64, 128, 128) + >>> output = norm(x) + >>> output.shape + torch.Size([32, 64, 128, 128]) + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute the forward pass of the group normalization layer. + + Args: + x (torch.Tensor): The input tensor of shape `(batch_size, num_channels, *)`. + + Returns: + torch.Tensor: The output tensor of the same shape as the input tensor. + """ + return super().forward(x.float()).type(x.dtype) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + + Args: + n_heads (int): The number of attention heads. + + Inputs: + qkv (torch.Tensor): An [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + + Outputs: + torch.Tensor: An [N x (H * C) x T] tensor after attention. + + Example: + >>> attn = QKVAttention(n_heads=8) + >>> x = torch.randn(32, 24 * 8, 128) + >>> output = attn(x) + >>> output.shape + torch.Size([32, 192, 128]) + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + Args: + qkv (torch.Tensor): An [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + + Returns: + torch.Tensor: An [N x (H * C) x T] tensor after attention. 
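+
+        Example:
+            A small sketch; note that the channel dimension collapses from
+            3 * H * C on the input to H * C on the output:
+
+            >>> attn = QKVAttention(n_heads=8)
+            >>> qkv = torch.randn(32, 3 * 8 * 8, 128)
+            >>> attn(qkv).shape
+            torch.Size([32, 64, 128])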
+ """ + # Get the batch size, width, and length of the input tensor + bs, width, length = qkv.shape + + # Ensure that the width is divisible by 3 * n_heads + assert width % (3 * self.n_heads) == 0 + + # Compute the number of channels per head + ch = width // (3 * self.n_heads) + + # Split the input tensor into Q, K, and V tensors + q, k, v = qkv.chunk(3, dim=1) + + # Compute the scaling factor for the dot product + scale = 1 / math.sqrt(math.sqrt(ch)) + + # Compute the dot product of Q and K, and apply the scaling factor + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) + + # Apply softmax to the dot product to get the attention weights + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # Compute the weighted sum of V using the attention weights + a = torch.einsum( + "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length) + ) + + # Reshape the output tensor to the original shape + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/output + heads shaping. + """ + + def __init__(self, n_heads: int): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv: torch.Tensor) -> torch.Tensor: + """ + Apply QKV attention. + + Args: + qkv (torch.Tensor): An [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + + Returns: + torch.Tensor: An [N x (H * C) x T] tensor after attention. + + Example: + >>> attn = QKVAttentionLegacy(n_heads=8) + >>> x = torch.randn(32, 24 * 8, 128) + >>> output = attn(x) + >>> output.shape + torch.Size([32, 192, 128]) + """ + # Get the batch size, width, and length of the input tensor + bs, width, length = qkv.shape + + # Ensure that the width is divisible by 3 * n_heads + assert width % (3 * self.n_heads) == 0 + + # Compute the number of channels per head + ch = width // (3 * self.n_heads) + + # Split the input tensor into Q, K, and V tensors + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + + # Compute the scaling factor for the dot product + scale = 1 / math.sqrt(math.sqrt(ch)) + + # Compute the dot product of Q and K, and apply the scaling factor + weight = torch.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + + # Apply softmax to the dot product to get the attention weights + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # Compute the weighted sum of V using the attention weights + a = torch.einsum("bts,bcs->bct", weight, v) + + # Reshape the output tensor to the original shape + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops( + model: nn.Module, _x: Tuple[torch.Tensor], y: Tuple[torch.Tensor] + ) -> None: + """ + A counter for the `thop` package to count the operations in an attention operation. + + Args: + model (nn.Module): The PyTorch model to count the operations for. + _x (Tuple[torch.Tensor]): The input tensors to the model (not used in this function). + y (Tuple[torch.Tensor]): The output tensors from the model. + + Returns: + None + + Example: + >>> macs, params = thop.profile( + ... model, + ... inputs=(inputs, timestamps), + ... custom_ops={QKVAttention: QKVAttention.count_flops}, + ... 
) + """ + count_flops_attn(model, _x, y) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + Args: + channels (int): The number of input and output channels. + use_conv (bool): Whether to apply a convolution after upsampling. + dims (int, optional): The number of dimensions of the input tensor (1, 2, or 3). + Defaults to 2. + out_channels (int, optional): The number of output channels. + Defaults to None (same as input channels). + padding (int, optional): The amount of padding to apply to the convolution. + Defaults to 1. + + Example: + >>> upsample = Upsample(channels=64, use_conv=True, dims=2, out_channels=128, padding=1) + >>> x = torch.randn(32, 64, 128, 128) + >>> output = upsample(x) + >>> output.shape + torch.Size([32, 128, 128, 128]) + """ + + def __init__( + self, + channels: int, + use_conv: bool, + dims: int = 2, + out_channels: Optional[int] = None, + padding: int = 1, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + # Create a convolutional layer with the specified number of dimensions + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + """ + Apply upsampling to the input tensor. + + Args: + x (torch.Tensor): The input tensor to upsample. + + Returns: + torch.Tensor: The upsampled tensor. + """ + # Ensure that the input tensor has the correct number of channels + assert x.shape[1] == self.channels + # Upsample the input tensor using nearest-neighbor interpolation + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + # Apply the convolutional layer if necessary + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + Args: + channels (int): The number of input and output channels. + use_conv (bool): Whether to apply a convolution after downsampling. + dims (int, optional): The number of dimensions of the input tensor (1, 2, or 3). Defaults to 2. + out_channels (int, optional): The number of output channels. Defaults to None + (same as input channels). + padding (int, optional): The amount of padding to apply to the convolution. Defaults to 1. + + Raises: + AssertionError: If the input tensor does not have the correct number of channels. + + Example: + >>> downsample = Downsample(64, use_conv=True, dims=2, out_channels=128, padding=1) + >>> x = torch.randn(1, 64, 32, 32) + >>> y = downsample(x) + """ + + def __init__( + self, + channels: int, + use_conv: bool, + dims: int = 2, + out_channels: Optional[int] = None, + padding: int = 1, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + # Create a convolutional layer with the specified number of dimensions + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + # Create an average pooling layer with the specified number of dimensions + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + """ + Apply downsampling to the input tensor. + + Args: + x (torch.Tensor): The input tensor to downsample. 
+ + Returns: + torch.Tensor: The downsampled tensor. + + Raises: + AssertionError: If the input tensor does not have the correct number of channels. + """ + + # Ensure that the input tensor has the correct number of channels + assert x.shape[1] == self.channels + + # Apply the convolutional or pooling layer + return self.op(x) diff --git a/supers2/models/opensr_diffusion/diffusion/__init__.py b/supers2/models/opensr_diffusion/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/supers2/models/opensr_diffusion/diffusion/latentdiffusion.py b/supers2/models/opensr_diffusion/diffusion/latentdiffusion.py new file mode 100644 index 0000000..635170e --- /dev/null +++ b/supers2/models/opensr_diffusion/diffusion/latentdiffusion.py @@ -0,0 +1,932 @@ +from contextlib import contextmanager +from functools import partial +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn + + +from supers2.models.opensr_diffusion.autoencoder.autoencoder import (AutoencoderKL, + DiagonalGaussianDistribution) +from supers2.models.opensr_diffusion.denoiser.unet import UNetModel +from supers2.models.opensr_diffusion.diffusion.utils import (LitEma, count_params, default, + disabled_train, exists, + extract_into_tensor, + make_beta_schedule, + make_convolutional_sample) + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + + +class DiffusionWrapper(nn.Module): + """ + A wrapper around a UNetModel that supports different types of conditioning. + + Args: + diff_model_config (dict): A dictionary of configuration options for the UNetModel. + conditioning_key (str, optional): The type of conditioning to use + (None, 'concat', 'crossattn', 'hybrid', or 'adm'). Defaults to None. + + Raises: + AssertionError: If the conditioning key is not one of the supported values. + + Example: + >>> diff_model_config = {'in_channels': 3, 'out_channels': 3, 'num_filters': 32} + >>> wrapper = DiffusionWrapper(diff_model_config, conditioning_key='concat') + >>> x = torch.randn(1, 3, 256, 256) + >>> t = torch.randn(1) + >>> c_concat = [torch.randn(1, 32, 256, 256)] + >>> y = wrapper(x, t, c_concat=c_concat) + """ + + def __init__(self, diff_model_config: dict, conditioning_key: Optional[str] = None): + super().__init__() + self.diffusion_model = UNetModel(**diff_model_config) + self.conditioning_key = conditioning_key + + ckey_options = [None, "concat", "crossattn", "hybrid", "adm"] + assert ( + self.conditioning_key in ckey_options + ), f"Unsupported conditioning key: {self.conditioning_key}" + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + c_concat: Optional[List[torch.Tensor]] = None, + c_crossattn: Optional[List[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Apply the diffusion model to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + t (torch.Tensor): The diffusion time. + c_concat (List[torch.Tensor], optional): A list of tensors to concatenate with the input tensor. + Used when conditioning_key is 'concat'. Defaults to None. + c_crossattn (List[torch.Tensor], optional): A list of tensors to use for cross-attention. + Used when conditioning_key is 'crossattn', 'hybrid', or 'adm'. Defaults to None. + + Returns: + torch.Tensor: The output tensor. + + Raises: + NotImplementedError: If the conditioning key is not one of the supported values. 
+ """ + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + return out + + +class DDPM(nn.Module): + """This class implements the classic DDPM (Diffusion Models) with Gaussian diffusion + in image space. + + Args: + unet_config (dict): A dictionary of configuration options for the UNetModel. + timesteps (int): The number of diffusion timesteps to use. + beta_schedule (str): The type of beta schedule to use (linear, cosine, or fixed). + use_ema (bool): Whether to use exponential moving averages (EMAs) of the model weights during training. + first_stage_key (str): The key to use for the first stage of the model (either "image" or "noise"). + linear_start (float): The starting value for the linear beta schedule. + linear_end (float): The ending value for the linear beta schedule. + cosine_s (float): The scaling factor for the cosine beta schedule. + given_betas (list): A list of beta values to use for the fixed beta schedule. + v_posterior (float): The weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta. + conditioning_key (str): The type of conditioning to use (None, 'concat', 'crossattn', 'hybrid', or 'adm'). + parameterization (str): The type of parameterization to use for the diffusion process (either "eps" or "x0"). + use_positional_encodings (bool): Whether to use positional encodings for the input. + + Methods: + register_schedule: Registers the schedule for the betas and alphas. + get_input: Gets the input from the DataLoader and rearranges it. + decode_first_stage: Decodes the first stage of the model. + ema_scope: Switches to EMA weights during training. + + Attributes: + parameterization (str): The type of parameterization used for the diffusion process. + cond_stage_model (None): The conditioning stage model (not used in this implementation). + first_stage_key (str): The key used for the first stage of the model. + use_positional_encodings (bool): Whether positional encodings are used for the input. + model (DiffusionWrapper): The diffusion model. + use_ema (bool): Whether EMAs of the model weights are used during training. + model_ema (LitEma): The EMA of the model weights. + v_posterior (float): The weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta. 
+ + Example: + >>> unet_config = { + 'in_channels': 3, + 'model_channels': 160, + 'num_res_blocks': 2, + 'attention_resolutions': [16, 8], + 'channel_mult': [1, 2, 2, 4], + 'num_head_channels': 32 + } + >>> model = DDPM( + unet_config, timesteps=1000, beta_schedule='linear', + use_ema=True, first_stage_key='image' + ) + """ + + def __init__( + self, + unet_config: Dict[str, Any], + timesteps: int = 1000, + beta_schedule: str = "linear", + use_ema: bool = True, + first_stage_key: str = "image", + linear_start: float = 1e-4, + linear_end: float = 2e-2, + cosine_s: float = 8e-3, + given_betas: Optional[List[float]] = None, + v_posterior: float = 0.0, + conditioning_key: Optional[str] = None, + parameterization: str = "eps", + use_positional_encodings: bool = False, + ) -> None: + super().__init__() + assert parameterization in [ + "eps", + "x0", + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + + print( + f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + ) + + self.cond_stage_model = None + self.first_stage_key = first_stage_key + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + + count_params(self.model, verbose=True) + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.v_posterior = v_posterior + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + def register_schedule( + self, + given_betas: Optional[List[float]] = None, + beta_schedule: str = "linear", + timesteps: int = 1000, + linear_start: float = 1e-4, + linear_end: float = 2e-2, + cosine_s: float = 8e-3, + ) -> None: + """ + Registers the schedule for the betas and alphas. + + Args: + given_betas (list, optional): A list of beta values to use for the fixed beta schedule. + Defaults to None. + beta_schedule (str, optional): The type of beta schedule to use (linear, cosine, or fixed). + Defaults to "linear". + timesteps (int, optional): The number of diffusion timesteps to use. Defaults to 1000. + linear_start (float, optional): The starting value for the linear beta schedule. Defaults to 1e-4. + linear_end (float, optional): The ending value for the linear beta schedule. Defaults to 2e-2. + cosine_s (float, optional): The scaling factor for the cosine beta schedule. Defaults to 8e-3. 
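+
+        Example (illustrative, assuming ``ddpm`` is an initialized DDPM instance):
+            >>> ddpm.register_schedule(beta_schedule="linear", timesteps=1000)
+            >>> ddpm.num_timesteps
+            1000
+            >>> ddpm.betas.shape
+            torch.Size([1000])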
+ """ + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + def get_input(self, batch: Dict[str, torch.Tensor], k: str) -> torch.Tensor: + """ + Gets the input from the DataLoader and rearranges it. + + Args: + batch (Dict[str, torch.Tensor]): The batch of data from the DataLoader. + k (str): The key for the input tensor in the batch. + + Returns: + torch.Tensor: The input tensor, rearranged and converted to float. + """ + + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + + x = x.to(memory_format=torch.contiguous_format).float() + + return x + + @contextmanager + def ema_scope(self, context: Optional[str] = None) -> Generator[None, None, None]: + """ + A context manager that switches to EMA weights during training. 
+ + Args: + context (Optional[str]): A string to print when switching to and from EMA weights. + + Yields: + None + """ + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + + def decode_first_stage(self, z: torch.Tensor) -> torch.Tensor: + """ + Decodes the first stage of the model. + + Args: + z (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The decoded output tensor. + """ + + z = 1.0 / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold( + z, ks, stride, uf=uf + ) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view( + (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) + ) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + output_list = [ + self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1]) + ] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + + else: + return self.first_stage_model.decode(z) + + else: + return self.first_stage_model.decode(z) + + +class LatentDiffusion(DDPM): + """ + LatentDiffusion is a class that extends the DDPM class and implements a diffusion + model with a latent variable. The model consists of two stages: a first stage that + encodes the input tensor into a latent tensor, and a second stage that decodes the + latent tensor into the output tensor. The model also has a conditional stage that + takes a conditioning tensor as input and produces a learned conditioning tensor + that is used to condition the first and second stages of the model. The class + provides methods for encoding and decoding tensors, computing the output tensor + and loss, and sampling from the distribution at a given latent tensor and timestep. + The class also provides methods for registering and applying schedules, and for + getting and setting the scale factor and conditioning key. + + Methods: + register_schedule(self, schedule: Schedule) -> None: Registers the given schedule + with the model. + make_cond_schedule(self, schedule: Schedule) -> Schedule: Returns a new schedule + with the given schedule applied to the conditional stage of the model. + encode_first_stage(self, x: torch.Tensor, t: int) -> torch.Tensor: Encodes the given + input tensor with the first stage of the model for the given timestep. + get_first_stage_encoding(self, x: torch.Tensor, t: int) -> torch.Tensor: Returns the + encoding of the given input tensor with the first stage of the model for the + given timestep. 
+ get_learned_conditioning(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: + Returns the learned conditioning tensor for the given input + tensor, timestep, and conditioning tensor. + get_input(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: + Returns the input tensor for the given input tensor, timestep, and + conditioning tensor. + compute(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + Computes the output tensor and loss for the given input tensor, + timestep, and conditioning tensor. + apply_model(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: Applies + the model to the given input tensor, timestep, and conditioning tensor. + get_fold_unfold(self, ks: int, stride: int, vqf: int) -> Tuple[Callable, Callable]: Returns the fold + and unfold functions for the given kernel size, stride, and vector quantization factor. + forward(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: Computes the + output tensor for the given input tensor, timestep, and conditioning tensor. + q_sample(self, z: torch.Tensor, t: int, eps: Optional[torch.Tensor] = None) -> torch.Tensor: Samples + from the distribution at the given latent tensor and timestep. + """ + + def __init__( + self, + first_stage_config: Dict[str, Any], + cond_stage_config: Union[str, Dict[str, Any]], + num_timesteps_cond: Optional[int] = None, + cond_stage_key: str = "image", + cond_stage_trainable: bool = False, + concat_mode: bool = True, + cond_stage_forward: Optional[Callable] = None, + conditioning_key: Optional[str] = None, + scale_factor: float = 1.0, + scale_by_std: bool = False, + *args: Any, + **kwargs: Any, + ): + """ + Initializes the LatentDiffusion model. + + Args: + first_stage_config (Dict[str, Any]): The configuration for the first stage of the model. + cond_stage_config (Union[str, Dict[str, Any]]): The configuration for the conditional stage of the model. + num_timesteps_cond (Optional[int]): The number of timesteps for the conditional stage of the model. + cond_stage_key (str): The key for the conditional stage of the model. + cond_stage_trainable (bool): Whether the conditional stage of the model is trainable. + concat_mode (bool): Whether to use concatenation or cross-attention for the conditioning. + cond_stage_forward (Optional[Callable]): A function to apply to the output of the conditional stage of the model. + conditioning_key (Optional[str]): The key for the conditioning. + scale_factor (float): The scale factor for the input tensor. + scale_by_std (bool): Whether to scale the input tensor by its standard deviation. + *args (Any): Additional arguments. + **kwargs (Any): Additional keyword arguments. 
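+
+        Example (illustrative; ``first_stage_cfg`` and ``unet_cfg`` are placeholder
+        dictionaries whose exact keys depend on the bundled AutoencoderKL and
+        UNetModel definitions):
+            >>> model = LatentDiffusion(
+            ...     first_stage_config=first_stage_cfg,  # must contain "embed_dim"
+            ...     cond_stage_config=cond_stage_cfg,    # only checked against "__is_unconditional__"
+            ...     unet_config=unet_cfg,
+            ...     timesteps=1000,
+            ...     concat_mode=True,
+            ... )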
+ """ + + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs["timesteps"] + + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = "concat" if concat_mode else "crossattn" + if cond_stage_config == "__is_unconditional__": + conditioning_key = None + + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer("scale_factor", torch.tensor(scale_factor)) + + self.cond_stage_forward = cond_stage_forward + + # Set Fusion parameters (SIMON) + # TODO: We only have SISR parameters + self.sr_type = "SISR" + + # Setup the AutoencoderKL model + embed_dim = first_stage_config["embed_dim"] # extract embedded dim fro first stage config + self.first_stage_model = AutoencoderKL(first_stage_config, embed_dim=embed_dim) + self.first_stage_model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + # Setup the Unet model + self.cond_stage_model = torch.nn.Identity() # Unet + self.cond_stage_model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + + def register_schedule( + self, + given_betas: Optional[Union[float, torch.Tensor]] = None, + beta_schedule: str = "linear", + timesteps: int = 1000, + linear_start: float = 1e-4, + linear_end: float = 2e-2, + cosine_s: float = 8e-3, + ) -> None: + """ + Registers the given schedule with the model. + + Args: + given_betas (Optional[Union[float, torch.Tensor]]): The betas for the schedule. + beta_schedule (str): The type of beta schedule to use. + timesteps (int): The number of timesteps for the schedule. + linear_start (float): The start value for the linear schedule. + linear_end (float): The end value for the linear schedule. + cosine_s (float): The scale factor for the cosine schedule. + """ + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def make_cond_schedule(self) -> None: + """ + Shortens the schedule for the conditional stage of the model. + """ + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + + def encode_first_stage(self, x: torch.Tensor) -> torch.Tensor: + """ + Encodes the given input tensor with the first stage of the model. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The encoded output tensor. + """ + return self.first_stage_model.encode(x) + + + def get_first_stage_encoding( + self, encoder_posterior: Union[DiagonalGaussianDistribution, torch.Tensor] + ) -> torch.Tensor: + """ + Returns the encoding of the given input tensor with the first stage of the + model for the given timestep. + + Args: + encoder_posterior (Union[DiagonalGaussianDistribution, torch.Tensor]): The + encoder posterior. + + Returns: + torch.Tensor: The encoding of the input tensor. 
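+
+        Example (illustrative, assuming ``model`` is an initialized LatentDiffusion
+        and ``x`` an image batch accepted by its first-stage autoencoder):
+            >>> posterior = model.encode_first_stage(x)        # DiagonalGaussianDistribution
+            >>> z = model.get_first_stage_encoding(posterior)  # sampled latent, scaled by scale_factor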
+ """ + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c: torch.Tensor) -> torch.Tensor: + """ + Returns the learned conditioning tensor for the given input tensor. + + Args: + c (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The learned conditioning tensor. + """ + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, "encode") and callable( + self.cond_stage_model.encode + ): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + # cond stage model is identity + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def get_input( + self, + batch: torch.Tensor, + k: int, + return_first_stage_outputs: bool = False, + force_c_encode: bool = False, + cond_key: Optional[str] = None, + return_original_cond: bool = False, + bs: Optional[int] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Returns the input tensor for the given batch and timestep. + + Args: + batch (torch.Tensor): The input batch tensor. + k (int): The timestep. + return_first_stage_outputs (bool): Whether to return the outputs of the first stage of the model. + force_c_encode (bool): Whether to force encoding of the conditioning tensor. + cond_key (Optional[str]): The key for the conditioning tensor. + return_original_cond (bool): Whether to return the original conditioning tensor. + bs (Optional[int]): The batch size. + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: The input tensor, the outputs of the + first stage of the model (if `return_first_stage_outputs` is `True`), and the encoded conditioning tensor + (if `force_c_encode` is `True` and `cond_key` is not `None`). 
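+
+        Example (illustrative; ``batch`` is a dict of tensors keyed by
+        ``first_stage_key``/``cond_stage_key``, e.g. "image"):
+            >>> z, c = model.get_input(batch, "image")
+            >>> z, c, x, xrec = model.get_input(batch, "image", return_first_stage_outputs=True)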
+ """ + + # k = first_stage_key on this SR example + x = super().get_input(batch, k) # line 333 + + if bs is not None: + x = x[:bs] + x = x.to(self.device) + + # perform always for HR and for HR only of SISR + if self.sr_type == "SISR" or k == "image": + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + # self.model.conditioning_key = "image" in SR example + + if cond_key is None: + cond_key = self.cond_stage_key + + if cond_key != self.first_stage_key: + if cond_key in ["caption", "coordinates_bbox"]: + xc = batch[cond_key] + elif cond_key == "class_label": + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + # BUG if use_positional_encodings is True + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {"pos_x": pos_x, "pos_y": pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + + """ + # overwrite LR original with encoded LR if wanted + self.encode_conditioning = True + if self.encode_conditioning==True and self.sr_type=="SISR": + #print("Encoding conditioning!") + # try to upsample->encode conditioning + c = torch.nn.functional.interpolate(out[1], size=(512,512), mode='bilinear', align_corners=False) + # encode c + c = self.encode_first_stage(c).sample() + out[1] = c + """ + + + return out + + def compute( + self, example: torch.Tensor, custom_steps: int = 200, temperature: float = 1.0 + ) -> torch.Tensor: + """ + Performs inference on the given example tensor. + + Args: + example (torch.Tensor): The example tensor. + custom_steps (int): The number of steps to perform. + temperature (float): The temperature to use. + + Returns: + torch.Tensor: The output tensor. + """ + guider = None + ckwargs = None + ddim_use_x0_pred = False + temperature = temperature + eta = 1.0 + custom_shape = None + + if hasattr(self, "split_input_params"): + delattr(self, "split_input_params") + + logs = make_convolutional_sample( + example, + self, + custom_steps=custom_steps, + eta=eta, + quantize_x0=False, + custom_shape=custom_shape, + temperature=temperature, + noise_dropout=0.0, + corrector=guider, + corrector_kwargs=ckwargs, + x_T=None, + ddim_use_x0_pred=ddim_use_x0_pred, + ) + + return logs["sample"] + + def apply_model( + self, + x_noisy: torch.Tensor, + t: int, + cond: Optional[torch.Tensor] = None, + return_ids: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Applies the model to the given noisy input tensor. + + Args: + x_noisy (torch.Tensor): The noisy input tensor. + t (int): The timestep. + cond (Optional[torch.Tensor]): The conditioning tensor. + return_ids (bool): Whether to return the IDs of the diffusion process. 
+ + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The output tensor, and optionally the IDs of the + diffusion process. + """ + + if isinstance(cond, dict): + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = ( + "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" + ) + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def get_fold_unfold( + self, x: torch.Tensor, kernel_size: int, stride: int, uf: int = 1, df: int = 1 + ) -> Tuple[nn.Conv2d, nn.ConvTranspose2d]: + """ + Returns the fold and unfold convolutional layers for the given input tensor. + + Args: + x (torch.Tensor): The input tensor. + kernel_size (int): The kernel size. + stride (int): The stride. + uf (int): The unfold factor. + df (int): The fold factor. + + Returns: + Tuple[nn.Conv2d, nn.ConvTranspose2d]: The fold and unfold convolutional layers. + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting( + kernel_size[0], kernel_size[1], Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) + fold = torch.nn.Fold( + output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 + ) + + weighting = self.get_weighting( + kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view( + 1, 1, h * uf, w * uf + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) + ) + + elif df > 1 and uf == 1: + fold_params = dict( + kernel_size=kernel_size, dilation=1, padding=0, stride=stride + ) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) + fold = torch.nn.Fold( + output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2 + ) + + weighting = self.get_weighting( + kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device + ).to(x.dtype) + normalization = fold(weighting).view( + 1, 1, h // df, w // df + ) # normalizes the overlap + weighting = weighting.view( + (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) + ) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + def forward( + self, x: torch.Tensor, c: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + """ + Computes the forward pass of the model. + + Args: + x (torch.Tensor): The input tensor. + c (torch.Tensor): The conditioning tensor. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + torch.Tensor: The output tensor. 
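+
+        Example (illustrative; ``z`` is an encoded latent and ``c`` a conditioning
+        tensor with matching batch size):
+            >>> out = model(z, c)  # draws random timesteps and delegates to p_losses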
+ """ + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: # This is FALSE in our case + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option # TRUE in our case + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + + return self.p_losses(x, c, t, *args, **kwargs) + + def q_sample( + self, x_start: torch.Tensor, t: int, noise: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Samples from the posterior distribution at the given timestep. + + Args: + x_start (torch.Tensor): The starting tensor. + t (int): The timestep. + noise (Optional[torch.Tensor]): The noise tensor. + + Returns: + torch.Tensor: The sampled tensor. + """ + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) diff --git a/supers2/models/opensr_diffusion/diffusion/utils.py b/supers2/models/opensr_diffusion/diffusion/utils.py new file mode 100644 index 0000000..879216b --- /dev/null +++ b/supers2/models/opensr_diffusion/diffusion/utils.py @@ -0,0 +1,901 @@ +import time +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor +from torch import nn as nn +from tqdm import tqdm + + +def exists(val: Any) -> bool: + """ + Returns whether the given value exists (i.e., is not None). + + Args: + val (Any): The value to check. + + Returns: + bool: Whether the value exists. + """ + return val is not None + + +def default(val: Any, d: Callable) -> Any: + """ + Returns the given value if it exists, otherwise returns the default value. + + Args: + val (Any): The value to check. + d (Callable): The default value or function to generate the default value. + + Returns: + Any: The given value or the default value. + """ + if exists(val): + return val + return d() if callable(d) else d + + +def count_params(model: nn.Module, verbose: bool = False) -> int: + """ + Returns the total number of parameters in the given model. + + Args: + model (nn.Module): The model. + verbose (bool): Whether to print the number of parameters. + + Returns: + int: The total number of parameters. + """ + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def disabled_train(self, mode: bool = True) -> nn.Module: + """ + Overwrites the `train` method of the model to disable changing the mode. + + Args: + mode (bool): Whether to enable or disable training mode. + + Returns: + nn.Module: The model. + """ + return self + + +def make_convolutional_sample( + batch: Tensor, + model: nn.Module, + custom_steps: Optional[Union[int, Tuple[int, int]]] = None, + eta: float = 1.0, + quantize_x0: bool = False, + custom_shape: Optional[Tuple[int, int]] = None, + temperature: float = 1.0, + noise_dropout: float = 0.0, + corrector: Optional[nn.Module] = None, + corrector_kwargs: Optional[dict] = None, + x_T: Optional[Tensor] = None, + ddim_use_x0_pred: bool = False, +) -> Tuple[Tensor, Optional[Tensor]]: + """ + Generates a convolutional sample using the given model. + + Args: + batch (Tensor): The input batch tensor. 
+ model (nn.Module): The model to use for sampling. + custom_steps (Optional[Union[int, Tuple[int, int]]]): The custom number of steps. + eta (float): The eta value. + quantize_x0 (bool): Whether to quantize the initial sample. + custom_shape (Optional[Tuple[int, int]]): The custom shape. + temperature (float): The temperature value. + noise_dropout (float): The noise dropout value. + corrector (Optional[nn.Module]): The corrector module. + corrector_kwargs (Optional[dict]): The corrector module keyword arguments. + x_T (Optional[Tensor]): The target tensor. + ddim_use_x0_pred (bool): Whether to use x0 prediction for DDim. + + Returns: + Tuple[Tensor, Optional[Tensor]]: The generated sample tensor and the + target tensor (if provided). + """ + # create an empty dictionary to store the log + log = dict() + + # get the input data and conditioning from the model + z, c, x, xrec, xc = model.get_input( + batch, + model.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=not ( + hasattr(model, "split_input_params") + and model.cond_stage_key == "coordinates_bbox" + ), + return_original_cond=True, + ) + + # if custom_shape is not None, generate random noise of the specified shape + if custom_shape is not None: + z = torch.randn(custom_shape) + print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") + + # store the input and reconstruction in the log + log["input"] = x + log["reconstruction"] = xrec + + # sample from the model using convsample_ddim + with model.ema_scope("Plotting"): + t0 = time.time() + sample, intermediates = convsample_ddim( + model=model, + cond=c, + steps=custom_steps, + shape=z.shape, + eta=eta, + quantize_x0=quantize_x0, + noise_dropout=noise_dropout, + mask=None, + x0=None, + temperature=temperature, + score_corrector=corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + ) + t1 = time.time() + + # if ddim_use_x0_pred is True, use the predicted x0 from the intermediates + if ddim_use_x0_pred: + sample = intermediates["pred_x0"][-1] + + # decode the sample to get the generated image + x_sample = model.decode_first_stage(sample) + + # try to decode the sample without quantization to get the unquantized image + try: + x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) + log["sample_noquant"] = x_sample_noquant + log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) + except: + pass + + # store the generated image, time taken, and other information in the log + log["sample"] = x_sample + log["time"] = t1 - t0 + + # return the log + return log + + +def disabled_train(self: nn.Module, mode: bool = True) -> nn.Module: + """ + Overwrites the `train` method of the model to disable changing the mode. + + Args: + mode (bool): Whether to enable or disable training mode. + + Returns: + nn.Module: The model. + """ + return self + + +def convsample_ddim( + model: nn.Module, + cond: Tensor, + steps: int, + shape: Tuple[int, int], + eta: float = 1.0, + callback: Optional[callable] = None, + noise_dropout: float = 0.0, + normals_sequence: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + x0: Optional[Tensor] = None, + quantize_x0: bool = False, + temperature: float = 1.0, + score_corrector: Optional[nn.Module] = None, + corrector_kwargs: Optional[dict] = None, + x_T: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor]]: + """ + Generates a convolutional sample using the given model and conditioning tensor. + + Args: + model (nn.Module): The model to use for sampling. 
+ cond (Tensor): The conditioning tensor. + steps (int): The number of steps. + shape (Tuple[int, int]): The shape of the sample. + eta (float): The eta value. + callback (Optional[callable]): The callback function. + normals_sequence (Optional[Tensor]): The normals sequence tensor. + noise_dropout (float): The noise dropout value. + mask (Optional[Tensor]): The mask tensor. + x0 (Optional[Tensor]): The initial sample tensor. + quantize_x0 (bool): Whether to quantize the initial sample. + temperature (float): The temperature value. + score_corrector (Optional[nn.Module]): The score corrector module. + corrector_kwargs (Optional[dict]): The score corrector module keyword arguments. + x_T (Optional[Tensor]): The target tensor. + + Returns: + Tuple[Tensor, Optional[Tensor]]: The generated sample tensor and the target tensor (if provided). + """ + ddim = DDIMSampler(model) + bs = shape[0] # dont know where this comes from but wayne + shape = shape[1:] # cut batch dim + print(f"Sampling with eta = {eta}; steps: {steps}") + samples, intermediates = ddim.sample( + steps, + batch_size=bs, + shape=shape, + conditioning=cond, + callback=callback, + normals_sequence=normals_sequence, + quantize_x0=quantize_x0, + eta=eta, + mask=mask, + x0=x0, + temperature=temperature, + verbose=False, + score_corrector=score_corrector, + noise_dropout=noise_dropout, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + ) + + return samples, intermediates + + +def make_ddim_sampling_parameters( + alphacums: np.ndarray, ddim_timesteps: np.ndarray, eta: float, verbose: bool = True +) -> tuple: + """ + Computes the variance schedule for the ddim sampler, based on the given parameters. + + Args: + alphacums (np.ndarray): Array of cumulative alpha values. + ddim_timesteps (np.ndarray): Array of timesteps to use for computing alphas. + eta (float): Scaling factor for computing sigmas. + verbose (bool, optional): Whether to print out the selected alphas and sigmas. Defaults to True. + + Returns: + tuple: A tuple containing three arrays: sigmas, alphas, and alphas_prev. + sigmas (np.ndarray): Array of sigma values for each timestep. + alphas (np.ndarray): Array of alpha values for each timestep. + alphas_prev (np.ndarray): Array of alpha values for the previous timestep. + """ + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def make_ddim_timesteps( + ddim_discr_method: str, + num_ddim_timesteps: int, + num_ddpm_timesteps: int, + verbose: bool = True, +) -> np.ndarray: + """ + Computes the timesteps to use for computing alphas in the ddim sampler. + + Args: + ddim_discr_method (str): The method to use for discretizing the timesteps. + Must be either 'uniform' or 'quad'. + num_ddim_timesteps (int): The number of timesteps to use for computing alphas. + num_ddpm_timesteps (int): The total number of timesteps in the DDPM model. + verbose (bool, optional): Whether to print out the selected timesteps. Defaults to True. 
+ + Returns: + np.ndarray: An array of timesteps to use for computing alphas in the ddim sampler. + """ + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def noise_like(shape: tuple, device: str, repeat: bool = False) -> torch.Tensor: + """ + Generates noise with the same shape as the given tensor. + + Args: + shape (tuple): The shape of the tensor to generate noise for. + device (str): The device to place the noise tensor on. + repeat (bool, optional): Whether to repeat the same noise across the batch dimension. Defaults to False. + + Returns: + torch.Tensor: A tensor of noise with the same shape as the input tensor. + """ + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +class DDIMSampler(object): + def __init__(self, model: object, schedule: str = "linear", **kwargs: dict) -> None: + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.device = model.device + + def register_buffer(self, name: str, attr: torch.Tensor) -> None: + if type(attr) == torch.Tensor: + if attr.device != torch.device(self.device): + attr = attr.to(torch.device(self.device)) + setattr(self, name, attr) + + def make_schedule( + self, + ddim_num_steps: int, + ddim_discretize: str = "uniform", + ddim_eta: float = 0.0, + verbose: bool = True, + ) -> None: + # make ddim timesteps. 
these are the timesteps at which we compute alphas + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + + # get alphas_cumprod from the model + alphas_cumprod = self.model.alphas_cumprod + + # check if alphas_cumprod is defined for each timestep + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + + # define a function to convert tensor to torch tensor + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + # register buffers for betas, alphas_cumprod, and alphas_cumprod_prev + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + + # calculate sigmas for original sampling steps + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + def sample( + self, + S: int, + batch_size: int, + shape: Tuple[int, int, int], + conditioning: Optional[torch.Tensor] = None, + callback: Optional[callable] = None, + img_callback: Optional[callable] = None, + quantize_x0: bool = False, + eta: float = 0.0, + mask: Optional[torch.Tensor] = None, + x0: Optional[torch.Tensor] = None, + temperature: float = 1.0, + noise_dropout: float = 0.0, + score_corrector: Optional[callable] = None, + corrector_kwargs: Optional[dict] = None, + verbose: bool = True, + x_T: Optional[torch.Tensor] = None, + log_every_t: int = 100, + unconditional_guidance_scale: float = 1.0, + unconditional_conditioning: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, dict]: + """ + Samples from the model using DDIM sampling. + + Args: + S (int): Number of DDIM steps. + batch_size (int): Batch size. + shape (Tuple[int, int, int]): Shape of the output tensor. + conditioning (Optional[torch.Tensor], optional): Conditioning tensor. Defaults to None. + callback (Optional[callable], optional): Callback function. Defaults to None. + img_callback (Optional[callable], optional): Image callback function. Defaults to None. 
+ quantize_x0 (bool, optional): Whether to quantize the denoised image. Defaults to False. + eta (float, optional): Learning rate for DDIM. Defaults to 0.. + mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None. + x0 (Optional[torch.Tensor], optional): Initial tensor. Defaults to None. + temperature (float, optional): Sampling temperature. Defaults to 1.. + noise_dropout (float, optional): Noise dropout rate. Defaults to 0.. + score_corrector (Optional[callable], optional): Score corrector function. Defaults to None. + corrector_kwargs (Optional[dict], optional): Keyword arguments for the score corrector function. + Defaults to None. + verbose (bool, optional): Whether to print verbose output. Defaults to True. + x_T (Optional[torch.Tensor], optional): Target tensor. Defaults to None. + log_every_t (int, optional): Log every t steps. Defaults to 100. + unconditional_guidance_scale (float, optional): Scale for unconditional guidance. Defaults to 1.. + unconditional_conditioning (Optional[torch.Tensor], optional): Unconditional conditioning tensor. + Defaults to None. + + Returns: + Tuple[torch.Tensor, dict]: Tuple containing the generated samples and intermediate results. + """ + # check if conditioning is None + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + # make schedule to compute alphas and sigmas + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + + # parameters for sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f"Data shape for DDIM sampling is {size}, eta {eta}") + + # sample from the model using ddim_sampling + samples, intermediates = self.ddim_sampling( + cond=conditioning, + shape=size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + + return samples, intermediates + + def ddim_sampling( + self, + cond: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]], + shape: Tuple[int, int, int], + x_T: Optional[torch.Tensor] = None, + ddim_use_original_steps: bool = False, + callback: Optional[callable] = None, + timesteps: Optional[List[int]] = None, + quantize_denoised: bool = False, + mask: Optional[torch.Tensor] = None, + x0: Optional[torch.Tensor] = None, + img_callback: Optional[callable] = None, + log_every_t: int = 100, + temperature: float = 1.0, + noise_dropout: float = 0.0, + score_corrector: Optional[callable] = None, + corrector_kwargs: Optional[Dict[str, Any]] = None, + unconditional_guidance_scale: float = 1.0, + unconditional_conditioning: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """ + Samples from the model using DDIM sampling. + + Args: + cond (Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]): Conditioning + tensor. Defaults to None. + shape (Tuple[int, int, int]): Shape of the output tensor. + x_T (Optional[torch.Tensor], optional): Target tensor. 
Defaults to None. + ddim_use_original_steps (bool, optional): Whether to use original DDIM steps. Defaults to False. + callback (Optional[callable], optional): Callback function. Defaults to None. + timesteps (Optional[List[int]], optional): List of timesteps. Defaults to None. + quantize_denoised (bool, optional): Whether to quantize the denoised image. Defaults to False. + mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None. + x0 (Optional[torch.Tensor], optional): Initial tensor. Defaults to None. + img_callback (Optional[callable], optional): Image callback function. Defaults to None. + log_every_t (int, optional): Log every t steps. Defaults to 100. + temperature (float, optional): Sampling temperature. Defaults to 1.. + noise_dropout (float, optional): Noise dropout rate. Defaults to 0.. + score_corrector (Optional[callable], optional): Score corrector function. Defaults to None. + corrector_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the score corrector + function. Defaults to None. + unconditional_guidance_scale (float, optional): Scale for unconditional guidance. Defaults to 1. + unconditional_conditioning (Optional[torch.Tensor], optional): Unconditional conditioning tensor. + Defaults to None. + + Returns: + Tuple[torch.Tensor, Dict[str, Any]]: Tuple containing the generated samples and intermediate results. + """ + # Get the device and batch size + device = self.model.betas.device + b = shape[0] + + # Initialize the image tensor + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + # Get the timesteps + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + # Initialize the intermediates dictionary + intermediates = {"x_inter": [img], "pred_x0": [img]} + + # Set the time range and total steps + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + # Initialize the progress bar iterator + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + # Loop over the timesteps + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + # Sample from the model using DDIM + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + temperature=temperature, + noise_dropout=noise_dropout, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + img, pred_x0 = outs + + + # Append the intermediate results to the intermediates dictionary + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + def p_sample_ddim( + self, + x: torch.Tensor, + c: torch.Tensor, + t: int, + index: int, + repeat_noise: bool = False, + use_original_steps: bool = False, + temperature: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Samples from the model using DDIM sampling. + + Args: + x (torch.Tensor): Input tensor. 
+ c (torch.Tensor): Conditioning tensor. + t (int): Current timestep. + index (int): Index of the current timestep. + repeat_noise (bool, optional): Whether to repeat noise. Defaults to False. + use_original_steps (bool, optional): Whether to use original DDIM steps. + Defaults to False. + quantize_denoised (bool, optional): Whether to quantize the denoised image. + Defaults to False. + temperature (float, optional): Sampling temperature. Defaults to 1.. + noise_dropout (float, optional): Noise dropout rate. Defaults to 0.. + score_corrector (Optional[callable], optional): Score corrector function. + Defaults to None. + corrector_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments + for the score corrector function. Defaults to None. + unconditional_guidance_scale (float, optional): Scale for unconditional + guidance. Defaults to 1.. + unconditional_conditioning (Optional[torch.Tensor], optional): Unconditional + conditioning tensor. Defaults to None. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing the generated samples and intermediate results. + """ + t = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long) + + # get batch size and device + b, *_, device = *x.shape, x.device + + # apply model with or without unconditional conditioning + e_t = self.model.apply_model(x, t, c) + + # get alphas, alphas_prev, sqrt_one_minus_alphas, and sigmas + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + + return x_prev, pred_x0 + + +def make_beta_schedule( + schedule: str, + n_timestep: int, + linear_start: float = 1e-4, + linear_end: float = 2e-2, + cosine_s: float = 8e-3, +) -> np.ndarray: + """ + Creates a beta schedule for the diffusion process. + + Args: + schedule (str): Type of schedule to use. Can be "linear", "cosine", "sqrt_linear", or "sqrt". + n_timestep (int): Number of timesteps. + linear_start (float, optional): Starting value for linear schedule. Defaults to 1e-4. + linear_end (float, optional): Ending value for linear schedule. Defaults to 2e-2. + cosine_s (float, optional): Scaling factor for cosine schedule. Defaults to 8e-3. + + Returns: + np.ndarray: Array of beta values. 
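+
+    Example (illustrative):
+        >>> betas = make_beta_schedule("linear", n_timestep=1000)
+        >>> betas.shape
+        (1000,)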
+    """
+    if schedule == "linear":
+        betas = (
+            torch.linspace(
+                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+            )
+            ** 2
+        )
+
+    elif schedule == "cosine":
+        timesteps = (
+            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+        )
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = torch.cos(alphas).pow(2)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        betas = np.clip(betas, a_min=0, a_max=0.999)
+
+    elif schedule == "sqrt_linear":
+        betas = torch.linspace(
+            linear_start, linear_end, n_timestep, dtype=torch.float64
+        )
+    elif schedule == "sqrt":
+        betas = (
+            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+            ** 0.5
+        )
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas.numpy()
+
+
+def extract_into_tensor(
+    a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...]
+) -> torch.Tensor:
+    """
+    Extracts values from a tensor into a new tensor based on indices.
+
+    Args:
+        a (torch.Tensor): Input tensor.
+        t (torch.Tensor): Indices tensor.
+        x_shape (Tuple[int, ...]): Shape of the output tensor.
+
+    Returns:
+        torch.Tensor: Output tensor.
+    """
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class LitEma(nn.Module):
+    def __init__(
+        self, model: nn.Module, decay: float = 0.9999, use_num_upates: bool = True
+    ) -> None:
+        """
+        Initializes the LitEma class.
+
+        Args:
+            model (nn.Module): The model to apply EMA to.
+            decay (float, optional): The decay rate for EMA. Must be between 0 and 1. Defaults to 0.9999.
+            use_num_upates (bool, optional): Whether to use the number of updates to adjust decay. Defaults to True.
+        """
+        super().__init__()
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.m_name2s_name = {}
+        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+        self.register_buffer(
+            "num_updates",
+            torch.tensor(0, dtype=torch.int)
+            if use_num_upates
+            else torch.tensor(-1, dtype=torch.int),
+        )
+
+        for name, p in model.named_parameters():
+            if p.requires_grad:
+                # remove as '.'-character is not allowed in buffers
+                s_name = name.replace(".", "")
+                self.m_name2s_name.update({name: s_name})
+                self.register_buffer(s_name, p.clone().detach().data)
+
+        self.collected_params = []
+
+    def forward(self, model: nn.Module) -> None:
+        """
+        Updates the EMA (shadow) parameters from the model's current parameters.
+
+        Args:
+            model (nn.Module): The model to apply EMA to.
+        """
+        decay = self.decay
+
+        if self.num_updates >= 0:
+            self.num_updates += 1
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+        one_minus_decay = 1.0 - decay
+
+        # The EMA update must not track gradients
+        with torch.no_grad():
+            m_param = dict(model.named_parameters())
+            shadow_params = dict(self.named_buffers())
+
+            for key in m_param:
+                if m_param[key].requires_grad:
+                    sname = self.m_name2s_name[key]
+                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+                    shadow_params[sname].sub_(
+                        one_minus_decay * (shadow_params[sname] - m_param[key])
+                    )
+                else:
+                    assert not key in self.m_name2s_name
+
+    def copy_to(self, model: nn.Module) -> None:
+        """
+        Copies the EMA parameters to the model.
+
+        Args:
+            model (nn.Module): The model to copy the EMA parameters to.
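+
+        Example (illustrative EMA evaluation pattern):
+            >>> ema = LitEma(model)
+            >>> ema(model)                       # update shadow weights after an optimizer step
+            >>> ema.store(model.parameters())    # stash the current training weights
+            >>> ema.copy_to(model)               # evaluate with EMA weights
+            >>> ema.restore(model.parameters())  # switch back to the training weights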
+        """
+        m_param = dict(model.named_parameters())
+        shadow_params = dict(self.named_buffers())
+        for key in m_param:
+            if m_param[key].requires_grad:
+                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+            else:
+                assert key not in self.m_name2s_name
+
+    def store(self, parameters: Iterable[nn.Parameter]) -> None:
+        """
+        Saves the current parameters for restoring later.
+
+        Args:
+            parameters (Iterable[nn.Parameter]): The parameters to be temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters: Iterable[nn.Parameter]) -> None:
+        """
+        Restores the parameters stored with the `store` method.
+
+        Useful for validating the model with EMA parameters without affecting the
+        original optimization process. Call `store` before `copy_to`; after
+        validation (or model saving), use this method to restore the former parameters.
+
+        Args:
+            parameters (Iterable[nn.Parameter]): The parameters to be updated with the stored values.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
diff --git a/supers2/models/utils.py b/supers2/models/utils.py
new file mode 100644
index 0000000..8e21ea7
--- /dev/null
+++ b/supers2/models/utils.py
@@ -0,0 +1,170 @@
+import torch
+from einops import rearrange
+
+
+def linear_transform_4b(t_input, stage="norm"):
+    # normalize ("norm") a 4-band (RGB + NIR) tensor to roughly [-1, 1], or undo it ("denorm")
+    assert stage in ["norm", "denorm"]
+    # get the shape of the tensor
+    shape = t_input.shape
+
+    # if 5D tensor (MISR), norm/denorm each 4-band slice individually
+    if len(shape) == 5:
+        stack = []
+        for batch in t_input:
+            stack2 = []
+            for i in range(0, t_input.size(1), 4):
+                slice_tensor = batch[i:i+4, :, :, :]
+                slice_denorm = linear_transform_4b(slice_tensor, stage=stage)
+                stack2.append(slice_denorm)
+            stack2 = torch.stack(stack2)
+            stack2 = stack2.reshape(shape[1], shape[2], shape[3], shape[4])
+            stack.append(stack2)
+        stack = torch.stack(stack)
+        return stack
+
+    # here only if the tensor is 3D or 4D
+    squeeze_needed = False
+    if len(shape) == 3:
+        squeeze_needed = True
+        t_input = t_input.unsqueeze(0)
+        shape = t_input.shape
+
+    assert len(shape) == 4 or len(shape) == 5, "Input tensor must have 4 dimensions (B,C,H,W) - or 5D for MISR"
+    transpose_needed = False
+    if shape[-1] > shape[1]:
+        transpose_needed = True
+        t_input = rearrange(t_input, "b c h w -> b w h c")
+
+    # define constants
+    rgb_c = 3.
+    nir_c = 5.
+
+    # iterate over batches
+    return_ls = []
+    for t in t_input:
+        if stage == "norm":
+            # scale bands according to the normalization convention
+            t[:,:,0] = t[:,:,0] * (10.0 / rgb_c)  # R
+            t[:,:,1] = t[:,:,1] * (10.0 / rgb_c)  # G
+            t[:,:,2] = t[:,:,2] * (10.0 / rgb_c)  # B
+            t[:,:,3] = t[:,:,3] * (10.0 / nir_c)  # NIR
+            # clamp to get rid of outlier pixels
+            t = t.clamp(0, 1)
+            # bring to -1..+1
+            t = (t * 2) - 1
+        if stage == "denorm":
+            # bring to 0..1
+            t = (t + 1) / 2
+            # undo the band scaling
+            t[:,:,0] = t[:,:,0] * (rgb_c / 10.0)  # R
+            t[:,:,1] = t[:,:,1] * (rgb_c / 10.0)  # G
+            t[:,:,2] = t[:,:,2] * (rgb_c / 10.0)  # B
+            t[:,:,3] = t[:,:,3] * (nir_c / 10.0)  # NIR
+            # clamp to get rid of outlier pixels
+            t = t.clamp(0, 1)
+
+        # append result to list
+        return_ls.append(t)
+
+    # after loop, stack image
+    t_output = torch.stack(return_ls)
+
+    if transpose_needed:
+        t_output = rearrange(t_output, "b w h c -> b c h w")
+    if squeeze_needed:
+        t_output = t_output.squeeze(0)
+
+    return t_output
+
+
+def linear_transform_6b(t_input, stage="norm"):
+    # iterate over batches
+    assert stage in ["norm", "denorm"]
+    bands_c = 5.
+    return_ls = []
+    clamp = False
+    for t in t_input:
+        if stage == "norm":
+            # scale bands according to the normalization convention
+            t[:,:,0] = t[:,:,0] * (10.0 / bands_c)
+            t[:,:,1] = t[:,:,1] * (10.0 / bands_c)
+            t[:,:,2] = t[:,:,2] * (10.0 / bands_c)
+            t[:,:,3] = t[:,:,3] * (10.0 / bands_c)
+            t[:,:,4] = t[:,:,4] * (10.0 / bands_c)
+            t[:,:,5] = t[:,:,5] * (10.0 / bands_c)
+            # clamp to get rid of outlier pixels
+            if clamp:
+                t = t.clamp(0, 1)
+            # bring to -1..+1
+            t = (t * 2) - 1
+        if stage == "denorm":
+            # bring to 0..1
+            t = (t + 1) / 2
+            # undo the band scaling
+            t[:,:,0] = t[:,:,0] * (bands_c / 10.0)
+            t[:,:,1] = t[:,:,1] * (bands_c / 10.0)
+            t[:,:,2] = t[:,:,2] * (bands_c / 10.0)
+            t[:,:,3] = t[:,:,3] * (bands_c / 10.0)
+            t[:,:,4] = t[:,:,4] * (bands_c / 10.0)
+            t[:,:,5] = t[:,:,5] * (bands_c / 10.0)
+            # clamp to get rid of outlier pixels
+            if clamp:
+                t = t.clamp(0, 1)
+
+        # append result to list
+        return_ls.append(t)
+
+    # after loop, stack image
+    t_output = torch.stack(return_ls)
+
+    return t_output
+
+
+def assert_tensor_validity(tensor):
+    # ASSERT BATCH DIMENSION
+    # if unbatched, add batch dimension
+    if len(tensor.shape) == 3:
+        tensor = tensor.unsqueeze(0)
+
+    # ASSERT BxCxHxW ORDER
+    # if the last dimension looks like a channel axis, reorder to BxCxHxW
+    if tensor.shape[-1] < 10:
+        tensor = rearrange(tensor, "b w h c -> b c h w")
+
+    height, width = tensor.shape[-2], tensor.shape[-1]
+    # Calculate how much padding is needed for height and width
+    if height < 128 or width < 128:
+        pad_height = max(0, 128 - height)  # Amount to pad on height
+        pad_width = max(0, 128 - width)  # Amount to pad on width
+
+        # Padding for height and width needs to be added to both sides of the dimension
+        # The pad has the format (left, right, top, bottom)
+        padding = (pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2)
+
+        # Apply symmetric padding
+        tensor = torch.nn.functional.pad(tensor, padding, mode='reflect')
+
+    else:  # no padding needed, store zero padding
+        padding = (0, 0, 0, 0)
+
+    return tensor, padding
+
+
+def revert_padding(tensor, padding):
+    left, right, top, bottom = padding
+    # account for the 4x upsampling factor
+    left, right, top, bottom = left * 4, right * 4, top * 4, bottom * 4
+    # Calculate the indices to slice from the padded tensor
+    start_height = top
+    end_height = tensor.size(-2) - bottom
+    start_width = left
+    end_width = tensor.size(-1) - right
+
+    # Slice the tensor to remove padding
+    unpadded_tensor = tensor[:, :, start_height:end_height, start_width:end_width]
+    return unpadded_tensor
\ No newline at end of file
diff --git a/supers2/setup.py b/supers2/setup.py
index 8c3c0b9..4324df4 100644
--- a/supers2/setup.py
+++ b/supers2/setup.py
@@ -46,6 +46,9 @@ def get_model_name(
 
 
 def load_model_parameters(model_name: str, model_size: str):
+    if model_name == "diffusion":
+        return {}
+
     # Dictionary mapping model names and sizes to corresponding functions/classes
     model_mapping = {
         "cnn": {
@@ -98,29 +101,33 @@ def load_model_parameters(model_name: str, model_size: str):
     return model_params
 
 
-def load_model(model_name: str, model_params: dict):
-    # Diccionario que mapea nombres de modelo a sus respectivos módulos y clases
+def load_model(model_name: str, model_params: dict, device: str = "cpu", **kwargs):
+    # Dictionary mapping model names to corresponding modules and classes
     model_mapping = {
         "cnn_legacy": ("supers2.models.cnn_legacy", "CNNSR_legacy"),
         "cnn": ("supers2.models.cnn", "CNNSR"),
         "swin": ("supers2.models.swin", "Swin2SR"),
         "mamba": ("supers2.models.mamba", "MambaSR"),
"diffusion": ("supers2.models.diffusion", "SRLatentDiffusion"), } - # Verificar si el modelo existe en el mapeo + # Check if the model name is valid if model_name not in model_mapping: raise ValueError(f"Model '{model_name}' not found") - # Obtener el módulo y clase del modelo + # Get the module and class names module_name, class_name = model_mapping[model_name] - # Importar el módulo y obtener la clase + # Load the module and class model_module = importlib.import_module(module_name) model_module = importlib.reload(model_module) model_class = getattr(model_module, class_name) - # Instanciar el modelo con los parámetros dados - return model_class(**model_params) + # Instantiate the model + if model_name == "diffusion": + return model_class(device=device, **kwargs) + else: + return model_class(**model_params) def load_fusionx2_model( @@ -236,6 +243,8 @@ def load_srx4_model( model_size: str, model_loss: str, weights_path: Union[str, pathlib.Path], + device: str = "cpu", + **kwargs ): # Get the model snippet @@ -248,9 +257,9 @@ def load_srx4_model( ) # Load the weights - weights_data = torch.load(model_snippet, map_location=torch.device("cpu")) + weights_data = torch.load(model_snippet, map_location=device) - # remove hard_constraint + # remove hard_constraint (if exists) # TODO remove for key in list(weights_data.keys()): if "hard_constraint" in key: weights_data.pop(key) @@ -262,10 +271,10 @@ def load_srx4_model( model_params["upscale"] = 4 # Load the model - FusionX4 = load_model(model_name, model_params) - FusionX4.load_state_dict(weights_data) - FusionX4.eval() - for param in FusionX4.parameters(): + SRX4 = load_model(model_name, model_params, kwargs) + SRX4.load_state_dict(weights_data) + SRX4.eval() + for param in SRX4.parameters(): param.requires_grad = False # Define the Hard Constraint @@ -281,4 +290,4 @@ def load_srx4_model( param.requires_grad = False # Apply Model then hard constraint - return CustomModel(SRmodel=FusionX4, HardConstraint=hard_constraint) + return CustomModel(SRmodel=SRX4, HardConstraint=hard_constraint)