From 817de9ba243450bbdf3164c1ef4a895f3b058545 Mon Sep 17 00:00:00 2001 From: joycemaalouf Date: Thu, 2 Nov 2023 01:38:15 +0000 Subject: [PATCH 1/2] fixing bugs 323 and 324 --- diffrax/global_interpolation.py | 24 ++++++++++++---- diffrax/misc.py | 4 ++- test/test_global_interpolation.py | 46 +++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 7 deletions(-) diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index e24224d6..32c8011b 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -438,11 +438,15 @@ def _linear_interpolation( if replace_nans_at_start is None: y0 = ys[0] else: - y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape) + y0 = jnp.broadcast_to( + jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape + ) + ys = ys.at[0].set(y0) _, (next_ts, next_ys) = lax.scan( _interpolation_reverse, (ts[-1], ys[-1]), (ts, ys), reverse=True ) + if fill_forward_nans_at_end: next_ys = fill_forward(next_ys) _, ys = lax.scan( @@ -657,18 +661,26 @@ def _backward_hermite_coefficients( ]: ts = left_broadcast_to(ts, ys.shape) + if replace_nans_at_start is None: + y0 = ys[0] + else: + y0 = jnp.broadcast_to( + jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape + ) + ys = ys.at[0].set(y0) + _, (next_ts, next_ys) = lax.scan( - _interpolation_reverse, (ts[-1], ys[-1]), (ts[1:], ys[1:]), reverse=True + _interpolation_reverse, (ts[-1], ys[-1]), (ts, ys), reverse=True ) if fill_forward_nans_at_end: next_ys = fill_forward(next_ys) + next_ts = next_ts[1:] + next_ys = next_ys[1:] + t0 = ts[0] - if replace_nans_at_start is None: - y0 = ys[0] - else: - y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape) + if deriv0 is None: deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0) else: diff --git a/diffrax/misc.py b/diffrax/misc.py index 85ca10a1..31b377c9 100644 --- a/diffrax/misc.py +++ b/diffrax/misc.py @@ -66,7 +66,9 @@ def fill_forward( if replace_nans_at_start is None: y0 = ys[0] else: - y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape) + y0 = jnp.broadcast_to( + jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape + ) _, ys = lax.scan(_fill_forward, y0, ys) return ys diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index cfcce19f..a1b1ad12 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -394,3 +394,49 @@ def test_dense_interpolation_vmap(solver, getkey): diffrax.Ralston: 1e-3, }.get(type(solver), 1e-6) assert shaped_allclose(derivs, true_derivs, atol=deriv_tol, rtol=deriv_tol) + + +@pytest.mark.parametrize("mode", ["linear", "rectilinear", "cubic"]) +@pytest.mark.parametrize( + "nans, expected", + [ + ( + jnp.array([0, 3, 4, 6, 9]), + jnp.array([20.0, 1.0, 2.0, 23.0, 24.0, 5.0, 26.0, 7.0, 8.0, 29.0]), + ), + (jnp.arange(0, 10, 1), jnp.arange(20, 30, 1)), + ], +) +@pytest.mark.parametrize("init_nan", [True, False]) +def test_replace_nans_at_start(mode, nans, expected, init_nan): + ts = jnp.linspace(0, 1, 15) + if init_nan: + ys = jnp.empty((15, 10)).at[:].set(jnp.nan) + else: + ys = jrandom.normal(jrandom.PRNGKey(0), (15, 10)) + ys = ys.at[0, :].set(jnp.arange(0, 10, 1)) + nan_ys = ys.at[0, nans].set(jnp.nan) + replace_nans_at_start = jnp.arange(20, 30, 1) + + if mode == "cubic": + coeffs = diffrax.backward_hermite_coefficients( + ts, + nan_ys, + replace_nans_at_start=replace_nans_at_start, + fill_forward_nans_at_end=True, + ) + interp = diffrax.CubicInterpolation(ts, coeffs) + elif mode == "linear": + interp = diffrax.linear_interpolation( + ts, + nan_ys, + replace_nans_at_start=replace_nans_at_start, + fill_forward_nans_at_end=True, + ) + interp = diffrax.LinearInterpolation(ts, interp) + elif mode == "rectilinear": + ts, coeffs = diffrax.rectilinear_interpolation( + ts, nan_ys, replace_nans_at_start=replace_nans_at_start + ) + interp = diffrax.LinearInterpolation(ts, coeffs) + assert shaped_allclose(interp.evaluate(0), expected) From ece99105ba224ed86c7788b543b5475ba58c4e82 Mon Sep 17 00:00:00 2001 From: JadM133 Date: Sun, 5 Nov 2023 01:28:57 +0000 Subject: [PATCH 2/2] fixing flake8 problem and the comment --- diffrax/step_size_controller/adaptive.py | 2 +- test/test_global_interpolation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py index 71d2761d..3d67eda4 100644 --- a/diffrax/step_size_controller/adaptive.py +++ b/diffrax/step_size_controller/adaptive.py @@ -424,7 +424,7 @@ def adapt_step_size( # ε_n = atol + norm(y) * rtol with y on the nth step # r_n = norm(y_error) with y_error on the nth step # δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth - # step and y on the mth step + # step and y on the mth step # β_1 = pcoeff + icoeff + dcoeff # β_2 = -(pcoeff + 2 * dcoeff) # β_3 = dcoeff diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index a1b1ad12..ab2802cb 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -411,7 +411,7 @@ def test_dense_interpolation_vmap(solver, getkey): def test_replace_nans_at_start(mode, nans, expected, init_nan): ts = jnp.linspace(0, 1, 15) if init_nan: - ys = jnp.empty((15, 10)).at[:].set(jnp.nan) + ys = jnp.full((15, 10), jnp.nan) else: ys = jrandom.normal(jrandom.PRNGKey(0), (15, 10)) ys = ys.at[0, :].set(jnp.arange(0, 10, 1))