From 1051965b6adcefe5c4c7e4a5ed6b77a1a6fde903 Mon Sep 17 00:00:00 2001
From: Martin Nicolas Everaert <55559086+martin-ev@users.noreply.github.com>
Date: Thu, 29 Feb 2024 15:28:50 +0100
Subject: [PATCH] Add files via upload

---
 colab.ipynb | 2308 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 2308 insertions(+)
 create mode 100644 colab.ipynb

diff --git a/colab.ipynb b/colab.ipynb
new file mode 100644
index 0000000..7249a93
--- /dev/null
+++ b/colab.ipynb
@@ -0,0 +1,2308 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "ec8409df-1ed5-40a2-b594-268b6d331c53",
+   "metadata": {
+    "id": "ec8409df-1ed5-40a2-b594-268b6d331c53"
+   },
+   "source": [
+    "# Exploiting the Signal-Leak Bias in Diffusion Models\n",
+    "\n",
+    "[![arXiv](https://img.shields.io/badge/arXiv-2309.15842-red)](https://arxiv.org/abs/2309.15842)\n",
+    "[![Project Page](https://img.shields.io/badge/Project%20Page-IVRL-blue)](https://ivrl.github.io/signal-leak-bias/)\n",
+    "[![Proceedings](https://img.shields.io/badge/WACV%20Proceedings-CVF-blue)](https://openaccess.thecvf.com/content/WACV2024/html/Everaert_Exploiting_the_Signal-Leak_Bias_in_Diffusion_Models_WACV_2024_paper.html)\n",
+    "\n",
+    "\n",
+    "## Overview\n",
+    "\n",
+    "This is the Colab version of the official implementation for our paper titled \"[Exploiting the Signal-Leak Bias in Diffusion Models](https://ivrl.github.io/signal-leak-bias/)\", presented at [WACV 2024](https://openaccess.thecvf.com/content/WACV2024/html/Everaert_Exploiting_the_Signal-Leak_Bias_in_Diffusion_Models_WACV_2024_paper.html) 🔥\n",
+    "\n",
+    "Code is available at [https://github.com/IVRL/signal-leak-bias](https://github.com/IVRL/signal-leak-bias).\n",
+    "\n",
+    "### 🔎 Research Highlights\n",
+    "- In the training of most diffusion models, data are never completely noised, creating signal leakage and leading to discrepancies between the training and inference processes.\n",
+    "- As a consequence of this signal leakage, the low-frequency / large-scale content of the generated images is mostly unchanged from the initial latents from which the generation process starts, yielding greyish images or images that do not match the desired style.\n",
+    "- Our research proposes to exploit this signal-leak bias at inference time to gain more control over generated images.\n",
+    "- We model the distribution of the signal leak present during training, to include a signal leak at inference time in the initial latents.\n",
+    "- ✨✨ No training required! ✨✨\n",
+    "\n",
+    "### 📃 [Exploiting the Signal-Leak Bias in Diffusion Models](https://ivrl.github.io/signal-leak-bias/)\n",
+    "\n",
+    "[Martin Nicolas Everaert](https://martin-ev.github.io/) <sup>1</sup>, [Athanasios Fitsios](https://www.linkedin.com/in/athanasiosfitsios/) <sup>1,2</sup>, [Marco Bocchio](https://scholar.google.com/citations?user=KDiTxBQAAAAJ) <sup>2</sup>, [Sami Arpa](https://scholar.google.com/citations?user=84FopNgAAAAJ) <sup>2</sup>, [Sabine Süsstrunk](https://scholar.google.com/citations?user=EX3OYP4AAAAJ) <sup>1</sup>, [Radhakrishna Achanta](https://scholar.google.com/citations?user=lc2HaZwAAAAJ) <sup>1</sup>\n",
+    "\n",
+    "<sup>1</sup>[School of Computer and Communication Sciences, EPFL, Switzerland](https://www.epfl.ch/labs/ivrl/) ; <sup>2</sup>[Largo.ai, Lausanne, Switzerland](https://home.largo.ai/)\n",
+    "\n",
+    "**Abstract**: There is a bias in the inference pipeline of most diffusion models. This bias arises from a signal leak whose distribution deviates from the noise distribution, creating a discrepancy between training and inference processes. 
We demonstrate that this signal-leak bias is particularly significant when models are tuned to a specific style, causing sub-optimal style matching. Recent research tries to avoid the signal leakage during training. We instead show how we can exploit this signal-leak bias in existing diffusion models to allow more control over the generated images. This enables us to generate images with more varied brightness, and images that better match a desired style or color. By modeling the distribution of the signal leak in the spatial frequency and pixel domains, and including a signal leak in the initial latent, we generate images that better match expected results without any additional training.\n", + "\n", + "\n", + "# License\n", + "\n", + "The implementation here is provided solely as part of the research publication \"[Exploiting the Signal-Leak Bias in Diffusion Models](https://ivrl.github.io/signal-leak-bias/)\", only for academic non-commercial usage. Details can be found in the [LICENSE file](https://github.com/ivrl/signal-leak-bias/blob/main/LICENSE). If the License is not suitable for your business or project, please contact Largo.ai (info@largo.ai) and EPFL-TTO (info.tto@epfl.ch) for a full commercial license.\n", + "\n", + "\n", + "# Citation\n", + "\n", + "Please cite the paper as follows:\n", + "\n", + "```\n", + "@InProceedings{Everaert_2024_WACV,\n", + " author = {Everaert, Martin Nicolas and Fitsios, Athanasios and Bocchio, Marco and Arpa, Sami and Süsstrunk, Sabine and Achanta, Radhakrishna},\n", + " title = {{E}xploiting the {S}ignal-{L}eak {B}ias in {D}iffusion {M}odels},\n", + " booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},\n", + " month = {January},\n", + " year = {2024},\n", + " pages = {4025-4034}\n", + "}\n", + "```\n", + "\n", + "\n", + "## Getting Started\n", + "\n", + "### Code and development environment\n", + "\n", + "Our code mainly builds on top of the code of the [🤗 Diffusers](https://huggingface.co/docs/diffusers/index) library.\n", + "\n", + "Clone this repository:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "deb44fb5-366d-40a7-ab21-1898ae139f09", + "metadata": { + "id": "deb44fb5-366d-40a7-ab21-1898ae139f09" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/IVRL/signal-leak-bias\n", + "%cd signal-leak-bias/src" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0BgeSbrWZ38", + "metadata": { + "id": "d0BgeSbrWZ38" + }, + "outputs": [], + "source": [ + "!rm -r examples #We will regenerate all the examples in this notebook :-)" + ] + }, + { + "cell_type": "markdown", + "id": "3c30d4be-88de-4296-b633-d4b983ad07a4", + "metadata": { + "id": "3c30d4be-88de-4296-b633-d4b983ad07a4" + }, + "source": [ + "Run the following command to install our dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20296877-e45e-4f61-b61d-495786484472", + "metadata": { + "id": "20296877-e45e-4f61-b61d-495786484472" + }, + "outputs": [], + "source": [ + "!pip install diffusers==0.25.1\n", + "!pip install accelerate==0.26.1" + ] + }, + { + "cell_type": "markdown", + "id": "b9c99c07-a896-4c94-8136-7799c347fb94", + "metadata": { + "id": "b9c99c07-a896-4c94-8136-7799c347fb94" + }, + "source": [ + "Run the following command to download some images for the examples:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3aaa9acf-8814-4679-9205-b520f8f49efa", + "metadata": { + "id": "3aaa9acf-8814-4679-9205-b520f8f49efa" 
+ }, + "outputs": [], + "source": [ + "!GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/sd-dreambooth-library/nasa-space-v2-768\n", + "!git clone https://huggingface.co/sd-concepts-library/line-art\n", + "!wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip\n", + "!unzip -q coco128" + ] + }, + { + "cell_type": "markdown", + "id": "8a309f28-ab66-4b4c-b5fd-067157f5f223", + "metadata": { + "id": "8a309f28-ab66-4b4c-b5fd-067157f5f223" + }, + "source": [ + "### Computing statistics of the signal leak\n", + "\n", + "The provided Python file for computing statistics of the signal leak can be used, for example, as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23c08437-ae69-4a37-81f5-0ef81ad38358", + "metadata": { + "id": "23c08437-ae69-4a37-81f5-0ef81ad38358" + }, + "outputs": [], + "source": [ + "!python signal_leak.py \\\n", + " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-2-1\" \\\n", + " --data_dir=\"coco128/images/train2017\" \\\n", + " --output_dir=\"examples/C\" \\\n", + " --resolution=768 \\\n", + " --n_components=3 \\\n", + " --statistic_type=\"dct+pixel\" \\\n", + " --center_crop" + ] + }, + { + "cell_type": "markdown", + "id": "7TQ273KnT2tn", + "metadata": { + "id": "7TQ273KnT2tn" + }, + "source": [ + "### Inference\n", + "\n", + "Once the statistics have been computed, you can use them to sample a signal-leak at inference time too, for instance as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "GzDMxPxNT35P", + "metadata": { + "id": "GzDMxPxNT35P" + }, + "outputs": [], + "source": [ + "from signal_leak import sample_from_stats\n", + "\n", + "signal_leak = sample_from_stats(path=\"examples/C\")" + ] + }, + { + "cell_type": "markdown", + "id": "wg5j28Gigtk5", + "metadata": { + "id": "wg5j28Gigtk5" + }, + "source": [ + "Images can be generated with the sampled signal-leak in the initial latents, for instance as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1B-gCJNogv7L", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 849, + "referenced_widgets": [ + "edeb82954da6454487447c3cded28fb5", + "b36394c79ecd499ea6f3542c1d8400a2", + "b563ef9c5b254b1b85374852cc1ca6f1", + "cdca892715d34db8801220591ff3069c", + "cf25de4cb76841dcb6d8ed58ebcd05bc", + "b688152435e44628b3397b07bfc11096", + "20f9fc4c61c94b7396d4a386a3c27ddb", + "c775f4e631ac40338be6e0e2322868f4", + "a5bcc403252c4965af6634c6778f9855", + "b8a9580ba1004c9390d6cc3a2263a9ac", + "8d26c64943534f858d12d4b009a0a6ab", + "55da9af2bc2846148251ba18b8503138", + "a5bb486b87684aad814574ea9ca145d7", + "e90aef66217f495082645bbc67d386a6", + "1bfd35792f5343de86238de0e807ba01", + "dd1644de2241479b8b974375a9e098e2", + "9e216a39076047508ba3a00d5a427560", + "81c7387460a244fb8625863158ba0103", + "9cba2a19d1854d289f7b74c2d147cf05", + "4e41df9b1027494f81d5f0449c6d9027", + "9ec248e2db2b416fb0a9db0fca6228bb", + "b07b1a7b188d4aa584bcdcb6b293a3c2" + ] + }, + "id": "1B-gCJNogv7L", + "outputId": "fcbd08bf-182c-4005-dac3-604975fbcea7" + }, + "outputs": [], + "source": [ + "from diffusers import StableDiffusionPipeline\n", + "import torch\n", + "\n", + "pipeline = StableDiffusionPipeline.from_pretrained(\"stabilityai/stable-diffusion-2-1\").to(\"cuda\")\n", + "num_inference_steps = 50\n", + "\n", + "# Get the timestep T of the first reverse diffusion iteration\n", + "pipeline.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n", + "first_inference_timestep = 
pipeline.scheduler.timesteps[0].item()\n",
+    "\n",
+    "# Get the values of sqrt(alpha_prod_T)\n",
+    "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n",
+    "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n",
+    "\n",
+    "# Generate the initial latents, without signal leak\n",
+    "latents = torch.randn([1, 4, 96, 96])  # 4 latent channels, 96x96 = (768/8)x(768/8) for 768x768 images\n",
+    "\n",
+    "# Add a signal leak in the initial latents\n",
+    "latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * latents\n",
+    "\n",
+    "# Generate image\n",
+    "image = pipeline(\n",
+    "    prompt = \"An astronaut riding a horse\",\n",
+    "    num_inference_steps = num_inference_steps,\n",
+    "    latents = latents,\n",
+    ").images[0]\n",
+    "display(image)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "LOwgNEAPhh2r",
+   "metadata": {
+    "id": "LOwgNEAPhh2r"
+   },
+   "outputs": [],
+   "source": [
+    "del image, pipeline, latents, signal_leak, first_inference_timestep, num_inference_steps, sqrt_alpha_prod, sqrt_one_minus_alpha_prod\n",
+    "import torch\n",
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "pWc8tPnW7jib",
+   "metadata": {
+    "id": "pWc8tPnW7jib"
+   },
+   "source": [
+    "## Examples"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "wzulJYgvg0ft",
+   "metadata": {
+    "id": "wzulJYgvg0ft"
+   },
+   "source": [
+    "\n",
+    "\n",
+    "### Improving style-tuned models\n",
+    "\n",
+    "Models tuned on specific styles often produce results that do not match the styles well (see the second column of the next two tables). We argue that this is because of a discrepancy between training (where the latents contain a signal leak whose distribution differs from the standard multivariate Gaussian) and inference (no signal leak). We fix this discrepancy by modeling the signal leak present during training and including a signal leak (see third column) at inference time too. 
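Concretely, writing $\\bar{\\alpha}_T$ for the scheduler's cumulative product of $\\alpha$ values at the first inference timestep $T$, we start the reverse diffusion from $\\sqrt{\\bar{\\alpha}_T} \\, x_{\\text{leak}} + \\sqrt{1 - \\bar{\\alpha}_T} \\, \\epsilon$, where $x_{\\text{leak}}$ is the sampled signal leak and $\\epsilon$ the usual Gaussian noise: this mirrors how latents are noised during training, and is exactly what the cells below compute with `sqrt_alpha_prod` and `sqrt_one_minus_alpha_prod`. 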
We use a \"pixel\" model, that is we estimate the mean and variance of each pixel (spatial elements of the latent encodings).\n", + "\n", + "In the 2 following examples, we show how to fix two models:\n", + "- [sd-dreambooth-library/nasa-space-v2-768](https://huggingface.co/sd-dreambooth-library/nasa-space-v2-768) is a model tuned with [DreamBooth](https://huggingface.co/docs/diffusers/en/training/dreambooth) (Ruiz et al., 2022) on 24 images of the sky.\n", + "- [sd-concepts-library/line-art](https://huggingface.co/sd-concepts-library/line-art) is an embedding for [Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4) trained with [Textual Inversion](https://huggingface.co/docs/diffusers/training/text_inversion) (Gal et al, 2022) on 7 images with line-art style.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "iYNxBb-bg626", + "metadata": { + "id": "iYNxBb-bg626" + }, + "source": [ + "#### Example 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "QZWAMe6XJV6K", + "metadata": { + "id": "QZWAMe6XJV6K" + }, + "outputs": [], + "source": [ + "!python signal_leak.py \\\n", + " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-2-1\" \\\n", + " --data_dir=\"nasa-space-v2-768/concept_images\" \\\n", + " --output_dir=\"examples/A1/\" \\\n", + " --resolution=768 \\\n", + " --statistic_type=\"pixel\" \\\n", + " --center_crop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "gQcITNs7hD1x", + "metadata": { + "id": "gQcITNs7hD1x" + }, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "from diffusers import StableDiffusionPipeline\n", + "from signal_leak import sample_from_stats\n", + "\n", + "folder = \"examples/A1/imgs\"\n", + "path_stats = \"examples/A1\"\n", + "\n", + "os.makedirs(folder, exist_ok=True)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# Load the pipeline\n", + "pipeline = StableDiffusionPipeline.from_pretrained(\n", + " \"sd-dreambooth-library/nasa-space-v2-768\",\n", + ").to(device)\n", + "num_inference_steps = 50\n", + "\n", + "# Get the timestep T of the first reverse diffusion iteration\n", + "pipeline.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n", + "first_inference_timestep = pipeline.scheduler.timesteps[0].item()\n", + "\n", + "# Get the values of sqrt(alpha_prod_T)\n", + "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n", + "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n", + "\n", + "# Dimensions of the latent space, with batch_size=1\n", + "shape_latents = [\n", + " 1,\n", + " pipeline.unet.config.in_channels,\n", + " pipeline.unet.config.sample_size,\n", + " pipeline.unet.config.sample_size,\n", + "]\n", + "\n", + "# Utility function to visualize initial latents / signal leak\n", + "def latents_to_pil(pipeline, latents, generator):\n", + " decoded = pipeline.vae.decode(\n", + " latents / pipeline.vae.config.scaling_factor,\n", + " return_dict=False,\n", + " generator=generator,\n", + " )[0]\n", + " image = pipeline.image_processor.postprocess(\n", + " decoded,\n", + " output_type=\"pil\",\n", + " do_denormalize=[True],\n", + " )[0]\n", + " return image\n", + "\n", + "# Random number generator\n", + "generator = torch.Generator(device=device)\n", + "generator = generator.manual_seed(12345)\n", + "\n", + "with torch.no_grad():\n", + " for n in range(5):\n", + "\n", + " # Generate the initial 
+  {
+   "cell_type": "markdown",
+   "id": "iYNxBb-bg626",
+   "metadata": {
+    "id": "iYNxBb-bg626"
+   },
+   "source": [
+    "#### Example 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "QZWAMe6XJV6K",
+   "metadata": {
+    "id": "QZWAMe6XJV6K"
+   },
+   "outputs": [],
+   "source": [
+    "!python signal_leak.py \\\n",
+    "    --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-2-1\" \\\n",
+    "    --data_dir=\"nasa-space-v2-768/concept_images\" \\\n",
+    "    --output_dir=\"examples/A1/\" \\\n",
+    "    --resolution=768 \\\n",
+    "    --statistic_type=\"pixel\" \\\n",
+    "    --center_crop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "gQcITNs7hD1x",
+   "metadata": {
+    "id": "gQcITNs7hD1x"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import torch\n",
+    "from diffusers import StableDiffusionPipeline\n",
+    "from signal_leak import sample_from_stats\n",
+    "\n",
+    "folder = \"examples/A1/imgs\"\n",
+    "path_stats = \"examples/A1\"\n",
+    "\n",
+    "os.makedirs(folder, exist_ok=True)\n",
+    "\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "\n",
+    "# Load the pipeline\n",
+    "pipeline = StableDiffusionPipeline.from_pretrained(\n",
+    "    \"sd-dreambooth-library/nasa-space-v2-768\",\n",
+    ").to(device)\n",
+    "num_inference_steps = 50\n",
+    "\n",
+    "# Get the timestep T of the first reverse diffusion iteration\n",
+    "pipeline.scheduler.set_timesteps(num_inference_steps, device=device)\n",
+    "first_inference_timestep = pipeline.scheduler.timesteps[0].item()\n",
+    "\n",
+    "# Get the values of sqrt(alpha_prod_T)\n",
+    "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n",
+    "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n",
+    "\n",
+    "# Dimensions of the latent space, with batch_size=1\n",
+    "shape_latents = [\n",
+    "    1,\n",
+    "    pipeline.unet.config.in_channels,\n",
+    "    pipeline.unet.config.sample_size,\n",
+    "    pipeline.unet.config.sample_size,\n",
+    "]\n",
+    "\n",
+    "# Utility function to visualize initial latents / signal leak\n",
+    "def latents_to_pil(pipeline, latents, generator):\n",
+    "    decoded = pipeline.vae.decode(\n",
+    "        latents / pipeline.vae.config.scaling_factor,\n",
+    "        return_dict=False,\n",
+    "        generator=generator,\n",
+    "    )[0]\n",
+    "    image = pipeline.image_processor.postprocess(\n",
+    "        decoded,\n",
+    "        output_type=\"pil\",\n",
+    "        do_denormalize=[True],\n",
+    "    )[0]\n",
+    "    return image\n",
+    "\n",
+    "# Random number generator\n",
+    "generator = torch.Generator(device=device)\n",
+    "generator = generator.manual_seed(12345)\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    for n in range(5):\n",
+    "\n",
+    "        # Generate the initial latents\n",
+    "        initial_latents = torch.randn(\n",
+    "            shape_latents, generator=generator, device=device, dtype=torch.float32\n",
+    "        )\n",
+    "        latents_to_pil(pipeline, initial_latents, generator).save(f\"{folder}/latents{n}.png\")\n",
+    "\n",
+    "\n",
+    "        # Generate an image WITHOUT signal leak in the initial latents\n",
+    "        image = pipeline(\n",
+    "            prompt=\"A very dark picture of the sky, Nasa style\",\n",
+    "            guidance_scale=1,\n",
+    "            num_inference_steps=num_inference_steps,\n",
+    "            latents=initial_latents,\n",
+    "        ).images[0]\n",
+    "        image.save(f\"{folder}/original{n}.png\")\n",
+    "\n",
+    "        # Generate a signal leak from computed statistics\n",
+    "        signal_leak = sample_from_stats(\n",
+    "            path=path_stats,\n",
+    "            dims=shape_latents,\n",
+    "            generator_pt=generator,\n",
+    "            generator_np=None,\n",
+    "            device=device\n",
+    "        )\n",
+    "        latents_to_pil(pipeline, signal_leak, generator).save(f\"{folder}/signal_leak{n}.png\")\n",
+    "\n",
+    "        # Add a signal leak in the initial latents\n",
+    "        initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents\n",
+    "\n",
+    "        # Generate an image WITH signal leak in the initial latents\n",
+    "        image_with_signalleak = pipeline(\n",
+    "            prompt=\"A very dark picture of the sky, Nasa style\",\n",
+    "            guidance_scale=1,\n",
+    "            num_inference_steps=num_inference_steps,\n",
+    "            latents=initial_latents,\n",
+    "        ).images[0]\n",
+    "        image_with_signalleak.save(f\"{folder}/ours{n}.png\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ekVdj9uqiWNh",
+   "metadata": {
+    "id": "ekVdj9uqiWNh"
+   },
+   "outputs": [],
+   "source": [
+    "del device, first_inference_timestep, generator, num_inference_steps, image, image_with_signalleak, initial_latents, pipeline, signal_leak, sqrt_alpha_prod, sqrt_one_minus_alpha_prod\n",
+    "import torch\n",
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7sB9fzXVRJP0",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "id": "7sB9fzXVRJP0",
+    "outputId": "55984400-7b8a-49d6-fd14-d28c904673c1"
+   },
+   "outputs": [],
+   "source": [
+    "from IPython.display import display, HTML\n",
+    "from PIL import Image\n",
+    "from io import BytesIO\n",
+    "import base64\n",
+    "\n",
+    "def load_images():\n",
+    "    images = []\n",
+    "    for i in range(5):\n",
+    "        latents_path = f'examples/A1/imgs/latents{i}.png'\n",
+    "        original_path = f'examples/A1/imgs/original{i}.png'\n",
+    "        signal_leak_path = f'examples/A1/imgs/signal_leak{i}.png'\n",
+    "        ours_path = f'examples/A1/imgs/ours{i}.png'\n",
+    "\n",
+    "        latents_image = Image.open(latents_path)\n",
+    "        original_image = Image.open(original_path)\n",
+    "        signal_leak_image = Image.open(signal_leak_path)\n",
+    "        ours_image = Image.open(ours_path)\n",
+    "\n",
+    "        images.append([\n",
+    "            latents_image, original_image, signal_leak_image, ours_image\n",
+    "        ])\n",
+    "\n",
+    "    return images\n",
+    "\n",
+    "# Load images\n",
+    "loaded_images = load_images()\n",
+    "\n",
+    "# Function to generate HTML code for the table\n",
+    "def generate_table(data, headers):\n",
+    "    table_code = \"<table>\"\n",
+    "\n",
+    "    # Add headers\n",
+    "    table_code += \"<tr>\"\n",
+    "    for header in headers:\n",
+    "        table_code += \"<th>\" + header + \"</th>\"\n",
+    "    table_code += \"</tr>\"\n",
+    "\n",
+    "    # Add data rows\n",
+    "    for row in data:\n",
+    "        table_code += \"<tr>\"\n",
+    "        for cell in row:\n",
+    "            img_data = BytesIO()\n",
+    "            cell.save(img_data, format=\"PNG\")\n",
+    "            img_data.seek(0)\n",
+    "            img_data_uri = f\"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}\"\n",
+    "            table_code += f'<td><img src=\"{img_data_uri}\"/></td>'\n",
+    "        table_code += \"</tr>\"\n",
+    "    table_code += \"</table>\"\n",
+    "\n",
+    "    return table_code\n",
+    "\n",
+    "# Data for the table\n",
+    "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n",
+    "\n",
+    "table_data = []\n",
+    "for i in range(5):\n",
+    "    table_data.append([\n",
+    "        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n",
+    "    ])\n",
+    "\n",
+    "# Display the table\n",
+    "display(HTML(generate_table(table_data, table_headers)))"
+   ]
+  },
\"\n", + "\n", + " return table_code\n", + "\n", + "# Data for the table\n", + "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n", + "\n", + "table_data = []\n", + "for i in range(5):\n", + " table_data.append([\n", + " loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n", + " ])\n", + "\n", + "# Display the table\n", + "display(HTML(generate_table(table_data, table_headers)))" + ] + }, + { + "cell_type": "markdown", + "id": "puimEnuzhQZm", + "metadata": { + "id": "puimEnuzhQZm" + }, + "source": [ + "#### Example 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "O6JJRZNINfUY", + "metadata": { + "id": "O6JJRZNINfUY" + }, + "outputs": [], + "source": [ + "!python signal_leak.py \\\n", + " --pretrained_model_name_or_path=\"CompVis/stable-diffusion-v1-4\" \\\n", + " --data_dir=\"line-art/concept_images\" \\\n", + " --output_dir=\"examples/A2/\" \\\n", + " --resolution=512 \\\n", + " --statistic_type=\"pixel\" \\\n", + " --center_crop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "CxLf2ojfhgqV", + "metadata": { + "id": "CxLf2ojfhgqV" + }, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "from diffusers import StableDiffusionPipeline\n", + "from signal_leak import sample_from_stats\n", + "\n", + "folder = \"examples/A2/imgs\"\n", + "path_stats = \"examples/A2\"\n", + "\n", + "os.makedirs(folder, exist_ok=True)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# Load the pipeline\n", + "pipeline = StableDiffusionPipeline.from_pretrained(\n", + " \"CompVis/stable-diffusion-v1-4\",\n", + ").to(device)\n", + "pipeline.load_textual_inversion(\n", + " \"sd-concepts-library/line-art\",\n", + ")\n", + "num_inference_steps = 50\n", + "\n", + "# Get the timestep T of the first reverse diffusion iteration\n", + "pipeline.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n", + "first_inference_timestep = pipeline.scheduler.timesteps[0].item()\n", + "\n", + "# Get the values of sqrt(alpha_prod_T)\n", + "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n", + "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n", + "\n", + "# Dimensions of the latent space, with batch_size=1\n", + "shape_latents = [\n", + " 1,\n", + " pipeline.unet.config.in_channels,\n", + " pipeline.unet.config.sample_size,\n", + " pipeline.unet.config.sample_size,\n", + "]\n", + "\n", + "# Utility function to visualize initial latents / signal leak\n", + "def latents_to_pil(pipeline, latents, generator):\n", + " decoded = pipeline.vae.decode(\n", + " latents / pipeline.vae.config.scaling_factor,\n", + " return_dict=False,\n", + " generator=generator,\n", + " )[0]\n", + " image = pipeline.image_processor.postprocess(\n", + " decoded,\n", + " output_type=\"pil\",\n", + " do_denormalize=[True],\n", + " )[0]\n", + " return image\n", + "\n", + "# Random number generator\n", + "generator = torch.Generator(device=device)\n", + "generator = generator.manual_seed(12345)\n", + "\n", + "with torch.no_grad():\n", + " for n in range(5):\n", + "\n", + " # Generate the initial latents\n", + " initial_latents = torch.randn(\n", + " shape_latents, generator=generator, device=device, dtype=torch.float32\n", + " )\n", + " latents_to_pil(pipeline, initial_latents, generator).save(f\"{folder}/latents{n}.png\")\n", + "\n", + "\n", + " # 
Generate an image WITHOUT signal leak in the initial latents\n", + " image = pipeline(\n", + " prompt=\"An astronaut riding a horse in the style of \",\n", + " num_inference_steps=num_inference_steps,\n", + " latents=initial_latents,\n", + " ).images[0]\n", + " image.save(f\"{folder}/original{n}.png\")\n", + "\n", + " # Generate a signal leak from computed statistics\n", + " signal_leak = sample_from_stats(\n", + " path=path_stats,\n", + " dims=shape_latents,\n", + " generator_pt=generator,\n", + " generator_np=None,\n", + " device=device\n", + " )\n", + " latents_to_pil(pipeline, signal_leak, generator).save(f\"{folder}/signal_leak{n}.png\")\n", + "\n", + " # Add a signal leak in the initial latents\n", + " initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents\n", + "\n", + " # Generate an image WITH signal leak in the initial latents\n", + " image_with_signalleak = pipeline(\n", + " prompt=\"An astronaut riding a horse in the style of \",\n", + " num_inference_steps=num_inference_steps,\n", + " latents=initial_latents,\n", + " ).images[0]\n", + " image_with_signalleak.save(f\"{folder}/ours{n}.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "jK3nv76-lKJb", + "metadata": { + "id": "jK3nv76-lKJb" + }, + "outputs": [], + "source": [ + "del device, first_inference_timestep, generator, num_inference_steps, image, image_with_signalleak, initial_latents, pipeline, signal_leak, sqrt_alpha_prod, sqrt_one_minus_alpha_prod\n", + "import torch\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "MRH0TdAbS-vo", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "MRH0TdAbS-vo", + "outputId": "e2cc8e25-f007-4196-8ac3-c4c4eef09741" + }, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "from PIL import Image\n", + "from io import BytesIO\n", + "import base64\n", + "\n", + "def load_images():\n", + " images = []\n", + " for i in range(5):\n", + " latents_path = f'examples/A2/imgs/latents{i}.png'\n", + " original_path = f'examples/A2/imgs/original{i}.png'\n", + " signal_leak_path = f'examples/A2/imgs/signal_leak{i}.png'\n", + " ours_path = f'examples/A2/imgs/ours{i}.png'\n", + "\n", + " latents_image = Image.open(latents_path)\n", + " original_image = Image.open(original_path)\n", + " signal_leak_image = Image.open(signal_leak_path)\n", + " ours_image = Image.open(ours_path)\n", + "\n", + " images.append([\n", + " latents_image, original_image, signal_leak_image, ours_image\n", + " ])\n", + "\n", + " return images\n", + "\n", + "# Load images\n", + "loaded_images = load_images()\n", + "\n", + "# Function to generate HTML code for the table\n", + "def generate_table(data, headers):\n", + " table_code = \"\"\n", + "\n", + " # Add headers\n", + " for header in headers:\n", + " table_code += \"\"\n", + " table_code += \"\"\n", + "\n", + " # Add data rows\n", + " for row in data:\n", + " table_code += \"\"\n", + " for cell in row:\n", + " img_data = BytesIO()\n", + " cell.save(img_data, format=\"PNG\")\n", + " img_data.seek(0)\n", + " img_data_uri = f\"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}\"\n", + " table_code += f''\n", + " table_code += \"\"\n", + " table_code += \"
\" + header + \"
\"\n", + "\n", + " return table_code\n", + "\n", + "# Data for the table\n", + "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n", + "\n", + "table_data = []\n", + "for i in range(5):\n", + " table_data.append([\n", + " loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n", + " ])\n", + "\n", + "# Display the table\n", + "display(HTML(generate_table(table_data, table_headers)))" + ] + }, + { + "cell_type": "markdown", + "id": "lT0c1woYhpTP", + "metadata": { + "id": "lT0c1woYhpTP" + }, + "source": [ + "\n", + "### Training-free style adaptation of Stable Diffusion\n", + "\n", + "The same approach as the previous example can be used directly in the base diffusion model, instead of the model finetuned on a style. That is, we include a signal leak at inference time to bias the image generation towards the desired style.\n", + "\n", + "Without our approach (see second column of the next two tables), the prompt alone is not sufficient enough to generate picture of the desired style. Complementing it with a signal leak of the style (third column) generates images (last column) that better match the desired output.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "bbzhCr3niGIa", + "metadata": { + "id": "bbzhCr3niGIa" + }, + "source": [ + "#### Example 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "zxl1Tc2AhotZ", + "metadata": { + "id": "zxl1Tc2AhotZ" + }, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "from diffusers import StableDiffusionPipeline\n", + "from signal_leak import sample_from_stats\n", + "\n", + "folder = \"examples/B1/imgs\"\n", + "path_stats = \"examples/A1\"\n", + "\n", + "os.makedirs(folder, exist_ok=True)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# Load the pipeline\n", + "pipeline = StableDiffusionPipeline.from_pretrained(\n", + " \"stabilityai/stable-diffusion-2-1\",\n", + ").to(device)\n", + "num_inference_steps = 50\n", + "\n", + "# Get the timestep T of the first reverse diffusion iteration\n", + "pipeline.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n", + "first_inference_timestep = pipeline.scheduler.timesteps[0].item()\n", + "\n", + "# Get the values of sqrt(alpha_prod_T)\n", + "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n", + "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n", + "\n", + "# Dimensions of the latent space, with batch_size=1\n", + "shape_latents = [\n", + " 1,\n", + " pipeline.unet.config.in_channels,\n", + " pipeline.unet.config.sample_size,\n", + " pipeline.unet.config.sample_size,\n", + "]\n", + "\n", + "# Utility function to visualize initial latents / signal leak\n", + "def latents_to_pil(pipeline, latents, generator):\n", + " decoded = pipeline.vae.decode(\n", + " latents / pipeline.vae.config.scaling_factor,\n", + " return_dict=False,\n", + " generator=generator,\n", + " )[0]\n", + " image = pipeline.image_processor.postprocess(\n", + " decoded,\n", + " output_type=\"pil\",\n", + " do_denormalize=[True],\n", + " )[0]\n", + " return image\n", + "\n", + "# Random number generator\n", + "generator = torch.Generator(device=device)\n", + "generator = generator.manual_seed(12345)\n", + "\n", + "with torch.no_grad():\n", + " for n in range(5):\n", + "\n", + " # Generate the initial latents\n", + " initial_latents = 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "qfgNleCPnJ57",
+   "metadata": {
+    "id": "qfgNleCPnJ57"
+   },
+   "outputs": [],
+   "source": [
+    "del device, first_inference_timestep, generator, num_inference_steps, image, image_with_signalleak, initial_latents, pipeline, signal_leak, sqrt_alpha_prod, sqrt_one_minus_alpha_prod\n",
+    "import torch\n",
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "nxYFoh1dTJN-",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "id": "nxYFoh1dTJN-",
+    "outputId": "859627fa-7faa-4fc6-ae84-3c059f5fba75"
+   },
+   "outputs": [],
+   "source": [
+    "from IPython.display import display, HTML\n",
+    "from PIL import Image\n",
+    "from io import BytesIO\n",
+    "import base64\n",
+    "\n",
+    "def load_images():\n",
+    "    images = []\n",
+    "    for i in range(5):\n",
+    "        latents_path = f'examples/B1/imgs/latents{i}.png'\n",
+    "        original_path = f'examples/B1/imgs/original{i}.png'\n",
+    "        signal_leak_path = f'examples/B1/imgs/signal_leak{i}.png'\n",
+    "        ours_path = f'examples/B1/imgs/ours{i}.png'\n",
+    "\n",
+    "        latents_image = Image.open(latents_path)\n",
+    "        original_image = Image.open(original_path)\n",
+    "        signal_leak_image = Image.open(signal_leak_path)\n",
+    "        ours_image = Image.open(ours_path)\n",
+    "\n",
+    "        images.append([\n",
+    "            latents_image, original_image, signal_leak_image, ours_image\n",
+    "        ])\n",
+    "\n",
+    "    return images\n",
+    "\n",
+    "# Load images\n",
+    "loaded_images = load_images()\n",
+    "\n",
+    "# Function to generate HTML code for the table\n",
+    "def generate_table(data, headers):\n",
+    "    table_code = \"<table>\"\n",
+    "\n",
+    "    # Add headers\n",
+    "    table_code += \"<tr>\"\n",
+    "    for header in headers:\n",
+    "        table_code += \"<th>\" + header + \"</th>\"\n",
+    "    table_code += \"</tr>\"\n",
+    "\n",
+    "    # Add data rows\n",
+    "    for row in data:\n",
+    "        table_code += \"<tr>\"\n",
+    "        for cell in row:\n",
+    "            img_data = BytesIO()\n",
+    "            cell.save(img_data, format=\"PNG\")\n",
+    "            img_data.seek(0)\n",
+    "            img_data_uri = f\"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}\"\n",
+    "            table_code += f'<td><img src=\"{img_data_uri}\"/></td>'\n",
+    "        table_code += \"</tr>\"\n",
+    "    table_code += \"</table>\"\n",
+    "\n",
+    "    return table_code\n",
+    "\n",
+    "# Data for the table\n",
+    "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n",
+    "\n",
+    "table_data = []\n",
+    "for i in range(5):\n",
+    "    table_data.append([\n",
+    "        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n",
+    "    ])\n",
+    "\n",
+    "# Display the table\n",
+    "display(HTML(generate_table(table_data, table_headers)))"
+   ]
+  },
\"\n", + "\n", + " return table_code\n", + "\n", + "# Data for the table\n", + "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n", + "\n", + "table_data = []\n", + "for i in range(5):\n", + " table_data.append([\n", + " loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n", + " ])\n", + "\n", + "# Display the table\n", + "display(HTML(generate_table(table_data, table_headers)))" + ] + }, + { + "cell_type": "markdown", + "id": "wXFXqRCaiJt3", + "metadata": { + "id": "wXFXqRCaiJt3" + }, + "source": [ + "\n", + "#### Example 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ykTY08UiMZK", + "metadata": { + "id": "2ykTY08UiMZK" + }, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "from diffusers import StableDiffusionPipeline\n", + "from signal_leak import sample_from_stats\n", + "\n", + "folder = \"examples/B2/imgs\"\n", + "path_stats = \"examples/A2\"\n", + "\n", + "os.makedirs(folder, exist_ok=True)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# Load the pipeline\n", + "pipeline = StableDiffusionPipeline.from_pretrained(\n", + " \"CompVis/stable-diffusion-v1-4\",\n", + ").to(device)\n", + "num_inference_steps = 50\n", + "\n", + "# Get the timestep T of the first reverse diffusion iteration\n", + "pipeline.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n", + "first_inference_timestep = pipeline.scheduler.timesteps[0].item()\n", + "\n", + "# Get the values of sqrt(alpha_prod_T)\n", + "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n", + "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n", + "\n", + "# Dimensions of the latent space, with batch_size=1\n", + "shape_latents = [\n", + " 1,\n", + " pipeline.unet.config.in_channels,\n", + " pipeline.unet.config.sample_size,\n", + " pipeline.unet.config.sample_size,\n", + "]\n", + "\n", + "# Utility function fo visualize initial latents / signal leak\n", + "def latents_to_pil(pipeline, latents, generator):\n", + " decoded = pipeline.vae.decode(\n", + " latents / pipeline.vae.config.scaling_factor,\n", + " return_dict=False,\n", + " generator=generator,\n", + " )[0]\n", + " image = pipeline.image_processor.postprocess(\n", + " decoded,\n", + " output_type=\"pil\",\n", + " do_denormalize=[True],\n", + " )[0]\n", + " return image\n", + "\n", + "# Random number generator\n", + "generator = torch.Generator(device=device)\n", + "generator = generator.manual_seed(12345)\n", + "\n", + "with torch.no_grad():\n", + " for n in range(5):\n", + "\n", + " # Generate the initial latents\n", + " initial_latents = torch.randn(\n", + " shape_latents, generator=generator, device=device, dtype=torch.float32\n", + " )\n", + " latents_to_pil(pipeline, initial_latents, generator).save(f\"{folder}/latents{n}.png\")\n", + "\n", + "\n", + " # Generate an image WITHOUT signal leak in the initial latents\n", + " image = pipeline(\n", + " prompt=\"An astronaut riding a horse, in the style of line art, pastel colors.\",\n", + " num_inference_steps=num_inference_steps,\n", + " latents=initial_latents,\n", + " ).images[0]\n", + " image.save(f\"{folder}/original{n}.png\")\n", + "\n", + " # Generate a signal leak from computed statistics\n", + " signal_leak = sample_from_stats(\n", + " path=path_stats,\n", + " dims=shape_latents,\n", + " generator_pt=generator,\n", + " 
generator_np=None,\n",
+    "            device=device\n",
+    "        )\n",
+    "        latents_to_pil(pipeline, signal_leak, generator).save(f\"{folder}/signal_leak{n}.png\")\n",
+    "\n",
+    "        # Add a signal leak in the initial latents\n",
+    "        initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents\n",
+    "\n",
+    "        # Generate an image WITH signal leak in the initial latents\n",
+    "        image_with_signalleak = pipeline(\n",
+    "            prompt=\"An astronaut riding a horse, in the style of line art, pastel colors.\",\n",
+    "            num_inference_steps=num_inference_steps,\n",
+    "            latents=initial_latents,\n",
+    "        ).images[0]\n",
+    "        image_with_signalleak.save(f\"{folder}/ours{n}.png\")\n",
+    "\n",
+    "del pipeline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "oGWrcDJqTRoY",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "id": "oGWrcDJqTRoY",
+    "outputId": "e17e879a-1777-4fe0-9100-74bc4a396a50"
+   },
+   "outputs": [],
+   "source": [
+    "from IPython.display import display, HTML\n",
+    "from PIL import Image\n",
+    "from io import BytesIO\n",
+    "import base64\n",
+    "\n",
+    "def load_images():\n",
+    "    images = []\n",
+    "    for i in range(5):\n",
+    "        latents_path = f'examples/B2/imgs/latents{i}.png'\n",
+    "        original_path = f'examples/B2/imgs/original{i}.png'\n",
+    "        signal_leak_path = f'examples/B2/imgs/signal_leak{i}.png'\n",
+    "        ours_path = f'examples/B2/imgs/ours{i}.png'\n",
+    "\n",
+    "        latents_image = Image.open(latents_path)\n",
+    "        original_image = Image.open(original_path)\n",
+    "        signal_leak_image = Image.open(signal_leak_path)\n",
+    "        ours_image = Image.open(ours_path)\n",
+    "\n",
+    "        images.append([\n",
+    "            latents_image, original_image, signal_leak_image, ours_image\n",
+    "        ])\n",
+    "\n",
+    "    return images\n",
+    "\n",
+    "# Load images\n",
+    "loaded_images = load_images()\n",
+    "\n",
+    "# Function to generate HTML code for the table\n",
+    "def generate_table(data, headers):\n",
+    "    table_code = \"<table>\"\n",
+    "\n",
+    "    # Add headers\n",
+    "    table_code += \"<tr>\"\n",
+    "    for header in headers:\n",
+    "        table_code += \"<th>\" + header + \"</th>\"\n",
+    "    table_code += \"</tr>\"\n",
+    "\n",
+    "    # Add data rows\n",
+    "    for row in data:\n",
+    "        table_code += \"<tr>\"\n",
+    "        for cell in row:\n",
+    "            img_data = BytesIO()\n",
+    "            cell.save(img_data, format=\"PNG\")\n",
+    "            img_data.seek(0)\n",
+    "            img_data_uri = f\"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}\"\n",
+    "            table_code += f'<td><img src=\"{img_data_uri}\"/></td>'\n",
+    "        table_code += \"</tr>\"\n",
+    "    table_code += \"</table>\"\n",
+    "\n",
+    "    return table_code\n",
+    "\n",
+    "# Data for the table\n",
+    "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n",
+    "\n",
+    "table_data = []\n",
+    "for i in range(5):\n",
+    "    table_data.append([\n",
+    "        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n",
+    "    ])\n",
+    "\n",
+    "# Display the table\n",
+    "display(HTML(generate_table(table_data, table_headers)))"
+   ]
+  },
\"\n", + "\n", + " return table_code\n", + "\n", + "# Data for the table\n", + "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n", + "\n", + "table_data = []\n", + "for i in range(5):\n", + " table_data.append([\n", + " loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n", + " ])\n", + "\n", + "# Display the table\n", + "display(HTML(generate_table(table_data, table_headers)))" + ] + }, + { + "cell_type": "markdown", + "id": "p0qzDKXEiU4C", + "metadata": { + "id": "p0qzDKXEiU4C" + }, + "source": [ + " \n", + "### More diverse generated images\n", + "\n", + "In the previous examples, the signal leak is modelled with a \"pixel\" model, realigning the training and inference distributions for stylized images. For *natural* images, the disrepency between training and inference distribution mostly lies in the frequency components: noised images during training still retain the low-frequency contents (large-scale patterns, main colors) of the original images, while the initial latents during inference always have medium low-frequency contents (e.g. *greyish* average color). Compared to the examples above, we then additionnaly model the low-frequency content of the signal leak, using a small set of natural images.\n", + "\n", + "In the next examples, we will use [this set of 128 images](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip) from [COCO](https://cocodataset.org/)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "_1x92MFiiUOB", + "metadata": { + "id": "_1x92MFiiUOB" + }, + "outputs": [], + "source": [ + "!rm -r examples/C\n", + "!python signal_leak.py \\\n", + " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-2-1\" \\\n", + " --data_dir=\"coco128/images/train2017\" \\\n", + " --output_dir=\"examples/C/\" \\\n", + " --resolution=768 \\\n", + " --n_components=3 \\\n", + " --statistic_type=\"dct+pixel\" \\\n", + " --center_crop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "p8U3jIuJclo-", + "metadata": { + "id": "p8U3jIuJclo-" + }, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import numpy as np\n", + "from diffusers import StableDiffusionPipeline\n", + "from signal_leak import sample_from_stats\n", + "\n", + "folder = \"examples/C/imgs\"\n", + "path_stats = \"examples/C\"\n", + "\n", + "os.makedirs(folder, exist_ok=True)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# Load the pipeline\n", + "pipeline = StableDiffusionPipeline.from_pretrained(\n", + " \"stabilityai/stable-diffusion-2-1\",\n", + ").to(device)\n", + "num_inference_steps = 50\n", + "\n", + "# Get the timestep T of the first reverse diffusion iteration\n", + "pipeline.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n", + "first_inference_timestep = pipeline.scheduler.timesteps[0].item()\n", + "\n", + "# Get the values of sqrt(alpha_prod_T)\n", + "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n", + "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n", + "\n", + "# Dimensions of the latent space, with batch_size=1\n", + "shape_latents = [\n", + " 1,\n", + " pipeline.unet.config.in_channels,\n", + " pipeline.unet.config.sample_size,\n", + " pipeline.unet.config.sample_size,\n", + "]\n", + "\n", + "# Utility function to visualize initial latents / signal 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "p8U3jIuJclo-",
+   "metadata": {
+    "id": "p8U3jIuJclo-"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "from diffusers import StableDiffusionPipeline\n",
+    "from signal_leak import sample_from_stats\n",
+    "\n",
+    "folder = \"examples/C/imgs\"\n",
+    "path_stats = \"examples/C\"\n",
+    "\n",
+    "os.makedirs(folder, exist_ok=True)\n",
+    "\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "\n",
+    "# Load the pipeline\n",
+    "pipeline = StableDiffusionPipeline.from_pretrained(\n",
+    "    \"stabilityai/stable-diffusion-2-1\",\n",
+    ").to(device)\n",
+    "num_inference_steps = 50\n",
+    "\n",
+    "# Get the timestep T of the first reverse diffusion iteration\n",
+    "pipeline.scheduler.set_timesteps(num_inference_steps, device=device)\n",
+    "first_inference_timestep = pipeline.scheduler.timesteps[0].item()\n",
+    "\n",
+    "# Get the values of sqrt(alpha_prod_T)\n",
+    "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n",
+    "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n",
+    "\n",
+    "# Dimensions of the latent space, with batch_size=1\n",
+    "shape_latents = [\n",
+    "    1,\n",
+    "    pipeline.unet.config.in_channels,\n",
+    "    pipeline.unet.config.sample_size,\n",
+    "    pipeline.unet.config.sample_size,\n",
+    "]\n",
+    "\n",
+    "# Utility function to visualize initial latents / signal leak\n",
+    "def latents_to_pil(pipeline, latents, generator):\n",
+    "    decoded = pipeline.vae.decode(\n",
+    "        latents / pipeline.vae.config.scaling_factor,\n",
+    "        return_dict=False,\n",
+    "        generator=generator,\n",
+    "    )[0]\n",
+    "    image = pipeline.image_processor.postprocess(\n",
+    "        decoded,\n",
+    "        output_type=\"pil\",\n",
+    "        do_denormalize=[True],\n",
+    "    )[0]\n",
+    "    return image\n",
+    "\n",
+    "# Random number generator\n",
+    "generator = torch.Generator(device=device)\n",
+    "generator = generator.manual_seed(12345)\n",
+    "generator_np = np.random.default_rng(seed=654321)\n",
+    "\n",
+    "with torch.no_grad():\n",
+    "    for n in range(5):\n",
+    "\n",
+    "        # Generate the initial latents\n",
+    "        initial_latents = torch.randn(\n",
+    "            shape_latents, generator=generator, device=device, dtype=torch.float32\n",
+    "        )\n",
+    "        latents_to_pil(pipeline, initial_latents, generator).save(f\"{folder}/latents{n}.png\")\n",
+    "\n",
+    "\n",
+    "        # Generate an image WITHOUT signal leak in the initial latents\n",
+    "        image = pipeline(\n",
+    "            prompt=\"An astronaut riding a horse\",\n",
+    "            num_inference_steps=num_inference_steps,\n",
+    "            latents=initial_latents,\n",
+    "        ).images[0]\n",
+    "        image.save(f\"{folder}/original{n}.png\")\n",
+    "\n",
+    "        # Generate a signal leak from computed statistics\n",
+    "        signal_leak = sample_from_stats(\n",
+    "            path=path_stats,\n",
+    "            dims=shape_latents,\n",
+    "            generator_pt=generator,\n",
+    "            generator_np=generator_np,\n",
+    "            device=device\n",
+    "        )\n",
+    "        latents_to_pil(pipeline, signal_leak, generator).save(f\"{folder}/signal_leak{n}.png\")\n",
+    "\n",
+    "        # Add a signal leak in the initial latents\n",
+    "        initial_latents = sqrt_alpha_prod * signal_leak + sqrt_one_minus_alpha_prod * initial_latents\n",
+    "\n",
+    "        # Generate an image WITH signal leak in the initial latents\n",
+    "        image_with_signalleak = pipeline(\n",
+    "            prompt=\"An astronaut riding a horse\",\n",
+    "            num_inference_steps=num_inference_steps,\n",
+    "            latents=initial_latents,\n",
+    "        ).images[0]\n",
+    "        image_with_signalleak.save(f\"{folder}/ours{n}.png\")\n",
+    "\n",
+    "\n",
+    "del pipeline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6kHnAi-RP8NG",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "id": "6kHnAi-RP8NG",
+    "outputId": "19e1ba82-b7f1-4460-c7cf-003d71339a78"
+   },
+   "outputs": [],
+   "source": [
+    "from IPython.display import display, HTML\n",
+    "from PIL import Image\n",
+    "from io import BytesIO\n",
+    "import base64\n",
+    "\n",
+    "def load_images():\n",
+    "    images = []\n",
+    "    for i in range(5):\n",
+    "        latents_path = f'examples/C/imgs/latents{i}.png'\n",
+    "        original_path = f'examples/C/imgs/original{i}.png'\n",
+    "        signal_leak_path = f'examples/C/imgs/signal_leak{i}.png'\n",
+    "        ours_path = f'examples/C/imgs/ours{i}.png'\n",
+    "\n",
+    "        latents_image = Image.open(latents_path)\n",
+    "        original_image = Image.open(original_path)\n",
+    "        signal_leak_image = Image.open(signal_leak_path)\n",
+    "        ours_image = Image.open(ours_path)\n",
+    "\n",
+    "        images.append([\n",
+    "            latents_image, original_image, signal_leak_image, ours_image\n",
+    "        ])\n",
+    "\n",
+    "    return images\n",
+    "\n",
+    "# Load images\n",
+    "loaded_images = load_images()\n",
+    "\n",
+    "# Function to generate HTML code for the table\n",
+    "def generate_table(data, headers):\n",
+    "    table_code = \"<table>\"\n",
+    "\n",
+    "    # Add headers\n",
+    "    table_code += \"<tr>\"\n",
+    "    for header in headers:\n",
+    "        table_code += \"<th>\" + header + \"</th>\"\n",
+    "    table_code += \"</tr>\"\n",
+    "\n",
+    "    # Add data rows\n",
+    "    for row in data:\n",
+    "        table_code += \"<tr>\"\n",
+    "        for cell in row:\n",
+    "            img_data = BytesIO()\n",
+    "            cell.save(img_data, format=\"PNG\")\n",
+    "            img_data.seek(0)\n",
+    "            img_data_uri = f\"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}\"\n",
+    "            table_code += f'<td><img src=\"{img_data_uri}\"/></td>'\n",
+    "        table_code += \"</tr>\"\n",
+    "    table_code += \"</table>\"\n",
+    "\n",
+    "    return table_code\n",
+    "\n",
+    "# Data for the table\n",
+    "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n",
+    "\n",
+    "table_data = []\n",
+    "for i in range(5):\n",
+    "    table_data.append([\n",
+    "        loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n",
+    "    ])\n",
+    "\n",
+    "# Display the table\n",
+    "display(HTML(generate_table(table_data, table_headers)))"
+   ]
+  },
\"\n", + "\n", + " return table_code\n", + "\n", + "# Data for the table\n", + "table_headers = [\"Initial latents\", \"Generated image (original)\", \"+ Signal Leak\", \"Generated image (ours)\"]\n", + "\n", + "table_data = []\n", + "for i in range(5):\n", + " table_data.append([\n", + " loaded_images[i][0], loaded_images[i][1], loaded_images[i][2], loaded_images[i][3]\n", + " ])\n", + "\n", + "# Display the table\n", + "display(HTML(generate_table(table_data, table_headers)))" + ] + }, + { + "cell_type": "markdown", + "id": "wFNlga2eihG-", + "metadata": { + "id": "wFNlga2eihG-" + }, + "source": [ + " \n", + "### Control on the average color\n", + "\n", + "In the previous example, the signal leak given at inference time is sampled randomly from the statistics of the signal leak present at training time. Instead, it is also possible to *manually* set its low-frequency components, providing control on the low-frequency content of the generated image, as we show in the following example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "gZpwtr60igZ0", + "metadata": { + "id": "gZpwtr60igZ0" + }, + "outputs": [], + "source": [ + "!python signal_leak.py \\\n", + " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-2-1\" \\\n", + " --data_dir=\"coco128/images/train2017\" \\\n", + " --output_dir=\"examples/D/\" \\\n", + " --resolution=768 \\\n", + " --n_components=1 \\\n", + " --statistic_type=\"dct+pixel\" \\\n", + " --center_crop" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "wmha41uFQPJa", + "metadata": { + "id": "wmha41uFQPJa" + }, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import numpy as np\n", + "from diffusers import StableDiffusionPipeline\n", + "from signal_leak import sample_from_stats\n", + "\n", + "folder = \"examples/D/imgs\"\n", + "path_stats = \"examples/D\"\n", + "\n", + "os.makedirs(folder, exist_ok=True)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# Load the pipeline\n", + "pipeline = StableDiffusionPipeline.from_pretrained(\n", + " \"stabilityai/stable-diffusion-2-1\",\n", + ").to(device)\n", + "num_inference_steps = 50\n", + "\n", + "# Get the timestep T of the first reverse diffusion iteration\n", + "pipeline.scheduler.set_timesteps(num_inference_steps, device=\"cuda\")\n", + "first_inference_timestep = pipeline.scheduler.timesteps[0].item()\n", + "\n", + "# Get the values of sqrt(alpha_prod_T)\n", + "sqrt_alpha_prod = pipeline.scheduler.alphas_cumprod[first_inference_timestep] ** 0.5\n", + "sqrt_one_minus_alpha_prod = (1 - pipeline.scheduler.alphas_cumprod[first_inference_timestep]) ** 0.5\n", + "\n", + "# Dimensions of the latent space, with batch_size=1\n", + "shape_latents = [\n", + " 1,\n", + " pipeline.unet.config.in_channels,\n", + " pipeline.unet.config.sample_size,\n", + " pipeline.unet.config.sample_size,\n", + "]\n", + "\n", + "# Utility function to visualize initial latents / signal leak\n", + "def latents_to_pil(pipeline, latents, generator):\n", + " decoded = pipeline.vae.decode(\n", + " latents / pipeline.vae.config.scaling_factor,\n", + " return_dict=False,\n", + " generator=generator,\n", + " )[0]\n", + " image = pipeline.image_processor.postprocess(\n", + " decoded,\n", + " output_type=\"pil\",\n", + " do_denormalize=[True],\n", + " )[0]\n", + " return image\n", + "\n", + "# Random number generator\n", + "generator = torch.Generator(device=device)\n", + "generator = generator.manual_seed(12345)\n", + "\n", + 
"\n", + "# Generate the initial latents WITHOUT signal-leak\n", + "shape_latents = [\n", + " 1,\n", + " pipeline.unet.config.in_channels,\n", + " pipeline.unet.config.sample_size,\n", + " pipeline.unet.config.sample_size,\n", + "]\n", + "initial_latents_without_signalleak = torch.randn(\n", + " shape_latents, generator=generator, device=device, dtype=torch.float32\n", + ")\n", + "\n", + "with torch.no_grad():\n", + " for channel in range(4):\n", + " for value in (-2, -1, 0, 1, 2):\n", + "\n", + " # Reset the seed, so that the only difference between the different initial latents is the LF components\n", + " generator = generator.manual_seed(123456)\n", + " generator_np = np.random.default_rng(seed=654321)\n", + "\n", + " # Generate the initial latents with signal leak\n", + " signal_leak = sample_from_stats(\n", + " path=path_stats,\n", + " dims=shape_latents,\n", + " generator_pt=generator,\n", + " generator_np=generator_np,\n", + " device=device,\n", + " only_hf=True\n", + " )\n", + " signal_leak[:, channel, :, :] += value\n", + "\n", + " initial_latents = (\n", + " sqrt_alpha_prod * signal_leak\n", + " + sqrt_one_minus_alpha_prod * initial_latents_without_signalleak\n", + " )\n", + " # Generate an image\n", + " image_with_signalleak = pipeline(\n", + " prompt=\"An astronaut riding a horse\",\n", + " num_inference_steps=num_inference_steps,\n", + " latents=initial_latents,\n", + " ).images[0]\n", + " image_with_signalleak.save(f\"{folder}/{channel}_{value}.png\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "Ape_LEgQMZKm", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 874 + }, + "id": "Ape_LEgQMZKm", + "outputId": "96d43a56-a20d-4c64-91ed-2c40ec5d2f60" + }, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "from PIL import Image\n", + "from io import BytesIO\n", + "import base64\n", + "\n", + "def load_images():\n", + " images = []\n", + " for channel in range(4):\n", + "\n", + " images.append([\n", + " Image.open(f'examples/D/imgs/{channel}_{value}.png') for value in (-2, -1, 0, 1, 2)\n", + " ])\n", + "\n", + " return images\n", + "\n", + "# Load images\n", + "loaded_images = load_images()\n", + "\n", + "# Function to generate HTML code for the table\n", + "def generate_table(data, headers):\n", + " table_code = \"\"\n", + "\n", + " # Add headers\n", + " for header in headers:\n", + " table_code += \"\"\n", + " table_code += \"\"\n", + "\n", + " # Add data rows\n", + " for row in data:\n", + " table_code += \"\"\n", + " for cell in row:\n", + " img_data = BytesIO()\n", + " cell.save(img_data, format=\"PNG\")\n", + " img_data.seek(0)\n", + " img_data_uri = f\"data:image/png;base64,{base64.b64encode(img_data.read()).decode()}\"\n", + " table_code += f''\n", + " table_code += \"\"\n", + " table_code += \"
\" + header + \"
\"\n",
+ "\n",
+ "    return table_code\n",
+ "\n",
+ "# Data for the table\n",
+ "table_headers = [\"-2\", \"-1\", \"0\", \"1\", \"2\"]\n",
+ "\n",
+ "# Display the table\n",
+ "display(HTML(generate_table(loaded_images, table_headers)))"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}