Skip to content

Commit

Permalink
added function for experiment automation
Browse files Browse the repository at this point in the history
  • Loading branch information
majoma7 committed Aug 16, 2024
1 parent 3021006 commit 161dbc3
Show file tree
Hide file tree
Showing 7 changed files with 591 additions and 41 deletions.
103 changes: 99 additions & 4 deletions ddopnew/_modidx.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ddopnew/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_release_tag(dataset_type, version, token=None):
else:
release_tag = f"{dataset_type}_{version}"

print(f"Filtered release tags: {release_tags_filtered}")
logging.debug(f"Filtered release tags: {release_tags_filtered}")
return release_tag

def get_dataset_url(dataset_type, dataset_number, release_tag, token=None):
Expand Down
17 changes: 3 additions & 14 deletions ddopnew/experiment_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/30_experiment_functions/10_experiment_functions.ipynb.

# %% auto 0
__all__ = ['EarlyStoppingHandler', 'calculate_score', 'log_info', 'update_best', 'save_agent', 'select_agent', 'test_agent',
'run_test_episode', 'run_experiment']
__all__ = ['EarlyStoppingHandler', 'calculate_score', 'log_info', 'update_best', 'save_agent', 'test_agent', 'run_test_episode',
'run_experiment']

# %% ../nbs/30_experiment_functions/10_experiment_functions.ipynb 3
from abc import ABC, abstractmethod
Expand All @@ -15,7 +15,6 @@

from .envs.base import BaseEnvironment
from .agents.base import BaseAgent
from .agents.class_names import AGENT_CLASSES

import importlib

Expand Down Expand Up @@ -160,17 +159,7 @@ def save_agent(agent: BaseAgent, # Any agent inheriting from BaseAgent
save_dir = f"{experiment_dir}/saved_models/best"
agent.save(save_dir)

# %% ../nbs/30_experiment_functions/10_experiment_functions.ipynb 9
def select_agent(agent_name: str) -> type: #
""" Select an agent class from a list of agent names and return the class"""
if agent_name in AGENT_CLASSES:
module_path, class_name = AGENT_CLASSES[agent_name].rsplit(".", 1)
module = importlib.import_module(module_path)
return getattr(module, class_name)
else:
raise ValueError(f"Unknown agent name: {agent_name}")

# %% ../nbs/30_experiment_functions/10_experiment_functions.ipynb 11
# %% ../nbs/30_experiment_functions/10_experiment_functions.ipynb 10
def test_agent(agent: BaseAgent,
env: BaseEnvironment,
return_dataset = False,
Expand Down
214 changes: 213 additions & 1 deletion ddopnew/meta_experiment_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb.

# %% auto 0
__all__ = ['select_agent']
__all__ = ['prep_experiment', 'init_wandb', 'track_libraries_and_git', 'import_config', 'transfer_lag_window_to_env',
'get_ddop_data', 'download_data', 'set_indices', 'set_up_env', 'set_up_earlystoppinghandler', 'select_agent']

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 3
from abc import ABC, abstractmethod
Expand All @@ -10,8 +11,14 @@
from datetime import datetime
import numpy as np
import sys
import os
import yaml

from .tracking import get_git_hash, get_library_version
from .agents.class_names import AGENT_CLASSES
from .dataloaders.tabular import XYDataLoader
from .datasets import DatasetLoader
from .experiment_functions import EarlyStoppingHandler

import wandb

Expand All @@ -23,7 +30,212 @@
# Think about how to handle mushroom integration.
from mushroom_rl.core import Core

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 5
def prep_experiment(
project_name: str,
libraries_to_track: List[str] = ["ddopnew"],
config_train_name: str = "config_train",
config_agent_name: str = "config_agent",
config_env_name: str = "config_env",
):
""" First stpes to always execute when starting an experiment (using wandb for tracking)"""

init_wandb(project_name)
track_libraries_and_git(libraries_to_track)

config_train = import_config(config_train_name)
config_agent = import_config(config_agent_name) # General config file containing all agent parameters
config_env = import_config(config_env_name)

AgentClass = select_agent(config_train["agent"]) # Select agent class and import dynamically
agent_name = config_train["agent"]
config_agent = config_agent[config_train["agent"]] # Get parameters for specific agent

transfer_lag_window_to_env(config_env, config_agent)

wandb.config.update(config_train)
wandb.config.update(config_agent)
wandb.config.update(config_env)

return config_train, config_agent, config_env, AgentClass, agent_name

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 6
def init_wandb(project_name: str): #

""" init wandb """

wandb.init(
project=project_name,
name = f"{project_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
)

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 7
def track_libraries_and_git( libraries_to_track: List[str],
tracking: bool = True,
tracking_tool = "wandb", # Currenty only wandb is supported
) -> None:

"""
Track the versions of the libraries and the git hash of the repository.
"""

for lib in libraries_to_track:
version = get_library_version(lib, tracking=tracking, tracking_tool=tracking_tool)
logging.info(f"{lib}: {version}")
git_hash = get_git_hash(".", tracking=tracking, tracking_tool=tracking_tool)
logging.info(f"Git hash: {git_hash}")

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 8
def import_config( filename: str, # Name of the file, must be a yaml file
path: str = None # Optional path to the file if it is not in the current directory
) -> Dict:

"""
Import a config file in YAML format
"""

# Check if filename has an extension
if '.' in filename:
extension = filename.split(".")[-1]
else:
extension = ''

if not extension:
filename += ".yaml"
elif extension not in ["yaml", "yml"]:
raise ValueError("The configuration file must have a .yaml or .yml extension.")


if path is not None:
full_path = os.path.join(path, filename)
else:
full_path = filename


# Check if the file exists
if not os.path.exists(full_path):
raise FileNotFoundError(f"The configuration file '{full_path}' does not exist.")

with open(full_path, "r") as stream:
try:
config = yaml.safe_load(stream)
except yaml.YAMLError as exc:
raise yaml.YAMLError(f"Error parsing YAML file '{full_path}': {exc}")

logging.info(f"Configuration file '{filename}' successfully loaded.")
logging.debug(f"Configuration: {config}")

return config

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 9
def transfer_lag_window_to_env(config_env: Dict, #
config_agent: Dict
) -> None:

"""
Transfer the lag window from the agent configuration to the environment configuration
"""

if "lag_window" in config_agent.keys():
if isinstance(config_agent["lag_window"], int):
config_env["lag_window_params"]["lag_window"] = config_agent["lag_window"]
else:
raise ValueError("The lag window must be an integer.")
del config_agent["lag_window"]
else:
logging.warning("No lag window specified in the agent configuration. Keeping value from env config")

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 11
def get_ddop_data(
config_env: Dict,
overwrite: bool = False
) -> Tuple:

""" Standard function to load data provided by the ddop package """

data = download_data(config_env, overwrite)

val_index_start, test_index_start = set_indices(config_env, data[0])

return data, val_index_start, test_index_start


# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 12
def download_data( config_env: Dict,
overwrite: bool = False #
) -> Tuple:

""" Download standard dataset from ddop repository using the DatasetLoader class """

datasetloader = DatasetLoader()

data = datasetloader.load_dataset(
dataset_type = config_env["dataset_type"],
dataset_number = config_env["dataset_number"],
overwrite=False)

data_tuple = data["data_raw_features"], data["data_raw_target"]

return data_tuple

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 13
def set_indices(config_env: Dict, #
X: np.ndarray
) -> Tuple:

""" Set the indices for the validation and test set """

val_index_start = len(X) - config_env["size_val"] - config_env["size_test"]
test_index_start = len(X) - config_env["size_test"]

return val_index_start, test_index_start

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 15
def set_up_env(
env_class,
raw_data: Tuple, #
val_index_start: int,
test_index_start: int,
config_env: Dict,
postprocessors: List,
) -> object:

""" Set up the environment """

dataloader = XYDataLoader( X = raw_data[0],
Y = raw_data[1],
val_index_start = val_index_start,
test_index_start = test_index_start,
lag_window_params = config_env["lag_window_params"],
normalize_features = {'normalize': config_env["normalize_features"], 'ignore_one_hot': True})

environment = env_class(
dataloader = dataloader,
postprocessors = postprocessors,
**config_env["env_kwargs"]
)

return environment

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 17
def set_up_earlystoppinghandler(config_train: Dict) -> object: #

""" Set up the early stopping handler """

# check if config_train has either early_stopping_patience or early_stopping_warmup
if "early_stopping_patience" in config_train or "early_stopping_warmup" in config_train:
warmup = config_train["early_stopping_warmup"] if "early_stopping_warmup" in config_train else 0
patience = config_train["early_stopping_patience"] if "early_stopping_patience" in config_train else 0
earlystoppinghandler = EarlyStoppingHandler(warmup=warmup, patience=warmup)
else:
earlystoppinghandler = None

return earlystoppinghandler

# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 19
def select_agent(agent_name: str) -> type: #
""" Select an agent class from a list of agent names and return the class"""
if agent_name in AGENT_CLASSES:
Expand Down
19 changes: 0 additions & 19 deletions nbs/30_experiment_functions/10_experiment_functions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
"\n",
"from ddopnew.envs.base import BaseEnvironment\n",
"from ddopnew.agents.base import BaseAgent\n",
"from ddopnew.agents.class_names import AGENT_CLASSES\n",
"\n",
"import importlib\n",
"\n",
Expand Down Expand Up @@ -331,24 +330,6 @@
" agent.save(save_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"def select_agent(agent_name: str) -> type: #\n",
" \"\"\" Select an agent class from a list of agent names and return the class\"\"\"\n",
" if agent_name in AGENT_CLASSES:\n",
" module_path, class_name = AGENT_CLASSES[agent_name].rsplit(\".\", 1)\n",
" module = importlib.import_module(module_path)\n",
" return getattr(module, class_name)\n",
" else:\n",
" raise ValueError(f\"Unknown agent name: {agent_name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
Loading

0 comments on commit 161dbc3

Please sign in to comment.