@@ -237,6 +237,18 @@ def test_direct_brownian():
237
237
final_activation = jnp .tanh ,
238
238
key = diffusionkey ,
239
239
)
240
class Field(eqx.Module):
    """Callable with the ``(t, y, args)`` signature that applies ``force`` to ``y``."""

    # Network evaluated on the state ``y``.
    force: eqx.nn.MLP

    def __call__(self, t, y, args):
        # ``t`` and ``args`` are accepted but not used; only ``y`` reaches
        # the network.
        network = self.force
        return network(y)
245
+
246
class DiffusionField(eqx.Module):
    """Callable that wraps ``force(y)`` in a ``lx.DiagonalLinearOperator``."""

    # Network evaluated on the state ``y``; its output is passed to
    # ``lx.DiagonalLinearOperator``.
    force: eqx.nn.MLP

    def __call__(self, t, y, args):
        # ``t`` and ``args`` are accepted but not used.
        diagonal = self.force(y)
        return lx.DiagonalLinearOperator(diagonal)
251
+
240
252
y0 = jr .normal (ykey , (3 ,))
241
253
242
254
k1 , k2 , k3 = jax .random .split (key , 3 )
@@ -250,25 +262,25 @@ def test_direct_brownian():
250
262
)
251
263
252
264
vbt_terms = diffrax .MultiTerm (
253
- diffrax .ODETerm (lambda t , y , args : drift_mlp ( y )),
265
+ diffrax .ODETerm (Field ( drift_mlp )),
254
266
diffrax .ControlTerm (
255
- lambda t , y , args : lx . DiagonalLinearOperator (diffusion_mlp ( y ) ), vbt
267
+ DiffusionField (diffusion_mlp ), vbt
256
268
),
257
269
)
258
270
dbp_terms = diffrax .MultiTerm (
259
- diffrax .ODETerm (lambda t , y , args : drift_mlp ( y )),
271
+ diffrax .ODETerm (Field ( drift_mlp )),
260
272
diffrax .ControlTerm (
261
- lambda t , y , args : lx . DiagonalLinearOperator (diffusion_mlp ( y ) ), dbp
273
+ DiffusionField (diffusion_mlp ), dbp
262
274
),
263
275
)
264
276
dbp_pre_terms = diffrax .MultiTerm (
265
- diffrax .ODETerm (lambda t , y , args : drift_mlp ( y )),
277
+ diffrax .ODETerm (Field ( drift_mlp )),
266
278
diffrax .ControlTerm (
267
- lambda t , y , args : lx . DiagonalLinearOperator (diffusion_mlp ( y ) ), dbp_pre
279
+ DiffusionField (diffusion_mlp ), dbp_pre
268
280
),
269
281
)
270
282
271
- solver = diffrax .GeneralShARK ()
283
+ solver = diffrax .Heun ()
272
284
273
285
y0_args_term0 = (y0 , None , vbt_terms )
274
286
y0_args_term1 = (y0 , None , dbp_terms )
@@ -307,7 +319,7 @@ def _run_finite_diff(y0__args__term, saveat, adjoint):
307
319
for t0 in (True , False ):
308
320
for t1 in (True , False ):
309
321
for ts in (None , [0.3 ], [2.0 ], [9.5 ], [1.0 , 7.0 ], [0.3 , 7.0 , 9.5 ]):
310
- for y0__args__term in ( y0_args_term0 ,): #, y0_args_term1, y0_args_term2):
322
+ for i , y0__args__term in enumerate (( y0_args_term0 , y0_args_term1 , y0_args_term2 ) ):
311
323
if t0 is False and t1 is False and ts is None :
312
324
continue
313
325
@@ -329,17 +341,20 @@ def _run_inexact(inexact, saveat, adjoint):
329
341
recursive_grads = _run_grad (
330
342
inexact , saveat , diffrax .RecursiveCheckpointAdjoint ()
331
343
)
332
- # backsolve_grads = _run_grad(
333
- # inexact, saveat, diffrax.BacksolveAdjoint()
334
- # )
344
+ if i == 0 :
345
+ backsolve_grads = _run_grad (
346
+ inexact , saveat , diffrax .BacksolveAdjoint ()
347
+ )
348
+ assert tree_allclose (fd_grads , backsolve_grads [0 ], atol = 1e-3 )
349
+
335
350
forward_grads = _run_fwd_grad (
336
351
inexact , saveat , diffrax .ForwardMode ()
337
352
)
353
+ # TODO: fix via https://github.com/patrick-kidger/equinox/issues/923
338
354
# direct_grads = _run_grad(inexact, saveat, diffrax.DirectAdjoint())
339
- # assert tree_allclose(fd_grads, direct_grads[0])
340
- assert tree_allclose (fd_grads , recursive_grads , atol = 1e-5 )
341
- # assert tree_allclose(fd_grads, backsolve_grads, atol=1e-5)
342
- assert tree_allclose (fd_grads , forward_grads , atol = 1e-5 )
355
+ # assert tree_allclose(fd_grads, direct_grads[0], atol=1e-3)
356
+ assert tree_allclose (fd_grads , recursive_grads [0 ], atol = 1e-3 )
357
+ assert tree_allclose (fd_grads , forward_grads [0 ], atol = 1e-3 )
343
358
344
359
345
360
def test_adjoint_seminorm ():
0 commit comments