diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index e24224d6..32c8011b 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -438,11 +438,15 @@ def _linear_interpolation( if replace_nans_at_start is None: y0 = ys[0] else: - y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape) + y0 = jnp.broadcast_to( + jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape + ) + ys = ys.at[0].set(y0) _, (next_ts, next_ys) = lax.scan( _interpolation_reverse, (ts[-1], ys[-1]), (ts, ys), reverse=True ) + if fill_forward_nans_at_end: next_ys = fill_forward(next_ys) _, ys = lax.scan( @@ -657,18 +661,26 @@ def _backward_hermite_coefficients( ]: ts = left_broadcast_to(ts, ys.shape) + if replace_nans_at_start is None: + y0 = ys[0] + else: + y0 = jnp.broadcast_to( + jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape + ) + ys = ys.at[0].set(y0) + _, (next_ts, next_ys) = lax.scan( - _interpolation_reverse, (ts[-1], ys[-1]), (ts[1:], ys[1:]), reverse=True + _interpolation_reverse, (ts[-1], ys[-1]), (ts, ys), reverse=True ) if fill_forward_nans_at_end: next_ys = fill_forward(next_ys) + next_ts = next_ts[1:] + next_ys = next_ys[1:] + t0 = ts[0] - if replace_nans_at_start is None: - y0 = ys[0] - else: - y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape) + if deriv0 is None: deriv0 = (next_ys[0] - y0) / (next_ts[0] - t0) else: diff --git a/diffrax/misc.py b/diffrax/misc.py index 16ed8bc2..37f9fa91 100644 --- a/diffrax/misc.py +++ b/diffrax/misc.py @@ -66,7 +66,9 @@ def fill_forward( if replace_nans_at_start is None: y0 = ys[0] else: - y0 = jnp.broadcast_to(replace_nans_at_start, ys[0].shape) + y0 = jnp.broadcast_to( + jnp.where(jnp.isnan(ys[0]), replace_nans_at_start, ys[0]), ys[0].shape + ) _, ys = lax.scan(_fill_forward, y0, ys) return ys diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py index e190b4b7..3d67eda4 100644 --- a/diffrax/step_size_controller/adaptive.py +++ b/diffrax/step_size_controller/adaptive.py @@ -424,7 +424,7 @@ def adapt_step_size( # ε_n = atol + norm(y) * rtol with y on the nth step # r_n = norm(y_error) with y_error on the nth step # δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth - # step and y on the mth step + # step and y on the mth step # β_1 = pcoeff + icoeff + dcoeff # β_2 = -(pcoeff + 2 * dcoeff) # β_3 = dcoeff diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index cfcce19f..ab2802cb 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -394,3 +394,49 @@ def test_dense_interpolation_vmap(solver, getkey): diffrax.Ralston: 1e-3, }.get(type(solver), 1e-6) assert shaped_allclose(derivs, true_derivs, atol=deriv_tol, rtol=deriv_tol) + + +@pytest.mark.parametrize("mode", ["linear", "rectilinear", "cubic"]) +@pytest.mark.parametrize( + "nans, expected", + [ + ( + jnp.array([0, 3, 4, 6, 9]), + jnp.array([20.0, 1.0, 2.0, 23.0, 24.0, 5.0, 26.0, 7.0, 8.0, 29.0]), + ), + (jnp.arange(0, 10, 1), jnp.arange(20, 30, 1)), + ], +) +@pytest.mark.parametrize("init_nan", [True, False]) +def test_replace_nans_at_start(mode, nans, expected, init_nan): + ts = jnp.linspace(0, 1, 15) + if init_nan: + ys = jnp.full((15, 10), jnp.nan) + else: + ys = jrandom.normal(jrandom.PRNGKey(0), (15, 10)) + ys = ys.at[0, :].set(jnp.arange(0, 10, 1)) + nan_ys = ys.at[0, nans].set(jnp.nan) + replace_nans_at_start = jnp.arange(20, 30, 1) + + if mode == "cubic": + coeffs = diffrax.backward_hermite_coefficients( + ts, + nan_ys, + replace_nans_at_start=replace_nans_at_start, + fill_forward_nans_at_end=True, + ) + interp = diffrax.CubicInterpolation(ts, coeffs) + elif mode == "linear": + interp = diffrax.linear_interpolation( + ts, + nan_ys, + replace_nans_at_start=replace_nans_at_start, + fill_forward_nans_at_end=True, + ) + interp = diffrax.LinearInterpolation(ts, interp) + elif mode == "rectilinear": + ts, coeffs = diffrax.rectilinear_interpolation( + ts, nan_ys, replace_nans_at_start=replace_nans_at_start + ) + interp = diffrax.LinearInterpolation(ts, coeffs) + assert shaped_allclose(interp.evaluate(0), expected)