Skip to content

Commit

Permalink
removed complex-specific samplers and pushed functionality into dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Mar 28, 2024
1 parent 618d8a5 commit 2b209df
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 44 deletions.
9 changes: 0 additions & 9 deletions docs/api/samplers.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ simple abstract class definition, [`traceax.AbstractSampler`][] using that subcl
members:
- __call__

## Floating-point Samplers

::: traceax.NormalSampler

Expand All @@ -25,11 +24,3 @@ simple abstract class definition, [`traceax.AbstractSampler`][] using that subcl
---

::: traceax.RademacherSampler


## Complex-value Samplers
::: traceax.ComplexNormalSampler

---

::: traceax.ComplexSphereSampler
2 changes: 0 additions & 2 deletions src/traceax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
)
from ._samplers import (
AbstractSampler as AbstractSampler,
ComplexNormalSampler as ComplexNormalSampler,
ComplexSphereSampler as ComplexSphereSampler,
NormalSampler as NormalSampler,
RademacherSampler as RademacherSampler,
SphereSampler as SphereSampler,
Expand Down
47 changes: 14 additions & 33 deletions src/traceax/_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import jax.numpy as jnp
import jax.random as rdm

from jaxtyping import Array, Inexact, PRNGKeyArray
from jaxtyping import Array, DTypeLike, Inexact, Num, PRNGKeyArray


class AbstractSampler(eqx.Module, strict=True):
"""Abstract base class for all samplers."""

@abstractmethod
def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
def __call__(self, key: PRNGKeyArray, n: int, k: int, dtype: DTypeLike = float) -> Num[Array, "n k"]:
r"""Sample random variates from the underlying distribution as an $n \times k$
matrix.
Expand All @@ -40,6 +40,7 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
- `key`: a jax PRNG key used as the random key.
- `n`: the size of the leading dimension.
- `k`: the size of the trailing dimension.
- `dtype`: the numerical type of generated samples (e.g., `float`, `int`, `complex`, etc.)
**Returns**:
Expand All @@ -52,10 +53,13 @@ class NormalSampler(AbstractSampler, strict=True):
r"""Standard normal distribution sampler.
Generates samples $X_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$.
!!! Note
Supports float and complex-valued types.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
return rdm.normal(key, (n, k))
def __call__(self, key: PRNGKeyArray, n: int, k: int, dtype: DTypeLike = float) -> Inexact[Array, "n k"]:
return rdm.normal(key, (n, k), dtype)


class SphereSampler(AbstractSampler, strict=True):
Expand All @@ -65,36 +69,13 @@ class SphereSampler(AbstractSampler, strict=True):
$k$ dimensional sphere (i.e. $k-1$-sphere) with radius $\sqrt{n}$. Internally,
this operates by sampling standard normal variates, and then rescaling such that
each $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
samples = rdm.normal(key, (n, k))
return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0))


class ComplexNormalSampler(AbstractSampler, strict=True):
r"""Standard complex normal distribution sampler.
Generates complex-valued samples $X_{ij} = A_{ij} + i B_{ij}$ where
$A_{ij} \sim N(0, 1)$ and $B_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
samples = rdm.normal(key, (n, k)) + 1j * rdm.normal(key, (n, k))
return samples / jnp.sqrt(2)


class ComplexSphereSampler(AbstractSampler, strict=True):
r"""Complex sphere distribution sampler.
Generates complex-valued samples $X_1, \dotsc, X_n$ uniformly distributed on the
surface of a $k$ dimensional complex-valued sphere with radius $\sqrt{n}$. Internally,
this operates by sampling standard complex normal variates, and then rescaling such
that each complex-valued $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$.
!!! Note
Supports float and complex-valued types.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
samples = rdm.normal(key, (n, k)) + 1j * rdm.normal(key, (n, k))
def __call__(self, key: PRNGKeyArray, n: int, k: int, dtype: DTypeLike = float) -> Inexact[Array, "n k"]:
samples = rdm.normal(key, (n, k), dtype)
return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0))


Expand All @@ -104,5 +85,5 @@ class RademacherSampler(AbstractSampler, strict=True):
Generates samples $X_{ij} \sim \mathcal{U}(-1, +1)$ for $i \in [n]$ and $j \in [k]$.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
return rdm.rademacher(key, (n, k))
def __call__(self, key: PRNGKeyArray, n: int, k: int, dtype: DTypeLike = int) -> Num[Array, "n k"]:
return rdm.rademacher(key, (n, k), dtype)

0 comments on commit 2b209df

Please sign in to comment.