Skip to content

Commit 7f76cdd

Browse files
committed
fixes
1 parent 29138ed commit 7f76cdd

File tree

5 files changed

+53
-22
lines changed

5 files changed

+53
-22
lines changed

diffrax/_adjoint.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -909,9 +909,10 @@ def loop(
909909
throw,
910910
passed_solver_state,
911911
passed_controller_state,
912+
passed_path_state,
912913
**kwargs,
913914
):
914-
del throw, passed_solver_state, passed_controller_state
915+
del throw, passed_solver_state, passed_controller_state, passed_path_state
915916
inner_while_loop = eqx.Partial(_inner_loop, kind="lax")
916917
outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
917918
# Support forward-mode autodiff.

diffrax/_integrate.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
181181
if not vf_type_compatible:
182182
raise ValueError(f"Vector field term {term} is incompatible.")
183183

184+
term_contr_kwargs["control_state"] = term.init(0.0, 0.0, y, args)
184185
contr = ft.partial(term.contr, **term_contr_kwargs)
185186
# Work around https://github.com/google/jax/issues/21825
186187
try:
@@ -387,7 +388,7 @@ def body_fun_aux(state):
387388

388389
tprev = jnp.minimum(tprev, t1)
389390
tnext = _clip_to_end(tprev, tnext, t1, keep_step)
390-
391+
391392
progress_meter_state = progress_meter.step(
392393
state.progress_meter_state, linear_rescale(t0, tprev, t1)
393394
)
@@ -1111,17 +1112,26 @@ def _promote(yi):
11111112
)
11121113

11131114
# Error checking for term compatibility
1115+
1116+
# try:
1117+
# contr_kwargs = jtu.tree_map(
1118+
# lambda _, x, y: jtu.tree_map(
1119+
# lambda a, b: a | {"control_state": b},
1120+
# x,
1121+
# y,
1122+
# is_leaf=lambda v: isinstance(v, dict),
1123+
# ),
1124+
# solver.term_structure,
1125+
# solver.term_compatible_contr_kwargs,
1126+
# path_state,
1127+
# is_leaf=lambda z: isinstance(z, AbstractTerm)
1128+
# and not isinstance(z, MultiTerm),
1129+
# )
1130+
# except Exception as e:
1131+
# raise ValueError("Terms are not compatible with solver!") from e
1132+
11141133
_assert_term_compatible(
1115-
y0,
1116-
args,
1117-
terms,
1118-
solver.term_structure,
1119-
jtu.tree_map(
1120-
lambda x, y: x | {"control_state": y},
1121-
solver.term_compatible_contr_kwargs,
1122-
path_state,
1123-
is_leaf=lambda x: isinstance(x, dict),
1124-
),
1134+
y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs
11251135
)
11261136

11271137
if is_sde(terms):

diffrax/_solver/euler_heun.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ def step(
6868
del solver_state, made_jump
6969

7070
drift, diffusion = terms.terms
71-
dt, path_state = drift.contr(t0, t1, path_state)
72-
dW, path_state = diffusion.contr(t0, t1, path_state)
71+
drift_path, diffusion_path = path_state
72+
dt, drift_path = drift.contr(t0, t1, drift_path)
73+
dW, diffusion_path = diffusion.contr(t0, t1, diffusion_path)
7374

7475
f0 = drift.vf_prod(t0, y0, args, dt)
7576
g0 = diffusion.vf_prod(t0, y0, args, dW)
@@ -80,7 +81,14 @@ def step(
8081
y1 = (y0**ω + f0**ω + 0.5 * (g0**ω + g_prime**ω)).ω
8182

8283
dense_info = dict(y0=y0, y1=y1)
83-
return y1, None, dense_info, None, path_state, RESULTS.successful
84+
return (
85+
y1,
86+
None,
87+
dense_info,
88+
None,
89+
(drift_path, diffusion_path),
90+
RESULTS.successful,
91+
)
8492

8593
def func(
8694
self,

diffrax/_term.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,9 @@ def is_vf_expensive(
777777
],
778778
args: Args,
779779
) -> bool:
780-
control_struct = eqx.filter_eval_shape(self.contr, t0, t1)
780+
control_struct = eqx.filter_eval_shape(
781+
self.contr, t0, t1, self.term.init(t0, t1, y, args)
782+
)
781783
if sum(c.size for c in jtu.tree_leaves(control_struct)) in (0, 1):
782784
return False
783785
else:

test/test_integrate.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import scipy.stats
1515
from diffrax import ControlTerm, MultiTerm, ODETerm
1616
from equinox.internal import ω
17-
from jaxtyping import Array, ArrayLike, Float
17+
from jaxtyping import Array, ArrayLike, Float, PyTree
1818

1919
from .helpers import (
2020
all_ode_solvers,
@@ -638,6 +638,10 @@ class TestSolver(diffrax.Euler):
638638

639639

640640
def test_term_compatibility_pytree():
641+
class _TestState(eqx.Module):
642+
y: PyTree
643+
state: PyTree
644+
641645
class TestSolver(diffrax.AbstractSolver):
642646
term_structure = {
643647
"a": diffrax.ODETerm,
@@ -661,14 +665,20 @@ def init(self, terms, t0, t1, y0, args, path_state):
661665
return None
662666

663667
def step(self, terms, t0, t1, y0, args, solver_state, made_jump, path_state):
664-
def _step(_term, _y):
665-
control = _term.contr(t0, t1)
666-
return _y + _term.vf_prod(t0, _y, args, control)
668+
def _step(_term, _y, state):
669+
control, new_state = _term.contr(t0, t1, state)
670+
return _TestState(_y + _term.vf_prod(t0, _y, args, control), new_state)
667671

668672
_is_term = lambda x: isinstance(x, diffrax.AbstractTerm)
669-
y1 = jtu.tree_map(_step, terms, y0, is_leaf=_is_term)
673+
output = jtu.tree_map(_step, terms, y0, path_state, is_leaf=_is_term)
674+
y1 = jtu.tree_map(
675+
lambda x: x.y, output, is_leaf=lambda x: isinstance(x, _TestState)
676+
)
677+
path_state = jtu.tree_map(
678+
lambda x: x.state, output, is_leaf=lambda x: isinstance(x, _TestState)
679+
)
670680
dense_info = dict(y0=y0, y1=y1)
671-
return y1, None, dense_info, None, None, diffrax.RESULTS.successful
681+
return y1, None, dense_info, None, path_state, diffrax.RESULTS.successful
672682

673683
def func(self, terms, t0, y0, args):
674684
assert False

0 commit comments

Comments
 (0)