Skip to content

Commit

Permalink
allowed for set_param to infer current shape if not new
Browse files Browse the repository at this point in the history
  • Loading branch information
majoma7 committed Oct 24, 2024
1 parent b17873a commit 4bb712d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ To be written.
To make any enviroment compatible with mushroomRL and other agents
defined within ddopai, there are some additional requirements when
defining the environment. Instead of inheriting from `gym.Env`, the
environment should inherit from `ddopai.envs.base.BaseEnvironment`. This
base class provides some additional necessary methods and attributes to
ensure compatibility with the agents. Below are the steps to convert a
Gym environment to a ddopai environment. We strongly recommend you to
also look at the implementation of the NewsvendorEnv
environment should inherit from
[`ddopai.envs.base.BaseEnvironment`](https://opimwue.github.io/ddopai/20_environments/20_base_env/base_env.html#baseenvironment).
This base class provides some additional necessary methods and
attributes to ensure compatibility with the agents. Below are the steps
to convert a Gym environment to a ddopai environment. We strongly
recommend you to also look at the implementation of the NewsvendorEnv
(nbs/20_environments/21_envs_inventory/20_single_period_envs.ipynb) as
an example.

Expand Down
10 changes: 10 additions & 0 deletions ddopai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,19 @@ def set_param(obj,
False, the function will raise an error if the parameter does not exist.
"""


if input is None:
param = None

if not new:
# get current shape of parameter
if not hasattr(self, name):
# if parameter is not a dict, get the shape
raise AttributeError(f"Parameter {name} does not exist")

if not isinstance(getattr(self, name), dict):
shape = getattr(self, name).shape

elif isinstance(input, Parameter):
if input.shape != shape:
raise ValueError("Parameter shape must be equal to the shape specified for this environment parameter")
Expand Down
10 changes: 10 additions & 0 deletions nbs/00_utils/00_utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -668,9 +668,19 @@
" False, the function will raise an error if the parameter does not exist.\n",
" \"\"\"\n",
"\n",
"\n",
" if input is None:\n",
" param = None\n",
"\n",
" if not new:\n",
" # get current shape of parameter\n",
" if not hasattr(self, name):\n",
" # if parameter is not a dict, get the shape\n",
" raise AttributeError(f\"Parameter {name} does not exist\")\n",
"\n",
" if not isinstance(getattr(self, name), dict):\n",
" shape = getattr(self, name).shape\n",
"\n",
" elif isinstance(input, Parameter):\n",
" if input.shape != shape:\n",
" raise ValueError(\"Parameter shape must be equal to the shape specified for this environment parameter\")\n",
Expand Down

0 comments on commit 4bb712d

Please sign in to comment.