Code Improvements #214

Open · wants to merge 5 commits into base: master
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -54,4 +54,6 @@ filterwarnings = [
"ignore:Call to deprecated create function EnumValueDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9
"ignore:Call to deprecated create function FileDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9
"ignore:Call to deprecated create function OneofDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9
"ignore:.+does not have many workers which may be a bottleneck",
"ignore:The number of training batches"
]
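The two new `filterwarnings` entries silence PyTorch Lightning's DataLoader-worker and batch-count advisories during the test run. A rough sketch of what they do, assuming pytest's `filterwarnings` entries follow `warnings.filterwarnings` semantics with the part after the colon treated as a message regex:

```python
# Rough runtime equivalent of the two new pyproject.toml entries (assumed
# equivalence between pytest's filterwarnings syntax and warnings.filterwarnings).
import warnings

warnings.filterwarnings("ignore", message=".+does not have many workers which may be a bottleneck")
warnings.filterwarnings("ignore", message="The number of training batches")
```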
6 changes: 3 additions & 3 deletions test/models/test_ode.py
@@ -39,7 +39,7 @@

def test_repr(small_mlp):
model = NeuralODE(small_mlp)
-assert type(model.__repr__()) == str and 'NFE' in model.__repr__()
+assert isinstance(model.__repr__(), str) and 'NFE' in model.__repr__()


# TODO: extend to GPU and Multi-GPU
@@ -105,14 +105,14 @@ def test_deepcopy(small_mlp, device):
model = NeuralODE(small_mlp)
x = torch.rand(1, 2)
copy_before_forward = copy.deepcopy(model)
-assert type(copy_before_forward) == NeuralODE
+assert isinstance(copy_before_forward, NeuralODE)

# do a forward+backward pass
y = model(x)
loss = y.sum()
loss.backward()
copy_after_forward = copy.deepcopy(model)
-assert type(copy_after_forward) == NeuralODE
+assert isinstance(copy_after_forward, NeuralODE)


@pytest.mark.skip(reason='clean up to new API')
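The `type(x) == T` → `isinstance(x, T)` swap in these tests is more than style: `isinstance` also accepts subclasses, which an exact-type comparison silently rejects. A minimal sketch with hypothetical classes, not taken from the PR:

```python
# isinstance accepts subclasses; exact-type equality does not.
class Base: ...
class Sub(Base): ...

obj = Sub()
assert isinstance(obj, Base)  # passes: a Sub is a Base
assert type(obj) != Base      # the old-style check would have failed here
```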
28 changes: 13 additions & 15 deletions torchdyn/core/neuralde.py
@@ -23,8 +23,6 @@
import torch.nn as nn
import torchsde

-import warnings
-

class NeuralODE(ODEProblem, pl.LightningModule):
def __init__(
@@ -51,12 +49,12 @@ def __init__(
In the second case, the Callable is automatically wrapped for consistency
solver (Union[str, nn.Module]):
order (int, optional): Order of the ODE. Defaults to 1.
-atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-4.
-rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-4.
+atol (float, optional): Absolute tolerance of the solver. Defaults to 1e-3.
+rtol (float, optional): Relative tolerance of the solver. Defaults to 1e-3.
sensitivity (str, optional): Sensitivity method ['autograd', 'adjoint', 'interpolated_adjoint']. Defaults to 'autograd'.
solver_adjoint (Union[str, nn.Module, None], optional): ODE solver for the adjoint. Defaults to None.
-atol_adjoint (float, optional): Defaults to 1e-6.
-rtol_adjoint (float, optional): Defaults to 1e-6.
+atol_adjoint (float, optional): Defaults to 1e-4.
+rtol_adjoint (float, optional): Defaults to 1e-4.
integral_loss (Union[Callable, None], optional): Defaults to None.
seminorm (bool, optional): Whether to use seminorms for adaptive stepping in backsolve adjoints. Defaults to False.
return_t_eval (bool): Whether to return (t_eval, sol) or only sol. Useful for chaining NeuralODEs in `nn.Sequential`.
@@ -171,15 +169,6 @@ def __init__(
bm=None,
return_t_eval: bool = True,
):
-super().__init__(
-defunc=SDEFunc(f=drift_func, g=diffusion_func, order=order),
-solver=solver,
-interpolator=interpolator,
-atol=atol,
-rtol=rtol,
-sensitivity=sensitivity,
-)
-
"""Generic Neural Stochastic Differential Equation. Follows the same design of the `NeuralODE` class.

Args:
@@ -204,6 +193,15 @@
Notes:
The current implementation is rougher around the edges compared to `NeuralODE`, and is not guaranteed to have the same features.
"""
+super().__init__(
+defunc=SDEFunc(f=drift_func, g=diffusion_func, order=order),
+solver=solver,
+interpolator=interpolator,
+atol=atol,
+rtol=rtol,
+sensitivity=sensitivity,
+)
+
if order != 1:
raise NotImplementedError
self.defunc.noise_type, self.defunc.sde_type = noise_type, sde_type
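Besides dropping the unused `warnings` import and aligning the documented tolerances with the constructor defaults, this hunk moves the `NeuralSDE` docstring above the `super().__init__` call. The order matters: Python binds a string literal to `__doc__` only when it is the first statement in the body. A minimal sketch with hypothetical classes:

```python
# A docstring only becomes __doc__ when it is the first statement; placed after
# other code it is just an evaluated-and-discarded expression.
class Before:
    def __init__(self):
        self.x = 1
        """Too late: this string is thrown away."""

class After:
    def __init__(self):
        """First statement: visible to help() and __doc__."""
        self.x = 1

assert Before.__init__.__doc__ is None
assert After.__init__.__doc__ is not None
```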
4 changes: 2 additions & 2 deletions torchdyn/core/problems.py
@@ -8,7 +8,7 @@
_gather_odefunc_interp_adjoint,
)
from torchdyn.numerics.odeint import odeint, odeint_mshooting
-from torchdyn.numerics.solvers.ode import str_to_solver, str_to_ms_solver
+from torchdyn.numerics.solvers.ode import str_to_solver
from torchdyn.core.utils import standardize_vf_call_signature
from torchdyn.core.defunc import SDEFunc
from torchdyn.numerics import sdeint
@@ -53,7 +53,7 @@ def __init__(
"""
super().__init__()
# instantiate solver at initialization
-if type(solver) == str:
+if isinstance(solver, str):
solver = str_to_solver(solver)
if solver_adjoint is None:
solver_adjoint = solver
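Here the now-unused `str_to_ms_solver` import is dropped and the solver dispatch switches to `isinstance`. A hedged sketch of the pattern this guard implements (`resolve_solver` is a hypothetical helper; only the `str_to_solver` import is confirmed by the diff):

```python
# Accept either a solver name or an already-instantiated solver module.
from typing import Union

import torch.nn as nn
from torchdyn.numerics.solvers.ode import str_to_solver

def resolve_solver(solver: Union[str, nn.Module]) -> nn.Module:
    if isinstance(solver, str):         # e.g. a name such as "dopri5"
        solver = str_to_solver(solver)  # torchdyn's name-to-solver lookup
    return solver
```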
2 changes: 1 addition & 1 deletion torchdyn/models/hybrid.py
@@ -34,7 +34,7 @@ def __init__(self, flow, jump, out, last_output=True, reverse=False):
# either take hidden and element of sequence (e.g RNNCell)
# or h, x_t and c (LSTMCell). Custom implementation assumes call
# signature of type (x_t, h) and .hidden_size property
-if type(jump) == nn.modules.rnn.LSTMCell:
+if isinstance(jump, nn.modules.rnn.LSTMCell):
self.jump_func = self._jump_latent_cell
else:
self.jump_func = self._jump_latent
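In `hybrid.py` the `isinstance` form is arguably a behavioral fix as well: a user-defined `LSTMCell` subclass should still be routed to the LSTM-specific jump function, which the old exact-type comparison would miss. A small sketch (the subclass is hypothetical):

```python
# A custom LSTMCell subclass still takes the LSTM branch under isinstance.
import torch.nn as nn

class PeepholeLSTMCell(nn.LSTMCell):  # hypothetical user subclass
    pass

jump = PeepholeLSTMCell(input_size=2, hidden_size=8)
assert type(jump) != nn.modules.rnn.LSTMCell      # old check: falls through
assert isinstance(jump, nn.modules.rnn.LSTMCell)  # new check: LSTM branch
```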
25 changes: 14 additions & 11 deletions torchdyn/numerics/odeint.py
@@ -57,9 +57,9 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n
t_span = -t_span
else: f_ = f

-if type(t_span) == list: t_span = torch.cat(t_span)
+if isinstance(t_span, list): t_span = torch.cat(t_span)
# instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype
-if type(solver) == str:
+if isinstance(solver, str):
solver = str_to_solver(solver, x.dtype)
x, t_span = solver.sync_device_dtype(x, t_span)
stepping_class = solver.stepping_class
@@ -69,15 +69,13 @@
if verbose: warn("Running interpolation not yet implemented for `tsit5`")
interpolator = None

-if type(interpolator) == str:
+if isinstance(interpolator, str):
interpolator = str_to_interp(interpolator, x.dtype)
x, t_span = interpolator.sync_device_dtype(x, t_span)

-# access parallel integration routines with different t_spans for each sample in `x`.
-if len(t_span.shape) > 1:
-raise NotImplementedError("Parallel routines not implemented yet, check experimental versions of `torchdyn`")
-# odeint routine with a single t_span for all samples
-elif len(t_span.shape) == 1:
-t_span.size
+if t_span.dim() == 1:
if stepping_class == 'fixed':
if atol != odeint.__defaults__[0] or rtol != odeint.__defaults__[1]:
warn("Setting tolerances has no effect on fixed-step methods")
@@ -89,6 +87,11 @@
dt = init_step(f, k1, x, t, solver.order, atol, rtol)
if len(save_at) > 0: warn("Setting save_at has no effect on adaptive-step methods")
return _adaptive_odeint(f_, k1, x, dt, t_span, solver, atol, rtol, args, interpolator, return_all_eval, seminorm)
+else:
+raise RuntimeError("Invalid stepping class provided")
+# access parallel integration routines with different t_spans for each sample in `x`.
+else:
+raise NotImplementedError("Parallel routines not implemented yet, check experimental versions of `torchdyn`")


# TODO (qol) state augmentation for symplectic methods
@@ -114,10 +117,10 @@ def odeint_symplectic(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:U
f_ = lambda t, x: -f(-t, x)
t_span = -t_span
else: f_ = f
-if type(t_span) == list: t_span = torch.cat(t_span)
+if isinstance(t_span, list): t_span = torch.cat(t_span)

# instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype
-if type(solver) == str:
+if isinstance(solver, str):
solver = str_to_solver(solver, x.dtype)
x, t_span = solver.sync_device_dtype(x, t_span)
stepping_class = solver.stepping_class
@@ -167,7 +170,7 @@ def odeint_mshooting(f:Callable, x:Tensor, t_span:Tensor, solver:Union[str, nn.M
TODO: At the moment assumes the ODE to NOT be time-varying. An extension is possible by adaptive the step
function of a parallel-in-time solvers.
"""
-if type(solver) == str:
+if isinstance(solver, str):
solver = str_to_ms_solver(solver)
x, t_span = solver.sync_device_dtype(x, t_span)
# first-guess B0 of shooting parameters
@@ -199,7 +202,7 @@ def odeint_hybrid(f, x, t_span, j_span, solver, callbacks, atol=1e-3, rtol=1e-3,
priority (str, optional): Defaults to 'jump'.
"""
# instantiate the solver in case the user has specified preference via a `str` and ensure compatibility of device ~ dtype
-if type(solver) == str: solver = str_to_solver(solver, x.dtype)
+if isinstance(solver, str): solver = str_to_solver(solver, x.dtype)
x, t_span = solver.sync_device_dtype(x, t_span)
x_shape = x.shape
ckpt_counter, ckpt_flag, jnum = 0, False, 0
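The `odeint` restructuring removes the stray no-op `t_span.size` expression and inverts the branch order: the common serial path (`t_span.dim() == 1`) comes first, an unknown stepping class now raises an explicit `RuntimeError`, and the unsupported parallel case moves to the final `else`. A condensed sketch of the resulting control flow, with bodies stubbed and names assumed from the visible hunks:

```python
# Condensed sketch of the new odeint dispatch (bodies stubbed out).
import torch

def dispatch(t_span: torch.Tensor, stepping_class: str) -> str:
    if t_span.dim() == 1:             # one shared t_span for all samples
        if stepping_class == 'fixed':
            return "fixed-step routine"
        elif stepping_class == 'adaptive':
            return "adaptive-step routine"
        else:
            raise RuntimeError("Invalid stepping class provided")
    else:                             # one t_span per sample
        raise NotImplementedError("Parallel routines not implemented yet")
```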
4 changes: 2 additions & 2 deletions torchdyn/numerics/sdeint.py
@@ -49,10 +49,10 @@ def sdeint(
# make sde to SDEFunc form?
sde = check_sde(sde)

-if type(t_span) == list:
+if isinstance(t_span, list):
t_span = torch.cat(t_span)

-if type(solver) == str:
+if isinstance(solver, str):
solver = sde_str_to_solver(solver, sde, bm, x.dtype)
x, t_span = solver.sync_device_dtype(x, t_span)
stepping_class = solver.stepping_class
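`sdeint` gets the same treatment for a list-valued `t_span`, which is concatenated into a single 1-D tensor before the solver is resolved. A minimal sketch of that normalization:

```python
# A t_span passed as a list of tensors is flattened into one tensor.
import torch

t_span = [torch.tensor([0.0, 0.5]), torch.tensor([1.0])]
if isinstance(t_span, list):
    t_span = torch.cat(t_span)
assert t_span.tolist() == [0.0, 0.5, 1.0]
```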
6 changes: 4 additions & 2 deletions torchdyn/numerics/solvers/templates.py
@@ -69,8 +69,10 @@ def __init__(self, coarse_method, fine_method):
from torchdyn.numerics.solvers.ode import str_to_solver

super(MultipleShootingDiffeqSolver, self).__init__()
-if type(coarse_method) == str: self.coarse_method = str_to_solver(coarse_method)
-if type(fine_method) == str: self.fine_method = str_to_solver(fine_method)
+if isinstance(coarse_method, str):
+self.coarse_method = str_to_solver(coarse_method)
+if isinstance(fine_method, str):
+self.fine_method = str_to_solver(fine_method)

def sync_device_dtype(self, x, t_span):
"Ensures `x`, `t_span`, `tableau` and other solver tensors are on the same device with compatible dtypes"
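The multiple-shooting template gets the same guard, split across lines for readability: both the coarse and fine solvers may be given by name and are resolved through `str_to_solver`. A hedged sketch ("euler" and "rk4" are assumed solver names, not confirmed by the diff):

```python
# Resolve coarse/fine multiple-shooting solvers given as names.
from torchdyn.numerics.solvers.ode import str_to_solver

coarse_method, fine_method = "euler", "rk4"  # assumed valid solver names
if isinstance(coarse_method, str):
    coarse_method = str_to_solver(coarse_method)
if isinstance(fine_method, str):
    fine_method = str_to_solver(fine_method)
```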