Skip to content

Commit a1374f9

Browse files
committed
tests + examples
1 parent d24e6f1 commit a1374f9

File tree

6 files changed

+53
-74
lines changed

6 files changed

+53
-74
lines changed

benchmarks/stateful_paths.py

+1
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ def step(y, dW):
268268
New UBP + Precompute: 0.002430
269269
Pure Jax: 0.002799
270270
271+
(these are out of date)
271272
Results on A100 GPU:
272273
VBT: 3.881952
273274
Old UBP: 0.337173

diffrax/_adjoint.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,9 @@ def loop(
377377
# "Cannot reverse-mode autodifferentiate when using "
378378
# "`UnsafeBrownianPath`."
379379
# )
380+
# if is_unsafe_sde(terms):
381+
# kind = "lax"
382+
# msg = None
380383
if max_steps is None:
381384
kind = "lax"
382385
msg = (
@@ -836,7 +839,10 @@ def loop(
836839
raise NotImplementedError(
837840
"Cannot use `adjoint=BacksolveAdjoint()` with `saveat=SaveAt(fn=...)`."
838841
)
839-
# is this still true with DirectAdjoint?
842+
# is this still true with DirectBP?
843+
# it seems to give inaccurate results, so not currently, but seems doable
844+
# might just require more careful thinking about path state management
845+
# and more knowledge about continuous adjoints than I have currently
840846
if is_unsafe_sde(terms):
841847
raise ValueError(
842848
"`adjoint=BacksolveAdjoint()` does not support `UnsafeBrownianPath`. "

diffrax/_brownian/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
BrownianIncrement,
1010
RealScalarLike,
1111
SpaceTimeLevyArea,
12+
SpaceTimeTimeLevyArea
1213
)
1314
from .._path import AbstractPath
1415

@@ -20,7 +21,7 @@
2021
class AbstractBrownianPath(AbstractPath[_Control, _BrownianState]):
2122
"""Abstract base class for all Brownian paths."""
2223

23-
levy_area: AbstractVar[type[Union[BrownianIncrement, SpaceTimeLevyArea]]]
24+
levy_area: AbstractVar[type[Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]]]
2425

2526
@abc.abstractmethod
2627
def __call__(

diffrax/_brownian/path.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
BrownianIncrement,
1818
levy_tree_transpose,
1919
RealScalarLike,
20+
IntScalarLike,
2021
SpaceTimeLevyArea,
2122
SpaceTimeTimeLevyArea,
2223
Y,
@@ -31,7 +32,7 @@
3132

3233
_Control = Union[PyTree[Array], AbstractBrownianIncrement]
3334
_BrownianState: TypeAlias = Union[
34-
tuple[None, PyTree[Array], int], tuple[PRNGKeyArray, None, None]
35+
tuple[None, PyTree[Array], IntScalarLike], tuple[PRNGKeyArray, None, None]
3536
]
3637

3738

@@ -73,10 +74,10 @@ class DirectBrownianPath(AbstractBrownianPath[_Control, _BrownianState]):
7374
"""
7475

7576
shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
77+
key: PRNGKeyArray
7678
levy_area: type[
7779
Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]
7880
] = eqx.field(static=True)
79-
key: PRNGKeyArray
8081
precompute: Optional[int] = eqx.field(static=True)
8182

8283
def __init__(
@@ -116,7 +117,7 @@ def _generate_noise(
116117
key: PRNGKeyArray,
117118
shape: jax.ShapeDtypeStruct,
118119
max_steps: int,
119-
) -> Float[Array, "levy_dims shape"]:
120+
) -> Float[Array, "..."]:
120121
# TODO: merge into a single jr.normal call
121122
if self.levy_area is SpaceTimeTimeLevyArea:
122123
noise = jr.normal(key, (3, max_steps, *shape.shape), shape.dtype)
@@ -254,7 +255,7 @@ def _evaluate_leaf_precomputed(
254255
Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]
255256
],
256257
use_levy: bool,
257-
noises: Float[Array, "levy_dims shape"],
258+
noises: Float[Array, "..."],
258259
):
259260
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
260261
dt = jnp.asarray(t1 - t0, dtype=complex_to_real_dtype(shape.dtype))

examples/underdamped_langevin_example.ipynb

+8-53
Large diffs are not rendered by default.

test/test_adjoint.py

+30-15
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,18 @@ def test_direct_brownian():
237237
final_activation=jnp.tanh,
238238
key=diffusionkey,
239239
)
240+
class Field(eqx.Module):
241+
force: eqx.nn.MLP
242+
243+
def __call__(self, t, y, args):
244+
return self.force(y)
245+
246+
class DiffusionField(eqx.Module):
247+
force: eqx.nn.MLP
248+
249+
def __call__(self, t, y, args):
250+
return lx.DiagonalLinearOperator(self.force(y))
251+
240252
y0 = jr.normal(ykey, (3,))
241253

242254
k1, k2, k3 = jax.random.split(key, 3)
@@ -250,25 +262,25 @@ def test_direct_brownian():
250262
)
251263

252264
vbt_terms = diffrax.MultiTerm(
253-
diffrax.ODETerm(lambda t, y, args: drift_mlp(y)),
265+
diffrax.ODETerm(Field(drift_mlp)),
254266
diffrax.ControlTerm(
255-
lambda t, y, args: lx.DiagonalLinearOperator(diffusion_mlp(y)), vbt
267+
DiffusionField(diffusion_mlp), vbt
256268
),
257269
)
258270
dbp_terms = diffrax.MultiTerm(
259-
diffrax.ODETerm(lambda t, y, args: drift_mlp(y)),
271+
diffrax.ODETerm(Field(drift_mlp)),
260272
diffrax.ControlTerm(
261-
lambda t, y, args: lx.DiagonalLinearOperator(diffusion_mlp(y)), dbp
273+
DiffusionField(diffusion_mlp), dbp
262274
),
263275
)
264276
dbp_pre_terms = diffrax.MultiTerm(
265-
diffrax.ODETerm(lambda t, y, args: drift_mlp(y)),
277+
diffrax.ODETerm(Field(drift_mlp)),
266278
diffrax.ControlTerm(
267-
lambda t, y, args: lx.DiagonalLinearOperator(diffusion_mlp(y)), dbp_pre
279+
DiffusionField(diffusion_mlp), dbp_pre
268280
),
269281
)
270282

271-
solver = diffrax.GeneralShARK()
283+
solver = diffrax.Heun()
272284

273285
y0_args_term0 = (y0, None, vbt_terms)
274286
y0_args_term1 = (y0, None, dbp_terms)
@@ -307,7 +319,7 @@ def _run_finite_diff(y0__args__term, saveat, adjoint):
307319
for t0 in (True, False):
308320
for t1 in (True, False):
309321
for ts in (None, [0.3], [2.0], [9.5], [1.0, 7.0], [0.3, 7.0, 9.5]):
310-
for y0__args__term in (y0_args_term0,):#, y0_args_term1, y0_args_term2):
322+
for i, y0__args__term in enumerate((y0_args_term0, y0_args_term1, y0_args_term2)):
311323
if t0 is False and t1 is False and ts is None:
312324
continue
313325

@@ -329,17 +341,20 @@ def _run_inexact(inexact, saveat, adjoint):
329341
recursive_grads = _run_grad(
330342
inexact, saveat, diffrax.RecursiveCheckpointAdjoint()
331343
)
332-
# backsolve_grads = _run_grad(
333-
# inexact, saveat, diffrax.BacksolveAdjoint()
334-
# )
344+
if i == 0:
345+
backsolve_grads = _run_grad(
346+
inexact, saveat, diffrax.BacksolveAdjoint()
347+
)
348+
assert tree_allclose(fd_grads, backsolve_grads[0], atol=1e-3)
349+
335350
forward_grads = _run_fwd_grad(
336351
inexact, saveat, diffrax.ForwardMode()
337352
)
353+
# TODO: fix via https://github.com/patrick-kidger/equinox/issues/923
338354
# direct_grads = _run_grad(inexact, saveat, diffrax.DirectAdjoint())
339-
# assert tree_allclose(fd_grads, direct_grads[0])
340-
assert tree_allclose(fd_grads, recursive_grads, atol=1e-5)
341-
# assert tree_allclose(fd_grads, backsolve_grads, atol=1e-5)
342-
assert tree_allclose(fd_grads, forward_grads, atol=1e-5)
355+
# assert tree_allclose(fd_grads, direct_grads[0], atol=1e-3)
356+
assert tree_allclose(fd_grads, recursive_grads[0], atol=1e-3)
357+
assert tree_allclose(fd_grads, forward_grads[0], atol=1e-3)
343358

344359

345360
def test_adjoint_seminorm():

0 commit comments

Comments (0)