Skip to content

Commit

Permalink
allows for dict in set-param (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
majoma7 authored Oct 24, 2024
1 parent 3ab07ee commit 28faa93
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ddopai/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# %% ../../nbs/20_environments/20_base_env/10_base_env.ipynb 4
import gymnasium as gym
from abc import ABC, abstractmethod
from typing import Union, List
from typing import Union, List, Dict
import numpy as np

from ..utils import MDPInfo, Parameter, set_param
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(self,

def set_param(self,
name: str, # name of the parameter (will become the attribute name)
input: Parameter | int | float | np.ndarray | List | None, # input value of the parameter
input: Parameter | int | float | np.ndarray | List | Dict | None, # input value of the parameter
shape: tuple = (1,), # shape of the parameter
new: bool = False # whether to create a new parameter or update an existing one
) -> None: #
Expand Down
10 changes: 7 additions & 3 deletions ddopai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# %% ../nbs/00_utils/00_utils.ipynb 3
from torch.utils.data import Dataset
from typing import Union, List, Tuple, Literal
from typing import Union, List, Tuple, Literal, Dict
from gymnasium.spaces import Space
from .dataloaders.base import BaseDataLoader

Expand Down Expand Up @@ -215,7 +215,7 @@ def merge_dictionaries(dict1, dict2):
# %% ../nbs/00_utils/00_utils.ipynb 20
def set_param(obj,
name: str, # name of the parameter (will become the attribute name)
input: Parameter | int | float | np.ndarray | List | None , # input value of the parameter
input: Parameter | int | float | np.ndarray | List | Dict | None , # input value of the parameter
shape: tuple = (1,), # shape of the parameter
new: bool = False, # whether to create a new parameter or update an existing one
):
Expand Down Expand Up @@ -253,8 +253,12 @@ def set_param(obj,
param = np.full(shape, input.item())
else:
raise ValueError("Error in setting parameter. Input array must match the specified shape or be a single-element array")

elif isinstance(input, dict):
param = input

else:
raise TypeError(f"Input must be a Parameter, scalar, or numpy array, got {type(input).__name__} with value {input}")
raise TypeError(f"Input must be a Parameter, scalar, numpy array, list, or dict. Got {type(input).__name__} with value {input}")

# set the parameter
if new:
Expand Down
10 changes: 7 additions & 3 deletions nbs/00_utils/00_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
"#| export\n",
"\n",
"from torch.utils.data import Dataset\n",
"from typing import Union, List, Tuple, Literal\n",
"from typing import Union, List, Tuple, Literal, Dict\n",
"from gymnasium.spaces import Space\n",
"from ddopai.dataloaders.base import BaseDataLoader\n",
"\n",
Expand Down Expand Up @@ -657,7 +657,7 @@
"#| export\n",
"def set_param(obj,\n",
" name: str, # name of the parameter (will become the attribute name)\n",
" input: Parameter | int | float | np.ndarray | List | None , # input value of the parameter\n",
" input: Parameter | int | float | np.ndarray | List | Dict | None , # input value of the parameter\n",
" shape: tuple = (1,), # shape of the parameter\n",
" new: bool = False, # whether to create a new parameter or update an existing one\n",
" ): \n",
Expand Down Expand Up @@ -695,8 +695,12 @@
" param = np.full(shape, input.item())\n",
" else:\n",
" raise ValueError(\"Error in setting parameter. Input array must match the specified shape or be a single-element array\")\n",
"\n",
" elif isinstance(input, dict):\n",
" param = input\n",
"\n",
" else:\n",
" raise TypeError(f\"Input must be a Parameter, scalar, or numpy array, got {type(input).__name__} with value {input}\")\n",
" raise TypeError(f\"Input must be a Parameter, scalar, numpy array, list, or dict. Got {type(input).__name__} with value {input}\")\n",
"\n",
" # set the parameter\n",
" if new:\n",
Expand Down
4 changes: 2 additions & 2 deletions nbs/20_environments/20_base_env/10_base_env.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"\n",
"import gymnasium as gym\n",
"from abc import ABC, abstractmethod\n",
"from typing import Union, List\n",
"from typing import Union, List, Dict\n",
"import numpy as np\n",
"\n",
"from ddopai.utils import MDPInfo, Parameter, set_param\n",
Expand Down Expand Up @@ -93,7 +93,7 @@
"\n",
" def set_param(self,\n",
" name: str, # name of the parameter (will become the attribute name)\n",
" input: Parameter | int | float | np.ndarray | List | None, # input value of the parameter\n",
" input: Parameter | int | float | np.ndarray | List | Dict | None, # input value of the parameter\n",
" shape: tuple = (1,), # shape of the parameter\n",
" new: bool = False # whether to create a new parameter or update an existing one\n",
" ) -> None: #\n",
Expand Down

0 comments on commit 28faa93

Please sign in to comment.