Skip to content

Commit

Permalink
created new folder with meta experiment functions
Browse files Browse the repository at this point in the history
  • Loading branch information
majoma7 committed Aug 16, 2024
1 parent 1c60e9b commit 3021006
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 2 deletions.
2 changes: 2 additions & 0 deletions ddopnew/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,8 @@
'ddopnew/loss_functions.py'),
'ddopnew.loss_functions.quantile_loss': ( '00_utils/loss_functions.html#quantile_loss',
'ddopnew/loss_functions.py')},
'ddopnew.meta_experiment_functions': { 'ddopnew.meta_experiment_functions.select_agent': ( '30_experiment_functions/meta_experiment_functions.html#select_agent',
'ddopnew/meta_experiment_functions.py')},
'ddopnew.obsprocessors': { 'ddopnew.obsprocessors.ConvertDictSpace': ( '00_utils/obsprocessors.html#convertdictspace',
'ddopnew/obsprocessors.py'),
'ddopnew.obsprocessors.ConvertDictSpace.__call__': ( '00_utils/obsprocessors.html#convertdictspace.__call__',
Expand Down
2 changes: 1 addition & 1 deletion ddopnew/envs/inventory/multi_period.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class MultiPeriodEnv(BaseInventoryEnv, ABC):

"""
XXX.
XXX
"""

def __init__(self,
Expand Down
34 changes: 34 additions & 0 deletions ddopnew/meta_experiment_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb.

# %% auto 0
__all__ = ['select_agent']

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 3
from abc import ABC, abstractmethod
from typing import Union, List, Tuple, Dict, Literal
import logging
from datetime import datetime
import numpy as np
import sys

from .tracking import get_git_hash, get_library_version

import wandb


import importlib

from tqdm import tqdm, trange

# Think about how to handle mushroom integration.
from mushroom_rl.core import Core

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 8
def select_agent(agent_name: str) -> type: #
""" Select an agent class from a list of agent names and return the class"""
if agent_name in AGENT_CLASSES:
module_path, class_name = AGENT_CLASSES[agent_name].rsplit(".", 1)
module = importlib.import_module(module_path)
return getattr(module, class_name)
else:
raise ValueError(f"Unknown agent name: {agent_name}")
2 changes: 1 addition & 1 deletion nbs/21_envs_inventory/30_multi_period_envs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"class MultiPeriodEnv(BaseInventoryEnv, ABC):\n",
" \n",
" \"\"\"\n",
" XXX.\n",
" XXX\n",
" \"\"\"\n",
"\n",
" def __init__(self,\n",
Expand Down
174 changes: 174 additions & 0 deletions nbs/30_experiment_functions/20_meta_experiment_functions.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Meta experiment functions\n",
"\n",
"> Very high-level functions to run experiments with minimal code, directly from terminal.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp meta_experiment_functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"from abc import ABC, abstractmethod\n",
"from typing import Union, List, Tuple, Dict, Literal\n",
"import logging\n",
"from datetime import datetime \n",
"import numpy as np\n",
"import sys\n",
"\n",
"from ddopnew.tracking import get_git_hash, get_library_version\n",
"\n",
"import wandb\n",
"\n",
"\n",
"import importlib\n",
"\n",
"from tqdm import tqdm, trange\n",
"\n",
"# Think about how to handle mushroom integration.\n",
"from mushroom_rl.core import Core"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load files and set-up tracking\n",
"\n",
"> Fist part of experiment: Log into wandb and load config files"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def init_wandb(project_name: str): #\n",
"\n",
" \"\"\" init wandb \"\"\"\n",
"\n",
" wandb.init(\n",
" project=project_name,\n",
" name = f\"{project_name}_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def track_libraries_and_git( libraries_to_track: List[str],\n",
" tracking: bool = True,\n",
" tracking_tool = \"wandb\", # Currenty only wandb is supported\n",
" ) -> None:\n",
" \n",
" \"\"\"\n",
" Track the versions of the libraries and the git hash of the repository.\n",
"\n",
" \"\"\"\n",
"\n",
" for lib in libraries_to_track:\n",
" version = get_library_version(lib, tracking=tracking, tracking_tool=tracking_tool)\n",
" logging.info(f\"{lib}: {version}\")\n",
" git_hash = get_git_hash(\".\", tracking=tracking, tracking_tool=tracking_tool)\n",
" logging.info(f\"Git hash: {git_hash}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions\n",
"\n",
"> Some functions that are needed to run an experiment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"def select_agent(agent_name: str) -> type: #\n",
" \"\"\" Select an agent class from a list of agent names and return the class\"\"\"\n",
" if agent_name in AGENT_CLASSES:\n",
" module_path, class_name = AGENT_CLASSES[agent_name].rsplit(\".\", 1)\n",
" module = importlib.import_module(module_path)\n",
" return getattr(module, class_name)\n",
" else:\n",
" raise ValueError(f\"Unknown agent name: {agent_name}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"import nbdev; nbdev.nbdev_export()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit 3021006

Please sign in to comment.