Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stateful Controls #559

Open
wants to merge 45 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
51c4c41
intermediate work
lockwo Dec 29, 2024
382d171
solver work
lockwo Dec 31, 2024
10b16bc
benchmark
lockwo Dec 31, 2024
9921407
timing
lockwo Dec 31, 2024
7956b5b
more work
lockwo Dec 31, 2024
7fb9baa
work
lockwo Jan 1, 2025
fd059f7
Merge branch 'patrick-kidger-main' into Owen/control_revamp
lockwo Jan 2, 2025
35dc705
work
lockwo Jan 3, 2025
427a594
some fixes
lockwo Jan 3, 2025
29138ed
testing work
lockwo Jan 3, 2025
7f76cdd
fixes
lockwo Jan 4, 2025
d24e6f1
add test
lockwo Jan 4, 2025
a1374f9
tests + examples
lockwo Jan 5, 2025
919abf9
format
lockwo Jan 5, 2025
12bcf5a
Merge branch 'main' into Owen/control_revamp
lockwo Jan 5, 2025
37640ed
remove todo
lockwo Jan 8, 2025
cc0d4bc
Allowing args into grad_f for ULD
ricor07 Dec 27, 2024
d304d9f
clean up
lockwo Jan 16, 2025
22d00ca
Merge branch 'dev' into Owen/control_revamp
lockwo Jan 16, 2025
1ad8dad
int
lockwo Jan 16, 2025
d0f161c
fix
lockwo Jan 22, 2025
16fedb2
ULD fix
lockwo Jan 22, 2025
9a19d68
more langevin fixes
lockwo Jan 22, 2025
4994982
adjoit
lockwo Jan 24, 2025
80fef54
shorten test
lockwo Jan 27, 2025
1d34946
Test fixes for v0.5.0 + args for langevin
patrick-kidger Jan 26, 2025
b8683f4
Merge branch 'dev' into Owen/control_revamp
lockwo Jan 28, 2025
1067c10
Fix for making vmap over diffeqsolve possible (#578)
LuggiStruggi Jan 28, 2025
5366e65
Tweak test name
patrick-kidger Jan 28, 2025
54e9e77
Update pyproject.toml to meet poetry conventions
joharkit Jan 28, 2025
92fb93d
Fixed a major source of bugs: ControlTerms no longer broadcast.
patrick-kidger Jan 12, 2025
7c2c720
Now using jaxtyping.Real for prettier documentation.
patrick-kidger Jan 12, 2025
583cd6d
Bumped minimum version of Python to 3.10
patrick-kidger Jan 28, 2025
4272270
Investigating if we can drop the typeguard dependency.
patrick-kidger Jan 28, 2025
9236a68
Merge branch 'dev' into Owen/control_revamp
lockwo Jan 30, 2025
96f8bf3
fix merge
lockwo Jan 30, 2025
1946a8b
fix tests
lockwo Feb 3, 2025
b3bb170
fix test2
lockwo Feb 3, 2025
03e5b92
trying larger stepsize
lockwo Feb 4, 2025
766b471
does splitting it up help? (passes locally, but github actions fails)
lockwo Feb 4, 2025
4d22b6f
Merge branch 'patrick-kidger:dev' into dev
lockwo Feb 9, 2025
865846b
Merge branch 'dev' into Owen/control_revamp
lockwo Feb 9, 2025
7865a16
update benchmark
lockwo Feb 9, 2025
20e700d
update jit results
lockwo Feb 9, 2025
e4cd2a3
return jit
lockwo Feb 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
284 changes: 284 additions & 0 deletions benchmarks/stateful_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
import math
from typing import cast, Optional, Union

import diffrax
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import lineax.internal as lxi
from jaxtyping import PRNGKeyArray, PyTree
from lineax.internal import complex_to_real_dtype


class OldBrownianPath(diffrax.AbstractBrownianPath):
    """Stateless Brownian path used as the "old" baseline in this benchmark.

    Each sample is drawn from a key obtained by folding the bit patterns of
    the two interval endpoints into ``self.key``. No state is carried between
    calls: ``init`` returns ``None`` and ``__call__`` passes the state it is
    given straight back.
    """

    shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
    levy_area: type[
        Union[
            diffrax.BrownianIncrement,
            diffrax.SpaceTimeLevyArea,
            diffrax.SpaceTimeTimeLevyArea,
        ]
    ] = eqx.field(static=True)
    key: PRNGKeyArray
    # Stored on the instance but never read by any method of this class.
    precompute: Optional[int] = eqx.field(static=True)

    def __init__(
        self,
        shape,
        key,
        levy_area=diffrax.BrownianIncrement,
        precompute=None,
    ):
        """Store the configuration and check that every leaf dtype is inexact.

        A plain tuple of ints is wrapped into a ``jax.ShapeDtypeStruct`` with
        the default floating dtype; anything else is stored unchanged.

        Raises:
            ValueError: if any leaf of the stored shape has a non-inexact dtype.
        """
        if diffrax._misc.is_tuple_of_ints(shape):
            struct = jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
        else:
            struct = shape
        self.shape = struct
        self.key = key
        self.levy_area = levy_area
        self.precompute = precompute

        leaves = jtu.tree_leaves(self.shape)
        if not all(jnp.issubdtype(leaf.dtype, jnp.inexact) for leaf in leaves):
            raise ValueError("OldBrownianPath dtypes all have to be floating-point.")

    @property
    def t0(self):
        """Lower end of the path's domain: negative infinity."""
        return -jnp.inf

    @property
    def t1(self):
        """Upper end of the path's domain: positive infinity."""
        return jnp.inf

    def init(
        self,
        t0,
        t1,
        y0,
        args,
    ):
        """Return the initial path state, which is always ``None``."""
        return None

    def __call__(
        self,
        t0,
        brownian_state,
        t1=None,
        left=True,
        use_levy=False,
    ):
        """Evaluate the path and return ``(value, brownian_state)``.

        ``brownian_state`` is returned exactly as it was received.
        """
        value = self.evaluate(t0, t1, left, use_levy)
        return value, brownian_state

    @eqx.filter_jit
    def evaluate(
        self,
        t0,
        t1=None,
        left=True,
        use_levy=False,
    ):
        """Sample the path over ``[t0, t1]``, or over ``[0, t0]`` if ``t1`` is None.

        ``left`` is accepted and discarded. With ``use_levy=True`` the result
        is transposed into a single ``self.levy_area`` instance; otherwise a
        tree of ``W`` samples matching ``self.shape`` is returned.
        """
        del left
        if t1 is None:
            # One-argument form: the given time is the end of an interval
            # starting at zero.
            dtype = jnp.result_type(t0)
            t0, t1 = jnp.array(0, dtype), t0
        else:
            with jax.numpy_dtype_promotion("standard"):
                dtype = jnp.result_type(t0, t1)
            t0 = jnp.astype(t0, dtype)
            t1 = jnp.astype(t1, dtype)
        t0 = eqxi.nondifferentiable(t0, name="t0")
        t1 = eqxi.nondifferentiable(t1, name="t1")
        t1 = cast(diffrax._custom_types.RealScalarLike, t1)

        # Derive the sampling key from the endpoints: t0 is folded in first,
        # then t1.
        interval_key = self.key
        for endpoint in (t0, t1):
            bits = diffrax._misc.force_bitcast_convert_type(endpoint, jnp.int32)
            interval_key = jr.fold_in(interval_key, bits)
        leaf_keys = diffrax._misc.split_by_tree(interval_key, self.shape)

        def sample_leaf(leaf_key, leaf_shape):
            return self._evaluate_leaf(
                t0, t1, leaf_key, leaf_shape, self.levy_area, use_levy
            )

        out = jtu.tree_map(sample_leaf, leaf_keys, self.shape)
        if use_levy:
            out = diffrax._custom_types.levy_tree_transpose(self.shape, out)
            assert isinstance(out, self.levy_area)
        return out

    @staticmethod
    def _evaluate_leaf(
        t0,
        t1,
        key,
        shape,
        levy_area,
        use_levy,
    ):
        """Sample one leaf of the path over ``[t0, t1]``.

        ``W`` is a normal sample scaled by ``sqrt(t1 - t0)``; ``H`` and ``K``
        (when the Levy-area type has them) use that scale divided by
        ``sqrt(12)`` and ``sqrt(720)`` respectively. Returns the Levy-area
        object when ``use_levy`` is true, otherwise just ``W``.
        """
        interval = t1 - t0
        w_std = jnp.sqrt(interval).astype(shape.dtype)
        dt = jnp.asarray(interval, dtype=complex_to_real_dtype(shape.dtype))

        def draw(subkey, std):
            return jr.normal(subkey, shape.shape, shape.dtype) * std

        if levy_area is diffrax.SpaceTimeTimeLevyArea:
            key_w, key_hh, key_kk = jr.split(key, 3)
            w = draw(key_w, w_std)
            hh = draw(key_hh, w_std / math.sqrt(12))
            kk = draw(key_kk, w_std / math.sqrt(720))
            levy_val = diffrax.SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk)
        elif levy_area is diffrax.SpaceTimeLevyArea:
            key_w, key_hh = jr.split(key, 2)
            w = draw(key_w, w_std)
            hh = draw(key_hh, w_std / math.sqrt(12))
            levy_val = diffrax.SpaceTimeLevyArea(dt=dt, W=w, H=hh)
        elif levy_area is diffrax.BrownianIncrement:
            w = draw(key, w_std)
            levy_val = diffrax.BrownianIncrement(dt=dt, W=w)
        else:
            assert False

        return levy_val if use_levy else w


# https://github.com/patrick-kidger/diffrax/issues/517
key = jax.random.key(42)

# Alternative configuration, currently disabled (same drift and diffusion):
#   t0 = 0, t1 = 100, y0 = 1.0, ndt = 4000, dt = (t1 - t0) / (ndt - 1)
#   saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, ndt))
t0, t1 = 0, 1
y0 = 1.0
ndt = 40010
dt = (t1 - t0) / (ndt - 1)


def drift(t, y, args):
    """Drift coefficient of the benchmark SDE: ``-y``."""
    return -y


def diffusion(t, y, args):
    """Diffusion coefficient of the benchmark SDE: the constant ``0.2``."""
    return 0.2


saveat = diffrax.SaveAt(steps=True)

# The four Brownian-motion implementations under comparison, all on one key.
brownian_motion = diffrax.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=key)
ubp = OldBrownianPath(shape=(), key=key)
new_ubp = diffrax.UnsafeBrownianPath(shape=(), key=key)
new_ubp_pre = diffrax.UnsafeBrownianPath(shape=(), key=key, precompute=ndt + 10)

solver = diffrax.Euler()


def _sde_terms(path):
    """Pair the shared drift with the shared diffusion driven by ``path``."""
    return diffrax.MultiTerm(
        diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, path)
    )


terms = _sde_terms(brownian_motion)
terms_old = _sde_terms(ubp)
terms_new = _sde_terms(new_ubp)
terms_new_precompute = _sde_terms(new_ubp_pre)


@jax.jit
def diffrax_vbt():
    """Solve with the module-level ``terms`` and return the solution's ``ys``."""
    solution = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0=dt,
        y0=y0,
        saveat=saveat,
        throw=False,
    )
    return solution.ys


@jax.jit
def diffrax_old():
    """Solve with the module-level ``terms_old`` and return the solution's ``ys``."""
    solution = diffrax.diffeqsolve(
        terms_old,
        solver,
        t0,
        t1,
        dt0=dt,
        y0=y0,
        saveat=saveat,
        throw=False,
    )
    return solution.ys


@jax.jit
def diffrax_new():
    """Solve with the module-level ``terms_new`` and return the solution's ``ys``."""
    solution = diffrax.diffeqsolve(
        terms_new,
        solver,
        t0,
        t1,
        dt0=dt,
        y0=y0,
        saveat=saveat,
        throw=False,
    )
    return solution.ys


@jax.jit
def diffrax_new_pre():
    """Solve with ``terms_new_precompute`` and return the solution's ``ys``."""
    solution = diffrax.diffeqsolve(
        terms_new_precompute,
        solver,
        t0,
        t1,
        dt0=dt,
        y0=y0,
        saveat=saveat,
        throw=False,
    )
    return solution.ys


@jax.jit
def homemade_simu():
    """Hand-written Euler-Maruyama loop in plain JAX, used as a baseline.

    Draws ``ndt`` normal increments scaled by ``sqrt(dt)`` and scans over
    them. The returned array holds the state as it was *before* each update.
    """
    noise = jax.random.normal(key, (ndt,))
    increments = jnp.sqrt(dt) * noise

    def euler_step(state, dW):
        drift_part = drift(None, state, None) * dt
        diffusion_part = diffusion(None, state, None) * dW
        # Carry the updated state forward; emit the pre-update state.
        return state + (drift_part + diffusion_part), state

    _, trajectory = jax.lax.scan(euler_step, y0, increments)
    return trajectory


from timeit import Timer


# (printed label, name of the module-level function to time), in report order.
_benchmarks = (
    ("VBT", "diffrax_vbt"),
    ("Old UBP", "diffrax_old"),
    ("New UBP", "diffrax_new"),
    ("New UBP + Precompute", "diffrax_new_pre"),
    ("Pure Jax", "homemade_simu"),
)

# Call every function once before timing so that the first call of each
# jitted function is not part of the measurements.
for _label, _fn_name in _benchmarks:
    _ = globals()[_fn_name]().block_until_ready()

num_runs = 10

# Report the mean wall-clock time per call over ``num_runs`` calls.
for _label, _fn_name in _benchmarks:
    timer = Timer(stmt=f"_ = {_fn_name}().block_until_ready()", globals=globals())
    total_time = timer.timeit(number=num_runs)
    print(f"{_label}: {total_time / num_runs:.6f}")

"""
Results on Mac M1 CPU:
VBT: 0.204524
Old UBP: 0.017464
New UBP: 0.018535
New UBP + Precompute: 0.002440
Pure Jax: 0.002908

Results on A100 GPU:
VBT: 2.275057
Old UBP: 0.112461
New UBP: 0.126370
New UBP + Precompute: 0.111837
Pure Jax: 0.261937

For small ndt (e.g. 100) the pure JAX loop is faster, but the diffrax overhead
becomes less important as the number of steps increases.

GPU being much slower isn't surprising and is a common trend for
small-medium sized SDEs with VFs that are relatively cheap to evaluate
(i.e. not neural networks).
"""
Loading