Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update arviz data extracted from MCMC results #485

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
Changelog
=========

- Update BOLFI and BOLFIRE to use a shared sample class that returns individual chains in the arviz inference data
- Use kernel copy to avoid pickle issue and allow BOLFI parallelisation with non-default kernel
- Restrict matplotlib version < 3.9 for compatibility with GPy
- Add option to use additive or multiplicative adjustment in any acquisition method
Expand Down
4 changes: 2 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Below is a list of inference methods included in ELFI.
OptimizationResult
Sample
SmcSample
BolfiSample
McmcSample


**Post-processing**
Expand Down Expand Up @@ -244,7 +244,7 @@ Inference API classes
:members:
:inherited-members:

.. autoclass:: BolfiSample
.. autoclass:: McmcSample
:members:
:inherited-members:

Expand Down
6 changes: 3 additions & 3 deletions elfi/methods/inference/bolfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from elfi.methods.bo.utils import stochastic_optimization
from elfi.methods.inference.parameter_inference import ParameterInference
from elfi.methods.posteriors import BolfiPosterior
from elfi.methods.results import BolfiSample, OptimizationResult
from elfi.methods.results import McmcSample, OptimizationResult
from elfi.methods.utils import arr2d_to_batch, batch_to_arr2d, ceil_to_batch_size, resolve_sigmas
from elfi.model.extensions import ModelPrior

Expand Down Expand Up @@ -507,7 +507,7 @@ def sample(self,

Returns
-------
BolfiSample
McmcSample

"""
if self.state['n_batches'] == 0:
Expand Down Expand Up @@ -588,7 +588,7 @@ def sample(self,
mcmc.gelman_rubin_statistic(chains[:, :, ii]))
self.target_model.is_sampling = False

return BolfiSample(
return McmcSample(
method_name='BOLFI',
chains=chains,
parameter_names=self.target_model.parameter_names,
Expand Down
18 changes: 9 additions & 9 deletions elfi/methods/inference/bolfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from elfi.methods.classifier import Classifier, LogisticRegression
from elfi.methods.inference.parameter_inference import ModelBased
from elfi.methods.posteriors import BOLFIREPosterior
from elfi.methods.results import BOLFIRESample
from elfi.methods.results import McmcSample
from elfi.methods.utils import batch_to_arr2d, resolve_sigmas
from elfi.model.extensions import ModelPrior

Expand Down Expand Up @@ -201,7 +201,7 @@ def sample(self,

Returns
-------
BOLFIRESample
McmcSample

"""
# Fit posterior in case not done
Expand Down Expand Up @@ -282,13 +282,13 @@ def sample(self,

self.target_model.is_sampling = False

return BOLFIRESample(method_name='BOLFIRE',
chains=chains,
parameter_names=self.parameter_names,
warmup=warmup,
n_sim=self.state['n_sim'],
seed=self.seed,
*args, **kwargs)
return McmcSample(method_name='BOLFIRE',
chains=chains,
parameter_names=self.parameter_names,
warmup=warmup,
n_sim=self.state['n_sim'],
seed=self.seed,
*args, **kwargs)

def _resolve_marginal(self, marginal, seed_marginal=None):
"""Resolve marginal data."""
Expand Down
47 changes: 10 additions & 37 deletions elfi/methods/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,8 @@ def plot_pairs(self, selector=None, bins=20, axes=None, all=False, **kwargs):
plt.suptitle("Population {}".format(i), fontsize=fontsize)


class BolfiSample(Sample):
"""Container for results from BOLFI."""
class McmcSample(Sample):
"""Container for MCMC results."""

def __init__(self, method_name, chains, parameter_names, warmup, **kwargs):
"""Initialize result.
Expand All @@ -529,7 +529,7 @@ def __init__(self, method_name, chains, parameter_names, warmup, **kwargs):
concatenated = warmed_up.reshape((-1,) + shape[2:])
outputs = dict(zip(parameter_names, concatenated.T))

super(BolfiSample, self).__init__(
super(McmcSample, self).__init__(
method_name=method_name,
outputs=outputs,
parameter_names=parameter_names,
Expand All @@ -538,6 +538,13 @@ def __init__(self, method_name, chains, parameter_names, warmup, **kwargs):
warmup=warmup,
**kwargs)

@property
def idata(self):
"""Convert MCMC chains to arviz InferenceData object."""
warmed_up = self.chains[:, self.warmup:]
sample_chains = dict(zip(self.parameter_names, np.transpose(warmed_up, (2, 0, 1))))
return az.from_dict(sample_chains)

def plot_traces(self, selector=None, axes=None, **kwargs):
"""Plot MCMC traces."""
return vis.plot_traces(self, selector, axes, **kwargs)
Expand Down Expand Up @@ -605,40 +612,6 @@ def compute_ess(self):
return {p: eff_sample_size(self.samples[p]) for p in self.parameter_names}


class BOLFIRESample(Sample):
"""Container for results from BOLFIRE."""

def __init__(self, method_name, chains, parameter_names, warmup, *args, **kwargs):
"""Initialize BOLFIRE result.

Parameters
----------
method_name: str
Name of the inference method.
chains: np.ndarray (n_chains, n_samples, n_parameters)
Chains from sampling, warmup included.
parameter_names: list
List of names in the outputs dict that refer to model parameters.
warmup: int
Number of warmup iterations in chains.

"""
n_chains = chains.shape[0]
warmed_up = chains[:, warmup:, :]
concatenated = warmed_up.reshape((-1,) + chains.shape[2:])
outputs = dict(zip(parameter_names, concatenated.T))

super(BOLFIRESample, self).__init__(
method_name=method_name,
outputs=outputs,
parameter_names=parameter_names,
chains=chains,
n_chains=n_chains,
warmup=warmed_up,
*args, **kwargs
)


class RomcSample(Sample):
"""Container for results from ROMC."""

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def test_sample():
sample.summary()


def test_bolfi_sample():
def test_mcmc_sample():
n_chains = 3
n_iters = 10
warmup = 5
parameter_names = ['a', 'b']
chains = np.random.random((n_chains, n_iters, len(parameter_names)))

result = elfi.methods.results.BolfiSample(
result = elfi.methods.results.McmcSample(
method_name="TestRes",
chains=chains,
parameter_names=parameter_names,
Expand Down
Loading