Skip to content

Commit 29138ed

Browse files
committed
testing work
1 parent 427a594 commit 29138ed

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

benchmarks/stateful_paths.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _evaluate_leaf(
163163
brownian_motion = diffrax.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=key)
164164
ubp = OldBrownianPath(shape=(), key=key)
165165
new_ubp = diffrax.UnsafeBrownianPath(shape=(), key=key)
166-
new_ubp_pre = diffrax.UnsafeBrownianPath(shape=(), key=key, precompute=True)
166+
new_ubp_pre = diffrax.UnsafeBrownianPath(shape=(), key=key, precompute=ndt + 10)
167167
solver = diffrax.Euler()
168168
terms = diffrax.MultiTerm(
169169
diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, brownian_motion)
@@ -177,7 +177,7 @@ def _evaluate_leaf(
177177
terms_new_precompute = diffrax.MultiTerm(
178178
diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, new_ubp_pre)
179179
)
180-
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, ndt))
180+
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 1000))
181181

182182

183183
@jax.jit

diffrax/_integrate.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def body_fun_aux(state):
387387

388388
tprev = jnp.minimum(tprev, t1)
389389
tnext = _clip_to_end(tprev, tnext, t1, keep_step)
390-
390+
391391
progress_meter_state = progress_meter.step(
392392
state.progress_meter_state, linear_rescale(t0, tprev, t1)
393393
)
@@ -862,7 +862,7 @@ class SaveAt(eqx.Module): # noqa: F811
862862
t1: bool
863863

864864

865-
# @eqx.filter_jit
865+
@eqx.filter_jit
866866
@eqxi.doc_remove_args("discrete_terminating_event")
867867
def diffeqsolve(
868868
terms: PyTree[AbstractTerm],

0 commit comments

Comments
 (0)