diff --git a/docs/api/samplers.md b/docs/api/samplers.md index 2a090ae..a678a96 100644 --- a/docs/api/samplers.md +++ b/docs/api/samplers.md @@ -16,11 +16,20 @@ simple abstract class definition, [`traceax.AbstractSampler`][] using that subcl ::: traceax.NormalSampler + options: + members: + - __init__ --- ::: traceax.SphereSampler + options: + members: + - __init__ --- ::: traceax.RademacherSampler + options: + members: + - __init__ diff --git a/mkdocs.yml b/mkdocs.yml index 57caf8a..49c7691 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -87,7 +87,7 @@ plugins: extensions: - pydantic: {schema: true} - docs/scripts/extension.py:DynamicDocstrings: - paths: [ traceax._estimators ] + paths: [ traceax._estimators, traceax._samplers ] # general options allow_inspection: true show_bases: true diff --git a/src/traceax/_samplers.py b/src/traceax/_samplers.py index df40011..68a214e 100644 --- a/src/traceax/_samplers.py +++ b/src/traceax/_samplers.py @@ -45,7 +45,6 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Num[Array, "n k"]: - `key`: a jax PRNG key used as the random key. - `n`: the size of the leading dimension. - `k`: the size of the trailing dimension. - - `dtype`: the numerical type of generated samples (e.g., `float`, `int`, `complex`, etc.) **Returns**: @@ -73,6 +72,12 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: return rdm.normal(key, (n, k), self.dtype) +NormalSampler.__init__.__doc__ = r"""**Arguments:** + +- `dtype`: numeric representation for sampled test-vectors. Default is `float`. +""" + + class SphereSampler(AbstractSampler, strict=True): r"""Sphere distribution sampler. @@ -96,10 +101,19 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]: return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0)) +SphereSampler.__init__.__doc__ = r"""**Arguments:** + +- `dtype`: numeric representation for sampled test-vectors. Default is `float`. +""" + + class RademacherSampler(AbstractSampler, strict=True): r"""Rademacher distribution sampler. Generates samples $X_{ij} \sim \mathcal{U}(-1, +1)$ for $i \in [n]$ and $j \in [k]$. + + !!! Note + Supports integer, float, and complex-valued types. """ dtype: DTypeLike = eqx.field(converter=canonicalize_dtype, default=int) @@ -110,3 +124,9 @@ def __check_init__(self): def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Num[Array, "n k"]: return rdm.rademacher(key, (n, k), self.dtype) + + +RademacherSampler.__init__.__doc__ = r"""**Arguments:** + +- `dtype`: numeric representation for sampled test-vectors. Default is `int`. +"""