@@ -69,6 +69,8 @@ def init(
69
69
) -> _SolverState :
70
70
return None
71
71
72
+ # TODO, a bunch of these solvers have tuple requirements, we can type the
73
+ # _PathState to be the same pytree.
72
74
def step (
73
75
self ,
74
76
terms : MultiTerm [
@@ -84,9 +86,10 @@ def step(
84
86
) -> tuple [Y , _ErrorEstimate , DenseInfo , _SolverState , _PathState , RESULTS ]:
85
87
del solver_state , made_jump
86
88
drift , diffusion = terms .terms
87
- # should these be same path state?
88
- dt , _ = drift .contr (t0 , t1 , path_state )
89
- dw , path_state = diffusion .contr (t0 , t1 , path_state )
89
+ drift_path , diffusion_path = path_state
90
+
91
+ dt , drift_path = drift .contr (t0 , t1 , drift_path )
92
+ dw , diffusion_path = diffusion .contr (t0 , t1 , diffusion_path )
90
93
91
94
f0_prod = drift .vf_prod (t0 , y0 , args , dt )
92
95
g0_prod = diffusion .vf_prod (t0 , y0 , args , dw )
@@ -98,7 +101,14 @@ def _to_jvp(_y0):
98
101
y1 = (y0 ** ω + f0_prod ** ω + g0_prod ** ω + 0.5 * v0_prod ** ω ).ω
99
102
100
103
dense_info = dict (y0 = y0 , y1 = y1 )
101
- return y1 , None , dense_info , None , path_state , RESULTS .successful
104
+ return (
105
+ y1 ,
106
+ None ,
107
+ dense_info ,
108
+ None ,
109
+ (drift_path , diffusion_path ),
110
+ RESULTS .successful ,
111
+ )
102
112
103
113
def func (
104
114
self ,
@@ -167,8 +177,9 @@ def step(
167
177
) -> tuple [Y , _ErrorEstimate , DenseInfo , _SolverState , _PathState , RESULTS ]:
168
178
del solver_state , made_jump
169
179
drift , diffusion = terms .terms
170
- Δt , path_state = drift .contr (t0 , t1 , path_state )
171
- Δw , path_state = diffusion .contr (t0 , t1 , path_state )
180
+ drift_path , diffusion_path = path_state
181
+ Δt , drift_path = drift .contr (t0 , t1 , drift_path )
182
+ Δw , diffusion_path = diffusion .contr (t0 , t1 , diffusion_path )
172
183
173
184
#
174
185
# So this is a bit involved, largely because of the generality that the rest of
@@ -379,7 +390,14 @@ def _dot(_, _v0):
379
390
#
380
391
381
392
dense_info = dict (y0 = y0 , y1 = y1 )
382
- return y1 , None , dense_info , None , path_state , RESULTS .successful
393
+ return (
394
+ y1 ,
395
+ None ,
396
+ dense_info ,
397
+ None ,
398
+ (drift_path , diffusion_path ),
399
+ RESULTS .successful ,
400
+ )
383
401
384
402
def func (
385
403
self ,
0 commit comments