From f9b6afdd919f60ddf507026713f1943659a1ef4a Mon Sep 17 00:00:00 2001 From: JuanKo96 Date: Wed, 14 Sep 2022 07:56:27 +0900 Subject: [PATCH] adding test for sdeint and debug the defunc --- test/test_sdeint.py | 121 ++++++++++++++++++++++++++++++++++++++++ torchdyn/core/defunc.py | 27 ++++++--- 2 files changed, 140 insertions(+), 8 deletions(-) create mode 100644 test/test_sdeint.py diff --git a/test/test_sdeint.py b/test/test_sdeint.py new file mode 100644 index 0000000..f9811ae --- /dev/null +++ b/test/test_sdeint.py @@ -0,0 +1,121 @@ +import pytest +from torch import nn +import torch +import torchsde +import numpy as np +from torchdyn.numerics import sdeint +from numpy.testing import assert_almost_equal + + +@pytest.mark.parametrize("solver", ["euler", "milstein_ito"]) +def test_geo_brownian_ito(solver): + torch.manual_seed(0) + np.random.seed(0) + + t0, t1 = 0, 1 + size = (1, 1) + device = "cpu" + + alpha = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device) + beta = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device) + x0 = torch.normal(mean=0.0, std=1.1, size=size).to(device) + t_size = 1000 + ts = torch.linspace(t0, t1, t_size).to(device) + + bm = torchsde.BrownianInterval( + t0=t0, t1=t1, size=size, device=device, levy_area_approximation="space-time" + ) + + def get_bm_queries(bm, ts): + bm_increments = torch.stack( + [bm(t0, t1) for t0, t1 in zip(ts[:-1], ts[1:])], dim=0 + ) + bm_queries = torch.cat( + (torch.zeros(1, 1, 1).to(device), torch.cumsum(bm_increments, dim=0)) + ) + return bm_queries + + class SDE(nn.Module): + def __init__(self, alpha, beta): + super().__init__() + self.alpha = nn.Parameter(alpha, requires_grad=True) + self.beta = nn.Parameter(beta, requires_grad=True) + self.noise_type = "diagonal" + self.sde_type = "ito" + + def f(self, t, x): + return self.alpha * x + + def g(self, t, x): + return self.beta * x + + sde = SDE(alpha, beta).to(device) + + with torch.no_grad(): + _, xs_torchdyn = sdeint(sde, x0, ts, solver=solver, bm=bm) + + bm_queries = get_bm_queries(bm, ts) + xs_true = x0.cpu() * np.exp( + (alpha.cpu() - 0.5 * beta.cpu() ** 2) * ts.cpu() + + beta.cpu() * bm_queries[:, 0, 0].cpu() + ) + + assert_almost_equal(xs_true[0][-1], xs_torchdyn[-1], decimal=2) + + +# todo : need to improve sdeint for stratonovich +@pytest.mark.parametrize("solver", ["eulerHeun", "milstein_stratonovich"]) +def test_geo_brownian_stratonovich(solver): + torch.manual_seed(0) + np.random.seed(0) + + t0, t1 = 0, 1 + size = (1, 1) + device = "cpu" + + alpha = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device) + beta = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=size)).to(device) + x0 = torch.normal(mean=0.0, std=1.1, size=size).to(device) + t_size = 1000 + ts = torch.linspace(t0, t1, t_size).to(device) + + bm = torchsde.BrownianInterval( + t0=t0, t1=t1, size=size, device=device, levy_area_approximation="space-time" + ) + + def get_bm_queries(bm, ts): + bm_increments = torch.stack( + [bm(t0, t1) for t0, t1 in zip(ts[:-1], ts[1:])], dim=0 + ) + bm_queries = torch.cat( + (torch.zeros(1, 1, 1).to(device), torch.cumsum(bm_increments, dim=0)) + ) + return bm_queries + + class SDE(nn.Module): + def __init__(self, alpha, beta): + super().__init__() + self.alpha = nn.Parameter(alpha, requires_grad=True) + self.beta = nn.Parameter(beta, requires_grad=True) + self.noise_type = "diagonal" + self.sde_type = "stratonovich" + + def f(self, t, x): + return self.alpha * x + + def g(self, t, x): + return self.beta * x + + sde = SDE(alpha, beta).to(device) + + with torch.no_grad(): + _, xs_torchdyn = sdeint(sde, x0, ts, solver=solver, bm=bm) + + bm_queries = get_bm_queries(bm, ts) + xs_true = x0.cpu() * np.exp( + (alpha.cpu() - 0.5 * beta.cpu() ** 2) * ts.cpu() + + beta.cpu() * bm_queries[:, 0, 0].cpu() + ) + + assert_almost_equal(xs_true[0][-1] - xs_torchdyn[-1], 1, decimal=0) + diff --git a/torchdyn/core/defunc.py b/torchdyn/core/defunc.py index 2118dac..2f0ce26 100644 --- a/torchdyn/core/defunc.py +++ b/torchdyn/core/defunc.py @@ -116,16 +116,27 @@ def forward(self, t: Tensor, x: Tensor) -> Tensor: def f(self, t: Tensor, x: Tensor) -> Tensor: self.nfe += 1 - # print(self.f_func) - - if "t" not in getfullargspec(self.f_func.forward).args: - return self.f_func(x) + if issubclass(type(self.f_func), nn.Module): + if "t" not in getfullargspec(self.f_func.forward).args: + return self.f_func(x) + else: + return self.f_func(t, x) else: - return self.f_func(t, x) + if "t" not in getfullargspec(self.f_func).args: + return self.f_func(x) + else: + return self.f_func(t, x) def g(self, t: Tensor, x: Tensor) -> Tensor: self.nfe += 1 - if "t" not in getfullargspec(self.g_func.forward).args: - return self.g_func(x) + if issubclass(type(self.g_func), nn.Module): + + if "t" not in getfullargspec(self.g_func.forward).args: + return self.g_func(x) + else: + return self.g_func(t, x) else: - return self.g_func(t, x) + if "t" not in getfullargspec(self.g_func).args: + return self.g_func(x) + else: + return self.g_func(t, x)