Skip to content

Commit

Permalink
updated documentation for utils
Browse files Browse the repository at this point in the history
  • Loading branch information
majoma7 committed Aug 13, 2024
1 parent 325296d commit a059a78
Show file tree
Hide file tree
Showing 15 changed files with 249 additions and 162 deletions.
16 changes: 8 additions & 8 deletions ddopnew/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,13 +547,13 @@
'ddopnew/loss_functions.py'),
'ddopnew.loss_functions.quantile_loss': ( '00_utils/loss_functions.html#quantile_loss',
'ddopnew/loss_functions.py')},
'ddopnew.obsprocessors': { 'ddopnew.obsprocessors.FlattenTimeDimNumpy': ( '00_utils/preprocessors.html#flattentimedimnumpy',
'ddopnew.obsprocessors': { 'ddopnew.obsprocessors.FlattenTimeDimNumpy': ( '00_utils/obsprocessors.html#flattentimedimnumpy',
'ddopnew/obsprocessors.py'),
'ddopnew.obsprocessors.FlattenTimeDimNumpy.__call__': ( '00_utils/preprocessors.html#flattentimedimnumpy.__call__',
'ddopnew.obsprocessors.FlattenTimeDimNumpy.__call__': ( '00_utils/obsprocessors.html#flattentimedimnumpy.__call__',
'ddopnew/obsprocessors.py'),
'ddopnew.obsprocessors.FlattenTimeDimNumpy.__init__': ( '00_utils/preprocessors.html#flattentimedimnumpy.__init__',
'ddopnew.obsprocessors.FlattenTimeDimNumpy.__init__': ( '00_utils/obsprocessors.html#flattentimedimnumpy.__init__',
'ddopnew/obsprocessors.py'),
'ddopnew.obsprocessors.FlattenTimeDimNumpy.check_input': ( '00_utils/preprocessors.html#flattentimedimnumpy.check_input',
'ddopnew.obsprocessors.FlattenTimeDimNumpy.check_input': ( '00_utils/obsprocessors.html#flattentimedimnumpy.check_input',
'ddopnew/obsprocessors.py')},
'ddopnew.postprocessors': { 'ddopnew.postprocessors.ClipAction': ( '00_utils/postprocessors.html#clipaction',
'ddopnew/postprocessors.py'),
Expand All @@ -579,13 +579,13 @@
'ddopnew/torch_utils/loss_functions.py'),
'ddopnew.torch_utils.loss_functions.quantile_loss': ( '00_utils/torch_loss_functions.html#quantile_loss',
'ddopnew/torch_utils/loss_functions.py')},
'ddopnew.torch_utils.obsprocessors': { 'ddopnew.torch_utils.obsprocessors.FlattenTimeDim': ( '00_utils/torch_pre_processors.html#flattentimedim',
'ddopnew.torch_utils.obsprocessors': { 'ddopnew.torch_utils.obsprocessors.FlattenTimeDim': ( '00_utils/torch_obs_processors.html#flattentimedim',
'ddopnew/torch_utils/obsprocessors.py'),
'ddopnew.torch_utils.obsprocessors.FlattenTimeDim.__call__': ( '00_utils/torch_pre_processors.html#flattentimedim.__call__',
'ddopnew.torch_utils.obsprocessors.FlattenTimeDim.__call__': ( '00_utils/torch_obs_processors.html#flattentimedim.__call__',
'ddopnew/torch_utils/obsprocessors.py'),
'ddopnew.torch_utils.obsprocessors.FlattenTimeDim.__init__': ( '00_utils/torch_pre_processors.html#flattentimedim.__init__',
'ddopnew.torch_utils.obsprocessors.FlattenTimeDim.__init__': ( '00_utils/torch_obs_processors.html#flattentimedim.__init__',
'ddopnew/torch_utils/obsprocessors.py'),
'ddopnew.torch_utils.obsprocessors.FlattenTimeDim.check_input': ( '00_utils/torch_pre_processors.html#flattentimedim.check_input',
'ddopnew.torch_utils.obsprocessors.FlattenTimeDim.check_input': ( '00_utils/torch_obs_processors.html#flattentimedim.check_input',
'ddopnew/torch_utils/obsprocessors.py')},
'ddopnew.torch_utils.preprocessors': { 'ddopnew.torch_utils.preprocessors.FlattenTimeDim': ( '00_utils/torch_pre_processors.html#flattentimedim',
'ddopnew/torch_utils/preprocessors.py'),
Expand Down
3 changes: 0 additions & 3 deletions ddopnew/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,7 @@ def reset_index(self,
if start_index=="random":
if self.mode == "train":
if self.dataloader.len_train is not None and self.dataloader.len_train > self.mdp_info.horizon:
# seed = self.seed_setter.return_seed()
# np.random.seed(seed)
random_index = np.random.choice(self.dataloader.len_train-self.mdp_info.horizon)
# print("reset period:", random_index)
else:
random_index = 0
self.start_index = random_index
Expand Down
6 changes: 3 additions & 3 deletions ddopnew/obsprocessors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_utils/11_preprocessors.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_utils/11_obsprocessors.ipynb.

# %% auto 0
__all__ = ['FlattenTimeDimNumpy']

# %% ../nbs/00_utils/11_preprocessors.ipynb 3
# %% ../nbs/00_utils/11_obsprocessors.ipynb 3
from typing import Union, Optional

import numpy as np
Expand All @@ -13,7 +13,7 @@
import torch.nn as nn
import torch.nn.functional as F

# %% ../nbs/00_utils/11_preprocessors.ipynb 4
# %% ../nbs/00_utils/11_obsprocessors.ipynb 4
class FlattenTimeDimNumpy():

"""
Expand Down
6 changes: 3 additions & 3 deletions ddopnew/postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch.nn.functional as F

# %% ../nbs/00_utils/10_postprocessors.ipynb 4
class ClipAction:
class ClipAction():
"""
A class to clip input values within specified bounds.
If the parameters lower and upper are not specified, no clipping is performed.
Expand Down Expand Up @@ -59,8 +59,8 @@ def __call__(self, input: np.ndarray) -> np.ndarray: #

return output

# %% ../nbs/00_utils/10_postprocessors.ipynb 7
class RoundAction:
# %% ../nbs/00_utils/10_postprocessors.ipynb 8
class RoundAction():
"""
A class to round input values to the nearest specified unit size.
Unit size can be any decimal value like 10, 3, 1, 0.1, 0.03, etc.
Expand Down
6 changes: 3 additions & 3 deletions ddopnew/torch_utils/obsprocessors.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/00_utils/21_torch_pre_processors.ipynb.
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/00_utils/21_torch_obs_processors.ipynb.

# %% auto 0
__all__ = ['FlattenTimeDim']

# %% ../../nbs/00_utils/21_torch_pre_processors.ipynb 3
# %% ../../nbs/00_utils/21_torch_obs_processors.ipynb 3
import numpy as np
from typing import Optional

import torch


# %% ../../nbs/00_utils/21_torch_pre_processors.ipynb 4
# %% ../../nbs/00_utils/21_torch_obs_processors.ipynb 4
class FlattenTimeDim():
"""
Obsprocessor to flatten the time and feature dimension of the input.
Expand Down
9 changes: 5 additions & 4 deletions ddopnew/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
__all__ = ['check_parameter_types', 'Parameter', 'MDPInfo', 'DatasetWrapper']

# %% ../nbs/00_utils/00_utils.ipynb 3
import numpy as np
from torch.utils.data import Dataset
from typing import Union, List, Tuple, Literal
from gymnasium.spaces import Space
from .dataloaders.base import BaseDataLoader

import numpy as np

# %% ../nbs/00_utils/00_utils.ipynb 4
def check_parameter_types(
*args, # any number of parameters to be checked
Expand All @@ -24,7 +25,7 @@ def check_parameter_types(
if not isinstance(arg, parameter_type):
raise TypeError(f"Argument {index+1} of {len(args)} is of type {type(arg).__name__}, expected {parameter_type.__name__}")

# %% ../nbs/00_utils/00_utils.ipynb 8
# %% ../nbs/00_utils/00_utils.ipynb 7
class Parameter():

"""
Expand Down Expand Up @@ -110,7 +111,7 @@ def size(self):
"""
return self._value.size

# %% ../nbs/00_utils/00_utils.ipynb 17
# %% ../nbs/00_utils/00_utils.ipynb 16
class MDPInfo():
"""
This class is used to store the information of the environment.
Expand Down Expand Up @@ -152,7 +153,7 @@ def shape(self):
"""
return self.observation_space.shape + self.action_space.shape

# %% ../nbs/00_utils/00_utils.ipynb 19
# %% ../nbs/00_utils/00_utils.ipynb 20
class DatasetWrapper(Dataset):
"""
This class is used to wrap a Pytorch Dataset around the ddopnew dataloader
Expand Down
Loading

0 comments on commit a059a78

Please sign in to comment.