Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding weak levy area #496

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
AbstractBrownianIncrement as AbstractBrownianIncrement,
AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea,
AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea,
AbstractWeakSpaceSpaceLevyArea as AbstractWeakSpaceSpaceLevyArea,
BrownianIncrement as BrownianIncrement,
DavieFosterWeakSpaceSpaceLevyArea as DavieFosterWeakSpaceSpaceLevyArea,
DavieWeakSpaceSpaceLevyArea as DavieWeakSpaceSpaceLevyArea,
SpaceTimeLevyArea as SpaceTimeLevyArea,
SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
)
Expand Down
74 changes: 65 additions & 9 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from .._custom_types import (
AbstractBrownianIncrement,
BrownianIncrement,
DavieFosterWeakSpaceSpaceLevyArea,
DavieWeakSpaceSpaceLevyArea,
levy_tree_transpose,
RealScalarLike,
SpaceTimeLevyArea,
Expand All @@ -27,6 +29,15 @@
from .base import AbstractBrownianPath


_Levy_Areas = Union[
BrownianIncrement,
SpaceTimeLevyArea,
SpaceTimeTimeLevyArea,
DavieWeakSpaceSpaceLevyArea,
DavieFosterWeakSpaceSpaceLevyArea,
]


class UnsafeBrownianPath(AbstractBrownianPath):
"""Brownian simulation that is only suitable for certain cases.

Expand Down Expand Up @@ -62,18 +73,14 @@ class UnsafeBrownianPath(AbstractBrownianPath):
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: type[
Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]
] = eqx.field(static=True)
levy_area: type[_Levy_Areas] = eqx.field(static=True)
key: PRNGKeyArray

def __init__(
self,
shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: PRNGKeyArray,
levy_area: type[
Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]
] = BrownianIncrement,
levy_area: type[_Levy_Areas] = BrownianIncrement,
):
self.shape = (
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
Expand Down Expand Up @@ -141,9 +148,7 @@ def _evaluate_leaf(
t1: RealScalarLike,
key,
shape: jax.ShapeDtypeStruct,
levy_area: type[
Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]
],
levy_area: type[_Levy_Areas],
use_levy: bool,
):
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
Expand All @@ -158,6 +163,57 @@ def _evaluate_leaf(
kk = jr.normal(key_kk, shape.shape, shape.dtype) * kk_std
levy_val = SpaceTimeTimeLevyArea(dt=dt, W=w, H=hh, K=kk)

elif levy_area is DavieWeakSpaceSpaceLevyArea:
key_w, key_hh, key_b = jr.split(key, 3)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
if w.ndim == 0 or w.ndim == 1:
a = jnp.zeros_like(w, dtype=shape.dtype)
levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)
else:
b_std = (dt / jnp.sqrt(12)).astype(shape.dtype)
b = (
jr.normal(key_b, shape.shape + shape.shape[-1:], shape.dtype)
* b_std
)
b = b - b.transpose(*range(b.ndim - 2), -1, -2)
a = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims(
w, -1
) * jnp.expand_dims(hh, -2)
a += b
levy_val = DavieWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)

elif levy_area is DavieFosterWeakSpaceSpaceLevyArea:
key_w, key_hh, key_b = jr.split(key, 3)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
if w.ndim == 0 or w.ndim == 1:
a = jnp.zeros_like(w, dtype=shape.dtype)
levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)
else:
tenth_dt = (0.1 * dt).astype(shape.dtype)
hh_squared = hh**2
b_std = jnp.sqrt(
tenth_dt
* (
tenth_dt
+ jnp.expand_dims(hh_squared, -1)
+ jnp.expand_dims(hh_squared, -2)
)
).astype(shape.dtype)
b = (
jr.normal(key_b, shape.shape + shape.shape[-1:], shape.dtype)
* b_std
)
b = b - b.transpose(*range(b.ndim - 2), -1, -2)
a = jnp.expand_dims(hh, -1) * jnp.expand_dims(w, -2) - jnp.expand_dims(
w, -1
) * jnp.expand_dims(hh, -2)
a += b
levy_val = DavieFosterWeakSpaceSpaceLevyArea(dt=dt, W=w, H=hh, A=a)

elif levy_area is SpaceTimeLevyArea:
key_w, key_hh = jr.split(key, 2)
w = jr.normal(key_w, shape.shape, shape.dtype) * w_std
Expand Down
34 changes: 34 additions & 0 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
Args = PyTree[Any]

BM = PyTree[Shaped[ArrayLike, "?*bm"], "BM"]
Area = PyTree[Shaped[ArrayLike, "?*area"], "Area"]

DenseInfo = dict[str, PyTree[Array]]
DenseInfos = dict[str, PyTree[Shaped[Array, "times-1 ..."]]]
Expand All @@ -72,6 +73,39 @@ class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement):
H: eqx.AbstractVar[BM]


class AbstractWeakSpaceSpaceLevyArea(AbstractBrownianIncrement):
"""
Abstract base class for all weak Space Space Levy Areas.
"""

H: eqx.AbstractVar[BM]
A: eqx.AbstractVar[BM]


class DavieWeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea):
"""
Davie's approximation to weak Space Space Levy Areas.
See (7.4.1) of Foster's thesis.
"""

dt: PyTree[FloatScalarLike, "BM"]
W: BM
H: BM
A: Area


class DavieFosterWeakSpaceSpaceLevyArea(AbstractWeakSpaceSpaceLevyArea):
"""
Davie's approximation to weak Space Space Levy Areas.
See (7.4.2) of Foster's thesis.
"""

dt: PyTree[FloatScalarLike, "BM"]
W: BM
H: BM
A: Area


class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea):
"""
Abstract base class for all Space Time Time Levy Areas.
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ def _promote(yi):
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
# Specific check to not work even if using HalfSolver(Euler())
if isinstance(solver, Euler):
raise ValueError(
warnings.warn(
"An SDE should not be solved with adaptive step sizes with Euler's "
"method, as it may not converge to the correct solution."
)
Expand Down
16 changes: 15 additions & 1 deletion test/test_brownian.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import jax


jax.config.update("jax_enable_x64", True)
import contextlib
import math
from typing import Literal
Expand Down Expand Up @@ -36,12 +40,22 @@ def _make_struct(shape, dtype):
@pytest.mark.parametrize(
"ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree]
)
@pytest.mark.parametrize("levy_area", _levy_areas)
@pytest.mark.parametrize(
"levy_area",
_levy_areas
+ (diffrax.DavieWeakSpaceSpaceLevyArea, diffrax.DavieFosterWeakSpaceSpaceLevyArea),
)
@pytest.mark.parametrize("use_levy", (False, True))
def test_shape_and_dtype(ctr, levy_area, use_levy, getkey):
t0 = 0.0
t1 = 2.0

if (
issubclass(levy_area, diffrax.AbstractWeakSpaceSpaceLevyArea)
and ctr is diffrax.VirtualBrownianTree
):
return

shapes_dtypes1 = (
((), None),
((0,), None),
Expand Down
Loading