From 2b209df1b7a881cf1ec48da6a9e9bd367109afec Mon Sep 17 00:00:00 2001 From: quattro Date: Thu, 28 Mar 2024 13:56:50 -0700 Subject: [PATCH] removed complex-specific samplers and pushed functionality into dtypes --- docs/api/samplers.md | 9 -------- src/traceax/__init__.py | 2 -- src/traceax/_samplers.py | 47 ++++++++++++---------------------------- 3 files changed, 14 insertions(+), 44 deletions(-) diff --git a/docs/api/samplers.md b/docs/api/samplers.md index 14db65e..2a090ae 100644 --- a/docs/api/samplers.md +++ b/docs/api/samplers.md @@ -14,7 +14,6 @@ simple abstract class definition, [`traceax.AbstractSampler`][] using that subcl members: - __call__ -## Floating-point Samplers ::: traceax.NormalSampler @@ -25,11 +24,3 @@ simple abstract class definition, [`traceax.AbstractSampler`][] using that subcl --- ::: traceax.RademacherSampler - - -## Complex-value Samplers -::: traceax.ComplexNormalSampler - ---- - -::: traceax.ComplexSphereSampler diff --git a/src/traceax/__init__.py b/src/traceax/__init__.py index be60c95..e807885 100644 --- a/src/traceax/__init__.py +++ b/src/traceax/__init__.py @@ -22,8 +22,6 @@ ) from ._samplers import ( AbstractSampler as AbstractSampler, - ComplexNormalSampler as ComplexNormalSampler, - ComplexSphereSampler as ComplexSphereSampler, NormalSampler as NormalSampler, RademacherSampler as RademacherSampler, SphereSampler as SphereSampler, diff --git a/src/traceax/_samplers.py b/src/traceax/_samplers.py index 12d8496..7814ada 100644 --- a/src/traceax/_samplers.py +++ b/src/traceax/_samplers.py @@ -18,14 +18,14 @@ import jax.numpy as jnp import jax.random as rdm -from jaxtyping import Array, Inexact, PRNGKeyArray +from jaxtyping import Array, DTypeLike, Inexact, Num, PRNGKeyArray class AbstractSampler(eqx.Module, strict=True): """Abstract base class for all samplers.""" @abstractmethod - def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: + def __call__(self, key: PRNGKeyArray, n: int, k: int, dtype: DTypeLike = float) -> Num[Array, "n k"]: r"""Sample random variates from the underlying distribution as an $n \times k$ matrix. @@ -40,6 +40,7 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: - `key`: a jax PRNG key used as the random key. - `n`: the size of the leading dimension. - `k`: the size of the trailing dimension. + - `dtype`: the numerical type of generated samples (e.g., `float`, `int`, `complex`, etc.) **Returns**: @@ -52,10 +53,13 @@ class NormalSampler(AbstractSampler, strict=True): r"""Standard normal distribution sampler. Generates samples $X_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$. + + !!! Note + Supports float and complex-valued types. """ - def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: - return rdm.normal(key, (n, k)) + def __call__(self, key: PRNGKeyArray, n: int, k: int, dtype: DTypeLike = float) -> Inexact[Array, "n k"]: + return rdm.normal(key, (n, k), dtype) class SphereSampler(AbstractSampler, strict=True): @@ -65,36 +69,13 @@ class SphereSampler(AbstractSampler, strict=True): $k$ dimensional sphere (i.e. $k-1$-sphere) with radius $\sqrt{n}$. Internally, this operates by sampling standard normal variates, and then rescaling such that each $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$. - """ - - def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: - samples = rdm.normal(key, (n, k)) - return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0)) - - -class ComplexNormalSampler(AbstractSampler, strict=True): - r"""Standard complex normal distribution sampler. - - Generates complex-valued samples $X_{ij} = A_{ij} + i B_{ij}$ where - $A_{ij} \sim N(0, 1)$ and $B_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$. - """ - - def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: - samples = rdm.normal(key, (n, k)) + 1j * rdm.normal(key, (n, k)) - return samples / jnp.sqrt(2) - - -class ComplexSphereSampler(AbstractSampler, strict=True): - r"""Complex sphere distribution sampler. - Generates complex-valued samples $X_1, \dotsc, X_n$ uniformly distributed on the - surface of a $k$ dimensional complex-valued sphere with radius $\sqrt{n}$. Internally, - this operates by sampling standard complex normal variates, and then rescaling such - that each complex-valued $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$. + !!! Note + Supports float and complex-valued types. """ - def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: - samples = rdm.normal(key, (n, k)) + 1j * rdm.normal(key, (n, k)) + def __call__(self, key: PRNGKeyArray, n: int, k: int, dtype: DTypeLike = float) -> Inexact[Array, "n k"]: + samples = rdm.normal(key, (n, k), dtype) return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0)) @@ -104,5 +85,5 @@ class RademacherSampler(AbstractSampler, strict=True): Generates samples $X_{ij} \sim \mathcal{U}(-1, +1)$ for $i \in [n]$ and $j \in [k]$. """ - def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: - return rdm.rademacher(key, (n, k)) + def __call__(self, key: PRNGKeyArray, n: int, k: int, dtype: DTypeLike = int) -> Num[Array, "n k"]: + return rdm.rademacher(key, (n, k), dtype)