Skip to content

Commit 427a594

Browse files
committed
some fixes
1 parent 35dc705 commit 427a594

8 files changed

+59
-18
lines changed

diffrax/_integrate.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@ def cond_fun(state):
340340

341341
def body_fun_aux(state):
342342
state = _handle_static(state)
343-
344343
#
345344
# Actually do some differential equation solving! Make numerical steps, adapt
346345
# step sizes, all that jazz.
@@ -1105,7 +1104,11 @@ def _promote(yi):
11051104
terms = MultiTerm(*terms)
11061105

11071106
if path_state is None:
1108-
path_state = terms.init(t0, t1, y0, args)
1107+
path_state = jax.tree.map(
1108+
lambda term: term.init(t0, t1, y0, args),
1109+
terms,
1110+
is_leaf=lambda x: isinstance(x, AbstractTerm),
1111+
)
11091112

11101113
# Error checking for term compatibility
11111114
_assert_term_compatible(
@@ -1252,7 +1255,11 @@ def _subsaveat_direction_fn(x):
12521255

12531256
if path_state is None:
12541257
passed_path_state = False
1255-
path_state = terms.init(t0, tnext, y0, args)
1258+
path_state = jax.tree.map(
1259+
lambda term: term.init(t0, tnext, y0, args),
1260+
terms,
1261+
is_leaf=lambda x: isinstance(x, AbstractTerm),
1262+
)
12561263
else:
12571264
passed_path_state = True
12581265

diffrax/_solver/foster_langevin_srk.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def init(
260260
evaluation of grad_f.
261261
"""
262262
drift, diffusion = terms.terms
263+
drift_path, diffusion_path = path_state
263264
(
264265
gamma_drift,
265266
u_drift,
@@ -271,7 +272,7 @@ def init(
271272
# is this the only solver class that has `init` depend on the path state?
272273
# feels irksome to change everything for one class, but I'm going to make
273274
# `init` now depend on path state for the sake of generality
274-
h, _ = drift.contr(t0, t1, path_state)
275+
h, _ = drift.contr(t0, t1, drift_path)
275276
x0, v0 = y0
276277

277278
gamma = broadcast_underdamped_langevin_arg(gamma_drift, x0, "gamma")
@@ -390,7 +391,6 @@ def step(
390391
drift, diffusion = terms.terms
391392
drift_path, diffusion_path = path_state
392393

393-
394394
h, drift_path = drift.contr(t0, t1, drift_path)
395395
h_prev = st.h
396396
tay: PyTree[_Coeffs] = st.taylor_coeffs

diffrax/_solver/milstein.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def init(
6969
) -> _SolverState:
7070
return None
7171

72+
# TODO, a bunch of these solvers have tuple requirements, we can type the
73+
# _PathState to be the same pytree.
7274
def step(
7375
self,
7476
terms: MultiTerm[
@@ -84,9 +86,10 @@ def step(
8486
) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, _PathState, RESULTS]:
8587
del solver_state, made_jump
8688
drift, diffusion = terms.terms
87-
# should these be same path state?
88-
dt, _ = drift.contr(t0, t1, path_state)
89-
dw, path_state = diffusion.contr(t0, t1, path_state)
89+
drift_path, diffusion_path = path_state
90+
91+
dt, drift_path = drift.contr(t0, t1, drift_path)
92+
dw, diffusion_path = diffusion.contr(t0, t1, diffusion_path)
9093

9194
f0_prod = drift.vf_prod(t0, y0, args, dt)
9295
g0_prod = diffusion.vf_prod(t0, y0, args, dw)
@@ -98,7 +101,14 @@ def _to_jvp(_y0):
98101
y1 = (y0**ω + f0_prod**ω + g0_prod**ω + 0.5 * v0_prod**ω).ω
99102

100103
dense_info = dict(y0=y0, y1=y1)
101-
return y1, None, dense_info, None, path_state, RESULTS.successful
104+
return (
105+
y1,
106+
None,
107+
dense_info,
108+
None,
109+
(drift_path, diffusion_path),
110+
RESULTS.successful,
111+
)
102112

103113
def func(
104114
self,
@@ -167,8 +177,9 @@ def step(
167177
) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, _PathState, RESULTS]:
168178
del solver_state, made_jump
169179
drift, diffusion = terms.terms
170-
Δt, path_state = drift.contr(t0, t1, path_state)
171-
Δw, path_state = diffusion.contr(t0, t1, path_state)
180+
drift_path, diffusion_path = path_state
181+
Δt, drift_path = drift.contr(t0, t1, drift_path)
182+
Δw, diffusion_path = diffusion.contr(t0, t1, diffusion_path)
172183

173184
#
174185
# So this is a bit involved, largely because of the generality that the rest of
@@ -379,7 +390,14 @@ def _dot(_, _v0):
379390
#
380391

381392
dense_info = dict(y0=y0, y1=y1)
382-
return y1, None, dense_info, None, path_state, RESULTS.successful
393+
return (
394+
y1,
395+
None,
396+
dense_info,
397+
None,
398+
(drift_path, diffusion_path),
399+
RESULTS.successful,
400+
)
383401

384402
def func(
385403
self,

diffrax/_solver/semi_implicit_euler.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,24 @@ def step(
6262
del solver_state, made_jump
6363

6464
term_1, term_2 = terms
65+
path_state1, path_state2 = path_state
6566
y0_1, y0_2 = y0
6667

67-
control1, path_state = term_1.contr(t0, t1, path_state)
68-
control2, path_state = term_2.contr(t0, t1, path_state)
68+
control1, path_state1 = term_1.contr(t0, t1, path_state1)
69+
control2, path_state2 = term_2.contr(t0, t1, path_state2)
6970
y1_1 = (y0_1**ω + term_1.vf_prod(t0, y0_2, args, control1) ** ω).ω
7071
y1_2 = (y0_2**ω + term_2.vf_prod(t0, y1_1, args, control2) ** ω).ω
7172

7273
y1 = (y1_1, y1_2)
7374
dense_info = dict(y0=y0, y1=y1)
74-
return y1, None, dense_info, None, path_state, RESULTS.successful
75+
return (
76+
y1,
77+
None,
78+
dense_info,
79+
None,
80+
(path_state1, path_state2),
81+
RESULTS.successful,
82+
)
7583

7684
def func(
7785
self,

diffrax/_solver/srk.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,14 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in):
663663

664664
y1 = (y0**ω + drift_result**ω + diffusion_result**ω).ω
665665
dense_info = dict(y0=y0, y1=y1)
666-
return y1, error, dense_info, None, (drift_path, diffusion_path), RESULTS.successful
666+
return (
667+
y1,
668+
error,
669+
dense_info,
670+
None,
671+
(drift_path, diffusion_path),
672+
RESULTS.successful,
673+
)
667674

668675
def func(
669676
self,

diffrax/_term.py

-1
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,6 @@ def contr(
675675
control_state: PyTree,
676676
**kwargs,
677677
) -> tuple[tuple[PyTree[ArrayLike], ...], tuple[PyTree, ...]]:
678-
679678
contrs = [
680679
term.contr(t0, t1, state, **kwargs)
681680
for term, state in zip(self.terms, control_state)

test/test_solver.py

+3
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ def test_everything_pytree(implicit, vf_expensive, adaptive):
205205
class Term(diffrax.AbstractTerm):
206206
coeff: float
207207

208+
def init(self, t0, t1, y0, args):
209+
return None
210+
208211
def vf(self, t, y, args):
209212
return {"f": -self.coeff * y["y"]}
210213

test/test_underdamped_langevin.py

-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def test_shape(solver, dtype):
9191
sde = get_pytree_uld(t0, t1, dtype)
9292
bm = sde.get_bm(jr.key(5678), diffrax.SpaceTimeTimeLevyArea, tol=0.2)
9393
terms = sde.get_terms(bm)
94-
print(terms)
9594

9695
sol = diffeqsolve(
9796
terms, solver, t0, t1, dt0=dt0, y0=sde.y0, args=None, saveat=saveat

0 commit comments

Comments
 (0)