Skip to content

Commit

Permalink
documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Apr 11, 2024
1 parent 4224aee commit 01cdd48
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
9 changes: 9 additions & 0 deletions docs/api/samplers.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@ simple abstract class definition, [`traceax.AbstractSampler`][] using that subcl


::: traceax.NormalSampler
options:
members:
- __init__

---

::: traceax.SphereSampler
options:
members:
- __init__

---

::: traceax.RademacherSampler
options:
members:
- __init__
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ plugins:
extensions:
- pydantic: {schema: true}
- docs/scripts/extension.py:DynamicDocstrings:
paths: [ traceax._estimators ]
paths: [ traceax._estimators, traceax._samplers ]
# general options
allow_inspection: true
show_bases: true
Expand Down
22 changes: 21 additions & 1 deletion src/traceax/_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Num[Array, "n k"]:
- `key`: a jax PRNG key used as the random key.
- `n`: the size of the leading dimension.
- `k`: the size of the trailing dimension.
- `dtype`: the numerical type of generated samples (e.g., `float`, `int`, `complex`, etc.)
**Returns**:
Expand Down Expand Up @@ -73,6 +72,12 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
return rdm.normal(key, (n, k), self.dtype)


NormalSampler.__init__.__doc__ = r"""**Arguments:**
- `dtype`: numeric representation for sampled test-vectors. Default is `float`.
"""


class SphereSampler(AbstractSampler, strict=True):
r"""Sphere distribution sampler.
Expand All @@ -96,10 +101,19 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0))


SphereSampler.__init__.__doc__ = r"""**Arguments:**
- `dtype`: numeric representation for sampled test-vectors. Default is `float`.
"""


class RademacherSampler(AbstractSampler, strict=True):
r"""Rademacher distribution sampler.
Generates samples $X_{ij} \sim \mathcal{U}(-1, +1)$ for $i \in [n]$ and $j \in [k]$.
!!! Note
Supports integer, float, and complex-valued types.
"""

dtype: DTypeLike = eqx.field(converter=canonicalize_dtype, default=int)
Expand All @@ -110,3 +124,9 @@ def __check_init__(self):

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Num[Array, "n k"]:
return rdm.rademacher(key, (n, k), self.dtype)


RademacherSampler.__init__.__doc__ = r"""**Arguments:**
- `dtype`: numeric representation for sampled test-vectors. Default is `int`.
"""

0 comments on commit 01cdd48

Please sign in to comment.