Skip to content

Commit

Permalink
pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
Vepricov committed Nov 26, 2024
1 parent f51e0b4 commit 7adc7c6
Show file tree
Hide file tree
Showing 36 changed files with 1,092 additions and 585 deletions.
19 changes: 14 additions & 5 deletions basic/approx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
"source": [
"import torch\n",
"import sys\n",
"sys.path.append('../src')\n",
"\n",
"sys.path.append(\"../src\")\n",
"from relaxit.distributions.LogisticNormalSoftmax import LogisticNormalSoftmax\n",
"from relaxit.distributions.approx import lognorm_approximation_fn, dirichlet_approximation_fn\n",
"from relaxit.distributions.approx import (\n",
" lognorm_approximation_fn,\n",
" dirichlet_approximation_fn,\n",
")\n",
"from pyro.distributions import Dirichlet"
]
},
Expand All @@ -44,29 +48,34 @@
"# Compute midpoints of the triangle sides\n",
"MIDPOINTS = [(CORNERS[(i + 1) % 3] + CORNERS[(i + 2) % 3]) / 2.0 for i in range(3)]\n",
"\n",
"def xy2bc(xy, tol=1.e-3):\n",
"\n",
"def xy2bc(xy, tol=1.0e-3):\n",
" \"\"\"Converts 2D Cartesian coordinates to barycentric.\"\"\"\n",
" s = [(CORNERS[i] - MIDPOINTS[i]).dot(xy - MIDPOINTS[i]) / 0.75 for i in range(3)]\n",
" return np.clip(s, tol, 1.0 - tol)\n",
"\n",
"\n",
"def argmin_norm(sample, positions):\n",
" \"\"\"Finds the index of the closest point in positions to sample.\"\"\"\n",
" return np.argmin(np.sum(np.square(sample - positions), axis=1))\n",
"\n",
"\n",
"def refine_triangulation(subdiv=5):\n",
" \"\"\"Refines the triangulation and returns the refined mesh.\"\"\"\n",
" refiner = tri.UniformTriRefiner(TRIANGLE)\n",
" return refiner.refine_triangulation(subdiv=subdiv)\n",
"\n",
"\n",
"def plot_contours(trimesh, pvals, nlevels=200, **kwargs):\n",
" \"\"\"Plots the contours of the probability values on the triangulated mesh.\"\"\"\n",
" plt.tricontourf(trimesh, pvals, nlevels, **kwargs)\n",
" plt.axis('equal')\n",
" plt.axis(\"equal\")\n",
" plt.xlim(0, 1)\n",
" plt.ylim(0, 0.75**0.5)\n",
" plt.axis('off')\n",
" plt.axis(\"off\")\n",
" plt.show()\n",
"\n",
"\n",
"def sample_and_plot(distribution, nlevels=200, subdiv=5, num_samples=100000, **kwargs):\n",
" \"\"\"Samples from a given distribution and plots the contours.\"\"\"\n",
" trimesh = refine_triangulation(subdiv)\n",
Expand Down
Loading

0 comments on commit 7adc7c6

Please sign in to comment.