diff --git a/docs/api/samplers.md b/docs/api/samplers.md index 6d0dd42..14db65e 100644 --- a/docs/api/samplers.md +++ b/docs/api/samplers.md @@ -1,6 +1,11 @@ # Stochastic Samplers -TBD +`traceax` uses a flexible approach to define how random samples are generated within +[`traceax.AbstractTraceEstimator`][] instances. While this typically wraps a single +jax random call, the varied interfaces for each randomization procedure may differ, +which makes uniformly interfacing with it a bit annoying. As such, we provide a +simple abstract class definition, [`traceax.AbstractSampler`][] using that subclasses +[`Equinox`](https://docs.kidger.site/equinox/) modules. ??? abstract "`traceax.AbstractSampler`" ::: traceax.AbstractSampler @@ -9,7 +14,7 @@ TBD members: - __call__ -# Floating-point Samplers +## Floating-point Samplers ::: traceax.NormalSampler @@ -21,13 +26,10 @@ TBD ::: traceax.RademacherSampler ---- -# Complex-value Samplers +## Complex-value Samplers ::: traceax.ComplexNormalSampler --- ::: traceax.ComplexSphereSampler - ---- diff --git a/src/traceax/_samplers.py b/src/traceax/_samplers.py index aa7da42..12d8496 100644 --- a/src/traceax/_samplers.py +++ b/src/traceax/_samplers.py @@ -22,34 +22,87 @@ class AbstractSampler(eqx.Module, strict=True): + """Abstract base class for all samplers.""" + @abstractmethod def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: + r"""Sample random variates from the underlying distribution as an $n \times k$ + matrix. + + !!! Example + + ```python + sampler = tr.RademacherSampler() + samples = sampler(key, n, k) + ``` + **Arguments:** + + - `key`: a jax PRNG key used as the random key. + - `n`: the size of the leading dimension. + - `k`: the size of the trailing dimension. + + **Returns**: + + An Array of random samples. + """ ... class NormalSampler(AbstractSampler, strict=True): + r"""Standard normal distribution sampler. + + Generates samples $X_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$. + """ + def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: return rdm.normal(key, (n, k)) class SphereSampler(AbstractSampler, strict=True): + r"""Sphere distribution sampler. + + Generates samples $X_1, \dotsc, X_n$ uniformly distributed on the surface of a + $k$ dimensional sphere (i.e. $k-1$-sphere) with radius $\sqrt{n}$. Internally, + this operates by sampling standard normal variates, and then rescaling such that + each $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$. + """ + def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: samples = rdm.normal(key, (n, k)) return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0)) class ComplexNormalSampler(AbstractSampler, strict=True): + r"""Standard complex normal distribution sampler. + + Generates complex-valued samples $X_{ij} = A_{ij} + i B_{ij}$ where + $A_{ij} \sim N(0, 1)$ and $B_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$. + """ + def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: samples = rdm.normal(key, (n, k)) + 1j * rdm.normal(key, (n, k)) return samples / jnp.sqrt(2) class ComplexSphereSampler(AbstractSampler, strict=True): + r"""Complex sphere distribution sampler. + + Generates complex-valued samples $X_1, \dotsc, X_n$ uniformly distributed on the + surface of a $k$ dimensional complex-valued sphere with radius $\sqrt{n}$. Internally, + this operates by sampling standard complex normal variates, and then rescaling such + that each complex-valued $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$. + """ + def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: samples = rdm.normal(key, (n, k)) + 1j * rdm.normal(key, (n, k)) return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0)) class RademacherSampler(AbstractSampler, strict=True): + r"""Rademacher distribution sampler. + + Generates samples $X_{ij} \sim \mathcal{U}(-1, +1)$ for $i \in [n]$ and $j \in [k]$. + """ + def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: return rdm.rademacher(key, (n, k))