Skip to content

Commit 13b263a

Browse files
authored
Merge pull request #1517 from pints-team/xnes-cov-matrix-update-fix
Fixed bug in XNES covariance matrix update.
2 parents 4f52588 + 040d3d7 commit 13b263a

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ All notable changes to this project will be documented in this file.
55
## Unreleased
66

77
### Added
8-
- [#1499](https://github.com/pints-team/pints/pull/1499) Added a log-uniform prior class.
98
- [#1505](https://github.com/pints-team/pints/pull/1505) Added notes to `ErrorMeasure` and `LogPDF` to say parameters must be real and continuous.
9+
- [#1499](https://github.com/pints-team/pints/pull/1499) Added a log-uniform prior class.
1010
### Changed
1111
- [#1503](https://github.com/pints-team/pints/pull/1503) Stopped showing time units in controller logs, because the units change depending on the output type (see #1467).
1212
### Deprecated
1313
### Removed
1414
### Fixed
15+
- [#1517](https://github.com/pints-team/pints/pull/1517) Fixed a major bug in the covariance matrix update for xNES.
1516
- [#1505](https://github.com/pints-team/pints/pull/1505) Fixed issues with toy problems that accept invalid inputs.
1617

1718

pints/_optimisers/_xnes.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class XNES(pints.PopulationBasedOptimiser):
3232
3333
.. [2] PyBrain: The Python machine learning library
3434
http://pybrain.org
35+
PyBrain is co-authored by xNES' authors.
36+
3537
"""
3638
def __init__(self, x0, sigma0=None, boundaries=None):
3739
super(XNES, self).__init__(x0, sigma0, boundaries)
@@ -47,7 +49,7 @@ def __init__(self, x0, sigma0=None, boundaries=None):
4749
self._bounded_ids = None # Indices of those xs
4850

4951
# Normalisation / distribution
50-
self._mu = np.array(self._x0) # Mean
52+
self._mu = pints.vector(x0) # Mean
5153
self._A = None # Covariance
5254

5355
# Best solution seen
@@ -106,13 +108,13 @@ def _initialise(self):
106108
d = self._n_parameters
107109
n = self._population_size
108110

109-
# Learning rates
111+
# Learning rates, see Table 1 in [1]
110112
# TODO Allow changing before run() with method call
111113
self._eta_mu = 1
112114
# TODO Allow changing before run() with method call
113115
self._eta_A = 0.6 * (3 + np.log(d)) * d ** -1.5
114116

115-
# Pre-calculated utilities
117+
# Pre-calculated utilities, see Table 1 in [1]
116118
self._us = np.maximum(0, np.log(n / 2 + 1) - np.log(1 + np.arange(n)))
117119
self._us /= np.sum(self._us)
118120
self._us -= 1 / n
@@ -162,10 +164,12 @@ def tell(self, fx):
162164
self._mu += self._eta_mu * np.dot(self._A, Gd)
163165

164166
# Update root of covariance matrix
167+
# Note that this is equation 11 (for the eta-sigma=eta-B case), not the
168+
# more general equations 9&10 version given in Algorithm 1
165169
Gm = np.dot(
166170
np.array([np.outer(z, z).T - self._I for z in self._zs]).T,
167171
self._us)
168-
self._A *= scipy.linalg.expm(np.dot(0.5 * self._eta_A, Gm))
172+
self._A = np.dot(self._A, scipy.linalg.expm(0.5 * self._eta_A * Gm))
169173

170174
# Update f_guessed on the assumption that the lowest value in our
171175
# sample approximates f(mu)

pints/tests/test_opt_controller.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def test_stopping_max_iterations(self):
233233
def test_logging(self):
234234

235235
# Test with logpdf
236+
np.random.seed(1)
236237
r = pints.toy.TwistedGaussianLogPDF(2, 0.01)
237238
x = np.array([0, 1.01])
238239
b = pints.RectangularBoundaries([-0.01, 0.95], [0.01, 1.05])
@@ -255,17 +256,17 @@ def test_logging(self):
255256
self.assertEqual(lines[5][:-3],
256257
'0 3 -4.140462 -4.140462 0:0')
257258
self.assertEqual(lines[6][:-3],
258-
'1 6 -4.140462 -4.140465 0:0')
259+
'1 6 -4.140462 -4.140482 0:0')
259260
self.assertEqual(lines[7][:-3],
260-
'2 11 -4.140462 -4.140462 0:0')
261+
'2 9 -4.140462 -4.140465 0:0')
261262
self.assertEqual(lines[8][:-3],
262-
'3 16 -4.140462 -4.140466 0:0')
263+
'3 14 -4.140462 -4.140462 0:0')
263264
self.assertEqual(lines[9][:-3],
264-
'6 33 -4.140462 -4.140462 0:0')
265+
'6 30 -4.140462 -4.140462 0:0')
265266
self.assertEqual(lines[10][:-3],
266-
'9 51 -4.140462 -4.140462 0:0')
267+
'9 47 -4.140462 -4.140463 0:0')
267268
self.assertEqual(lines[11][:-3],
268-
'10 51 -4.140462 -4.140462 0:0')
269+
'10 47 -4.140462 -4.140463 0:0')
269270
self.assertEqual(
270271
lines[12], 'Halting: Maximum number of iterations (10) reached.')
271272

@@ -448,8 +449,8 @@ def test_post_run_statistics(self):
448449
opt.run()
449450
t_upper = t.time()
450451

451-
self.assertEqual(opt.iterations(), 84)
452-
self.assertEqual(opt.evaluations(), 495)
452+
self.assertEqual(opt.iterations(), 125)
453+
self.assertEqual(opt.evaluations(), 734)
453454

454455
# Time after run is greater than zero
455456
self.assertIsInstance(opt.time(), float)

0 commit comments

Comments
 (0)