Skip to content

Commit

Permalink
chore: core non-jit + new Nuance __repr__ + docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrcia committed Feb 12, 2024
1 parent 8e318f5 commit 49f2eb6
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 27 deletions.
60 changes: 49 additions & 11 deletions nuance/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import jax.numpy as jnp
import numpy as np
from scipy.linalg import block_diag
from tqdm.autonotebook import tqdm
from tqdm import tqdm

from nuance import core, utils
from nuance.nuance import Nuance
from nuance.search_data import SearchData

Expand All @@ -29,8 +28,6 @@ class CombinedNuance:
"""Nuance instance of each dataset, where the linear search must be already ran."""
search_data: SearchData = None
"""SearchData instance of the combined dataset."""
c: float = 12.0
"""The c parameter of the transit model."""

def __post_init__(self):
self._fill_search_data()
Expand Down Expand Up @@ -94,11 +91,49 @@ def eval_model(ms):

self.eval_model = eval_model

def linear_search(self, t0s, Ds, progress=True):
def linear_search(
self,
t0s: np.ndarray,
Ds: np.ndarray,
positive: bool = True,
progress: bool = True,
backend: str = None,
batch_size: int = None,
):
"""Performs the linear search for each dataset. Linear searches are saved as :py:class:`~nuance.SearchData`
within each :py:class:`~nuance.Nuance` dataset.
Parameters
----------
t0s : np.ndarray
array of model epochs
Ds : np.ndarray
array of model durations
positive : bool, optional
wether to force depth to be positive, by default True
progress : bool, optional
wether to show progress bar, by default True
backend : str, optional
backend to use, by default jax.default_backend() (options: "cpu", "gpu").
This affects the linear search function jax-mapping strategy. For more details, see
:py:func:`nuance.core.map_function`
batch_size : int, optional
batch size for parallel evaluation, by default None
Returns
-------
None
"""
for d in self.datasets:
d.linear_search(t0s, Ds, progress=progress)
d.linear_search(
t0s,
Ds,
progress=progress,
backend=backend,
batch_size=batch_size,
positive=positive,
)

def solve(self, t0, D, P):
def solve(self, t0: float, D: float, P: float):
"""Solve the combined model for a given set of parameters.
Parameters
Expand All @@ -121,7 +156,7 @@ def solve(self, t0, D, P):
w, v = self.eval_model(models)
return w, v

def snr(self, t0, D, P):
def snr(self, t0: float, D: float, P: float):
"""SNR of transit linearly solved for epoch `t0` and duration `D` (and period `P` for a periodic transit)
Parameters
Expand All @@ -143,7 +178,7 @@ def snr(self, t0, D, P):
w, v = self.solve(t0, D, P)
return jnp.max(jnp.array([0, w[-1] / jnp.sqrt(v[-1, -1])]))

def periodic_search(self, periods, dphi=0.01):
def periodic_search(self, periods: np.ndarray, progress=True, dphi=0.01):
"""Performs the periodic search
Parameters
Expand Down Expand Up @@ -184,7 +219,10 @@ def _search(p):
snr = np.zeros(n)
params = np.zeros((n, 3))

for i, p in enumerate(tqdm(periods)):
def _progress(x, **kwargs):
return tqdm(x, **kwargs) if progress else x

for i, p in enumerate(_progress(periods)):
snr[i], params[i] = _search(p)

new_search_data = self.search_data.copy()
Expand All @@ -195,7 +233,7 @@ def _search(p):

return new_search_data

def models(self, t0, D, P, split=False):
def models(self, t0: float, D: float, P: float, split=False):
"""Solve the combined model for a given set of parameters.
Parameters
Expand Down
2 changes: 0 additions & 2 deletions nuance/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@ def function(m):
return function


@jax.jit
def transit_protopapas(t, t0, D, P=1e15, c=12):
_t = P * jnp.sin(jnp.pi * (t - t0) / P) / (jnp.pi * D)
return -0.5 * jnp.tanh(c * (_t + 1 / 2)) + 0.5 * jnp.tanh(c * (_t - 1 / 2))


@jax.jit
def transit_box(time, t0, D, P=1e15):
return -((jnp.abs(time - t0) % P) < D / 2).astype(float)

Expand Down
39 changes: 25 additions & 14 deletions nuance/nuance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from scipy.ndimage import minimum_filter1d
from tinygp import GaussianProcess, kernels
from tqdm import tqdm
from tqdm.autonotebook import tqdm

from nuance import DEVICES_COUNT, core, utils
from nuance.search_data import SearchData
Expand Down Expand Up @@ -47,6 +46,8 @@ def __post_init__(self):
if self.model is None:
self.model = partial(core.transit_protopapas, c=12)

self.model = jax.jit(self.model)

assert (self.error is None) ^ (
self.gp is None
), "Either error or gp must be defined"
Expand All @@ -66,6 +67,19 @@ def __post_init__(self):

self.search_data = None

def __repr__(self):
noise = (
f"kernel={self.gp.kernel}"
if self.error is None
else f"error={self.error:.3e}"
)
return f"Nuance(N={len(self.time)}, M={self.X.shape[0]}, {noise}, searched={self.searched})"

@property
def searched(self):
"""Whether the linear search has been performed"""
return self.search_data is not None

@property
def ll0(self) -> float:
"""log-likelihood of data without model
Expand Down Expand Up @@ -161,23 +175,20 @@ def mu(self, time=None):
Parameters
----------
mask : np.ndarray, optional
A boolean mask to apply to the data, by default None.
time : np.ndarray, optional
The time at which to compute the mean model, by default None (uses `self.time`).
Returns
-------
np.ndarray
The mean model of the GP.
Example
-------
>>> mu = model.mu()
"""
if time is None:
time = self.time

@jax.jit
def _mu():
gp = self.gp
_, w, _ = self.eval_model(np.zeros_like(self.time))
_, w, _ = self.eval_model(np.zeros_like(time))
w = w[0:-1]
cond_gp = gp.condition(self.flux - w @ self.X, time).gp
return cond_gp.loc + w @ self.X
Expand All @@ -201,9 +212,9 @@ def models(self, t0: float = None, D: float = None, P: float = 1e15):
list np.ndarray
a list of three np.ndarray:
- linear: linear model
- astro: signal being searched
- noise: noise model
- linear: linear model (using `X`)
- model: model being searched
- noise: noise model (using `GP`)
Example
-------
Expand All @@ -225,7 +236,7 @@ def models(self, t0: float = None, D: float = None, P: float = 1e15):
return self._models(m)

def solve(self, t0: float, D: float, P: float = None):
"""solve linear model (design matrix `Nuance.X`)
"""solve linear model (suing `X`)
Parameters
----------
Expand Down Expand Up @@ -308,7 +319,7 @@ def linear_search(
wether to show progress bar, by default True
backend : str, optional
backend to use, by default jax.default_backend() (options: "cpu", "gpu").
This affect the linear search function jax mapping strategy. For more details, see
This affects the linear search function jax-mapping strategy. For more details, see
:py:func:`nuance.core.map_function`
batch_size : int, optional
batch size for parallel evaluation, by default None
Expand Down Expand Up @@ -376,12 +387,12 @@ def periodic_search(self, periods: np.ndarray, progress=True, dphi=0.01):
dphi: float, optional
the relative step size of the phase grid. For each period, all likelihood quantities along time are
interpolated along a phase grid of resolution `min(1/200, dphi/P))`. The smaller dphi
the finer the grid, and the more resolved the model epoch and period (the the more computationally expensive the
the finer the grid, and the more resolved the model epoch and period (but the more computationally expensive the
periodic search). The default is 0.01.
Returns
-------
:py:class:`nuance.SearchData`
:py:class:`~nuance.SearchData`
search results
"""
new_search_data = self.search_data.copy()
Expand Down

0 comments on commit 49f2eb6

Please sign in to comment.