Skip to content

Commit 6ee5a7e

Browse files
committed
fix maxvar initialisation
1 parent 432933e commit 6ee5a7e

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

elfi/methods/bo/acquisition.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -327,16 +327,20 @@ class MaxVar(AcquisitionBase):
327327
328328
"""
329329

330-
def __init__(self, quantile_eps=.01, *args, **opts):
330+
def __init__(self, model, prior, quantile_eps=.01, **opts):
331331
"""Initialise MaxVar.
332332
333333
Parameters
334334
----------
335+
model : elfi.GPyRegression
336+
Gaussian process model used to calculate the unnormalised approximate likelihood.
337+
prior : scipy-like distribution
338+
Prior distribution.
335339
quantile_eps : int, optional
336340
Quantile of the observed discrepancies used in setting the ABC threshold.
337341
338342
"""
339-
super(MaxVar, self).__init__(*args, **opts)
343+
super(MaxVar, self).__init__(model, prior=prior, **opts)
340344
self.name = 'max_var'
341345
self.label_fn = 'Variance of the Unnormalised Approximate Posterior'
342346
self.quantile_eps = quantile_eps
@@ -492,13 +496,16 @@ class RandMaxVar(MaxVar):
492496
493497
"""
494498

495-
def __init__(self, quantile_eps=.01, sampler='nuts', n_samples=50, warmup=None,
496-
limit_faulty_init=1000, init_from_prior=False, sigma_proposals=None,
497-
*args, **opts):
499+
def __init__(self, model, prior, quantile_eps=.01, sampler='nuts', n_samples=50, warmup=None,
500+
limit_faulty_init=1000, init_from_prior=False, sigma_proposals=None, **opts):
498501
"""Initialise RandMaxVar.
499502
500503
Parameters
501504
----------
505+
model : elfi.GPyRegression
506+
Gaussian process model used to calculate the unnormalised approximate likelihood.
507+
prior : scipy-like distribution
508+
Prior distribution.
502509
quantile_eps : int, optional
503510
Quantile of the observed discrepancies used in setting the ABC threshold.
504511
sampler : string, optional
@@ -517,7 +524,7 @@ def __init__(self, quantile_eps=.01, sampler='nuts', n_samples=50, warmup=None,
517524
Markov Chain sampler. Defaults to 1/10 of surrogate model bound lengths.
518525
519526
"""
520-
super(RandMaxVar, self).__init__(quantile_eps, *args, **opts)
527+
super(RandMaxVar, self).__init__(model, prior, quantile_eps, **opts)
521528
self.name = 'rand_max_var'
522529
self.name_sampler = sampler
523530
self._n_samples = n_samples
@@ -648,13 +655,17 @@ class ExpIntVar(MaxVar):
648655
649656
"""
650657

651-
def __init__(self, quantile_eps=.01, integration='grid', d_grid=.2,
658+
def __init__(self, model, prior, quantile_eps=.01, integration='grid', d_grid=.2,
652659
n_samples_imp=100, iter_imp=2, sampler='nuts', n_samples=2000,
653-
sigma_proposals=None, *args, **opts):
660+
sigma_proposals=None, **opts):
654661
"""Initialise ExpIntVar.
655662
656663
Parameters
657664
----------
665+
model : elfi.GPyRegression
666+
Gaussian process model used to calculate the approximate unnormalised likelihood.
667+
prior : scipy-like distribution
668+
Prior distribution.
658669
quantile_eps : int, optional
659670
Quantile of the observed discrepancies used in setting the discrepancy threshold.
660671
integration : str, optional
@@ -680,7 +691,7 @@ def __init__(self, quantile_eps=.01, integration='grid', d_grid=.2,
680691
Markov Chain sampler. Defaults to 1/10 of surrogate model bound lengths.
681692
682693
"""
683-
super(ExpIntVar, self).__init__(quantile_eps, *args, **opts)
694+
super(ExpIntVar, self).__init__(model, prior, quantile_eps, **opts)
684695
self.name = 'exp_int_var'
685696
self.label_fn = 'Expected Loss'
686697
self._integration = integration

0 commit comments

Comments
 (0)