-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
created new folder with meta experiment functions
- Loading branch information
Showing
5 changed files
with
212 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ | |
class MultiPeriodEnv(BaseInventoryEnv, ABC): | ||
|
||
""" | ||
XXX. | ||
XXX | ||
""" | ||
|
||
def __init__(self, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb. | ||
|
||
# %% auto 0 | ||
__all__ = ['select_agent'] | ||
|
||
# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 3 | ||
from abc import ABC, abstractmethod | ||
from typing import Union, List, Tuple, Dict, Literal | ||
import logging | ||
from datetime import datetime | ||
import numpy as np | ||
import sys | ||
|
||
from .tracking import get_git_hash, get_library_version | ||
|
||
import wandb | ||
|
||
|
||
import importlib | ||
|
||
from tqdm import tqdm, trange | ||
|
||
# Think about how to handle mushroom integration. | ||
from mushroom_rl.core import Core | ||
|
||
# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 8 | ||
def select_agent(agent_name: str) -> type: # | ||
""" Select an agent class from a list of agent names and return the class""" | ||
if agent_name in AGENT_CLASSES: | ||
module_path, class_name = AGENT_CLASSES[agent_name].rsplit(".", 1) | ||
module = importlib.import_module(module_path) | ||
return getattr(module, class_name) | ||
else: | ||
raise ValueError(f"Unknown agent name: {agent_name}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
174 changes: 174 additions & 0 deletions
174
nbs/30_experiment_functions/20_meta_experiment_functions.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Meta experiment functions\n", | ||
"\n", | ||
"> Very high-level functions to run experiments with minimal code, directly from terminal.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| default_exp meta_experiment_functions" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| hide\n", | ||
"from nbdev.showdoc import *" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| export\n", | ||
"\n", | ||
"from abc import ABC, abstractmethod\n", | ||
"from typing import Union, List, Tuple, Dict, Literal\n", | ||
"import logging\n", | ||
"from datetime import datetime \n", | ||
"import numpy as np\n", | ||
"import sys\n", | ||
"\n", | ||
"from ddopnew.tracking import get_git_hash, get_library_version\n", | ||
"\n", | ||
"import wandb\n", | ||
"\n", | ||
"\n", | ||
"import importlib\n", | ||
"\n", | ||
"from tqdm import tqdm, trange\n", | ||
"\n", | ||
"# Think about how to handle mushroom integration.\n", | ||
"from mushroom_rl.core import Core" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load files and set-up tracking\n", | ||
"\n", | ||
"> Fist part of experiment: Log into wandb and load config files" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def init_wandb(project_name: str): #\n", | ||
"\n", | ||
" \"\"\" init wandb \"\"\"\n", | ||
"\n", | ||
" wandb.init(\n", | ||
" project=project_name,\n", | ||
" name = f\"{project_name}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}\"\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def track_libraries_and_git( libraries_to_track: List[str],\n", | ||
" tracking: bool = True,\n", | ||
" tracking_tool = \"wandb\", # Currenty only wandb is supported\n", | ||
" ) -> None:\n", | ||
" \n", | ||
" \"\"\"\n", | ||
" Track the versions of the libraries and the git hash of the repository.\n", | ||
"\n", | ||
" \"\"\"\n", | ||
"\n", | ||
" for lib in libraries_to_track:\n", | ||
" version = get_library_version(lib, tracking=tracking, tracking_tool=tracking_tool)\n", | ||
" logging.info(f\"{lib}: {version}\")\n", | ||
" git_hash = get_git_hash(\".\", tracking=tracking, tracking_tool=tracking_tool)\n", | ||
" logging.info(f\"Git hash: {git_hash}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Helper functions\n", | ||
"\n", | ||
"> Some functions that are needed to run an experiment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| export\n", | ||
"def select_agent(agent_name: str) -> type: #\n", | ||
" \"\"\" Select an agent class from a list of agent names and return the class\"\"\"\n", | ||
" if agent_name in AGENT_CLASSES:\n", | ||
" module_path, class_name = AGENT_CLASSES[agent_name].rsplit(\".\", 1)\n", | ||
" module = importlib.import_module(module_path)\n", | ||
" return getattr(module, class_name)\n", | ||
" else:\n", | ||
" raise ValueError(f\"Unknown agent name: {agent_name}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| hide\n", | ||
"import nbdev; nbdev.nbdev_export()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "python3", | ||
"language": "python", | ||
"name": "python3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |