Skip to content

Commit

Permalink
small changes + formatting with black
Browse files Browse the repository at this point in the history
  • Loading branch information
JuanKo96 committed Aug 30, 2022
1 parent dac6c7e commit 0a3258d
Show file tree
Hide file tree
Showing 6 changed files with 512 additions and 191 deletions.
66 changes: 37 additions & 29 deletions torchdyn/core/defunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,26 @@


class DEFuncBase(nn.Module):
    def __init__(self, vector_field: Callable, has_time_arg: bool = True):
        """Basic wrapper to ensure call signature compatibility between generic torch Modules and vector fields.

        Args:
            vector_field (Callable): callable defining the dynamics / vector field / `dxdt` / forcing function
            has_time_arg (bool, optional): Internal arg. to indicate whether the callable has `t` in its `__call__`
                or `forward` method. Defaults to True.
        """
        super().__init__()
        # nfe: number of function evaluations, incremented on every forward call
        self.nfe, self.vf, self.has_time_arg = 0.0, vector_field, has_time_arg

    def forward(self, t: Tensor, x: Tensor, args: Dict = None) -> Tensor:
        """Evaluate the wrapped vector field at time `t` and state `x`.

        Args:
            t (Tensor): current integration time ("depth") variable.
            x (Tensor): current state.
            args (Dict, optional): extra arguments forwarded to the vector field. Defaults to None.

        Returns:
            Tensor: value of the vector field at `(t, x)`.
        """
        # BUGFIX: `args` previously defaulted to the mutable `{}`, a single dict
        # shared across every call; create a fresh one per call instead.
        if args is None:
            args = {}
        self.nfe += 1
        if self.has_time_arg:
            return self.vf(t, x, args=args)
        else:
            return self.vf(x)


class DEFunc(nn.Module):
    def __init__(self, vector_field: Callable, order: int = 1):
        """Special vector field wrapper for Neural ODEs.

        Handles auxiliary tasks: time ("depth") concatenation, higher-order dynamics and
        forward propagated integral losses:
        (1) assigns the current `t` to any submodule exposing a `t` attribute (DepthCat-style);
        (2) during autograd training with an integral loss, propagates the loss dynamics in `x[:, :1]`;
        (3) in case of higher-order dynamics, adjusts the vector field forward to recursively
            compute the various orders.

        Args:
            vector_field (Callable): callable defining the dynamics / vector field / `dxdt`
            order (int, optional): order of the differential equation. Defaults to 1.
        """
        super().__init__()
        self.vf, self.nfe = vector_field, 0.0
        # `integral_loss` / `sensitivity` are assigned externally (e.g. by NeuralODE)
        self.order, self.integral_loss, self.sensitivity = order, None, None

    def forward(self, t: Tensor, x: Tensor, args: Dict = None) -> Tensor:
        """Evaluate the wrapped vector field, handling depth-cat, integral losses and higher orders.

        Args:
            t (Tensor): current integration time ("depth") variable.
            x (Tensor): current (possibly loss-augmented) state.
            args (Dict, optional): extra arguments forwarded to the vector field. Defaults to None.

        Returns:
            Tensor: value of the (possibly augmented) vector field at `(t, x)`.
        """
        # BUGFIX: `args` previously defaulted to a shared mutable `{}`.
        if args is None:
            args = {}
        self.nfe += 1
        # set `t` depth-variable on DepthCat-style submodules
        for _, module in self.vf.named_modules():
            if hasattr(module, "t"):
                module.t = t

        # autograd training with integral loss propagated in x[:, 0]
        if (self.integral_loss is not None) and self.sensitivity == "autograd":
            x_dyn = x[:, 1:]
            dlds = self.integral_loss(t, x_dyn)
            if len(dlds.shape) == 1:
                dlds = dlds[:, None]
            if self.order > 1:
                # BUGFIX: previously called the nonexistent `self.horder_forward`,
                # which raised AttributeError in this branch.
                x_dyn = self.higher_order_forward(t, x_dyn, args)
            else:
                x_dyn = self.vf(t, x_dyn)
            return cat([dlds, x_dyn], 1).to(x_dyn)

        # regular forward
        else:
            if self.order > 1:
                x = self.higher_order_forward(t, x, args)
            else:
                x = self.vf(t, x, args=args)
            return x

    def higher_order_forward(self, t: Tensor, x: Tensor, args: Dict = None) -> Tensor:
        """Recursive forward for higher-order dynamics.

        The state is split into `order` chunks: derivatives of orders 1..order-1 are read
        directly off the state, and the highest-order derivative is produced by the vector field.

        Args:
            t (Tensor): current integration time.
            x (Tensor): state containing all derivative orders, chunked along dim 1.
            args (Dict, optional): unused; kept for call-signature compatibility. Defaults to None.

        Returns:
            Tensor: concatenation of the shifted derivatives and the new highest-order term.
        """
        x_new = []
        size_order = x.size(1) // self.order
        for i in range(1, self.order):
            x_new.append(x[:, size_order * i : size_order * (i + 1)])
        x_new.append(self.vf(t, x))
        return cat(x_new, dim=1).to(x)


class SDEFunc(nn.Module):
    def __init__(
        self, f: Callable, g: Callable, order: int = 1, noise_type=None, sde_type=None
    ):
        """Special vector field wrapper for Neural SDEs.

        Args:
            f (Callable): callable defining the drift
            g (Callable): callable defining the diffusion term
            order (int, optional): order of the differential equation. Defaults to 1.
            noise_type (optional): noise type, stored for downstream SDE solvers. Defaults to None.
            sde_type (optional): SDE type, stored for downstream SDE solvers. Defaults to None.
        """
        # BUGFIX: docstring previously opened with a stray fourth quote (`""""`),
        # leaking a literal `"` into the rendered documentation.
        super().__init__()
        self.order, self.intloss, self.sensitivity = order, None, None
        self.f_func, self.g_func = f, g
        self.nfe = 0  # counts both drift and diffusion evaluations
        self.noise_type = noise_type
        self.sde_type = sde_type

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        """Direct forward is not supported; use the `f` (drift) and `g` (diffusion) accessors."""
        raise NotImplementedError("Hopefully soon...")

    def f(self, t: Tensor, x: Tensor) -> Tensor:
        """Evaluate the drift term, dispatching on whether its `forward` accepts `t`.

        NOTE(review): assumes `f_func` is an nn.Module exposing `.forward` — confirm callers
        never pass a bare function here, since `getfullargspec(self.f_func.forward)` would fail.
        """
        self.nfe += 1
        if "t" not in getfullargspec(self.f_func.forward).args:
            return self.f_func(x)
        else:
            return self.f_func(t, x)

    def g(self, t: Tensor, x: Tensor) -> Tensor:
        """Evaluate the diffusion term, dispatching on whether its `forward` accepts `t`.

        NOTE(review): same nn.Module assumption as `f` — verify against call sites.
        """
        self.nfe += 1
        if "t" not in getfullargspec(self.g_func.forward).args:
            return self.g_func(x)
        else:
            return self.g_func(t, x)

Loading

0 comments on commit 0a3258d

Please sign in to comment.