Skip to content

Commit

Permalink
documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Apr 10, 2024
1 parent 94c9375 commit f5b2034
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/api/estimators.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,9 @@ differentiable and performant [JAX](https://github.com/google/jax) based numeric
options:
members:
- __init__
---

::: traceax.XNysTraceEstimator
options:
members:
- __init__
25 changes: 24 additions & 1 deletion src/traceax/_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int)
class HutchPlusPlusEstimator(AbstractTraceEstimator):
r"""Hutch++ Trace Estimator:
Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the the _low-rank approximation_
Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the a _low-rank approximation_
to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for
$\Omega = [\omega_1, \dotsc, \omega_k]$.
Expand Down Expand Up @@ -235,6 +235,21 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int)


class XNysTraceEstimator(AbstractTraceEstimator):
r"""XNysTrace Trace Estimator:
XNysTrace improves upon XTrace estimator when $\mathbf{A}$ is (negative-) positive-semidefinite, by
performing a [Nyström approximation](https://en.wikipedia.org/wiki/Low-rank_matrix_approximations#Nystr%C3%B6m_approximation),
rather than a randomized SVD (i.e., random projection followed by QR decomposition).
Like, [`traceax.XTraceEstimator`][], the *improved* XNysTrace algorithm (i.e. `improved = True`), ensures
that test-vectors are orthogonalized against the low rank approximation and renormalized.
This improved XNysTrace approach may provide better empirical results compared with the non-orthogonalized version.
As with the Girard-Hutchinson estimator, it requires
$\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$.
"""

sampler: AbstractSampler = SphereSampler()
improved: bool = True

Expand Down Expand Up @@ -282,3 +297,11 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int)
trace_est = jnp.where(is_nsd, -trace_est, trace_est)

return trace_est, {"std.err": std_err}


XNysTraceEstimator.__init__.__doc__ = r"""**Arguments:**
- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][].
- `improved`: whether to use the *improved* XNysTrace estimator, which rescales predicted samples.
Default is `True` (see Notes).
"""

0 comments on commit f5b2034

Please sign in to comment.