Skip to content

Commit

Permalink
fixed mistake in EarlyStoppingHandler setup
Browse files Browse the repository at this point in the history
  • Loading branch information
majoma7 committed Aug 17, 2024
1 parent 79019f1 commit cc7d29f
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 46 deletions.
3 changes: 3 additions & 0 deletions ddopnew/experiment_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def __init__(
self.criteria = criteria
self.direction = direction

print("warmup", warmup)
print("patience", patience)

def add_result(self,
J: float, # Return (discounted rewards) of the last epoch
R: float, # Total rewards of the last epoch
Expand Down
3 changes: 2 additions & 1 deletion ddopnew/meta_experiment_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def set_up_earlystoppinghandler(config_train: Dict) -> object: #
if "early_stopping_patience" in config_train or "early_stopping_warmup" in config_train:
warmup = config_train["early_stopping_warmup"] if "early_stopping_warmup" in config_train else 0
patience = config_train["early_stopping_patience"] if "early_stopping_patience" in config_train else 0
earlystoppinghandler = EarlyStoppingHandler(warmup=warmup, patience=warmup)

earlystoppinghandler = EarlyStoppingHandler(warmup=warmup, patience=patience)
else:
earlystoppinghandler = None

Expand Down
58 changes: 29 additions & 29 deletions nbs/30_experiment_functions/10_experiment_functions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -30,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -122,15 +122,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L28){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L27){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## EarlyStoppingHandler\n",
"\n",
Expand All @@ -153,7 +153,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L28){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L27){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## EarlyStoppingHandler\n",
"\n",
Expand All @@ -174,7 +174,7 @@
"| direction | str | max | Whether reward shall be maximized or minimized |"
]
},
"execution_count": null,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -185,15 +185,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L54){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### EarlyStoppingHandler.add_result\n",
"\n",
Expand All @@ -210,7 +210,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L54){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### EarlyStoppingHandler.add_result\n",
"\n",
Expand All @@ -225,7 +225,7 @@
"| **Returns** | **bool** | |"
]
},
"execution_count": null,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -245,7 +245,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -341,7 +341,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -598,15 +598,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L256){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### run_experiment\n",
"\n",
Expand Down Expand Up @@ -640,7 +640,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L256){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### run_experiment\n",
"\n",
Expand Down Expand Up @@ -672,7 +672,7 @@
"| eval_step_info | bool | False | |"
]
},
"execution_count": null,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -713,15 +713,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L174){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L166){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### test_agent\n",
"\n",
Expand All @@ -742,7 +742,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L174){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L166){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### test_agent\n",
"\n",
Expand All @@ -761,7 +761,7 @@
"| eval_step_info | bool | False | |"
]
},
"execution_count": null,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -772,15 +772,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L194){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### run_test_episode\n",
"\n",
Expand All @@ -801,7 +801,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L194){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### run_test_episode\n",
"\n",
Expand All @@ -820,7 +820,7 @@
"| eval_step_info | bool | False | Print step info during evaluation |"
]
},
"execution_count": null,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -838,14 +838,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"R: -5.0652502410719995, J: -5.049058241594489\n"
"R: -3.1622384231698444, J: -3.150371113545027\n"
]
}
],
Expand Down Expand Up @@ -882,7 +882,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand Down
Loading

0 comments on commit cc7d29f

Please sign in to comment.