diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index aafe08f..a844ffa 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -28,7 +28,7 @@ jobs:
- name: Generate coverage badge
run: |
- python src/badge_generator.py
+ python badge_generator.py
# - name: Commit coverage badge
# run: |
diff --git a/README.md b/README.md
index 04297e6..7b8f1e4 100644
--- a/README.md
+++ b/README.md
@@ -73,6 +73,11 @@ In this project we implement different alternatives to it.
## 🛠️ Install
+### Install using pip
+```bash
+pip install relaxit
+```
+
### Install from source
```bash
pip install git+https://github.com/intsystems/discrete-variables-relaxation
@@ -85,10 +90,11 @@ cd discrete-variables-relaxation
pip install -e .
```
-## 🚀 Quickstart
+## 🚀 Quickstart
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/quickstart.ipynb)
```python
import torch
-from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
+from relaxit.distributions import InvertibleGaussian
# initialize distribution parameters
loc = torch.zeros(3, 4, 5, requires_grad=True)
@@ -100,6 +106,8 @@ distribution = InvertibleGaussian(loc, scale, temperature)
# sample with reparameterization
sample = distribution.rsample()
+print('sample.shape:', sample.shape)
+print('sample.requires_grad:', sample.requires_grad)
```
## 🎮 Demo
diff --git a/README.rst b/README.rst
deleted file mode 100644
index e0161f8..0000000
--- a/README.rst
+++ /dev/null
@@ -1,153 +0,0 @@
-
-
-
-
-
Just Relax It
-
-
Discrete Variables Relaxation
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-# 📬 Assets
-==========
-
-1. `Technical Meeting 1 - Presentation `_
-
-2. `Technical Meeting 2 - Jupyter Notebook `_
-
-3. `Technical Meeting 3 — Jupyter Notebook `_
-
-4. `Blog Post `_
-
-5. `Documentation `_
-
-6. `Tests `_
-
-# 💡 Motivation
-==============
-
-For lots of mathematical problems we need an ability to sample discrete random variables.
-
-The problem is that due to continuos nature of deep learning optimization, the usage of truely discrete random variables is infeasible.
-
-Thus we use different relaxation method.
-
-One of them, `Concrete distribution](https://arxiv.org/abs/1611.00712) or [Gumbel-Softmax `_ or `Gumbel-Softmax `_ (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages.
-
-In this project we implement different alternatives to it.
-
-
-
-
-
-
-
-# 🗃 Algorithms
-==============
-- `x] [Relaxed Bernoulli `_
-- `x] [Correlated relaxed Bernoulli `_
-- `x] [Gumbel-softmax TOP-K `_
-- `x] [Straight-Through Bernoulli `_
-- `x] [Invertible Gaussian `_ with KL implemented
-- `x] [Hard Concrete `_
-- `x] [REINFORCE `_
-- `x] [Logit-Normal](https://en.wikipedia.org/wiki/Logit-normal_distribution) and [Laplace-form approximation of Dirichlet `_ and `Laplace-form approximation of Dirichlet `_
-
-# 🎮 Demo
-========
-
-For demonstration purposes, we have implemented a simple VAE with discrete latents. Our code is available `here `_.
-
-Each of the discussed relaxation techniques allowed us to learn the latent space with the corresponding distribution.
-
-# 📚 Stack
-=========
-
-Some of the alternatives for GS were implemented in `pyro `_, so we base our library on their codebase.
-
-# 🧩 Some details
-================
-
-To make to library consistent, we integrate imports of distributions from `pyro` and `torch` into the library, so that all the categorical distributions can be imported from one entrypoint.
-
-# 👥 Contributors
-================
-- `Daniil Dorin `_ (Basic code writing, Final demo, Algorithms)
-- `Igor Ignashin `_ (Project wrapping, Documentation writing, Algorithms)
-- `Nikita Kiselev `_ (Project planning, Blog post, Algorithms)
-- `Andrey Veprikov `_ (Tests writing, Documentation writing, Algorithms)
-- You are welcome to contribute to our project!
-
-# 🔗 Useful links
-================
-- `About top-k GS `_
-- `VAE implementation with different latent distributions `_
-- `KL divergence between Dirichlet and Logistic-Normal implemented in R `_
-- `About score function (SF) and pathwise derivate (PD) estimators, VAE and REINFORCE `_ and pathwise derivate (PD) estimators, VAE and REINFORCE](https://arxiv.org/abs/1506.05254)
-
diff --git a/src/badge_generator.py b/badge_generator.py
similarity index 100%
rename from src/badge_generator.py
rename to badge_generator.py
diff --git a/basic/basic_code.ipynb b/basic/basic_code.ipynb
index 7ffcdc5..2ff11ed 100644
--- a/basic/basic_code.ipynb
+++ b/basic/basic_code.ipynb
@@ -365,12 +365,10 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "current_dir = os.getcwd()\n",
- "sys.path.append(os.path.abspath(os.path.join(current_dir, \"..\", \"src\")))\n",
"from relaxit.distributions import GaussianRelaxedBernoulli"
]
},
@@ -511,12 +509,10 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "current_dir = os.getcwd()\n",
- "sys.path.append(os.path.abspath(os.path.join(current_dir, \"..\", \"src\")))\n",
"from relaxit.distributions import CorrelatedRelaxedBernoulli"
]
},
@@ -581,13 +577,10 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "# Igor code\n",
- "current_dir = os.getcwd()\n",
- "sys.path.append(os.path.abspath(os.path.join(current_dir, \"..\", \"src\")))\n",
"from relaxit.distributions import StraightThroughBernoulli"
]
},
@@ -692,7 +685,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -710,8 +703,6 @@
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
- "current_dir = os.getcwd()\n",
- "sys.path.append(os.path.abspath(os.path.join(current_dir, \"..\", \"src\")))\n",
"from relaxit.distributions import HardConcrete\n",
"\n",
"%load_ext autoreload\n",
diff --git a/demo/demo.ipynb b/demo/demo.ipynb
index 0b31aef..d66d34a 100644
--- a/demo/demo.ipynb
+++ b/demo/demo.ipynb
@@ -44,7 +44,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -70,8 +70,6 @@
"from torch.nn import functional as F\n",
"from torchvision import datasets, transforms\n",
"\n",
- "current_dir = os.getcwd()\n",
- "sys.path.append(os.path.abspath(os.path.join(current_dir, '..', 'src')))\n",
"from relaxit.distributions import CorrelatedRelaxedBernoulli\n",
"\n",
"%load_ext autoreload\n",
diff --git a/demo/laplace-bridge.ipynb b/demo/laplace-bridge.ipynb
index 98599a3..b6f918f 100644
--- a/demo/laplace-bridge.ipynb
+++ b/demo/laplace-bridge.ipynb
@@ -12,15 +12,14 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import sys\n",
"\n",
- "sys.path.append(\"../src\")\n",
- "from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax\n",
+ "from relaxit.distributions import LogisticNormalSoftmax\n",
"from relaxit.distributions.approx import (\n",
" lognorm_approximation_fn,\n",
" dirichlet_approximation_fn,\n",
diff --git a/demo/quickstart.ipynb b/demo/quickstart.ipynb
new file mode 100644
index 0000000..c36a1f4
--- /dev/null
+++ b/demo/quickstart.ipynb
@@ -0,0 +1,55 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Quickstart into \"Just Relax It\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from relaxit.distributions import InvertibleGaussian\n",
+ "\n",
+ "# initialize distribution parameters\n",
+ "loc = torch.zeros(3, 4, 5, requires_grad=True)\n",
+ "scale = torch.ones(3, 4, 5, requires_grad=True)\n",
+ "temperature = torch.tensor([1e-0])\n",
+ "\n",
+ "# initialize distribution\n",
+ "distribution = InvertibleGaussian(loc, scale, temperature)\n",
+ "\n",
+ "# sample with reparameterization\n",
+ "sample = distribution.rsample()\n",
+ "print('sample.shape:', sample.shape)\n",
+ "print('sample.requires_grad:', sample.requires_grad)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "nkiselev_relaxit",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.15"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/demo/requirements.txt b/demo/requirements.txt
index c38d5a8..546ee61 100644
--- a/demo/requirements.txt
+++ b/demo/requirements.txt
@@ -7,4 +7,5 @@ torchvision==0.20.1
matplotlib==3.9.2
networkx==3.3
tqdm==4.66.5
-pillow==10.4.0
\ No newline at end of file
+pillow==10.4.0
+relaxit==0.1.2
\ No newline at end of file
diff --git a/demo/vae_correlated_bernoulli.py b/demo/vae_correlated_bernoulli.py
index fd32061..e7a2885 100644
--- a/demo/vae_correlated_bernoulli.py
+++ b/demo/vae_correlated_bernoulli.py
@@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import CorrelatedRelaxedBernoulli
diff --git a/demo/vae_gaussian_bernoulli.py b/demo/vae_gaussian_bernoulli.py
index 63ce3ee..39638a1 100644
--- a/demo/vae_gaussian_bernoulli.py
+++ b/demo/vae_gaussian_bernoulli.py
@@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import GaussianRelaxedBernoulli
diff --git a/demo/vae_gumbel_softmax_topk.py b/demo/vae_gumbel_softmax_topk.py
index 1c81285..a557d57 100644
--- a/demo/vae_gumbel_softmax_topk.py
+++ b/demo/vae_gumbel_softmax_topk.py
@@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import GumbelSoftmaxTopK
parser = argparse.ArgumentParser(description="VAE MNIST Example")
diff --git a/demo/vae_hard_concrete.py b/demo/vae_hard_concrete.py
index 039cf5b..acae950 100644
--- a/demo/vae_hard_concrete.py
+++ b/demo/vae_hard_concrete.py
@@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import HardConcrete
diff --git a/demo/vae_invertible_gaussian.py b/demo/vae_invertible_gaussian.py
index 107d4ad..912b36e 100644
--- a/demo/vae_invertible_gaussian.py
+++ b/demo/vae_invertible_gaussian.py
@@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import InvertibleGaussian
from relaxit.distributions.kl import kl_divergence
diff --git a/demo/vae_straight_through_bernoulli.py b/demo/vae_straight_through_bernoulli.py
index 8475812..6d5cf69 100644
--- a/demo/vae_straight_through_bernoulli.py
+++ b/demo/vae_straight_through_bernoulli.py
@@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import StraightThroughBernoulli
diff --git a/docs/index.rst b/docs/index.rst
index 82bf45f..114f005 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -9,6 +9,19 @@ Just Relax It
"Just Relax It" is a cutting-edge Python library designed to streamline the optimization of discrete probability distributions in neural networks, offering a suite of advanced relaxation techniques compatible with PyTorch.
+Motivation
+----------
+
+For lots of mathematical problems we need an ability to sample discrete random variables.
+The problem is that due to continuos nature of deep learning optimization, the usage of truely discrete random variables is infeasible.
+Thus we use different relaxation method.
+One of them, `Concrete distribution `_ or `Gumbel-Softmax `_ (this is one distribution proposed in parallel by two research groups) is implemented in different DL packages.
+In this project we implement different alternatives to it.
+
+.. image:: ../assets/overview.png
+ :width: 600
+ :align: center
+
.. toctree::
:maxdepth: 1
:caption: Guidelines
diff --git a/docs/install.md b/docs/install.md
index 6640e16..1ba1d57 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -1,10 +1,9 @@
## Install
-
+pip install relaxit
+```
### Install from source
```bash
diff --git a/docs/quickstart.md b/docs/quickstart.md
index 5b0c2b2..74caf23 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -1,8 +1,8 @@
## Quickstart
-
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/quickstart.ipynb)
```python
import torch
-from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
+from relaxit.distributions import InvertibleGaussian
# initialize distribution parameters
loc = torch.zeros(3, 4, 5, requires_grad=True)
@@ -14,4 +14,6 @@ distribution = InvertibleGaussian(loc, scale, temperature)
# sample with reparameterization
sample = distribution.rsample()
+print('sample.shape:', sample.shape)
+print('sample.requires_grad:', sample.requires_grad)
```
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0454c6e..5531ab7 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,13 @@ def read(file_path):
with io.open(file_path, "r", encoding="utf-8") as f:
return f.read()
-readme = read("README.rst")
+try:
+ long_description = open("README.md", encoding="utf-8").read()
+except Exception as e:
+ sys.stderr.write("Failed to read README.md: {}\n".format(e))
+ sys.stderr.flush()
+ long_description = ""
+
# # вычищаем локальные версии из файла requirements (согласно PEP440)
requirements = '\n'.join(
re.findall(r'^([^\s^+]+).*$',
@@ -26,8 +32,8 @@ def read(file_path):
license="MIT",
author="",
author_email="",
- description="relaxit, python package",
- long_description=readme,
+ description="A Python library for discrete variables relaxation",
+ long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/intsystems/discrete-variables-relaxation",
# options
diff --git a/src/README.rst b/src/README.rst
deleted file mode 100644
index 0a7b2f3..0000000
--- a/src/README.rst
+++ /dev/null
@@ -1,29 +0,0 @@
-************
-Installation
-************
-
-Requirements
-============
-
-- Python Test
-- pip
-
-Installing by using PyPi
-========================
-
-Install
--------
-To install the package, follow these steps:
-
-.. code-block:: bash
-
- git clone ttps://github.com/intsystems/discrete-variables-relaxation.git /tmp/discrete-variables-relaxation
- python3 -m pip install /tmp/discrete-variables-relaxation/src/
-
-Uninstall
----------
-To uninstall the package, run the following command:
-
-.. code-block:: bash
-
- python3 -m pip uninstall relaxit