Skip to content

Commit

Permalink
Merge pull request #216 from renecotyfanboy/better-mocks
Browse files Browse the repository at this point in the history
Clarify how to mock data
  • Loading branch information
renecotyfanboy authored Jan 23, 2025
2 parents 345e975 + 202c7f3 commit 20551f3
Show file tree
Hide file tree
Showing 15 changed files with 326 additions and 86 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/test-and-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ jobs:

steps:
- uses: actions/checkout@v4
with:
submodules: 'recursive'
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ main.py
doc_notebooks/

.ipynb_checkpoints
*/.ipynb_checkpoints/*
tests/data/
*.ipynb
poetry.lock
# IPython
Expand Down
34 changes: 30 additions & 4 deletions docs/examples/fakeits.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from jaxspec.model.additive import Powerlaw, Blackbodyrad
from jaxspec.model.multiplicative import Tbabs
from jaxspec.data import ObsConfiguration

obs = ObsConfiguration.from_pha_file('obs_1.pha')
obsconf = ObsConfiguration.from_pha_file('obs_1.pha')
model = Tbabs() * (Powerlaw() + Blackbodyrad())
```

Expand All @@ -46,7 +46,7 @@ And now we can fakeit!
``` python
from jaxspec.data.util import fakeit_for_multiple_parameters

spectra = fakeit_for_multiple_parameters(obs, model, parameters)
spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)
```

Let's plot some of the resulting spectra
Expand All @@ -59,7 +59,7 @@ plt.figure(figsize=(5,4))
for i in range(10):

plt.step(
obs.out_energies[0],
obsconf.out_energies[0],
spectra[i, :],
where="post"
)
Expand All @@ -71,6 +71,32 @@ plt.loglog()

![Some spectra](statics/fakeits.png)

## Using only the instrument

If you don't have an observation to use as a reference, you can still build a mock [`ObsConfiguration`][jaxspec.data.ObsConfiguration]
directly from the instrument you want to use.

``` python
from jaxspec.data import ObsConfiguration, Instrument

instrument = Instrument.from_ogip_file(
"instrument.rmf",
arf_path="instrument.arf"
)

obsconf = ObsConfiguration.mock_from_instrument(
instrument,
exposure=1e5,
)
```

Then you can use this [`ObsConfiguration`][jaxspec.data.ObsConfiguration] within `fakeit_for_multiple_parameters` as before.

``` python
spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)
```


## Computing in parallel

Thanks to the amazing [PositionalSharding](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PositionalSharding)
Expand Down Expand Up @@ -129,7 +155,7 @@ sharded_parameters = jax.device_put(parameters, sharding)
Then we can use these sharded parameters to compute the fakeits in parallel

``` python
fakeit_for_multiple_parameters(obs, model, sharded_parameters, apply_stat=False)
fakeit_for_multiple_parameters(obsconf, model, sharded_parameters, apply_stat=False)
```

!!! info
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jaxspec"
version = "0.2.0"
version = "0.2.1dev-2"
description = "jaxspec is a bayesian spectral fitting library for X-ray astronomy."
authors = ["sdupourque <sdupourque@irap.omp.eu>"]
license = "MIT"
Expand All @@ -11,7 +11,7 @@ documentation = "https://jaxspec.readthedocs.io/en/latest/"

[tool.poetry.dependencies]
python = ">=3.10,<3.13"
jax = "^0.4.37"
jax = "^0.5.0"
numpy = "<2.0.0"
pandas = "^2.2.0"
astropy = "^6.0.0"
Expand Down
18 changes: 14 additions & 4 deletions src/jaxspec/analysis/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ def plot_ppc(
alpha_envelope: (float, float) = (0.15, 0.25),
style: str | Any = "default",
title: str | None = None,
figsize: tuple[float, float] = (6, 6),
x_lims: tuple[float, float] | None = None,
) -> list[plt.Figure]:
r"""
Plot the posterior predictive distribution of the model. It also features a residual plot, defined using the
Expand All @@ -400,14 +402,17 @@ def plot_ppc(
{(\text{Posterior counts})_{84\%}-(\text{Posterior counts})_{16\%}} $$
Parameters:
percentile: The percentile of the posterior predictive distribution to plot.
n_sigmas: The number of sigmas to plot the envelopes.
x_unit: The units of the x-axis. It can be either a string (parsable by astropy.units) or an astropy unit. It must be homogeneous to either a length, a frequency or an energy.
y_type: The type of the y-axis. It can be either "counts", "countrate", "photon_flux" or "photon_flux_density".
plot_background: Whether to plot the background model if it is included in the fit.
plot_components: Whether to plot the components of the model separately.
scale: The axes scaling
alpha_envelope: The transparency range for envelopes
style: The style of the plot. It can be either a string or a matplotlib style context.
title: The title of the plot.
figsize: The size of the figure.
x_lims: The limits of the x-axis.
Returns:
A list of matplotlib figures for each observation in the model.
Expand Down Expand Up @@ -436,7 +441,7 @@ def plot_ppc(
fig, ax = plt.subplots(
2,
1,
figsize=(6, 6),
figsize=figsize,
sharex="col",
height_ratios=[0.7, 0.3],
)
Expand Down Expand Up @@ -525,8 +530,10 @@ def plot_ppc(
alpha_envelope=alpha_envelope,
)

name = component_name.split("*")[-1]

legend_plots += component_plot
legend_labels.append(component_name)
legend_labels.append(name)

if self.background_model is not None and plot_background:
# We plot the background only if it is included in the fit, i.e. by subtracting
Expand Down Expand Up @@ -617,6 +624,9 @@ def plot_ppc(
ax[0].set_xscale("log")
ax[0].set_yscale("log")

if x_lims is not None:
ax[0].set_xlim(*x_lims)

fig.align_ylabels()
plt.subplots_adjust(hspace=0.0)
fig.tight_layout()
Expand Down Expand Up @@ -654,7 +664,7 @@ def plot_corner(
"""

consumer = ChainConsumer()
consumer.add_chain(self.to_chain(self.model.to_string()))
consumer.add_chain(self.to_chain("Results"))
consumer.set_plot_config(config)

# Context for default mpl style
Expand Down
76 changes: 68 additions & 8 deletions src/jaxspec/data/util.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
from collections.abc import Mapping
from pathlib import Path
from typing import Literal, TypeVar
from typing import TYPE_CHECKING, Literal, TypeVar

import jax
import jax.numpy as jnp
import numpy as np
import numpyro

from astropy.io import fits
from jax.experimental.sparse import BCOO
from numpyro import handlers

from .._fit._build_model import forward_model
from ..model.abc import SpectralModel
from ..util.online_storage import table_manager
from . import Instrument, ObsConfiguration, Observation

K = TypeVar("K")
V = TypeVar("V")

if TYPE_CHECKING:
from ..data import ObsConfiguration
from ..model.abc import SpectralModel


def load_example_pha(
source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"],
Expand Down Expand Up @@ -124,8 +130,40 @@ def load_example_obsconf(source: Literal["NGC7793_ULX4_PN", "NGC7793_ULX4_ALL"])
raise ValueError(f"{source} not recognized.")


def forward_model_with_multiple_inputs(
model: "SpectralModel",
parameters,
obs_configuration: "ObsConfiguration",
sparse=False,
):
energies = np.asarray(obs_configuration.in_energies)
parameter_dims = next(iter(parameters.values())).shape

def flux_func(p):
return model.photon_flux(p, *energies)

for _ in parameter_dims:
flux_func = jax.vmap(flux_func)

flux_func = jax.jit(flux_func)

if sparse:
# folding.transfer_matrix.data.density > 0.015 is a good criterion to consider sparsify
transfer_matrix = BCOO.from_scipy_sparse(
obs_configuration.transfer_matrix.data.to_scipy_sparse().tocsr()
)

else:
transfer_matrix = np.asarray(obs_configuration.transfer_matrix.data.todense())

expected_counts = jnp.matvec(transfer_matrix, flux_func(parameters))

# The result is clipped at 1e-6 to avoid 0 round-off and diverging likelihoods
return jnp.clip(expected_counts, a_min=1e-6)


def fakeit_for_multiple_parameters(
instrument: ObsConfiguration | list[ObsConfiguration],
obsconfs: ObsConfiguration | list[ObsConfiguration],
model: SpectralModel,
parameters: Mapping[K, V],
rng_key: int = 0,
Expand All @@ -134,23 +172,45 @@ def fakeit_for_multiple_parameters(
):
"""
Convenience function to simulate multiple spectra from a given model and a set of parameters.
This is supposed to be somewhat optimized and can handle multiple parameters at once without blowing
up the memory. The parameters should be passed as a dictionary with the parameter names as keys and
the parameter values as values; each value can be a scalar or an nd-array.
# Example:
``` python
from jaxspec.data.util import fakeit_for_multiple_parameters
from numpy.random import default_rng
rng = default_rng(42)
size = (10, 30)
parameters = {
"tbabs_1_nh": rng.uniform(0.1, 0.4, size=size),
"powerlaw_1_alpha": rng.uniform(1, 3, size=size),
"powerlaw_1_norm": rng.exponential(10 ** (-0.5), size=size),
"blackbodyrad_1_kT": rng.uniform(0.1, 3.0, size=size),
"blackbodyrad_1_norm": rng.exponential(10 ** (-3), size=size)
}
spectra = fakeit_for_multiple_parameters(obsconf, model, parameters)
```
Parameters:
instrument: The instrumental setup.
obsconfs: The observational setup(s).
model: The model to use.
parameters: The parameters of the model.
rng_key: The random number generator seed.
apply_stat: Whether to apply Poisson statistic on the folded spectra or not.
sparsify_matrix: Whether to sparsify the matrix or not.
"""

instruments = [instrument] if isinstance(instrument, ObsConfiguration) else instrument
obsconf_list = [obsconfs] if isinstance(obsconfs, ObsConfiguration) else obsconfs
fakeits = []

for i, obs in enumerate(instruments):
countrate = jax.vmap(lambda p: forward_model(model, p, instrument, sparse=sparsify_matrix))(
parameters
for i, obsconf in enumerate(obsconf_list):
countrate = forward_model_with_multiple_inputs(
model, parameters, obsconf, sparse=sparsify_matrix
)

if apply_stat:
Expand Down
24 changes: 22 additions & 2 deletions src/jaxspec/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import numpyro

from jax import random
from jax.experimental import mesh_utils
from jax.random import PRNGKey
from jax.sharding import PositionalSharding
from numpyro.contrib.nested_sampling import NestedSampler
from numpyro.distributions import Poisson, TransformedDistribution
from numpyro.infer import AIES, ESS, MCMC, NUTS, Predictive
Expand Down Expand Up @@ -312,14 +314,27 @@ def prior_predictive_coverage(
Check if the prior distribution includes the observed data.
"""
key_prior, key_posterior = jax.random.split(key, 2)
n_devices = len(jax.local_devices())
sharding = PositionalSharding(mesh_utils.create_device_mesh((n_devices,)))

# Sample from prior and correct if the number of samples is not a multiple of the number of devices
if num_samples % n_devices != 0:
num_samples = num_samples + n_devices - (num_samples % n_devices)

prior_params = self.prior_samples(key=key_prior, num_samples=num_samples)
posterior_observations = self.mock_observations(prior_params, key=key_posterior)

# Split the parameters on every device
sharded_parameters = jax.device_put(prior_params, sharding)
posterior_observations = self.mock_observations(sharded_parameters, key=key_posterior)

for key, value in self.observation_container.items():
fig, ax = plt.subplots(
nrows=2, ncols=1, sharex=True, figsize=(5, 6), height_ratios=[3, 1]
)

legend_plots = []
legend_labels = []

y_observed, y_observed_low, y_observed_high = _error_bars_for_observed_data(
value.folded_counts.values, 1.0, "ct"
)
Expand All @@ -337,6 +352,11 @@ def prior_predictive_coverage(
ax[0], value.out_energies, posterior_observations["obs_" + key], n_sigmas=3
)

legend_plots.append((true_data_plot,))
legend_labels.append("Observed")
legend_plots += prior_plot
legend_labels.append("Prior Predictive")

# rank = np.vstack((posterior_observations["obs_" + key], value.folded_counts.values)).argsort(axis=0)[-1] / (num_samples) * 100
counts = posterior_observations["obs_" + key]
observed = value.folded_counts.values
Expand All @@ -363,7 +383,7 @@ def prior_predictive_coverage(
ax[1].set_ylim(0, 100)
ax[0].set_xlim(value.out_energies.min(), value.out_energies.max())
ax[0].loglog()
ax[0].legend(loc="upper right")
ax[0].legend(legend_plots, legend_labels)
plt.suptitle(f"Prior Predictive coverage for {key}")
plt.tight_layout()
plt.show()
Expand Down
Loading

0 comments on commit 20551f3

Please sign in to comment.