From 2324b25d8eaac5aa094a3fe2fed2ae4c564d8de2 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 9 Nov 2024 18:29:39 -0500 Subject: [PATCH 1/5] fix docstrings and reorder super().__init__ --- torchdyn/core/neuralde.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/torchdyn/core/neuralde.py b/torchdyn/core/neuralde.py index 7be9b3b..bcb9a83 100644 --- a/torchdyn/core/neuralde.py +++ b/torchdyn/core/neuralde.py @@ -23,8 +23,6 @@ import torch.nn as nn import torchsde -import warnings - class NeuralODE(ODEProblem, pl.LightningModule): def __init__( @@ -51,12 +49,12 @@ def __init__( In the second case, the Callable is automatically wrapped for consistency solver (Union[str, nn.Module]): order (int, optional): Order of the ODE. Defaults to 1. - atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4. - rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4. + atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-3. + rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-3. sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'. solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None. - atol_adjoint (float, optional): Defaults to 1e-6. - rtol_adjoint (float, optional): Defaults to 1e-6. + atol_adjoint (float, optional): Defaults to 1e-4. + rtol_adjoint (float, optional): Defaults to 1e-4. integral_loss (Union[Callable, None], optional): Defaults to None. seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False. return_t_eval (bool): Whether to return (t_eval, sol) or only sol. Useful for chaining NeuralODEs in `nn.Sequential`. @@ -171,15 +169,6 @@ def __init__( bm=None, return_t_eval: bool = True, ): - super().__init__( - defunc=SDEFunc(f=drift_func, g=diffusion_func, order=order), - solver=solver, - interpolator=interpolator, - atol=atol, - rtol=rtol, - sensitivity=sensitivity, - ) - """Generic Neural Stochastic Differential Equation. Follows the same design of the `NeuralODE` class. Args: @@ -204,6 +193,15 @@ def __init__( Notes: The current implementation is rougher around the edges compared to `NeuralODE`, and is not guaranteed to have the same features. """ + super().__init__( + defunc=SDEFunc(f=drift_func, g=diffusion_func, order=order), + solver=solver, + interpolator=interpolator, + atol=atol, + rtol=rtol, + sensitivity=sensitivity, + ) + if order != 1: raise NotImplementedError self.defunc.noise_type, self.defunc.sde_type = noise_type, sde_type From 8a093fd9cb2f983bedb621f6755ad888a5d07782 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 9 Nov 2024 18:29:57 -0500 Subject: [PATCH 2/5] remove unneeded import --- torchdyn/core/problems.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdyn/core/problems.py b/torchdyn/core/problems.py index 3157b96..89590f9 100644 --- a/torchdyn/core/problems.py +++ b/torchdyn/core/problems.py @@ -8,7 +8,7 @@ _gather_odefunc_interp_adjoint, ) from torchdyn.numerics.odeint import odeint, odeint_mshooting -from torchdyn.numerics.solvers.ode import str_to_solver, str_to_ms_solver +from torchdyn.numerics.solvers.ode import str_to_solver from torchdyn.core.utils import standardize_vf_call_signature from torchdyn.core.defunc import SDEFunc from torchdyn.numerics import sdeint From b927a6347b8eff9845aceb54dfca389d889ef56c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 9 Nov 2024 18:30:20 -0500 Subject: [PATCH 3/5] use isinstance and define all return paths for odeint --- torchdyn/numerics/odeint.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torchdyn/numerics/odeint.py b/torchdyn/numerics/odeint.py index 7f10158..65bc76f 100644 --- a/torchdyn/numerics/odeint.py +++ b/torchdyn/numerics/odeint.py @@ -57,9 +57,9 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n t_span = -t_span else: f_ = f - if type(t_span) == list: t_span = torch.cat(t_span) + if isinstance(t_span, list): t_span = torch.cat(t_span) # instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype - if type(solver) == str: + if isinstance(solver, str): solver = str_to_solver(solver, x.dtype) x, t_span = solver.sync_device_dtype(x, t_span) stepping_class = solver.stepping_class @@ -69,15 +69,13 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n if verbose: warn("Running interpolation not yet implemented for `tsit5`") interpolator = None - if type(interpolator) == str: + if isinstance(interpolator, str): interpolator = str_to_interp(interpolator, x.dtype) x, t_span = interpolator.sync_device_dtype(x, t_span) - # access parallel integration routines with different t_spans for each sample in `x`. - if len(t_span.shape) > 1: - raise NotImplementedError("Parallel routines not implemented yet, check experimental versions of `torchdyn`") # odeint routine with a single t_span for all samples - elif len(t_span.shape) == 1: + t_span.size + if t_span.dim() == 1: if stepping_class == 'fixed': if atol != odeint.__defaults__[0] or rtol != odeint.__defaults__[1]: warn("Setting tolerances has no effect on fixed-step methods") @@ -89,6 +87,11 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n dt = init_step(f, k1, x, t, solver.order, atol, rtol) if len(save_at) > 0: warn("Setting save_at has no effect on adaptive-step methods") return _adaptive_odeint(f_, k1, x, dt, t_span, solver, atol, rtol, args, interpolator, return_all_eval, seminorm) + else: + raise RuntimeError("Invalid stepping class provided") + # access parallel integration routines with different t_spans for each sample in `x`. + else: + raise NotImplementedError("Parallel routines not implemented yet, check experimental versions of `torchdyn`") # TODO (qol) state augmentation for symplectic methods @@ -114,10 +117,10 @@ def odeint_symplectic(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:U f_ = lambda t, x: -f(-t, x) t_span = -t_span else: f_ = f - if type(t_span) == list: t_span = torch.cat(t_span) + if isinstance(t_span, list): t_span = torch.cat(t_span) # instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype - if type(solver) == str: + if isinstance(solver, str): solver = str_to_solver(solver, x.dtype) x, t_span = solver.sync_device_dtype(x, t_span) stepping_class = solver.stepping_class @@ -167,7 +170,7 @@ def odeint_mshooting(f:Callable, x:Tensor, t_span:Tensor, solver:Union[str, nn.M TODO: At the moment assumes the ODE to NOT be time-varying. An extension is possible by adaptive the step function of a parallel-in-time solvers. """ - if type(solver) == str: + if isinstance(solver, str): solver = str_to_ms_solver(solver) x, t_span = solver.sync_device_dtype(x, t_span) # first-guess B0 of shooting parameters @@ -199,7 +202,7 @@ def odeint_hybrid(f, x, t_span, j_span, solver, callbacks, atol=1e-3, rtol=1e-3, priority (str, optional): Defaults to 'jump'. """ # instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype - if type(solver) == str: solver = str_to_solver(solver, x.dtype) + if isinstance(solver, str): solver = str_to_solver(solver, x.dtype) x, t_span = solver.sync_device_dtype(x, t_span) x_shape = x.shape ckpt_counter, ckpt_flag, jnum = 0, False, 0 From f801614ca62dfe29bc80c6c82466345a426b3c21 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 9 Nov 2024 18:44:44 -0500 Subject: [PATCH 4/5] ignore pytorch_lightning warnings in tests --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 06cc469..c171581 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,4 +54,6 @@ filterwarnings = [ "ignore:Call to deprecated create function EnumValueDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9 "ignore:Call to deprecated create function FileDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9 "ignore:Call to deprecated create function OneofDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9 + "ignore:.+does not have many workers which may be a bottleneck", + "ignore:The number of training batches" ] From 2ca7d0285d1c9d6f41c1472b57743aef29f7a6fa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 9 Nov 2024 18:47:38 -0500 Subject: [PATCH 5/5] additional isinstance --- test/models/test_ode.py | 6 +++--- torchdyn/core/problems.py | 2 +- torchdyn/models/hybrid.py | 2 +- torchdyn/numerics/sdeint.py | 4 ++-- torchdyn/numerics/solvers/templates.py | 6 ++++-- 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/test/models/test_ode.py b/test/models/test_ode.py index b5eaa4e..9dbdb80 100644 --- a/test/models/test_ode.py +++ b/test/models/test_ode.py @@ -39,7 +39,7 @@ def test_repr(small_mlp): model = NeuralODE(small_mlp) - assert type(model.__repr__()) == str and 'NFE' in model.__repr__() + assert isinstance(model.__repr__(), str) and 'NFE' in model.__repr__() # TODO: extend to GPU and Multi-GPU @@ -105,14 +105,14 @@ def test_deepcopy(small_mlp, device): model = NeuralODE(small_mlp) x = torch.rand(1, 2) copy_before_forward = copy.deepcopy(model) - assert type(copy_before_forward) == NeuralODE + assert isinstance(copy_before_forward, NeuralODE) # do a forward+backward pass y = model(x) loss = y.sum() loss.backward() copy_after_forward = copy.deepcopy(model) - assert type(copy_after_forward) == NeuralODE + assert isinstance(copy_after_forward, NeuralODE) @pytest.mark.skip(reason='clean up to new API') diff --git a/torchdyn/core/problems.py b/torchdyn/core/problems.py index 89590f9..c6e98c1 100644 --- a/torchdyn/core/problems.py +++ b/torchdyn/core/problems.py @@ -53,7 +53,7 @@ def __init__( """ super().__init__() # instantiate solver at initialization - if type(solver) == str: + if isinstance(solver, str): solver = str_to_solver(solver) if solver_adjoint is None: solver_adjoint = solver diff --git a/torchdyn/models/hybrid.py b/torchdyn/models/hybrid.py index b95958a..eef39f9 100644 --- a/torchdyn/models/hybrid.py +++ b/torchdyn/models/hybrid.py @@ -34,7 +34,7 @@ def __init__(self, flow, jump, out, last_output=True, reverse=False): # either take hidden and element of sequence (e.g RNNCell) # or h, x_t and c (LSTMCell). Custom implementation assumes call # signature of type (x_t, h) and .hidden_size property - if type(jump) == nn.modules.rnn.LSTMCell: + if isinstance(jump, nn.modules.rnn.LSTMCell): self.jump_func = self._jump_latent_cell else: self.jump_func = self._jump_latent diff --git a/torchdyn/numerics/sdeint.py b/torchdyn/numerics/sdeint.py index 1c2ba80..da992ad 100644 --- a/torchdyn/numerics/sdeint.py +++ b/torchdyn/numerics/sdeint.py @@ -49,10 +49,10 @@ def sdeint( # make sde to SDEFunc form? sde = check_sde(sde) - if type(t_span) == list: + if isinstance(t_span, list): t_span = torch.cat(t_span) - if type(solver) == str: + if isinstance(solver, str): solver = sde_str_to_solver(solver, sde, bm, x.dtype) x, t_span = solver.sync_device_dtype(x, t_span) stepping_class = solver.stepping_class diff --git a/torchdyn/numerics/solvers/templates.py b/torchdyn/numerics/solvers/templates.py index f7e5e36..6111443 100644 --- a/torchdyn/numerics/solvers/templates.py +++ b/torchdyn/numerics/solvers/templates.py @@ -69,8 +69,10 @@ def __init__(self, coarse_method, fine_method): from torchdyn.numerics.solvers.ode import str_to_solver super(MultipleShootingDiffeqSolver, self).__init__() - if type(coarse_method) == str: self.coarse_method = str_to_solver(coarse_method) - if type(fine_method) == str: self.fine_method = str_to_solver(fine_method) + if isinstance(coarse_method, str): + self.coarse_method = str_to_solver(coarse_method) + if isinstance(fine_method, str): + self.fine_method = str_to_solver(fine_method) def sync_device_dtype(self, x, t_span): "Ensures `x`, `t_span`, `tableau` and other solver tensors are on the same device with compatible dtypes"