Skip to content

Commit

Permalink
upd imports and others
Browse files Browse the repository at this point in the history
  • Loading branch information
kisnikser committed Nov 30, 2024
1 parent 7be292f commit 82c1d48
Show file tree
Hide file tree
Showing 20 changed files with 104 additions and 220 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Generate coverage badge
run: |
python src/badge_generator.py
python badge_generator.py
# - name: Commit coverage badge
# run: |
Expand Down
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ In this project we implement different alternatives to it.

## 🛠️ Install

### Install using pip
```bash
pip install relaxit
```

### Install from source
```bash
pip install git+https://github.com/intsystems/discrete-variables-relaxation
Expand All @@ -85,10 +90,11 @@ cd discrete-variables-relaxation
pip install -e .
```

## 🚀 Quickstart
## 🚀 Quickstart
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/intsystems/discrete-variables-relaxation/blob/main/demo/quickstart.ipynb)
```python
import torch
from relaxit.distributions.InvertibleGaussian import InvertibleGaussian
from relaxit.distributions import InvertibleGaussian

# initialize distribution parameters
loc = torch.zeros(3, 4, 5, requires_grad=True)
Expand All @@ -100,6 +106,8 @@ distribution = InvertibleGaussian(loc, scale, temperature)

# sample with reparameterization
sample = distribution.rsample()
print('sample.shape:', sample.shape)
print('sample.requires_grad:', sample.requires_grad)
```

## 🎮 Demo
Expand Down
153 changes: 0 additions & 153 deletions README.rst

This file was deleted.

File renamed without changes.
17 changes: 4 additions & 13 deletions basic/basic_code.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,10 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"current_dir = os.getcwd()\n",
"sys.path.append(os.path.abspath(os.path.join(current_dir, \"..\", \"src\")))\n",
"from relaxit.distributions import GaussianRelaxedBernoulli"
]
},
Expand Down Expand Up @@ -511,12 +509,10 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"current_dir = os.getcwd()\n",
"sys.path.append(os.path.abspath(os.path.join(current_dir, \"..\", \"src\")))\n",
"from relaxit.distributions import CorrelatedRelaxedBernoulli"
]
},
Expand Down Expand Up @@ -581,13 +577,10 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Igor code\n",
"current_dir = os.getcwd()\n",
"sys.path.append(os.path.abspath(os.path.join(current_dir, \"..\", \"src\")))\n",
"from relaxit.distributions import StraightThroughBernoulli"
]
},
Expand Down Expand Up @@ -692,7 +685,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -710,8 +703,6 @@
"import torch\n",
"import matplotlib.pyplot as plt\n",
"\n",
"current_dir = os.getcwd()\n",
"sys.path.append(os.path.abspath(os.path.join(current_dir, \"..\", \"src\")))\n",
"from relaxit.distributions import HardConcrete\n",
"\n",
"%load_ext autoreload\n",
Expand Down
4 changes: 1 addition & 3 deletions demo/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -70,8 +70,6 @@
"from torch.nn import functional as F\n",
"from torchvision import datasets, transforms\n",
"\n",
"current_dir = os.getcwd()\n",
"sys.path.append(os.path.abspath(os.path.join(current_dir, '..', 'src')))\n",
"from relaxit.distributions import CorrelatedRelaxedBernoulli\n",
"\n",
"%load_ext autoreload\n",
Expand Down
5 changes: 2 additions & 3 deletions demo/laplace-bridge.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import sys\n",
"\n",
"sys.path.append(\"../src\")\n",
"from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax\n",
"from relaxit.distributions import LogisticNormalSoftmax\n",
"from relaxit.distributions.approx import (\n",
" lognorm_approximation_fn,\n",
" dirichlet_approximation_fn,\n",
Expand Down
55 changes: 55 additions & 0 deletions demo/quickstart.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quickstart into \"Just Relax It\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from relaxit.distributions import InvertibleGaussian\n",
"\n",
"# initialize distribution parameters\n",
"loc = torch.zeros(3, 4, 5, requires_grad=True)\n",
"scale = torch.ones(3, 4, 5, requires_grad=True)\n",
"temperature = torch.tensor([1e-0])\n",
"\n",
"# initialize distribution\n",
"distribution = InvertibleGaussian(loc, scale, temperature)\n",
"\n",
"# sample with reparameterization\n",
"sample = distribution.rsample()\n",
"print('sample.shape:', sample.shape)\n",
"print('sample.requires_grad:', sample.requires_grad)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "nkiselev_relaxit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 2 additions & 1 deletion demo/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ torchvision==0.20.1
matplotlib==3.9.2
networkx==3.3
tqdm==4.66.5
pillow==10.4.0
pillow==10.4.0
relaxit==0.1.2
1 change: 0 additions & 1 deletion demo/vae_correlated_bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import CorrelatedRelaxedBernoulli


Expand Down
1 change: 0 additions & 1 deletion demo/vae_gaussian_bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import GaussianRelaxedBernoulli


Expand Down
1 change: 0 additions & 1 deletion demo/vae_gumbel_softmax_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import GumbelSoftmaxTopK

parser = argparse.ArgumentParser(description="VAE MNIST Example")
Expand Down
1 change: 0 additions & 1 deletion demo/vae_hard_concrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import HardConcrete


Expand Down
1 change: 0 additions & 1 deletion demo/vae_invertible_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from torchvision import datasets, transforms
from torchvision.utils import save_image

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))
from relaxit.distributions import InvertibleGaussian
from relaxit.distributions.kl import kl_divergence

Expand Down
Loading

0 comments on commit 82c1d48

Please sign in to comment.