diff --git a/ddopnew/_modidx.py b/ddopnew/_modidx.py index 32ace96..62bc335 100644 --- a/ddopnew/_modidx.py +++ b/ddopnew/_modidx.py @@ -154,7 +154,11 @@ 'ddopnew/agents/newsvendor.py'), 'ddopnew.agents.newsvendor.NewsvendorSAAagent.fit': ( '41_NV_agents/nv_agents.html#newsvendorsaaagent.fit', 'ddopnew/agents/newsvendor.py')}, - 'ddopnew.agents.newsvendor.erm': { 'ddopnew.agents.newsvendor.erm.NVBaseAgent': ( '41_NV_agents/nv_erm_agents.html#nvbaseagent', + 'ddopnew.agents.newsvendor.erm': { 'ddopnew.agents.newsvendor.erm.BaseMetaAgent': ( '41_NV_agents/nv_erm_agents.html#basemetaagent', + 'ddopnew/agents/newsvendor/erm.py'), + 'ddopnew.agents.newsvendor.erm.BaseMetaAgent.set_meta_dataloader': ( '41_NV_agents/nv_erm_agents.html#basemetaagent.set_meta_dataloader', + 'ddopnew/agents/newsvendor/erm.py'), + 'ddopnew.agents.newsvendor.erm.NVBaseAgent': ( '41_NV_agents/nv_erm_agents.html#nvbaseagent', 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.NVBaseAgent.__init__': ( '41_NV_agents/nv_erm_agents.html#nvbaseagent.__init__', 'ddopnew/agents/newsvendor/erm.py'), @@ -166,12 +170,20 @@ 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.NewsvendorDLAgent.set_model': ( '41_NV_agents/nv_erm_agents.html#newsvendordlagent.set_model', 'ddopnew/agents/newsvendor/erm.py'), + 'ddopnew.agents.newsvendor.erm.NewsvendorDLMetaAgent': ( '41_NV_agents/nv_erm_agents.html#newsvendordlmetaagent', + 'ddopnew/agents/newsvendor/erm.py'), + 'ddopnew.agents.newsvendor.erm.NewsvendorDLMetaAgent.__init__': ( '41_NV_agents/nv_erm_agents.html#newsvendordlmetaagent.__init__', + 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.NewsvendorlERMAgent': ( '41_NV_agents/nv_erm_agents.html#newsvendorlermagent', 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.NewsvendorlERMAgent.__init__': ( '41_NV_agents/nv_erm_agents.html#newsvendorlermagent.__init__', 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.NewsvendorlERMAgent.set_model': ( '41_NV_agents/nv_erm_agents.html#newsvendorlermagent.set_model', 'ddopnew/agents/newsvendor/erm.py'), + 'ddopnew.agents.newsvendor.erm.NewsvendorlERMMetaAgent': ( '41_NV_agents/nv_erm_agents.html#newsvendorlermmetaagent', + 'ddopnew/agents/newsvendor/erm.py'), + 'ddopnew.agents.newsvendor.erm.NewsvendorlERMMetaAgent.__init__': ( '41_NV_agents/nv_erm_agents.html#newsvendorlermmetaagent.__init__', + 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.SGDBaseAgent': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent', 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.SGDBaseAgent.__init__': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.__init__', @@ -626,8 +638,30 @@ 'ddopnew/envs/inventory/single_period.py'), 'ddopnew.envs.inventory.single_period.NewsvendorEnv.__init__': ( '21_envs_inventory/single_period_envs.html#newsvendorenv.__init__', 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnv.determine_cost': ( '21_envs_inventory/single_period_envs.html#newsvendorenv.determine_cost', + 'ddopnew/envs/inventory/single_period.py'), 'ddopnew.envs.inventory.single_period.NewsvendorEnv.step_': ( '21_envs_inventory/single_period_envs.html#newsvendorenv.step_', - 'ddopnew/envs/inventory/single_period.py')}, + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.__init__': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.__init__', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.check_evaluation_metric': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.check_evaluation_metric', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.check_sl_distribution': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.check_sl_distribution', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.determine_cost': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.determine_cost', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.draw_parameter': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.draw_parameter', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.get_observation': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.get_observation', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.set_observation_space': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.set_observation_space', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.set_val_test_sl': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.set_val_test_sl', + 'ddopnew/envs/inventory/single_period.py'), + 'ddopnew.envs.inventory.single_period.NewsvendorEnvVariableSL.update_cu_co': ( '21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.update_cu_co', + 'ddopnew/envs/inventory/single_period.py')}, 'ddopnew.experiment_functions': { 'ddopnew.experiment_functions.EarlyStoppingHandler': ( '30_experiment_functions/experiment_functions.html#earlystoppinghandler', 'ddopnew/experiment_functions.py'), 'ddopnew.experiment_functions.EarlyStoppingHandler.__init__': ( '30_experiment_functions/experiment_functions.html#earlystoppinghandler.__init__', @@ -652,7 +686,9 @@ 'ddopnew/loss_functions.py'), 'ddopnew.loss_functions.quantile_loss': ( '00_utils/loss_functions.html#quantile_loss', 'ddopnew/loss_functions.py')}, - 'ddopnew.meta_experiment_functions': { 'ddopnew.meta_experiment_functions.download_data': ( '30_experiment_functions/meta_experiment_functions.html#download_data', + 'ddopnew.meta_experiment_functions': { 'ddopnew.meta_experiment_functions.clean_up': ( '30_experiment_functions/meta_experiment_functions.html#clean_up', + 'ddopnew/meta_experiment_functions.py'), + 'ddopnew.meta_experiment_functions.download_data': ( '30_experiment_functions/meta_experiment_functions.html#download_data', 'ddopnew/meta_experiment_functions.py'), 'ddopnew.meta_experiment_functions.get_ddop_data': ( '30_experiment_functions/meta_experiment_functions.html#get_ddop_data', 'ddopnew/meta_experiment_functions.py'), @@ -660,6 +696,8 @@ 'ddopnew/meta_experiment_functions.py'), 'ddopnew.meta_experiment_functions.init_wandb': ( '30_experiment_functions/meta_experiment_functions.html#init_wandb', 'ddopnew/meta_experiment_functions.py'), + 'ddopnew.meta_experiment_functions.prep_and_run_test': ( '30_experiment_functions/meta_experiment_functions.html#prep_and_run_test', + 'ddopnew/meta_experiment_functions.py'), 'ddopnew.meta_experiment_functions.prep_experiment': ( '30_experiment_functions/meta_experiment_functions.html#prep_experiment', 'ddopnew/meta_experiment_functions.py'), 'ddopnew.meta_experiment_functions.select_agent': ( '30_experiment_functions/meta_experiment_functions.html#select_agent', @@ -670,18 +708,28 @@ 'ddopnew/meta_experiment_functions.py'), 'ddopnew.meta_experiment_functions.set_up_env': ( '30_experiment_functions/meta_experiment_functions.html#set_up_env', 'ddopnew/meta_experiment_functions.py'), + 'ddopnew.meta_experiment_functions.set_warnings': ( '30_experiment_functions/meta_experiment_functions.html#set_warnings', + 'ddopnew/meta_experiment_functions.py'), 'ddopnew.meta_experiment_functions.track_libraries_and_git': ( '30_experiment_functions/meta_experiment_functions.html#track_libraries_and_git', 'ddopnew/meta_experiment_functions.py'), 'ddopnew.meta_experiment_functions.transfer_lag_window_to_env': ( '30_experiment_functions/meta_experiment_functions.html#transfer_lag_window_to_env', 'ddopnew/meta_experiment_functions.py')}, - 'ddopnew.obsprocessors': { 'ddopnew.obsprocessors.ConvertDictSpace': ( '00_utils/obsprocessors.html#convertdictspace', + 'ddopnew.obsprocessors': { 'ddopnew.obsprocessors.AddParamsToFeatures': ( '00_utils/obsprocessors.html#addparamstofeatures', + 'ddopnew/obsprocessors.py'), + 'ddopnew.obsprocessors.AddParamsToFeatures.__call__': ( '00_utils/obsprocessors.html#addparamstofeatures.__call__', + 'ddopnew/obsprocessors.py'), + 'ddopnew.obsprocessors.AddParamsToFeatures.__init__': ( '00_utils/obsprocessors.html#addparamstofeatures.__init__', + 'ddopnew/obsprocessors.py'), + 'ddopnew.obsprocessors.BaseProcessor': ( '00_utils/obsprocessors.html#baseprocessor', + 'ddopnew/obsprocessors.py'), + 'ddopnew.obsprocessors.BaseProcessor.determine_output_shape': ( '00_utils/obsprocessors.html#baseprocessor.determine_output_shape', + 'ddopnew/obsprocessors.py'), + 'ddopnew.obsprocessors.ConvertDictSpace': ( '00_utils/obsprocessors.html#convertdictspace', 'ddopnew/obsprocessors.py'), 'ddopnew.obsprocessors.ConvertDictSpace.__call__': ( '00_utils/obsprocessors.html#convertdictspace.__call__', 'ddopnew/obsprocessors.py'), 'ddopnew.obsprocessors.ConvertDictSpace.__init__': ( '00_utils/obsprocessors.html#convertdictspace.__init__', 'ddopnew/obsprocessors.py'), - 'ddopnew.obsprocessors.ConvertDictSpace.determine_output_shape': ( '00_utils/obsprocessors.html#convertdictspace.determine_output_shape', - 'ddopnew/obsprocessors.py'), 'ddopnew.obsprocessors.FlattenTimeDimNumpy': ( '00_utils/obsprocessors.html#flattentimedimnumpy', 'ddopnew/obsprocessors.py'), 'ddopnew.obsprocessors.FlattenTimeDimNumpy.__call__': ( '00_utils/obsprocessors.html#flattentimedimnumpy.__call__', @@ -710,6 +758,8 @@ 'ddopnew/torch_utils/loss_functions.py'), 'ddopnew.torch_utils.loss_functions.TorchQuantileLoss.__init__': ( '00_utils/torch_loss_functions.html#torchquantileloss.__init__', 'ddopnew/torch_utils/loss_functions.py'), + 'ddopnew.torch_utils.loss_functions.TorchQuantileLoss.convert_quantile': ( '00_utils/torch_loss_functions.html#torchquantileloss.convert_quantile', + 'ddopnew/torch_utils/loss_functions.py'), 'ddopnew.torch_utils.loss_functions.TorchQuantileLoss.forward': ( '00_utils/torch_loss_functions.html#torchquantileloss.forward', 'ddopnew/torch_utils/loss_functions.py'), 'ddopnew.torch_utils.loss_functions.quantile_loss': ( '00_utils/torch_loss_functions.html#quantile_loss', @@ -738,6 +788,11 @@ 'ddopnew/utils.py'), 'ddopnew.utils.DatasetWrapper.__init__': ('00_utils/utils.html#datasetwrapper.__init__', 'ddopnew/utils.py'), 'ddopnew.utils.DatasetWrapper.__len__': ('00_utils/utils.html#datasetwrapper.__len__', 'ddopnew/utils.py'), + 'ddopnew.utils.DatasetWrapperMeta': ('00_utils/utils.html#datasetwrappermeta', 'ddopnew/utils.py'), + 'ddopnew.utils.DatasetWrapperMeta.__getitem__': ( '00_utils/utils.html#datasetwrappermeta.__getitem__', + 'ddopnew/utils.py'), + 'ddopnew.utils.DatasetWrapperMeta.__init__': ( '00_utils/utils.html#datasetwrappermeta.__init__', + 'ddopnew/utils.py'), 'ddopnew.utils.MDPInfo': ('00_utils/utils.html#mdpinfo', 'ddopnew/utils.py'), 'ddopnew.utils.MDPInfo.__init__': ('00_utils/utils.html#mdpinfo.__init__', 'ddopnew/utils.py'), 'ddopnew.utils.MDPInfo.shape': ('00_utils/utils.html#mdpinfo.shape', 'ddopnew/utils.py'), diff --git a/ddopnew/agents/base.py b/ddopnew/agents/base.py index ed0ba9e..1907327 100644 --- a/ddopnew/agents/base.py +++ b/ddopnew/agents/base.py @@ -48,10 +48,16 @@ def draw_action(self, observation: np.ndarray) -> np.ndarray: # Internal logic of the agent to be implemented in draw_action_ method. """ - observation = self.add_batch_dim(observation) + batch_added = False + if not isinstance(observation, dict): + observation = self.add_batch_dim(observation) + batch_added = True for obsprocessor in self.obsprocessors: observation = obsprocessor(observation) + if not isinstance(observation, dict) and not batch_added: + observation = self.add_batch_dim(observation) + batch_added = True action = self.draw_action_(observation) diff --git a/ddopnew/agents/class_names.py b/ddopnew/agents/class_names.py index d41cf81..44fe34b 100644 --- a/ddopnew/agents/class_names.py +++ b/ddopnew/agents/class_names.py @@ -6,11 +6,20 @@ # %% ../../nbs/40_base_agents/10_AGENT_CLASSES.ipynb 3 AGENT_CLASSES = { "RandomAgent": "ddopnew.agents.saa.SAA", + "SAA": "ddopnew.agents.newsvendor.saa.NewsvendorSAAagent", "wSAA": "ddopnew.agents.newsvendor.saa.NewsvendorRFwSAAagent", "RFwSAA": "ddopnew.agents.newsvendor.saa.NewsvendorRFwSAAagent", + "lERM": "ddopnew.agents.newsvendor.erm.NewsvendorlERMAgent", "DLNV": "ddopnew.agents.newsvendor.erm.NewsvendorDLAgent", + "DLNVRNN": "ddopnew.agents.newsvendor.erm.NewsvendorDLRNNAgent", + "DLNVTransformer": "ddopnew.agents.newsvendor.erm.NewsvendorDLTransformerAgent", + + "lERMMeta": "ddopnew.agents.newsvendor.erm.NewsvendorlERMMetaAgent", + "DLNVMeta": "ddopnew.agents.newsvendor.erm.NewsvendorDLMetaAgent", + "DLNVRNNMeta": "ddopnew.agents.newsvendor.erm.NewsvendorDLRNNMetaAgent", + "DLNVTransformerMeta": "ddopnew.agents.newsvendor.erm.NewsvendorDLTransformerMetaAgent", "SAC": "ddopnew.agents.rl.sac.SACAgent", "SACRNN": "ddopnew.agents.rl.sac.SACRNNAgent", diff --git a/ddopnew/agents/newsvendor/erm.py b/ddopnew/agents/newsvendor/erm.py index d26a73a..49c37cb 100644 --- a/ddopnew/agents/newsvendor/erm.py +++ b/ddopnew/agents/newsvendor/erm.py @@ -1,7 +1,8 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/41_NV_agents/11_NV_erm_agents.ipynb. # %% auto 0 -__all__ = ['SGDBaseAgent', 'NVBaseAgent', 'NewsvendorlERMAgent', 'NewsvendorDLAgent'] +__all__ = ['SGDBaseAgent', 'NVBaseAgent', 'NewsvendorlERMAgent', 'NewsvendorDLAgent', 'BaseMetaAgent', 'NewsvendorlERMMetaAgent', + 'NewsvendorDLMetaAgent'] # %% ../../../nbs/41_NV_agents/11_NV_erm_agents.ipynb 3 import logging @@ -14,7 +15,7 @@ from ...envs.base import BaseEnvironment from ..base import BaseAgent -from ...utils import MDPInfo, Parameter, DatasetWrapper +from ...utils import MDPInfo, Parameter, DatasetWrapper, DatasetWrapperMeta from ...torch_utils.loss_functions import TorchQuantileLoss from ...torch_utils.obsprocessors import FlattenTimeDim @@ -55,7 +56,9 @@ def __init__(self, self.device = self.set_device(device) self.set_dataloader(dataloader, dataloader_params) + self.set_model(input_shape, output_shape) + self.loss_function_params=None # default self.set_loss_function() self.set_optimizer(optimizer_params) self.set_learning_rate_scheduler(learning_rate_scheduler) @@ -89,8 +92,12 @@ def set_dataloader(self, Set the dataloader for the agent by wrapping it into a Torch Dataset """ - dataset = DatasetWrapper(dataloader) - self.dataloader = torch.utils.data.DataLoader(dataset, **dataloader_params) + + # check if class already have a dataloader + if not hasattr(self, 'dataloader'): + + dataset = DatasetWrapper(dataloader) + self.dataloader = torch.utils.data.DataLoader(dataset, **dataloader_params) @abstractmethod def set_loss_function(self): @@ -105,18 +112,21 @@ def set_model(self, input_shape: Tuple, output_shape: Tuple): def set_optimizer(self, optimizer_params: dict): # dict with keys: optimizer, lr, weight_decay """ Set the optimizer for the model """ - optimizer = optimizer_params["optimizer"] - optimizer_params_copy = optimizer_params.copy() - del optimizer_params_copy["optimizer"] - - if optimizer == "Adam": - self.optimizer = torch.optim.Adam(self.model.parameters(), **optimizer_params_copy) - elif optimizer == "SGD": - self.optimizer = torch.optim.SGD(self.model.parameters(), **optimizer_params_copy) - elif optimizer == "RMSprop": - self.optimizer = torch.optim.RMSprop(self.model.parameters(), **optimizer_params_copy) - else: - raise ValueError(f"Optimizer {optimizer} not supported") + + if not hasattr(self, 'optimizer'): + + optimizer = optimizer_params["optimizer"] + optimizer_params_copy = optimizer_params.copy() + del optimizer_params_copy["optimizer"] + + if optimizer == "Adam": + self.optimizer = torch.optim.Adam(self.model.parameters(), **optimizer_params_copy) + elif optimizer == "SGD": + self.optimizer = torch.optim.SGD(self.model.parameters(), **optimizer_params_copy) + elif optimizer == "RMSprop": + self.optimizer = torch.optim.RMSprop(self.model.parameters(), **optimizer_params_copy) + else: + raise ValueError(f"Optimizer {optimizer} not supported") def set_learning_rate_scheduler(self, learning_rate_scheduler: None = None): # """ Set learning rate scheudler (can be None) """ @@ -135,7 +145,11 @@ def fit_epoch(self): for i, output in enumerate(self.dataloader): - X, y = output + if len(output)==3: + X, y, loss_function_params = output + else: + X, y = output + loss_function_params = None # convert X and y to float32 X = X.type(torch.float32) @@ -150,10 +164,12 @@ def fit_epoch(self): y_pred = self.model(X) - if self.loss_function_params==None: - loss = self.loss_function(y_pred, y) + if loss_function_params is not None: + loss = self.loss_function(y_pred, y, **loss_function_params) + elif self.loss_function_params is not None: + loss = self.loss_function(y_pred, y, **self.loss_function_params) else: - loss = self.loss_function(y_pred, y, **self.loss_function_params) # TODO: add reduction param when defining loss function + loss = self.loss_function(y_pred, y) loss.backward() self.optimizer.step() @@ -286,7 +302,6 @@ def __init__(self, agent_name: str | None = None, ): - cu = self.convert_to_numpy_array(cu) co = self.convert_to_numpy_array(co) @@ -306,12 +321,13 @@ def __init__(self, device=device, agent_name=agent_name ) + def set_loss_function(self): - + """Set the loss function for the model to the quantile loss. For training the model uses quantile loss and not the pinball loss with specific cu and co values to ensure similar scale of the feedback signal during training.""" - + self.loss_function_params = {"quantile": self.sl} self.loss_function = TorchQuantileLoss(reduction="mean") @@ -448,3 +464,144 @@ def set_model(self, input_shape, output_shape): from ddopnew.approximators import MLP self.model = MLP(input_size=input_size, output_size=output_size, **self.model_params) + +# %% ../../../nbs/41_NV_agents/11_NV_erm_agents.ipynb 35 +class BaseMetaAgent(): + + def set_meta_dataloader( + self, + dataloader: BaseDataLoader, + dataloader_params, # dict with keys: batch_size, shuffle + draw_parameter_function: callable, # function to draw parameters from distribution + distribution: str, # distribution for params during training + bounds_low: Union[int, float], # lower bound for params during training + bounds_high: Union[int, float], # upper bound for params during training + obsprocessor: callable, # function to process observations + parameter_names: List[str] = None, # names of parameters + ) -> None: + + """ """ + + # check if class already have a dataloader + + print("setting meta datloader") + + dataset = DatasetWrapperMeta( + dataloader = dataloader, + draw_parameter_function = draw_parameter_function, + distribution = distribution, + bounds_low = bounds_low, + bounds_high = bounds_high, + obsprocessor = obsprocessor, + parameter_names = parameter_names, + ) + + self.dataloader = torch.utils.data.DataLoader(dataset, **dataloader_params) + +# %% ../../../nbs/41_NV_agents/11_NV_erm_agents.ipynb 36 +class NewsvendorlERMMetaAgent(NewsvendorlERMAgent, BaseMetaAgent): + + """ + Newsvendor agent implementing Empirical Risk Minimization (ERM) approach + based on a linear (regression) model. In addition to the features, the agent + also gets the sl as input to be able to forecast the optimal order quantity + for different sl values. Depending on the training pipeline, this model can be + adapted to become a full meta-learning algorithm cross products and cross sls. + + """ + + def __init__(self, + # Parameters for meta Agent + dataset_meta_params: dict, # Parameters for meta dataloader + + # Parameters for lERM agent + environment_info: MDPInfo, + dataloader: BaseDataLoader, + cu: np.ndarray | Parameter, + co: np.ndarray | Parameter, + input_shape: Tuple, + output_shape: Tuple, + optimizer_params: dict | None = None, # default: {"optimizer": "Adam", "lr": 0.01, "weight_decay": 0.0} + learning_rate_scheduler = None, # TODO: add base class for learning rate scheduler for typing + model_params: dict | None = None, # default: {"relu_output": False} + dataloader_params: dict | None = None, # default: {"batch_size": 32, "shuffle": True} + obsprocessors: list | None = None, # default: [] + torch_obsprocessors: list | None = None, # default: [FlattenTimeDim(allow_2d=False)] + device: str = "cpu", # "cuda" or "cpu" + agent_name: str | None = "lERMMeta" + ): + + self.set_meta_dataloader(dataloader, dataloader_params, **dataset_meta_params) + + super().__init__( + environment_info=environment_info, + dataloader=dataloader, + cu=cu, + co=co, + input_shape=input_shape, + output_shape=output_shape, + optimizer_params=optimizer_params, + learning_rate_scheduler=learning_rate_scheduler, + model_params=model_params, + dataloader_params=dataloader_params, + obsprocessors=obsprocessors, + torch_obsprocessors=torch_obsprocessors, + device=device, + agent_name=agent_name + ) + +# %% ../../../nbs/41_NV_agents/11_NV_erm_agents.ipynb 37 +class NewsvendorDLMetaAgent(NewsvendorDLAgent, BaseMetaAgent): + + """ + Newsvendor agent implementing Empirical Risk Minimization (ERM) approach + based on a Neural Network. In addition to the features, the agent + also gets the sl as input to be able to forecast the optimal order quantity + for different sl values. Depending on the training pipeline, this model can be + adapted to become a full meta-learning algorithm cross products and cross sls. + + """ + + def __init__(self, + # Parameters for meta Agent + dataset_meta_params: dict, # Parameters for meta dataloader + + environment_info: MDPInfo, + dataloader: BaseDataLoader, + cu: np.ndarray | Parameter, + co: np.ndarray | Parameter, + input_shape: Tuple, + output_shape: Tuple, + learning_rate_scheduler = None, # TODO: add base class for learning rate scheduler for typing + + # parameters in yaml file + optimizer_params: dict | None = None, # default: {"optimizer": "Adam", "lr": 0.01, "weight_decay": 0.0} + model_params: dict | None = None, # default: {"hidden_layers": [64, 64], "drop_prob": 0.0, "batch_norm": False, "relu_output": False} + dataloader_params: dict | None = None, # default: {"batch_size": 32, "shuffle": True} + device: str = "cpu", # "cuda" or "cpu" + + obsprocessors: list | None = None, # default: [] + torch_obsprocessors: list | None = None, # default: [FlattenTimeDim(allow_2d=False)] + agent_name: str | None = "DLNV", + ): + + self.set_meta_dataloader(dataloader, dataloader_params, **dataset_meta_params) + + super().__init__( + environment_info=environment_info, + dataloader=dataloader, + cu=cu, + co=co, + input_shape=input_shape, + output_shape=output_shape, + learning_rate_scheduler=learning_rate_scheduler, + + optimizer_params=optimizer_params, + model_params=model_params, + dataloader_params=dataloader_params, + device=device, + + obsprocessors=obsprocessors, + torch_obsprocessors=torch_obsprocessors, + agent_name=agent_name + ) diff --git a/ddopnew/agents/rl/sac.py b/ddopnew/agents/rl/sac.py index 8c74008..725c919 100644 --- a/ddopnew/agents/rl/sac.py +++ b/ddopnew/agents/rl/sac.py @@ -28,6 +28,7 @@ import torch import torch.nn.functional as F from torchinfo import summary +from IPython import get_ipython from copy import deepcopy @@ -166,7 +167,10 @@ def __init__(self, else: input_tensor = torch.randn(batch_dim, *actor_mu_params["input_shape"]).to(self.device) input_tuple = (input_tensor,) - print(summary(self.actor, input_data=input_tuple, device=self.device)) + if get_ipython() is not None: + print(summary(self.actor, input_data=input_tuple, device=self.device)) + else: + summary(self.actor, input_data=input_tuple, device=self.device) time.sleep(0.2) logging.info("################################################################################") @@ -183,7 +187,11 @@ def __init__(self, state_mlp_sample = torch.randn(batch_dim, *critic_params["input_shape"][0][1]).to(self.device) state_sample = torch.cat((state_sample, state_mlp_sample), dim=1) input_tuple = (state_sample, action_sample) - print(summary(self.critic, input_data=input_tuple, device=self.device)) + if get_ipython() is not None: + print(summary(self.critic, input_data=input_tuple, device=self.device)) + else: + summary(self.critic, input_data=input_tuple, device=self.device) + # print(summary(self.critic, input_data=input_tuple, device=self.device)) def get_network_list(self, set_actor_critic_attributes: bool = True): """ Get the list of networks in the agent for the save and load functions @@ -207,7 +215,7 @@ def get_network_list(self, set_actor_critic_attributes: bool = True): def predict_(self, observation: np.ndarray) -> np.ndarray: # """ Do one forward pass of the model directly and return the prediction. Apply tanh as implemented for the SAC actor in mushroom_rl""" - + # make observation torch tensor device = next(self.actor.parameters()).device observation = torch.tensor(observation, dtype=torch.float32).to(device) diff --git a/ddopnew/envs/base.py b/ddopnew/envs/base.py index ee3460d..77d60d3 100644 --- a/ddopnew/envs/base.py +++ b/ddopnew/envs/base.py @@ -3,7 +3,7 @@ # %% auto 0 __all__ = ['BaseEnvironment'] -# %% ../../nbs/20_base_env/10_base_env.ipynb 3 +# %% ../../nbs/20_base_env/10_base_env.ipynb 4 import gymnasium as gym from abc import ABC, abstractmethod from typing import Union, List @@ -12,7 +12,7 @@ from ..utils import MDPInfo, Parameter, set_param import time -# %% ../../nbs/20_base_env/10_base_env.ipynb 4 +# %% ../../nbs/20_base_env/10_base_env.ipynb 5 class BaseEnvironment(gym.Env, ABC): """ diff --git a/ddopnew/envs/inventory/multi_period.py b/ddopnew/envs/inventory/multi_period.py index d67bd07..04d3864 100644 --- a/ddopnew/envs/inventory/multi_period.py +++ b/ddopnew/envs/inventory/multi_period.py @@ -198,7 +198,7 @@ def get_observation(self): return observation, Y_item - + def reset(self, start_index: int | str = None, # index to start from state: np.ndarray = None # initial state diff --git a/ddopnew/envs/inventory/single_period.py b/ddopnew/envs/inventory/single_period.py index 05ad47f..b57055a 100644 --- a/ddopnew/envs/inventory/single_period.py +++ b/ddopnew/envs/inventory/single_period.py @@ -1,15 +1,15 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/21_envs_inventory/20_single_period_envs.ipynb. # %% auto 0 -__all__ = ['NewsvendorEnv'] +__all__ = ['NewsvendorEnv', 'NewsvendorEnvVariableSL'] # %% ../../../nbs/21_envs_inventory/20_single_period_envs.ipynb 3 from abc import ABC, abstractmethod -from typing import Union, Tuple +from typing import Union, Tuple, Literal from ...utils import Parameter, MDPInfo from ...dataloaders.base import BaseDataLoader -from ...loss_functions import pinball_loss +from ...loss_functions import pinball_loss, quantile_loss from .base import BaseInventoryEnv import gymnasium as gym @@ -43,6 +43,7 @@ def __init__(self, self.print=False num_SKUs = dataloader.num_units if num_SKUs is None else num_SKUs + if not isinstance(num_SKUs, int): raise ValueError("num_SKUs must be an integer.") @@ -82,7 +83,7 @@ def step_(self, if action.ndim == 2 and action.shape[0] == 1: action = np.squeeze(action, axis=0) # Remove the first dimension - cost_per_SKU = pinball_loss(self.demand, action, self.underage_cost, self.overage_cost) + cost_per_SKU = self.determine_cost(action) reward = -np.sum(cost_per_SKU) # negative because we want to minimize the cost terminated = False # in this problem there is no termination condition @@ -103,7 +104,6 @@ def step_(self, observation, self.demand = self.get_observation() - return observation, reward, terminated, truncated, info else: @@ -117,3 +117,175 @@ def step_(self, time.sleep(3) return observation, reward, terminated, truncated, info + + def determine_cost(self, action: np.ndarray) -> np.ndarray: + """ + Determine the cost per SKU given the action taken. The cost is the sum of underage and overage costs. + """ + # Compute the cost per SKU + return pinball_loss(self.demand, action, self.underage_cost, self.overage_cost) + +# %% ../../../nbs/21_envs_inventory/20_single_period_envs.ipynb 13 +class NewsvendorEnvVariableSL(NewsvendorEnv, ABC): + def __init__(self, + + # Additional parameters: + sl_bound_low: Union[np.ndarray, Parameter, int, float] = 0.1, # lower bound of the service level during training + sl_bound_high: Union[np.ndarray, Parameter, int, float] = 0.9, # upper bound of the service level during training + sl_distribution: Literal["fixed", "uniform"] = "fixed", # distribution of the random service level during training, if fixed then the service level is fixed to sl_test_val + evaluation_metric: Literal["pinball_loss", "quantile_loss"] = "quantile_loss", # quantile loss is the generic quantile loss (independent of cost levels) while pinball loss uses the specific under- and overage costs + sl_test_val: Union[np.ndarray, Parameter, int, float] = None, # service level during test and validation, alternatively use cu and co + + underage_cost: Union[np.ndarray, Parameter, int, float] = 1, # underage cost per unit + overage_cost: Union[np.ndarray, Parameter, int, float] = 1, # overage cost per unit + q_bound_low: Union[np.ndarray, Parameter, int, float] = 0, # lower bound of the order quantity + q_bound_high: Union[np.ndarray, Parameter, int, float] = np.inf, # upper bound of the order quantity + dataloader: BaseDataLoader = None, # dataloader + num_SKUs: Union[int] = None, # if None it will be inferred from the DataLoader + gamma: float = 1, # discount factor + horizon_train: int | str = "use_all_data", # if "use_all_data" then horizon is inferred from the DataLoader + postprocessors: list[object] | None = None, # default is empty list + mode: str = "train", # Initial mode (train, val, test) of the environment + return_truncation: str = True # whether to return a truncated condition in step function + ) -> None: + + self.set_param("sl_bound_low", sl_bound_low, shape=(1,), new=True) + self.set_param("sl_bound_high", sl_bound_high, shape=(1,), new=True) + self.evaluation_metric = evaluation_metric + self.check_evaluation_metric + self.sl_distribution = sl_distribution + self.check_sl_distribution + + super().__init__(underage_cost=underage_cost, + overage_cost=overage_cost, + q_bound_low=q_bound_low, + q_bound_high=q_bound_high, + dataloader=dataloader, + num_SKUs=num_SKUs, + gamma=gamma, + horizon_train=horizon_train, + postprocessors=postprocessors, + mode=mode, + return_truncation=return_truncation) + + if sl_test_val is not None: + if self.underage_cost is None and self.overage_cost is None: + self.set_param("sl", sl_test_val, shape=(num_SKUs[0],), new=True) + else: + raise ValueError("sl_test_val can only be used when underage_cost and overage_cost are None.") + else: + if self.underage_cost is None or self.overage_cost is None: + raise ValueError("Either sl_test_val or underage_cost and overage_cost must be provided.") + sl = self.underage_cost / (self.underage_cost + self.overage_cost) + self.set_param("sl", sl, shape=(self.num_SKUs[0],), new=True) + + def determine_cost(self, action: np.ndarray) -> np.ndarray: + """ + Determine the cost per SKU given the action taken. The cost is the sum of underage and overage costs. + """ + + # Compute the cost per SKU + if self.mode == "train": # during training only the service level is relevant + return quantile_loss(self.demand, action, self.sl_period) + else: + if self.evaluation_metric == "pinball_loss": + return pinball_loss(self.demand, action, self.underage_cost, self.overage_cost) + elif self.evaluation_metric == "quantile_loss": + return quantile_loss(self.demand, action, self.sl) + + def set_observation_space(self, + shape: tuple, # shape of the dataloader features + low: Union[np.ndarray, float] = -np.inf, # lower bound of the observation space + high: Union[np.ndarray, float] = np.inf, # upper bound of the observation space + samples_dim_included = True # whether the first dimension of the shape input is the number of samples + ) -> None: + + ''' + Set the observation space of the environment. + This is a standard function for simple observation spaces. For more complex observation spaces, + this function should be overwritten. Note that it is assumped that the first dimension + is n_samples that is not relevant for the observation space. + + ''' + + # To handle cases when no external information is available (e.g., parametric NV) + + if shape is None: + self.observation_space = None + + spaces = {} + if isinstance(shape, tuple): + if samples_dim_included: + shape = shape[1:] # assumed that the first dimension is the number of samples + spaces["features"] = gym.spaces.Box(low=low, high=high, shape=shape, dtype=np.float32) + + elif feature_shape is None: + pass + + else: + raise ValueError("Shape for features must be a tuple or None") + + spaces["service_level"] = gym.spaces.Box(low=0, high=1, shape=(self.num_SKUs[0],), dtype=np.float32) + + self.observation_space = gym.spaces.Dict(spaces) + + @staticmethod # staticmethod such that the dataloader can also use the funciton + def draw_parameter(distribution, sl_bound_low, sl_bound_high, samples): + + if distribution == "fixed": + sl = np.random.uniform(sl_bound_low, sl_bound_high, size=(samples,)) + elif distribution == "uniform": + sl = np.random.uniform(sl_bound_low, sl_bound_high, size=(samples,)) + else: + raise ValueError("sl_distribution not recognized.") + + return sl + + def get_observation(self): + + """ + Return the current observation. This function is for the simple case where the observation + is only an x,y pair. For more complex observations, this function should be overwritten. + """ + + X_item, Y_item = self.dataloader[self.index] + + if self.mode == "train": + sl = self.draw_parameter(self.sl_distribution, self.sl_bound_low, self.sl_bound_high, samples = self.num_SKUs[0]) + else: + sl = self.sl.copy() # evaluate on fixed sls + + self.sl_period = sl # store the service level to assess the action + + return {"features": X_item, "service_level": sl}, Y_item + + def check_evaluation_metric(self): + if self.evaluation_metric not in ["pinball_loss", "quantile_loss"]: + raise ValueError("evaluation_metric must be either 'pinball_loss' or 'quantile_loss'.") + if self.evaluation_metric == "pinball_loss" and (self.underage_cost is None or self.overage_cost is None): + raise ValueError("Underage and overage costs must be provided for pinball loss.") + if self.evaluation_metric == "quantile_loss" and (self.sl_test_val is None): + raise ValueError("sl_test_val must be provided for quantile loss.") + + def check_sl_distribution(self): + if self.sl_distribution not in ["fixed", "uniform"]: + raise ValueError("sl_distribution must be 'uniform' or 'fixed'.") + + def set_val_test_sl(self, sl_test_val): + self.set_param("sl", sl_test_val, shape=(self.num_SKUs[0],), new=False) + + def update_cu_co(self, cu=None, co=None): + + if not hasattr(self, "underage_cost") or not hasattr(self, "overage_cost"): + logging.warning("Underage and overage costs were not set previously, setting them as new parameters.") + self.set_param("underage_cost", cu, shape=(self.num_SKUs[0],), new=True) + self.set_param("overage_cost", co, shape=(self.num_SKUs[0],), new=True) + + if cu is not None: + self.set_param("underage_cost", cu, shape=(self.num_SKUs[0],), new=False) + if co is not None: + self.set_param("overage_cost", co, shape=(self.num_SKUs[0],), new=False) + + sl = self.underage_cost / (self.underage_cost + self.overage_cost) + + self.set_param("sl", sl, shape=(self.num_SKUs[0],), new=False) diff --git a/ddopnew/experiment_functions.py b/ddopnew/experiment_functions.py index 49b7ad0..bab5ff0 100644 --- a/ddopnew/experiment_functions.py +++ b/ddopnew/experiment_functions.py @@ -226,7 +226,7 @@ def run_test_episode( env: BaseEnvironment, # Any environment inheriting from logging.debug("truncated: %s", truncated) sample = (obs, action, reward, next_obs, terminated, truncated) # unlike mushroom do not include policy_state - + obs = next_obs dataset.append((sample, info)) @@ -334,6 +334,7 @@ def run_experiment( agent: BaseAgent, stop = False if stop: + log_info(R, J, n_epochs-epoch-1, tracking, "val") logging.info(f"Early stopping after {epoch+1} epochs") break @@ -390,6 +391,7 @@ def run_experiment( agent: BaseAgent, stop = False if stop: + log_info(R, J, n_epochs-epoch-1, tracking, "val") logging.info(f"Early stopping after {epoch+1} epochs") break diff --git a/ddopnew/meta_experiment_functions.py b/ddopnew/meta_experiment_functions.py index 9f1b826..d17a22c 100644 --- a/ddopnew/meta_experiment_functions.py +++ b/ddopnew/meta_experiment_functions.py @@ -1,8 +1,9 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb. # %% auto 0 -__all__ = ['prep_experiment', 'init_wandb', 'track_libraries_and_git', 'import_config', 'transfer_lag_window_to_env', - 'get_ddop_data', 'download_data', 'set_indices', 'set_up_env', 'set_up_earlystoppinghandler', 'select_agent'] +__all__ = ['set_warnings', 'prep_experiment', 'init_wandb', 'track_libraries_and_git', 'import_config', + 'transfer_lag_window_to_env', 'get_ddop_data', 'download_data', 'set_indices', 'set_up_env', + 'set_up_earlystoppinghandler', 'prep_and_run_test', 'clean_up', 'select_agent'] # %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 3 from abc import ABC, abstractmethod @@ -13,15 +14,19 @@ import sys import os import yaml +import pickle +import warnings +import torch from .tracking import get_git_hash, get_library_version from .agents.class_names import AGENT_CLASSES from .dataloaders.tabular import XYDataLoader from .datasets import DatasetLoader -from .experiment_functions import EarlyStoppingHandler +from .experiment_functions import EarlyStoppingHandler, test_agent import wandb +import gc import importlib @@ -31,6 +36,16 @@ from mushroom_rl.core import Core # %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 5 +def set_warnings (logging_level): + + """ Set warnings to be ignored for the given logging level or higher.""" + + if logging.getLogger().isEnabledFor(logging_level): + warnings.filterwarnings("ignore", category=UserWarning, message=".*Box bound precision lowered by casting to float32.*") + warnings.filterwarnings("ignore", category=UserWarning, message=".*TypedStorage is deprecated.*") + warnings.filterwarnings("ignore", category=FutureWarning, message=".*You are using `torch.load` with `weights_only=False`.*") + +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 7 def prep_experiment( project_name: str, libraries_to_track: List[str] = ["ddopnew"], @@ -59,7 +74,7 @@ def prep_experiment( return config_train, config_agent, config_env, AgentClass, agent_name -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 6 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 8 def init_wandb(project_name: str): # """ init wandb """ @@ -69,7 +84,7 @@ def init_wandb(project_name: str): # name = f"{project_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" ) -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 7 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 9 def track_libraries_and_git( libraries_to_track: List[str], tracking: bool = True, tracking_tool = "wandb", # Currenty only wandb is supported @@ -86,7 +101,7 @@ def track_libraries_and_git( libraries_to_track: List[str], git_hash = get_git_hash(".", tracking=tracking, tracking_tool=tracking_tool) logging.info(f"Git hash: {git_hash}") -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 8 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 10 def import_config( filename: str, # Name of the file, must be a yaml file path: str = None # Optional path to the file if it is not in the current directory ) -> Dict: @@ -129,7 +144,7 @@ def import_config( filename: str, # Name of the file, must be a yaml file return config -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 9 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 11 def transfer_lag_window_to_env(config_env: Dict, # config_agent: Dict ) -> None: @@ -148,7 +163,7 @@ def transfer_lag_window_to_env(config_env: Dict, # else: logging.warning("No lag window specified in the agent configuration. Keeping value from env config") -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 11 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 13 def get_ddop_data( config_env: Dict, overwrite: bool = False @@ -163,7 +178,7 @@ def get_ddop_data( return data, val_index_start, test_index_start -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 12 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 14 def download_data( config_env: Dict, overwrite: bool = False # ) -> Tuple: @@ -181,7 +196,7 @@ def download_data( config_env: Dict, return data_tuple -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 13 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 15 def set_indices(config_env: Dict, # X: np.ndarray ) -> Tuple: @@ -193,7 +208,7 @@ def set_indices(config_env: Dict, # return val_index_start, test_index_start -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 15 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 17 def set_up_env( env_class, raw_data: Tuple, # @@ -220,7 +235,7 @@ def set_up_env( return environment -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 17 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 19 def set_up_earlystoppinghandler(config_train: Dict) -> object: # """ Set up the early stopping handler """ @@ -229,13 +244,95 @@ def set_up_earlystoppinghandler(config_train: Dict) -> object: # if "early_stopping_patience" in config_train or "early_stopping_warmup" in config_train: warmup = config_train["early_stopping_warmup"] if "early_stopping_warmup" in config_train else 0 patience = config_train["early_stopping_patience"] if "early_stopping_patience" in config_train else 0 - earlystoppinghandler = EarlyStoppingHandler(warmup=warmup, patience=warmup) + + earlystoppinghandler = EarlyStoppingHandler(warmup=warmup, patience=patience) else: earlystoppinghandler = None return earlystoppinghandler -# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 19 +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 21 +def prep_and_run_test( + agent, + environment, + agent_dir: str, + save_dataset: bool = True, + dataset_dir: str = None, + tracking = "wandb"): + + """ + Test the agent in the environment. + """ + + if save_dataset: + if dataset_dir is None: + raise ValueError("If save_dataset is True, dataset_dir must be specified.") + + # load parameters of agent + agent.load(agent_dir) + + # Set agent and environment to test mode + agent.eval() + environment.test() + + # Run test episode + output = test_agent( + agent, + environment, + return_dataset=save_dataset, + tracking=tracking + ) + + # Save dataset + if save_dataset: + + R, J, dataset = output + + if not os.path.exists(dataset_dir): + os.mkdir(dataset_dir) + else: + raise ValueError("Path to save dataset already exists") # it should never exist since run_id is usually part or path and unique + + dir = os.path.join(dataset_dir, "dataset_test.pkl") + + with open (os.path.join(dir), "wb") as f: + pickle.dump(dataset, f) + + artifact = wandb.Artifact("transition_test_set", type="dataset") + + artifact.add_file(os.path.join(dir)) + + wandb.run.log_artifact(artifact) + + else: + + R, J = output + + logging.info(f"final evaluation on test set: R = {np.round(R, 10)} J = {np.round(J, 10)}") + + + +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 23 +def clean_up(agent, environment): + + """ Clean up agent and environment to free up GPU memory """ + + # Delete agent and environment to free up GPU memory + del agent + del environment + + # Force garbage collection + gc.collect() + + # Clear GPU cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + wandb.finish() + + return None, None + +# %% ../nbs/30_experiment_functions/20_meta_experiment_functions.ipynb 27 def select_agent(agent_name: str) -> type: # """ Select an agent class from a list of agent names and return the class""" if agent_name in AGENT_CLASSES: diff --git a/ddopnew/obsprocessors.py b/ddopnew/obsprocessors.py index 611ae7f..0d556db 100644 --- a/ddopnew/obsprocessors.py +++ b/ddopnew/obsprocessors.py @@ -1,7 +1,7 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_utils/11_obsprocessors.ipynb. # %% auto 0 -__all__ = ['FlattenTimeDimNumpy', 'ConvertDictSpace'] +__all__ = ['BaseProcessor', 'FlattenTimeDimNumpy', 'ConvertDictSpace', 'AddParamsToFeatures'] # %% ../nbs/00_utils/11_obsprocessors.ipynb 3 from typing import Union, Optional, List, Tuple, Dict @@ -14,6 +14,27 @@ import torch.nn.functional as F # %% ../nbs/00_utils/11_obsprocessors.ipynb 4 +class BaseProcessor(): + + def determine_output_shape(self, + sample_input: Dict, # + flat: bool = False # if the flattend output shape should be returned + ) -> Tuple | List: + + """ + Determine the output shape based on the input dictionary. + """ + + if flat: + output = self.__call__(sample_input, flatten=True) + else: + output = self.__call__(sample_input, flatten=False) + if isinstance(output, list): + return [output_element.shape for output_element in output] + else: + return output.shape + +# %% ../nbs/00_utils/11_obsprocessors.ipynb 5 class FlattenTimeDimNumpy(): """ @@ -100,8 +121,8 @@ def __call__(self, return output -# %% ../nbs/00_utils/11_obsprocessors.ipynb 8 -class ConvertDictSpace(): +# %% ../nbs/00_utils/11_obsprocessors.ipynb 9 +class ConvertDictSpace(BaseProcessor): """ @@ -176,21 +197,75 @@ def __call__(self, return np.concatenate(obs, axis=0) - def determine_output_shape(self, - sample_input: Dict, # - flat: bool = False # if the flattend output shape should be returned - ) -> Tuple | List: +# %% ../nbs/00_utils/11_obsprocessors.ipynb 10 +class AddParamsToFeatures(BaseProcessor): + + """ + + A utility class to process a dictionary of numpy arrays, with options to preserve or flatten the time dimension. + + Note, this class is only used to preprocess output from the environment without batch dimension. + + """ + + def __init__(self, + keep_time_dim: Optional[bool] = False, #If time timension should be flattened as well. + hybrid: Optional[bool] = False, # If the param dim should be added as separate vector or concatenated to the features. + ): + + self.keep_time_dim = keep_time_dim + self.hybrid = hybrid + + if not keep_time_dim and hybrid: + raise ValueError("For flattened vector, hybrid should be be merged with features directy.") + + + def __call__(self, + input: Dict, # Observation as dict of with numpy arrays + flatten: bool = False, # whether to flatten composite spaces (non-composite spaces will depend on self.keep_time_dim) + ) -> List[np.ndarray] | np.ndarray: """ - Determine the output shape based on the input dictionary. + Process the input dictionary by converting it to a numpy array. """ - if flat: - output = self.__call__(sample_input, flatten=True) - else: - output = self.__call__(sample_input, flatten=False) - if isinstance(output, list): - return [output[0].shape, output[1].shape] - else: - return output.shape + # print("input to processor: ", input) + input = input.copy() + features = input["features"] if self.keep_time_dim else input["features"].flatten() + del input["features"] + + + if self.hybrid: + obs_1d = [] # features or time X features + obs_2d = [] # time X features + obs.append(input["features"]) + for counter, (key, value) in enumerate(input.items()): + if not isinstance(value, np.ndarray): + raise TypeError(f"Expected input to be a dictionary of numpy arrays, but got {type(value)} instead.") + + if value.ndim == 1: + if features.ndim == 1: + features = np.concatenate([features, value]) + else: + if self.hybrid: + raise NotImplementedError("Hybrid not implemented yet.") + # expand value to 2d by copy time dimension + else: + value = np.expand_dims(value, axis=0) + value = np.repeat(value, features.shape[0], axis=0) + features = np.concatenate([features, value.flatten()]) + else: + if value.shape == features.shape: + features = np.concatenate([features, value.flatten()]) + else: + raise ValueError(f"Expected input to have the same shape as features, but got {value.shape} instead (feature shape: {features.shape}).") + + if self.hybrid: + if flatten: + raise NotImplementedError("Hybrid not implemented yet.") + else: + raise NotImplementedError("Hybrid not implemented yet.") + else: + return features + diff --git a/ddopnew/torch_utils/loss_functions.py b/ddopnew/torch_utils/loss_functions.py index 8ca2732..17ed565 100644 --- a/ddopnew/torch_utils/loss_functions.py +++ b/ddopnew/torch_utils/loss_functions.py @@ -14,6 +14,8 @@ import torch.nn.functional as F from torch.nn.modules.loss import _Loss +import warnings + # %% ../../nbs/00_utils/20_torch_loss_functions.ipynb 4 def quantile_loss( input: torch.Tensor, @@ -31,9 +33,14 @@ def quantile_loss( ) expanded_input, expanded_target = torch.broadcast_tensors(input, target) - + + # print(expanded_input.size(), expanded_target.size(), quantile.size()) + # print(quantile) + loss = torch.max((expanded_target - expanded_input) * quantile, (expanded_input - expanded_target) * (1 - quantile)) + # print(losks.size()) + if reduction == 'mean': return loss.mean() elif reduction == 'sum': @@ -64,19 +71,30 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, quantile: Parameter """ - if isinstance(quantile, Parameter): - quantile = quantile.get_value() - - quantile = torch.tensor(quantile, dtype=input.dtype, device=input.device) + quantile = self.convert_quantile(quantile, input_dtype=input.dtype, device=input.device) - if not (target.shape[1] == input.shape[1] == quantile.shape[0]): + if not (target.shape == input.shape == quantile.shape): warnings.warn( - f"Mismatch in dimensions: target dimension 2 size ({target.size(2)}), input dimension 2 size ({input.size(2)}), " - f"and quantile dimension 1 size ({quantile.size(1)}) must be the same. " + f"Mismatch in dimensions: target dimension ({target.shape}), input dimension ({input.shape}), " + f"and quantile dimension ({quantile.shape}) must be the same. " "This will likely lead to incorrect results due to broadcasting. " "Please ensure they have the same size.", stacklevel=2, ) return quantile_loss(input, target, quantile, reduction=self.reduction) + + def convert_quantile(self, quantile: Parameter | np.ndarray, input_dtype: torch.dtype = torch.float32, device: torch.device = torch.device('cpu')) -> torch.Tensor: + + if isinstance(quantile, Parameter): + quantile = quantile.get_value() + elif isinstance(quantile, np.ndarray): + quantile = torch.tensor(quantile, dtype=input_dtype, device=device) + elif isinstance(quantile, torch.Tensor): + # ensure dtype and device are the same as the input tensor + quantile = quantile.to(dtype=input_dtype, device=device) + else: + raise ValueError(f"quantile must be of type Parameter, np.ndarray, or torch.Tensor, but got {type(quantile)}") + + return quantile diff --git a/ddopnew/utils.py b/ddopnew/utils.py index b733b1f..ba85200 100644 --- a/ddopnew/utils.py +++ b/ddopnew/utils.py @@ -1,7 +1,8 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_utils/00_utils.ipynb. # %% auto 0 -__all__ = ['check_parameter_types', 'Parameter', 'MDPInfo', 'DatasetWrapper', 'merge_dictionaries', 'set_param'] +__all__ = ['check_parameter_types', 'Parameter', 'MDPInfo', 'DatasetWrapper', 'DatasetWrapperMeta', 'merge_dictionaries', + 'set_param'] # %% ../nbs/00_utils/00_utils.ipynb 3 from torch.utils.data import Dataset @@ -203,6 +204,62 @@ def __len__(self): raise ValueError("Dataset type must be either 'train', 'val' or 'test'") # %% ../nbs/00_utils/00_utils.ipynb 24 +class DatasetWrapperMeta(DatasetWrapper): + """ + This class is used to wrap a Pytorch Dataset around the ddopnew dataloader + to enable the usage of the Pytorch Dataloader during training. This way, + agents that are trained using Pytorch without interacting with the environment + can directly train on the data generated by the dataloader. + + """ + + def __init__(self, + dataloader: BaseDataLoader, # Any dataloader that inherits from BaseDataLoader + draw_parameter_function: callable = None, # function to draw parameters from distribution + distribution: Literal["fixed", "uniform"] | List = "fixed", # distribution for params during training, can be List for multiple parameters + parameter_names: List[str] = None, # names of the parameters + bounds_low: Union[int, float] | List = 0, # lower bound for params during training, can be List for multiple parameters + bounds_high: Union[int, float] | List = 1, # upper bound for params during training, can be List for multiple parameters + obsprocessor: callable = None # processor to combine features and parameters + ): + + if isinstance(distribution, list) or isinstance(bounds_low, list) or isinstance(bounds_high, list): + raise NotImplementedError("Multiple parameters not yet implemented") + if obsprocessor is None: + raise ValueError("Obsprocessor must be provided") + + self.distribution = [distribution] + self.bounds_low = [bounds_low] + self.bounds_high = [bounds_high] + + self.dataloader = dataloader + + self.draw_parameter = draw_parameter_function + self.obsprocessor = obsprocessor + + self.parameter_names = parameter_names + + def __getitem__(self, idx): + """ + Get the item at the provided idx. + + """ + + features, demand = self.dataloader[idx] + params = {} + for i in range(len(self.distribution)): + param = self.draw_parameter(self.distribution[0], self.bounds_low[0], self.bounds_high[0], samples=1) # idx always gets a single sample + params[self.parameter_names[i]] = param + + obs = params.copy() + obs["features"] = features + + obs = self.obsprocessor(obs) + + # create tuple of items + return obs, demand, params + +# %% ../nbs/00_utils/00_utils.ipynb 27 def merge_dictionaries(dict1, dict2): """ Merge two dictionaries. If a key is found in both dictionaries, raise a KeyError. """ for key in dict2: @@ -213,7 +270,7 @@ def merge_dictionaries(dict1, dict2): merged_dict = {**dict1, **dict2} return merged_dict -# %% ../nbs/00_utils/00_utils.ipynb 26 +# %% ../nbs/00_utils/00_utils.ipynb 29 def set_param(obj, name: str, # name of the parameter (will become the attribute name) input: Parameter | int | float | np.ndarray | List | None , # input value of the parameter diff --git a/nbs/00_utils/00_utils.ipynb b/nbs/00_utils/00_utils.ipynb index 9157219..58e70b6 100644 --- a/nbs/00_utils/00_utils.ipynb +++ b/nbs/00_utils/00_utils.ipynb @@ -202,7 +202,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L29){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L32){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## Parameter\n", "\n", @@ -225,7 +225,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L29){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L32){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## Parameter\n", "\n", @@ -265,7 +265,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L48){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L51){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.__call__\n", "\n", @@ -276,7 +276,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L48){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L51){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.__call__\n", "\n", @@ -304,7 +304,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L55){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L58){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.get_value\n", "\n", @@ -315,7 +315,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L55){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L58){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.get_value\n", "\n", @@ -343,7 +343,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L64){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L67){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.set_value\n", "\n", @@ -368,7 +368,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L64){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L67){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.set_value\n", "\n", @@ -410,7 +410,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L101){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L104){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.shape\n", "\n", @@ -421,7 +421,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L101){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L104){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.shape\n", "\n", @@ -449,7 +449,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L108){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L111){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.size\n", "\n", @@ -460,7 +460,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L108){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L111){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### Parameter.size\n", "\n", @@ -573,7 +573,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L115){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L118){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## MDPInfo\n", "\n", @@ -600,7 +600,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L115){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L118){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## MDPInfo\n", "\n", @@ -644,7 +644,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L141){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L144){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MDPInfo.size\n", "\n", @@ -655,7 +655,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L141){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L144){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MDPInfo.size\n", "\n", @@ -683,7 +683,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L149){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L152){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MDPInfo.shape\n", "\n", @@ -694,7 +694,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L149){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L152){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### MDPInfo.shape\n", "\n", @@ -777,7 +777,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L157){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L160){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## DatasetWrapper\n", "\n", @@ -795,7 +795,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L157){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L160){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## DatasetWrapper\n", "\n", @@ -830,7 +830,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L175){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L178){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### DatasetWrapper.__getitem__\n", "\n", @@ -841,7 +841,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L175){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L178){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### DatasetWrapper.__getitem__\n", "\n", @@ -869,7 +869,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L185){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### DatasetWrapper.__len__\n", "\n", @@ -881,7 +881,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L185){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L188){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### DatasetWrapper.__len__\n", "\n", @@ -900,6 +900,82 @@ "show_doc(DatasetWrapper.__len__)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class DatasetWrapperMeta(DatasetWrapper):\n", + " \"\"\"\n", + " This class is used to wrap a Pytorch Dataset around the ddopnew dataloader\n", + " to enable the usage of the Pytorch Dataloader during training. This way,\n", + " agents that are trained using Pytorch without interacting with the environment\n", + " can directly train on the data generated by the dataloader.\n", + " \n", + " \"\"\"\n", + "\n", + " def __init__(self, \n", + " dataloader: BaseDataLoader, # Any dataloader that inherits from BaseDataLoader\n", + " draw_parameter_function: callable = None, # function to draw parameters from distribution\n", + " distribution: Literal[\"fixed\", \"uniform\"] | List = \"fixed\", # distribution for params during training, can be List for multiple parameters\n", + " parameter_names: List[str] = None, # names of the parameters\n", + " bounds_low: Union[int, float] | List = 0, # lower bound for params during training, can be List for multiple parameters\n", + " bounds_high: Union[int, float] | List = 1, # upper bound for params during training, can be List for multiple parameters\n", + " obsprocessor: callable = None # processor to combine features and parameters\n", + " ):\n", + "\n", + " if isinstance(distribution, list) or isinstance(bounds_low, list) or isinstance(bounds_high, list):\n", + " raise NotImplementedError(\"Multiple parameters not yet implemented\")\n", + " if obsprocessor is None:\n", + " raise ValueError(\"Obsprocessor must be provided\")\n", + " \n", + " self.distribution = [distribution]\n", + " self.bounds_low = [bounds_low]\n", + " self.bounds_high = [bounds_high]\n", + " \n", + " self.dataloader = dataloader\n", + "\n", + " self.draw_parameter = draw_parameter_function\n", + " self.obsprocessor = obsprocessor\n", + "\n", + " self.parameter_names = parameter_names\n", + " \n", + " def __getitem__(self, idx):\n", + " \"\"\"\n", + " Get the item at the provided idx.\n", + "\n", + " \"\"\"\n", + "\n", + " features, demand = self.dataloader[idx]\n", + " params = {}\n", + " for i in range(len(self.distribution)):\n", + " param = self.draw_parameter(self.distribution[0], self.bounds_low[0], self.bounds_high[0], samples=1) # idx always gets a single sample\n", + " params[self.parameter_names[i]] = param\n", + " \n", + " obs = params.copy()\n", + " obs[\"features\"] = features\n", + "\n", + " obs = self.obsprocessor(obs)\n", + "\n", + " # create tuple of items\n", + " return obs, demand, params" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -928,7 +1004,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L204){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L260){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## merge_dictionaries\n", "\n", @@ -939,7 +1015,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L204){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L260){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## merge_dictionaries\n", "\n", @@ -1030,10 +1106,13 @@ "text/markdown": [ "---\n", "\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L271){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "## set_param\n", "\n", - "> set_param (obj, name:str, input:__main__.Parameter|float|numpy.ndarray,\n", - "> shape:tuple=(1,), new:bool=False)\n", + "> set_param (obj, name:str,\n", + "> input:Union[__main__.Parameter,int,float,numpy.ndarray,List,No\n", + "> neType], shape:tuple=(1,), new:bool=False)\n", "\n", "*Set a parameter for the class. It converts scalar values to numpy arrays and ensures that\n", "environment parameters are either of the Parameter class of Numpy arrays. If new is set to True, \n", @@ -1044,17 +1123,20 @@ "| -- | -------- | ----------- | ----------- |\n", "| obj | | | |\n", "| name | str | | name of the parameter (will become the attribute name) |\n", - "| input | __main__.Parameter \\| float \\| numpy.ndarray | | input value of the parameter |\n", + "| input | Union | | input value of the parameter |\n", "| shape | tuple | (1,) | shape of the parameter |\n", "| new | bool | False | whether to create a new parameter or update an existing one |" ], "text/plain": [ "---\n", "\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/utils.py#L271){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "\n", "## set_param\n", "\n", - "> set_param (obj, name:str, input:__main__.Parameter|float|numpy.ndarray,\n", - "> shape:tuple=(1,), new:bool=False)\n", + "> set_param (obj, name:str,\n", + "> input:Union[__main__.Parameter,int,float,numpy.ndarray,List,No\n", + "> neType], shape:tuple=(1,), new:bool=False)\n", "\n", "*Set a parameter for the class. It converts scalar values to numpy arrays and ensures that\n", "environment parameters are either of the Parameter class of Numpy arrays. If new is set to True, \n", @@ -1065,7 +1147,7 @@ "| -- | -------- | ----------- | ----------- |\n", "| obj | | | |\n", "| name | str | | name of the parameter (will become the attribute name) |\n", - "| input | __main__.Parameter \\| float \\| numpy.ndarray | | input value of the parameter |\n", + "| input | Union | | input value of the parameter |\n", "| shape | tuple | (1,) | shape of the parameter |\n", "| new | bool | False | whether to create a new parameter or update an existing one |" ] diff --git a/nbs/00_utils/11_obsprocessors.ipynb b/nbs/00_utils/11_obsprocessors.ipynb index 1c1b1ca..8d4995f 100644 --- a/nbs/00_utils/11_obsprocessors.ipynb +++ b/nbs/00_utils/11_obsprocessors.ipynb @@ -46,6 +46,35 @@ "import torch.nn.functional as F" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class BaseProcessor():\n", + "\n", + " def determine_output_shape(self,\n", + " sample_input: Dict, #\n", + " flat: bool = False # if the flattend output shape should be returned\n", + " ) -> Tuple | List:\n", + "\n", + " \"\"\"\n", + " Determine the output shape based on the input dictionary.\n", + " \"\"\"\n", + "\n", + " if flat:\n", + " output = self.__call__(sample_input, flatten=True)\n", + " else:\n", + " output = self.__call__(sample_input, flatten=False)\n", + " if isinstance(output, list):\n", + " return [output_element.shape for output_element in output]\n", + " else:\n", + " return output.shape" + ] + }, { "cell_type": "code", "execution_count": null, @@ -151,7 +180,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L17){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L38){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## FlattenTimeDimNumpy\n", "\n", @@ -170,7 +199,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L17){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L38){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## FlattenTimeDimNumpy\n", "\n", @@ -206,7 +235,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L32){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L53){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### FlattenTimeDimNumpy.check_input\n", "\n", @@ -221,7 +250,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L32){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L53){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### FlattenTimeDimNumpy.check_input\n", "\n", @@ -253,7 +282,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L72){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L93){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### FlattenTimeDimNumpy.__call__\n", "\n", @@ -265,7 +294,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L72){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/obsprocessors.py#L93){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### FlattenTimeDimNumpy.__call__\n", "\n", @@ -291,7 +320,7 @@ "outputs": [], "source": [ "#| export\n", - "class ConvertDictSpace():\n", + "class ConvertDictSpace(BaseProcessor):\n", "\n", " \"\"\" \n", "\n", @@ -364,26 +393,87 @@ " else:\n", " return np.concatenate(obs, axis=1)\n", "\n", - " return np.concatenate(obs, axis=0)\n", + " return np.concatenate(obs, axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class AddParamsToFeatures(BaseProcessor):\n", "\n", - " def determine_output_shape(self,\n", - " sample_input: Dict, #\n", - " flat: bool = False # if the flattend output shape should be returned\n", - " ) -> Tuple | List:\n", + " \"\"\" \n", + "\n", + " A utility class to process a dictionary of numpy arrays, with options to preserve or flatten the time dimension.\n", + "\n", + " Note, this class is only used to preprocess output from the environment without batch dimension.\n", + " \n", + " \"\"\"\n", + "\n", + " def __init__(self,\n", + " keep_time_dim: Optional[bool] = False, #If time timension should be flattened as well.\n", + " hybrid: Optional[bool] = False, # If the param dim should be added as separate vector or concatenated to the features.\n", + " ):\n", + "\n", + " self.keep_time_dim = keep_time_dim\n", + " self.hybrid = hybrid\n", + "\n", + " if not keep_time_dim and hybrid:\n", + " raise ValueError(\"For flattened vector, hybrid should be be merged with features directy.\")\n", + "\n", + "\n", + " def __call__(self, \n", + " input: Dict, # Observation as dict of with numpy arrays\n", + " flatten: bool = False, # whether to flatten composite spaces (non-composite spaces will depend on self.keep_time_dim)\n", + " ) -> List[np.ndarray] | np.ndarray: \n", "\n", " \"\"\"\n", - " Determine the output shape based on the input dictionary.\n", + " Process the input dictionary by converting it to a numpy array.\n", " \"\"\"\n", "\n", - " if flat:\n", - " output = self.__call__(sample_input, flatten=True)\n", - " else:\n", - " output = self.__call__(sample_input, flatten=False)\n", - " if isinstance(output, list):\n", - " return [output[0].shape, output[1].shape]\n", - " else:\n", - " return output.shape\n", - " " + " # print(\"input to processor: \", input)\n", + " input = input.copy()\n", + " features = input[\"features\"] if self.keep_time_dim else input[\"features\"].flatten()\n", + " del input[\"features\"]\n", + "\n", + "\n", + " if self.hybrid:\n", + " obs_1d = [] # features or time X features\n", + " obs_2d = [] # time X features\n", + " obs.append(input[\"features\"])\n", + " \n", + " for counter, (key, value) in enumerate(input.items()):\n", + " if not isinstance(value, np.ndarray):\n", + " raise TypeError(f\"Expected input to be a dictionary of numpy arrays, but got {type(value)} instead.\")\n", + " \n", + " if value.ndim == 1:\n", + " if features.ndim == 1:\n", + " features = np.concatenate([features, value])\n", + " else:\n", + " if self.hybrid:\n", + " raise NotImplementedError(\"Hybrid not implemented yet.\")\n", + " # expand value to 2d by copy time dimension\n", + " else:\n", + " value = np.expand_dims(value, axis=0)\n", + " value = np.repeat(value, features.shape[0], axis=0)\n", + " features = np.concatenate([features, value.flatten()])\n", + " else:\n", + " if value.shape == features.shape:\n", + " features = np.concatenate([features, value.flatten()])\n", + " else:\n", + " raise ValueError(f\"Expected input to have the same shape as features, but got {value.shape} instead (feature shape: {features.shape}).\")\n", + " \n", + " if self.hybrid:\n", + " if flatten:\n", + " raise NotImplementedError(\"Hybrid not implemented yet.\")\n", + " else:\n", + " raise NotImplementedError(\"Hybrid not implemented yet.\")\n", + " else: \n", + " return features\n", + " " ] }, { diff --git a/nbs/00_utils/20_torch_loss_functions.ipynb b/nbs/00_utils/20_torch_loss_functions.ipynb index 052a649..9a85dab 100644 --- a/nbs/00_utils/20_torch_loss_functions.ipynb +++ b/nbs/00_utils/20_torch_loss_functions.ipynb @@ -44,7 +44,9 @@ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", - "from torch.nn.modules.loss import _Loss" + "from torch.nn.modules.loss import _Loss\n", + "\n", + "import warnings" ] }, { @@ -71,9 +73,14 @@ " )\n", "\n", " expanded_input, expanded_target = torch.broadcast_tensors(input, target)\n", - " \n", + "\n", + " # print(expanded_input.size(), expanded_target.size(), quantile.size())\n", + " # print(quantile)\n", + "\n", " loss = torch.max((expanded_target - expanded_input) * quantile, (expanded_input - expanded_target) * (1 - quantile))\n", "\n", + " # print(losks.size())\n", + "\n", " if reduction == 'mean':\n", " return loss.mean()\n", " elif reduction == 'sum':\n", @@ -112,23 +119,53 @@ "\n", " \"\"\"\n", "\n", - " if isinstance(quantile, Parameter):\n", - " quantile = quantile.get_value()\n", - " \n", - " quantile = torch.tensor(quantile, dtype=input.dtype, device=input.device)\n", + " quantile = self.convert_quantile(quantile, input_dtype=input.dtype, device=input.device)\n", " \n", - " if not (target.shape[1] == input.shape[1] == quantile.shape[0]):\n", + " if not (target.shape == input.shape == quantile.shape):\n", " warnings.warn(\n", - " f\"Mismatch in dimensions: target dimension 2 size ({target.size(2)}), input dimension 2 size ({input.size(2)}), \"\n", - " f\"and quantile dimension 1 size ({quantile.size(1)}) must be the same. \"\n", + " f\"Mismatch in dimensions: target dimension ({target.shape}), input dimension ({input.shape}), \"\n", + " f\"and quantile dimension ({quantile.shape}) must be the same. \"\n", " \"This will likely lead to incorrect results due to broadcasting. \"\n", " \"Please ensure they have the same size.\",\n", " stacklevel=2,\n", " )\n", "\n", - " return quantile_loss(input, target, quantile, reduction=self.reduction)\n" + " return quantile_loss(input, target, quantile, reduction=self.reduction)\n", + " \n", + " def convert_quantile(self, quantile: Parameter | np.ndarray, input_dtype: torch.dtype = torch.float32, device: torch.device = torch.device('cpu')) -> torch.Tensor:\n", + " \n", + " if isinstance(quantile, Parameter):\n", + " quantile = quantile.get_value()\n", + " elif isinstance(quantile, np.ndarray):\n", + " quantile = torch.tensor(quantile, dtype=input_dtype, device=device)\n", + " elif isinstance(quantile, torch.Tensor):\n", + " # ensure dtype and device are the same as the input tensor\n", + " quantile = quantile.to(dtype=input_dtype, device=device)\n", + " else:\n", + " raise ValueError(f\"quantile must be of type Parameter, np.ndarray, or torch.Tensor, but got {type(quantile)}\")\n", + "\n", + " return quantile\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -139,7 +176,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/torch_utils/loss_functions.py#L47){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/torch_utils/loss_functions.py#L53){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TorchQuantileLoss\n", "\n", @@ -158,7 +195,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/torch_utils/loss_functions.py#L47){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/torch_utils/loss_functions.py#L53){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## TorchQuantileLoss\n", "\n", @@ -194,7 +231,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/torch_utils/loss_functions.py#L60){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/torch_utils/loss_functions.py#L66){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### TorchQuantileLoss.forward\n", "\n", @@ -214,7 +251,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/torch_utils/loss_functions.py#L60){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/torch_utils/loss_functions.py#L66){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### TorchQuantileLoss.forward\n", "\n", diff --git a/nbs/20_base_env/10_base_env.ipynb b/nbs/20_base_env/10_base_env.ipynb index 93fcefa..e74ba6e 100644 --- a/nbs/20_base_env/10_base_env.ipynb +++ b/nbs/20_base_env/10_base_env.ipynb @@ -28,6 +28,11 @@ "from nbdev.showdoc import *" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/21_envs_inventory/10_base_inventory_env.ipynb b/nbs/21_envs_inventory/10_base_inventory_env.ipynb index d59c33d..47e16e9 100644 --- a/nbs/21_envs_inventory/10_base_inventory_env.ipynb +++ b/nbs/21_envs_inventory/10_base_inventory_env.ipynb @@ -183,7 +183,11 @@ "\n", "> BaseInventoryEnv (mdp_info:ddopnew.utils.MDPInfo,\n", "> postprocessors:list[object]|None=None,\n", - "> mode:str='train', return_truncation:str=True)\n", + "> mode:str='train', return_truncation:str=True, dataloade\n", + "> r:ddopnew.dataloaders.base.BaseDataLoader=None,\n", + "> horizon_train:int=100, underage_cost:Union[numpy.ndarra\n", + "> y,ddopnew.utils.Parameter,int,float]=1, overage_cost:Un\n", + "> ion[numpy.ndarray,ddopnew.utils.Parameter,int,float]=0)\n", "\n", "*Base class for inventory management environments. This class inherits from BaseEnvironment.*\n", "\n", @@ -193,6 +197,10 @@ "| postprocessors | list[object] \\| None | None | default is empty list |\n", "| mode | str | train | Initial mode (train, val, test) of the environment |\n", "| return_truncation | str | True | whether to return a truncated condition in step function |\n", + "| dataloader | BaseDataLoader | None | dataloader for the environment |\n", + "| horizon_train | int | 100 | horizon for training mode |\n", + "| underage_cost | Union | 1 | underage cost per unit |\n", + "| overage_cost | Union | 0 | overage cost per unit (zero in most cases) |\n", "| **Returns** | **None** | | |" ], "text/plain": [ @@ -204,7 +212,11 @@ "\n", "> BaseInventoryEnv (mdp_info:ddopnew.utils.MDPInfo,\n", "> postprocessors:list[object]|None=None,\n", - "> mode:str='train', return_truncation:str=True)\n", + "> mode:str='train', return_truncation:str=True, dataloade\n", + "> r:ddopnew.dataloaders.base.BaseDataLoader=None,\n", + "> horizon_train:int=100, underage_cost:Union[numpy.ndarra\n", + "> y,ddopnew.utils.Parameter,int,float]=1, overage_cost:Un\n", + "> ion[numpy.ndarray,ddopnew.utils.Parameter,int,float]=0)\n", "\n", "*Base class for inventory management environments. This class inherits from BaseEnvironment.*\n", "\n", @@ -214,6 +226,10 @@ "| postprocessors | list[object] \\| None | None | default is empty list |\n", "| mode | str | train | Initial mode (train, val, test) of the environment |\n", "| return_truncation | str | True | whether to return a truncated condition in step function |\n", + "| dataloader | BaseDataLoader | None | dataloader for the environment |\n", + "| horizon_train | int | 100 | horizon for training mode |\n", + "| underage_cost | Union | 1 | underage cost per unit |\n", + "| overage_cost | Union | 0 | overage cost per unit (zero in most cases) |\n", "| **Returns** | **None** | | |" ] }, @@ -236,7 +252,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L36){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L50){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### BaseInventoryEnv.set_observation_space\n", "\n", @@ -262,7 +278,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L36){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L50){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### BaseInventoryEnv.set_observation_space\n", "\n", @@ -305,7 +321,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L65){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L79){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### BaseInventoryEnv.set_action_space\n", "\n", @@ -330,7 +346,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L65){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L79){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### BaseInventoryEnv.set_action_space\n", "\n", @@ -372,7 +388,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L100){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### BaseInventoryEnv.reset\n", "\n", @@ -392,7 +408,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L100){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### BaseInventoryEnv.reset\n", "\n", @@ -429,7 +445,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L87){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L101){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### BaseInventoryEnv.get_observation\n", "\n", @@ -441,7 +457,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L87){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/base.py#L101){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### BaseInventoryEnv.get_observation\n", "\n", diff --git a/nbs/21_envs_inventory/20_single_period_envs.ipynb b/nbs/21_envs_inventory/20_single_period_envs.ipynb index ed594ad..f5ae701 100644 --- a/nbs/21_envs_inventory/20_single_period_envs.ipynb +++ b/nbs/21_envs_inventory/20_single_period_envs.ipynb @@ -36,11 +36,11 @@ "source": [ "#| export\n", "from abc import ABC, abstractmethod\n", - "from typing import Union, Tuple\n", + "from typing import Union, Tuple, Literal\n", "\n", "from ddopnew.utils import Parameter, MDPInfo\n", "from ddopnew.dataloaders.base import BaseDataLoader\n", - "from ddopnew.loss_functions import pinball_loss\n", + "from ddopnew.loss_functions import pinball_loss, quantile_loss\n", "from ddopnew.envs.inventory.base import BaseInventoryEnv\n", "\n", "import gymnasium as gym\n", @@ -81,6 +81,7 @@ " self.print=False\n", "\n", " num_SKUs = dataloader.num_units if num_SKUs is None else num_SKUs\n", + "\n", " if not isinstance(num_SKUs, int):\n", " raise ValueError(\"num_SKUs must be an integer.\")\n", " \n", @@ -120,7 +121,7 @@ " if action.ndim == 2 and action.shape[0] == 1:\n", " action = np.squeeze(action, axis=0) # Remove the first dimension\n", "\n", - " cost_per_SKU = pinball_loss(self.demand, action, self.underage_cost, self.overage_cost)\n", + " cost_per_SKU = self.determine_cost(action)\n", " reward = -np.sum(cost_per_SKU) # negative because we want to minimize the cost\n", "\n", " terminated = False # in this problem there is no termination condition\n", @@ -141,7 +142,6 @@ "\n", " observation, self.demand = self.get_observation()\n", "\n", - "\n", " return observation, reward, terminated, truncated, info\n", " \n", " else:\n", @@ -154,7 +154,14 @@ " print(\"next demand:\", self.demand)\n", " time.sleep(3)\n", "\n", - " return observation, reward, terminated, truncated, info" + " return observation, reward, terminated, truncated, info\n", + "\n", + " def determine_cost(self, action: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Determine the cost per SKU given the action taken. The cost is the sum of underage and overage costs.\n", + " \"\"\"\n", + " # Compute the cost per SKU\n", + " return pinball_loss(self.demand, action, self.underage_cost, self.overage_cost)" ] }, { @@ -260,13 +267,13 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/single_period.py#L71){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/single_period.py#L70){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NewsvendorEnv.step_\n", "\n", "> NewsvendorEnv.step_ (action:numpy.ndarray)\n", "\n", - "*Step function implementing the Newsvendor logic. Note that the dataloader will return an observation and a demad,\n", + "*Step function implementing the Newsvendor logic. Note that the dataloader will return an observation and a demand,\n", "which will be relevant in the next period. The observation will be returned directly, while the demand will be \n", "temporarily stored under self.demand and used in the next step.*\n", "\n", @@ -278,13 +285,13 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/single_period.py#L71){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/inventory/single_period.py#L70){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NewsvendorEnv.step_\n", "\n", "> NewsvendorEnv.step_ (action:numpy.ndarray)\n", "\n", - "*Step function implementing the Newsvendor logic. Note that the dataloader will return an observation and a demad,\n", + "*Step function implementing the Newsvendor logic. Note that the dataloader will return an observation and a demand,\n", "which will be relevant in the next period. The observation will be returned directly, while the demand will be \n", "temporarily stored under self.demand and used in the next step.*\n", "\n", @@ -319,21 +326,25 @@ "name": "stdout", "output_type": "stream", "text": [ - "[2]\n" - ] - }, - { - "ename": "TypeError", - "evalue": "only integer scalar arrays can be converted to a scalar index", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 16\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtruncated:\u001b[39m\u001b[38;5;124m\"\u001b[39m, truncated)\n\u001b[1;32m 14\u001b[0m dataloader \u001b[38;5;241m=\u001b[39m NormalDistributionDataLoader(mean\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m3\u001b[39m], std\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m], num_units\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 16\u001b[0m test_env \u001b[38;5;241m=\u001b[39m \u001b[43mNewsvendorEnv\u001b[49m\u001b[43m(\u001b[49m\u001b[43munderage_cost\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverage_cost\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataloader\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhorizon_train\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 18\u001b[0m obs \u001b[38;5;241m=\u001b[39m test_env\u001b[38;5;241m.\u001b[39mreset(start_index\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m##### RESET #####\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "Cell \u001b[0;32mIn[4], line 34\u001b[0m, in \u001b[0;36mNewsvendorEnv.__init__\u001b[0;34m(self, underage_cost, overage_cost, q_bound_low, q_bound_high, dataloader, num_SKUs, gamma, horizon_train, postprocessors, mode, return_truncation)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_param(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_SKUs\u001b[39m\u001b[38;5;124m\"\u001b[39m, num_SKUs, new\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 32\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_SKUs)\n\u001b[0;32m---> 34\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset_param\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mq_bound_low\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mq_bound_low\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_SKUs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnew\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_param(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mq_bound_high\u001b[39m\u001b[38;5;124m\"\u001b[39m, q_bound_high, shape\u001b[38;5;241m=\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_SKUs,), new\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mset_observation_space(dataloader\u001b[38;5;241m.\u001b[39mX_shape)\n", - "File \u001b[0;32m~/Documents/02_PhD/Other_python_projects/00_ddop_new/ddopnew/ddopnew/envs/base.py:73\u001b[0m, in \u001b[0;36mBaseEnvironment.set_param\u001b[0;34m(self, name, input, shape, new)\u001b[0m\n\u001b[1;32m 70\u001b[0m param \u001b[38;5;241m=\u001b[39m \u001b[38;5;28minput\u001b[39m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28minput\u001b[39m, (\u001b[38;5;28mint\u001b[39m, \u001b[38;5;28mfloat\u001b[39m)):\n\u001b[0;32m---> 73\u001b[0m param \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfull\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28minput\u001b[39m, \u001b[38;5;28mlist\u001b[39m):\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(\u001b[38;5;28minput\u001b[39m)\n", - "File \u001b[0;32m~/miniforge3/envs/inventory_gym_2/lib/python3.11/site-packages/numpy/core/numeric.py:329\u001b[0m, in \u001b[0;36mfull\u001b[0;34m(shape, fill_value, dtype, order, like)\u001b[0m\n\u001b[1;32m 327\u001b[0m fill_value \u001b[38;5;241m=\u001b[39m asarray(fill_value)\n\u001b[1;32m 328\u001b[0m dtype \u001b[38;5;241m=\u001b[39m fill_value\u001b[38;5;241m.\u001b[39mdtype\n\u001b[0;32m--> 329\u001b[0m a \u001b[38;5;241m=\u001b[39m \u001b[43mempty\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43morder\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 330\u001b[0m multiarray\u001b[38;5;241m.\u001b[39mcopyto(a, fill_value, casting\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124munsafe\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 331\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m a\n", - "\u001b[0;31mTypeError\u001b[0m: only integer scalar arrays can be converted to a scalar index" + "##### RESET #####\n", + "determining cost\n", + "##### STEP: 1 #####\n", + "reward: -3.082737612950405\n", + "info: {'demand': array([3.03025524, 0.45675355]), 'action': array([1.5745691, 1.2702793], dtype=float32), 'cost_per_SKU': array([1.45568614, 1.62705147])}\n", + "next observation: None\n", + "truncated: False\n", + "determining cost\n", + "##### STEP: 2 #####\n", + "reward: -1.1137939333845048\n", + "info: {'demand': array([3.15054544, 1.6558931 ]), 'action': array([3.4893591, 1.2197266], dtype=float32), 'cost_per_SKU': array([0.6776274 , 0.43616654])}\n", + "next observation: None\n", + "truncated: False\n", + "determining cost\n", + "##### STEP: 3 #####\n", + "reward: -6.56132299073508\n", + "info: {'demand': array([5.11224062, 0. ]), 'action': array([1.6151099, 1.5320961], dtype=float32), 'cost_per_SKU': array([3.4971307, 3.0641923])}\n", + "next observation: None\n", + "truncated: True\n" ] } ], @@ -377,7 +388,67 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "#################### RESET ####################\n", + "#################### RUN IN TRAIN MODE ####################\n", + "determining cost\n", + "##### STEP: 1 #####\n", + "reward: -0.626389269353475\n", + "info: {'demand': array([0.41801109, 0.41814421]), 'action': array([1.1567007 , 0.16109976], dtype=float32), 'cost_per_SKU': array([0.36934482, 0.25704445])}\n", + "next observation: [0.51654708 0.67238019]\n", + "truncated: False\n", + "determining cost\n", + "##### STEP: 2 #####\n", + "reward: -1.0822899643361983\n", + "info: {'demand': array([0.61617324, 0.52211535]), 'action': array([0.19506973, 1.8444883 ], dtype=float32), 'cost_per_SKU': array([0.42110351, 0.66118645])}\n", + "next observation: [0.71467365 0.37996181]\n", + "truncated: False\n", + "determining cost\n", + "##### STEP: 3 #####\n", + "reward: -1.420132036516529\n", + "info: {'demand': array([0.45242345, 0.60924132]), 'action': array([2.9406414 , 0.96128744], dtype=float32), 'cost_per_SKU': array([1.24410898, 0.17602306])}\n", + "next observation: [0.78011439 1. ]\n", + "truncated: True\n", + "#################### RUN IN VAL MODE ####################\n", + "determining cost\n", + "##### STEP: 1 #####\n", + "reward: -0.3543074271338995\n", + "info: {'demand': array([0. , 0.16760013]), 'action': array([0.5284514, 0.3477636], dtype=float32), 'cost_per_SKU': array([0.26422569, 0.09008174])}\n", + "next observation: [0. 0.59527916]\n", + "truncated: True\n", + "#################### RUN IN TEST MODE ####################\n", + "determining cost\n", + "##### STEP: 1 #####\n", + "reward: -0.32474637093475456\n", + "info: {'demand': array([0.3316407 , 0.33063685]), 'action': array([0.637311 , 0.15872562], dtype=float32), 'cost_per_SKU': array([0.15283514, 0.17191123])}\n", + "next observation: [1. 0.71807281]\n", + "truncated: True\n", + "#################### RUN IN TRAIN MODE AGAIN ####################\n", + "determining cost\n", + "##### STEP: 1 #####\n", + "reward: -0.37947741548097014\n", + "info: {'demand': array([0.41801109, 0.41814421]), 'action': array([0.06180833, 0.46469352], dtype=float32), 'cost_per_SKU': array([0.35620276, 0.02327465])}\n", + "next observation: [0.51654708 0.67238019]\n", + "truncated: False\n", + "determining cost\n", + "##### STEP: 2 #####\n", + "reward: -0.7171106756228236\n", + "info: {'demand': array([0.61617324, 0.52211535]), 'action': array([0.86906, 1.70345], dtype=float32), 'cost_per_SKU': array([0.12644337, 0.59066731])}\n", + "next observation: [0.71467365 0.37996181]\n", + "truncated: False\n", + "determining cost\n", + "##### STEP: 3 #####\n", + "reward: -0.9650588692783887\n", + "info: {'demand': array([0.45242345, 0.60924132]), 'action': array([0.1302417, 1.8949956], dtype=float32), 'cost_per_SKU': array([0.32218174, 0.64287713])}\n", + "next observation: [0.78011439 1. ]\n", + "truncated: True\n" + ] + } + ], "source": [ "from sklearn.datasets import make_regression\n", "from sklearn.preprocessing import MinMaxScaler\n", @@ -414,174 +485,186 @@ "run_test_loop(test_env)\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Newsvendor Env that can provide a variable service level\n", + "\n", + "> Static inventory environment where a decision only affects the next period (Newsvendor problem),\n", + "> but with a variable service level (random during training, fixed during testing)" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# #| export\n", - "# class NewsvendorVariableSLEnv(NewsvendorEnv, ABC):\n", - "# \"\"\"\n", - "\n", - "# \"\"\"\n", - "# def __init__(self,\n", - "# underage_cost: Union[np.ndarray, Parameter] = np.array([1]),\n", - "# overage_cost: Union[np.ndarray, Parameter] = np.array([1]),\n", - "# q_bound_low: Union[np.ndarray, Parameter] = np.array([0]),\n", - "# q_bound_high: Union[np.ndarray, Parameter] = np.array([np.inf]),\n", - "# dataloader: BaseDataLoader = None,\n", - "# gamma: float = 1,\n", - "# horizon: int = 100,\n", - "\n", - "# low_sl: np.ndarray = np.array([0.1]),\n", - "# high_sl: np.ndarray = np.array([0.9]),\n", - "\n", - "# ) -> None:\n", - " \n", - "# super().__init__( \n", - "# underage_cost=underage_cost,\n", - "# overage_cost=overage_cost,\n", - "# q_bound_low=q_bound_low,\n", - "# q_bound_high=q_bound_high,\n", - "# dataloader=dataloader,\n", - "# gamma=gamma,\n", - "# horizon=horizon,\n", - "# )\n", - "\n", - "# self.low_sl = set_env_parameter(low_sl, self.num_SKUs)\n", - "# self.high_sl = set_env_parameter(high_sl, self.num_SKUs)\n", - " \n", - "# def set_observation_space(self,\n", - "# shape: tuple,\n", - "# low: Union[np.ndarray, float] = -np.inf,\n", - "# high: Union[np.ndarray, float] = np.inf) -> None:\n", + "#| export\n", + "class NewsvendorEnvVariableSL(NewsvendorEnv, ABC):\n", + " def __init__(self,\n", + "\n", + " # Additional parameters:\n", + " sl_bound_low: Union[np.ndarray, Parameter, int, float] = 0.1, # lower bound of the service level during training\n", + " sl_bound_high: Union[np.ndarray, Parameter, int, float] = 0.9, # upper bound of the service level during training\n", + " sl_distribution: Literal[\"fixed\", \"uniform\"] = \"fixed\", # distribution of the random service level during training, if fixed then the service level is fixed to sl_test_val\n", + " evaluation_metric: Literal[\"pinball_loss\", \"quantile_loss\"] = \"quantile_loss\", # quantile loss is the generic quantile loss (independent of cost levels) while pinball loss uses the specific under- and overage costs\n", + " sl_test_val: Union[np.ndarray, Parameter, int, float] = None, # service level during test and validation, alternatively use cu and co\n", + "\n", + " underage_cost: Union[np.ndarray, Parameter, int, float] = 1, # underage cost per unit\n", + " overage_cost: Union[np.ndarray, Parameter, int, float] = 1, # overage cost per unit\n", + " q_bound_low: Union[np.ndarray, Parameter, int, float] = 0, # lower bound of the order quantity\n", + " q_bound_high: Union[np.ndarray, Parameter, int, float] = np.inf, # upper bound of the order quantity\n", + " dataloader: BaseDataLoader = None, # dataloader\n", + " num_SKUs: Union[int] = None, # if None it will be inferred from the DataLoader\n", + " gamma: float = 1, # discount factor\n", + " horizon_train: int | str = \"use_all_data\", # if \"use_all_data\" then horizon is inferred from the DataLoader\n", + " postprocessors: list[object] | None = None, # default is empty list\n", + " mode: str = \"train\", # Initial mode (train, val, test) of the environment\n", + " return_truncation: str = True # whether to return a truncated condition in step function\n", + " ) -> None:\n", + "\n", + " self.set_param(\"sl_bound_low\", sl_bound_low, shape=(1,), new=True)\n", + " self.set_param(\"sl_bound_high\", sl_bound_high, shape=(1,), new=True)\n", + " self.evaluation_metric = evaluation_metric\n", + " self.check_evaluation_metric\n", + " self.sl_distribution = sl_distribution\n", + " self.check_sl_distribution\n", + "\n", + " super().__init__(underage_cost=underage_cost,\n", + " overage_cost=overage_cost,\n", + " q_bound_low=q_bound_low,\n", + " q_bound_high=q_bound_high,\n", + " dataloader=dataloader,\n", + " num_SKUs=num_SKUs,\n", + " gamma=gamma,\n", + " horizon_train=horizon_train,\n", + " postprocessors=postprocessors,\n", + " mode=mode,\n", + " return_truncation=return_truncation)\n", + "\n", + " if sl_test_val is not None:\n", + " if self.underage_cost is None and self.overage_cost is None:\n", + " self.set_param(\"sl\", sl_test_val, shape=(num_SKUs[0],), new=True)\n", + " else:\n", + " raise ValueError(\"sl_test_val can only be used when underage_cost and overage_cost are None.\")\n", + " else:\n", + " if self.underage_cost is None or self.overage_cost is None:\n", + " raise ValueError(\"Either sl_test_val or underage_cost and overage_cost must be provided.\")\n", + " sl = self.underage_cost / (self.underage_cost + self.overage_cost)\n", + " self.set_param(\"sl\", sl, shape=(self.num_SKUs[0],), new=True)\n", + "\n", + " def determine_cost(self, action: np.ndarray) -> np.ndarray:\n", + " \"\"\"\n", + " Determine the cost per SKU given the action taken. The cost is the sum of underage and overage costs.\n", + " \"\"\"\n", + "\n", + " # Compute the cost per SKU\n", + " if self.mode == \"train\": # during training only the service level is relevant\n", + " return quantile_loss(self.demand, action, self.sl_period)\n", + " else:\n", + " if self.evaluation_metric == \"pinball_loss\":\n", + " return pinball_loss(self.demand, action, self.underage_cost, self.overage_cost)\n", + " elif self.evaluation_metric == \"quantile_loss\":\n", + " return quantile_loss(self.demand, action, self.sl)\n", + "\n", + " def set_observation_space(self,\n", + " shape: tuple, # shape of the dataloader features\n", + " low: Union[np.ndarray, float] = -np.inf, # lower bound of the observation space\n", + " high: Union[np.ndarray, float] = np.inf, # upper bound of the observation space\n", + " samples_dim_included = True # whether the first dimension of the shape input is the number of samples\n", + " ) -> None:\n", " \n", - "# '''\n", - "# Set the observation space of the environment.\n", - "# '''\n", + " '''\n", + " Set the observation space of the environment.\n", + " This is a standard function for simple observation spaces. For more complex observation spaces,\n", + " this function should be overwritten. Note that it is assumped that the first dimension\n", + " is n_samples that is not relevant for the observation space.\n", "\n", - "# ### THIS MAKES NO SENSE:\n", + " '''\n", "\n", - "# # if shape is not None:\n", - "# # if not isinstance(shape, tuple):\n", - "# # raise ValueError(\"Shape must be a tuple.\")\n", - " \n", - "# # shape = shape[1:]\n", + " # To handle cases when no external information is available (e.g., parametric NV)\n", " \n", - "# # self.observation_space = gym.spaces.Dict({\n", - "# # 'X': gym.spaces.Box(low=low, high=high, shape=shape, dtype=np.float32),\n", - "# # 'sl': gym.spaces.Box(low=0, high=1, shape=(self.num_SKUs,), dtype=np.float32)\n", - "# # })\n", - "# # else:\n", - "# # self.observation_space = gym.spaces.Dict({\n", - "# # 'sl': gym.spaces.Box(low=0, high=1, shape=(self.num_SKUs,), dtype=np.float32)\n", - "# # })\n", - "\n", - "# def get_observation(self):\n", - "# \"\"\"\n", - "# Return the current observation.\n", - "# \"\"\"\n", + " if shape is None:\n", + " self.observation_space = None\n", + "\n", + " spaces = {}\n", + " if isinstance(shape, tuple):\n", + " if samples_dim_included:\n", + " shape = shape[1:] # assumed that the first dimension is the number of samples\n", + " spaces[\"features\"] = gym.spaces.Box(low=low, high=high, shape=shape, dtype=np.float32)\n", " \n", - "# X_item, Y_item = self.dataloader[self.index]\n", + " elif feature_shape is None:\n", + " pass\n", "\n", - "# underage_cost, overage_cost, sl = self.draw_service_level()\n", + " else:\n", + " raise ValueError(\"Shape for features must be a tuple or None\")\n", "\n", - "# self.underage_cost.set_value(underage_cost, (self.num_SKUs,))\n", - "# self.overage_cost.set_value(overage_cost, (self.num_SKUs,))\n", + " spaces[\"service_level\"] = gym.spaces.Box(low=0, high=1, shape=(self.num_SKUs[0],), dtype=np.float32)\n", "\n", - "# if X_item is not None:\n", - "# obs = {'X': X_item, 'sl': sl}\n", - "# else:\n", - "# obs = {'sl': sl}\n", + " self.observation_space = gym.spaces.Dict(spaces)\n", "\n", - "# return obs, Y_item\n", - " \n", - "# def draw_service_level(self):\n", + " @staticmethod # staticmethod such that the dataloader can also use the funciton\n", + " def draw_parameter(distribution, sl_bound_low, sl_bound_high, samples):\n", + " \n", + " if distribution == \"fixed\":\n", + " sl = np.random.uniform(sl_bound_low, sl_bound_high, size=(samples,))\n", + " elif distribution == \"uniform\":\n", + " sl = np.random.uniform(sl_bound_low, sl_bound_high, size=(samples,))\n", + " else:\n", + " raise ValueError(\"sl_distribution not recognized.\")\n", + " \n", + " return sl\n", + "\n", + " def get_observation(self):\n", " \n", - "# sl = np.random.uniform(self.low_sl, self.high_sl, self.num_SKUs)\n", + " \"\"\"\n", + " Return the current observation. This function is for the simple case where the observation\n", + " is only an x,y pair. For more complex observations, this function should be overwritten.\n", + " \"\"\"\n", + " \n", + " X_item, Y_item = self.dataloader[self.index]\n", + " \n", + " if self.mode == \"train\":\n", + " sl = self.draw_parameter(self.sl_distribution, self.sl_bound_low, self.sl_bound_high, samples = self.num_SKUs[0])\n", + " else:\n", + " sl = self.sl.copy() # evaluate on fixed sls\n", + " \n", + " self.sl_period = sl # store the service level to assess the action\n", "\n", - "# overage_cost = np.ones_like(sl)\n", - "# underage_cost = np.ones_like(sl)\n", + " return {\"features\": X_item, \"service_level\": sl}, Y_item\n", "\n", - "# # # Calculate underage_cost where sl >= 0.5\n", - "# underage_cost = np.where(sl < 0.5, sl / (1 - sl), underage_cost)\n", + " def check_evaluation_metric(self):\n", + " if self.evaluation_metric not in [\"pinball_loss\", \"quantile_loss\"]:\n", + " raise ValueError(\"evaluation_metric must be either 'pinball_loss' or 'quantile_loss'.\")\n", + " if self.evaluation_metric == \"pinball_loss\" and (self.underage_cost is None or self.overage_cost is None):\n", + " raise ValueError(\"Underage and overage costs must be provided for pinball loss.\")\n", + " if self.evaluation_metric == \"quantile_loss\" and (self.sl_test_val is None):\n", + " raise ValueError(\"sl_test_val must be provided for quantile loss.\")\n", + " \n", + " def check_sl_distribution(self):\n", + " if self.sl_distribution not in [\"fixed\", \"uniform\"]:\n", + " raise ValueError(\"sl_distribution must be 'uniform' or 'fixed'.\")\n", + "\n", + " def set_val_test_sl(self, sl_test_val):\n", + " self.set_param(\"sl\", sl_test_val, shape=(self.num_SKUs[0],), new=False)\n", + "\n", + " def update_cu_co(self, cu=None, co=None):\n", + "\n", + " if not hasattr(self, \"underage_cost\") or not hasattr(self, \"overage_cost\"):\n", + " logging.warning(\"Underage and overage costs were not set previously, setting them as new parameters.\")\n", + " self.set_param(\"underage_cost\", cu, shape=(self.num_SKUs[0],), new=True)\n", + " self.set_param(\"overage_cost\", co, shape=(self.num_SKUs[0],), new=True)\n", " \n", - "# # Calculate overage_cost where sl < 0.5\n", - "# overage_cost = np.where(sl >= 0.5, 1 / sl -1, overage_cost)\n", + " if cu is not None:\n", + " self.set_param(\"underage_cost\", cu, shape=(self.num_SKUs[0],), new=False)\n", + " if co is not None:\n", + " self.set_param(\"overage_cost\", co, shape=(self.num_SKUs[0],), new=False)\n", " \n", - "# return underage_cost, overage_cost, sl" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# test_env = NewsvendorVariableSLEnv(underage_cost=Parameter(np.array([1,1]), shape = (2,)), overage_cost=Parameter(np.array([0.5,0.5]), shape = (2,)), dataloader=dataloader, horizon=3)\n", - "\n", - "# print(test_env.observation_space)\n", - "# print(test_env.observation_space.sample())\n", - "\n", - "# obs = test_env.reset(start_index=0)\n", - "# print(\"##### RESET #####\")\n", - "# print(\"obs:\", obs)\n", - "\n", - "# truncated = False\n", - "# while not truncated:\n", - "# action = test_env.action_space.sample()\n", - "# obs, reward, terminated, truncated, info = test_env.step(action)\n", - "# print(\"##### STEP: \", test_env.index, \"#####\")\n", - "# print(\"reward:\", reward)\n", - "# print(\"info:\", info)\n", - "# print(\"obs:\", obs)\n", - "# print(\"truncated:\", truncated)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Simple Example with synthetic data:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# from sklearn.datasets import make_regression\n", - "# from sklearn.preprocessing import MinMaxScaler\n", - "\n", - "# from ddopnew.dataloaders.tabular import XYDataLoader\n", - "\n", - "# # create a simple dataset bounded between 0 and 1\n", - "# X, Y = make_regression(n_samples=100, n_features=2, n_targets=2, noise=0.1)\n", - "# scaler = MinMaxScaler()\n", - "# X = scaler.fit_transform(X)\n", - "# Y = scaler.fit_transform(Y)\n", - "\n", - "# dataloader = XYDataLoader(X, Y)\n", - "# test_env = NewsvendorVariableSLEnv(underage_cost=Parameter(np.array([1,1]), shape = (2,)), overage_cost=Parameter(np.array([0.5,0.5]), shape = (2,)), dataloader=dataloader, horizon=len(dataloader))\n", - "\n", - "# print(test_env.observation_space)\n", - "# print(test_env.observation_space.sample())\n", - "\n", - "# obs = test_env.reset(start_index=0)\n", - "# print(\"##### RESET #####\")\n", - "# print(\"obs:\", obs)\n", - "\n", - "# truncated = False\n", - "# while not truncated:\n", - "# action = test_env.action_space.sample()\n", - "# obs, reward, terminated, truncated, info = test_env.step(action)\n", - "# print(\"##### STEP: \", test_env.index, \"#####\")\n", - "# print(\"reward:\", reward)\n", - "# print(\"info:\", info)\n", - "# print(\"obs:\", obs)\n", - "# print(\"truncated:\", truncated)" + " sl = self.underage_cost / (self.underage_cost + self.overage_cost)\n", + "\n", + " self.set_param(\"sl\", sl, shape=(self.num_SKUs[0],), new=False)" ] }, { diff --git a/nbs/21_envs_inventory/30_multi_period_envs.ipynb b/nbs/21_envs_inventory/30_multi_period_envs.ipynb index 272cc59..fff32f1 100644 --- a/nbs/21_envs_inventory/30_multi_period_envs.ipynb +++ b/nbs/21_envs_inventory/30_multi_period_envs.ipynb @@ -236,7 +236,7 @@ "\n", " return observation, Y_item\n", "\n", - "\n", + " \n", " def reset(self,\n", " start_index: int | str = None, # index to start from\n", " state: np.ndarray = None # initial state\n", diff --git a/nbs/30_experiment_functions/10_experiment_functions.ipynb b/nbs/30_experiment_functions/10_experiment_functions.ipynb index 3bab4e7..9a7878e 100644 --- a/nbs/30_experiment_functions/10_experiment_functions.ipynb +++ b/nbs/30_experiment_functions/10_experiment_functions.ipynb @@ -130,7 +130,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L28){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L27){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## EarlyStoppingHandler\n", "\n", @@ -153,7 +153,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L28){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L27){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## EarlyStoppingHandler\n", "\n", @@ -193,7 +193,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L54){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### EarlyStoppingHandler.add_result\n", "\n", @@ -210,7 +210,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L54){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L56){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### EarlyStoppingHandler.add_result\n", "\n", @@ -413,7 +413,7 @@ " logging.debug(\"truncated: %s\", truncated)\n", "\n", " sample = (obs, action, reward, next_obs, terminated, truncated) # unlike mushroom do not include policy_state\n", - "\n", + " \n", " obs = next_obs\n", " \n", " dataset.append((sample, info))\n", @@ -521,6 +521,7 @@ " stop = False\n", "\n", " if stop:\n", + " log_info(R, J, n_epochs-epoch-1, tracking, \"val\")\n", " logging.info(f\"Early stopping after {epoch+1} epochs\")\n", " break\n", " \n", @@ -577,6 +578,7 @@ " stop = False\n", "\n", " if stop:\n", + " log_info(R, J, n_epochs-epoch-1, tracking, \"val\")\n", " logging.info(f\"Early stopping after {epoch+1} epochs\")\n", " break\n", " \n", @@ -606,7 +608,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L256){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### run_experiment\n", "\n", @@ -640,7 +642,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L256){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### run_experiment\n", "\n", @@ -721,7 +723,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L174){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L166){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### test_agent\n", "\n", @@ -742,7 +744,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L174){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L166){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### test_agent\n", "\n", @@ -780,7 +782,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L194){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### run_test_episode\n", "\n", @@ -801,7 +803,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/experiment_functions.py#L194){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### run_test_episode\n", "\n", @@ -845,7 +847,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "R: -5.299441731286187, J: -5.285677168806183\n" + "R: -3.1622384231698444, J: -3.150371113545027\n" ] } ], diff --git a/nbs/30_experiment_functions/20_meta_experiment_functions.ipynb b/nbs/30_experiment_functions/20_meta_experiment_functions.ipynb index e952719..67e89b9 100644 --- a/nbs/30_experiment_functions/20_meta_experiment_functions.ipynb +++ b/nbs/30_experiment_functions/20_meta_experiment_functions.ipynb @@ -44,15 +44,19 @@ "import sys\n", "import os\n", "import yaml\n", + "import pickle\n", + "import warnings\n", + "import torch\n", "\n", "from ddopnew.tracking import get_git_hash, get_library_version\n", "from ddopnew.agents.class_names import AGENT_CLASSES\n", "from ddopnew.dataloaders.tabular import XYDataLoader\n", "from ddopnew.datasets import DatasetLoader\n", - "from ddopnew.experiment_functions import EarlyStoppingHandler\n", + "from ddopnew.experiment_functions import EarlyStoppingHandler, test_agent\n", "\n", "import wandb\n", "\n", + "import gc\n", "\n", "import importlib\n", "\n", @@ -62,6 +66,33 @@ "from mushroom_rl.core import Core" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Warnings\n", + "\n", + "> Some warnings are irrelevant for this library" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "def set_warnings (logging_level):\n", + "\n", + " \"\"\" Set warnings to be ignored for the given logging level or higher.\"\"\"\n", + "\n", + " if logging.getLogger().isEnabledFor(logging_level):\n", + " warnings.filterwarnings(\"ignore\", category=UserWarning, message=\".*Box bound precision lowered by casting to float32.*\")\n", + " warnings.filterwarnings(\"ignore\", category=UserWarning, message=\".*TypedStorage is deprecated.*\")\n", + " warnings.filterwarnings(\"ignore\", category=FutureWarning, message=\".*You are using `torch.load` with `weights_only=False`.*\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -370,13 +401,142 @@ " if \"early_stopping_patience\" in config_train or \"early_stopping_warmup\" in config_train:\n", " warmup = config_train[\"early_stopping_warmup\"] if \"early_stopping_warmup\" in config_train else 0\n", " patience = config_train[\"early_stopping_patience\"] if \"early_stopping_patience\" in config_train else 0\n", - " earlystoppinghandler = EarlyStoppingHandler(warmup=warmup, patience=warmup)\n", + "\n", + " earlystoppinghandler = EarlyStoppingHandler(warmup=warmup, patience=patience)\n", " else:\n", " earlystoppinghandler = None\n", "\n", " return earlystoppinghandler" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing\n", + "\n", + "> Some functions to test the final model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "def prep_and_run_test(\n", + " agent,\n", + " environment,\n", + " agent_dir: str,\n", + " save_dataset: bool = True,\n", + " dataset_dir: str = None,\n", + " tracking = \"wandb\"):\n", + " \n", + " \"\"\"\n", + " Test the agent in the environment.\n", + " \"\"\"\n", + "\n", + " if save_dataset:\n", + " if dataset_dir is None:\n", + " raise ValueError(\"If save_dataset is True, dataset_dir must be specified.\")\n", + "\n", + " # load parameters of agent\n", + " agent.load(agent_dir)\n", + "\n", + " # Set agent and environment to test mode\n", + " agent.eval()\n", + " environment.test()\n", + "\n", + " # Run test episode\n", + " output = test_agent(\n", + " agent,\n", + " environment,\n", + " return_dataset=save_dataset,\n", + " tracking=tracking\n", + " )\n", + "\n", + " # Save dataset\n", + " if save_dataset:\n", + "\n", + " R, J, dataset = output\n", + "\n", + " if not os.path.exists(dataset_dir):\n", + " os.mkdir(dataset_dir)\n", + " else:\n", + " raise ValueError(\"Path to save dataset already exists\") # it should never exist since run_id is usually part or path and unique\n", + " \n", + " dir = os.path.join(dataset_dir, \"dataset_test.pkl\")\n", + "\n", + " with open (os.path.join(dir), \"wb\") as f:\n", + " pickle.dump(dataset, f)\n", + "\n", + " artifact = wandb.Artifact(\"transition_test_set\", type=\"dataset\")\n", + "\n", + " artifact.add_file(os.path.join(dir))\n", + "\n", + " wandb.run.log_artifact(artifact)\n", + " \n", + " else:\n", + "\n", + " R, J = output\n", + "\n", + " logging.info(f\"final evaluation on test set: R = {np.round(R, 10)} J = {np.round(J, 10)}\")\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean-up\n", + "\n", + "> Function to clean-up the experiment script" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "def clean_up(agent, environment):\n", + "\n", + " \"\"\" Clean up agent and environment to free up GPU memory \"\"\"\n", + " \n", + " # Delete agent and environment to free up GPU memory\n", + " del agent\n", + " del environment\n", + "\n", + " # Force garbage collection\n", + " gc.collect()\n", + "\n", + " # Clear GPU cache\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + "\n", + " wandb.finish()\n", + "\n", + " return None, None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/nbs/40_base_agents/10_AGENT_CLASSES.ipynb b/nbs/40_base_agents/10_AGENT_CLASSES.ipynb index bbe7f17..b51a05c 100644 --- a/nbs/40_base_agents/10_AGENT_CLASSES.ipynb +++ b/nbs/40_base_agents/10_AGENT_CLASSES.ipynb @@ -37,11 +37,20 @@ "#| export\n", "AGENT_CLASSES = {\n", " \"RandomAgent\": \"ddopnew.agents.saa.SAA\",\n", + "\n", " \"SAA\": \"ddopnew.agents.newsvendor.saa.NewsvendorSAAagent\",\n", " \"wSAA\": \"ddopnew.agents.newsvendor.saa.NewsvendorRFwSAAagent\",\n", " \"RFwSAA\": \"ddopnew.agents.newsvendor.saa.NewsvendorRFwSAAagent\",\n", + "\n", " \"lERM\": \"ddopnew.agents.newsvendor.erm.NewsvendorlERMAgent\",\n", " \"DLNV\": \"ddopnew.agents.newsvendor.erm.NewsvendorDLAgent\",\n", + " \"DLNVRNN\": \"ddopnew.agents.newsvendor.erm.NewsvendorDLRNNAgent\",\n", + " \"DLNVTransformer\": \"ddopnew.agents.newsvendor.erm.NewsvendorDLTransformerAgent\",\n", + "\n", + " \"lERMMeta\": \"ddopnew.agents.newsvendor.erm.NewsvendorlERMMetaAgent\",\n", + " \"DLNVMeta\": \"ddopnew.agents.newsvendor.erm.NewsvendorDLMetaAgent\",\n", + " \"DLNVRNNMeta\": \"ddopnew.agents.newsvendor.erm.NewsvendorDLRNNMetaAgent\",\n", + " \"DLNVTransformerMeta\": \"ddopnew.agents.newsvendor.erm.NewsvendorDLTransformerMetaAgent\",\n", "\n", " \"SAC\": \"ddopnew.agents.rl.sac.SACAgent\",\n", " \"SACRNN\": \"ddopnew.agents.rl.sac.SACRNNAgent\",\n", diff --git a/nbs/40_base_agents/10_base_agents.ipynb b/nbs/40_base_agents/10_base_agents.ipynb index e93c9df..036ac90 100644 --- a/nbs/40_base_agents/10_base_agents.ipynb +++ b/nbs/40_base_agents/10_base_agents.ipynb @@ -99,10 +99,16 @@ " Internal logic of the agent to be implemented in draw_action_ method.\n", " \"\"\"\n", "\n", - " observation = self.add_batch_dim(observation)\n", + " batch_added = False\n", + " if not isinstance(observation, dict):\n", + " observation = self.add_batch_dim(observation)\n", + " batch_added = True\n", "\n", " for obsprocessor in self.obsprocessors:\n", " observation = obsprocessor(observation)\n", + " if not isinstance(observation, dict) and not batch_added:\n", + " observation = self.add_batch_dim(observation)\n", + " batch_added = True\n", "\n", " action = self.draw_action_(observation)\n", " \n", diff --git a/nbs/41_NV_agents/11_NV_erm_agents.ipynb b/nbs/41_NV_agents/11_NV_erm_agents.ipynb index 101db0d..d85d55d 100644 --- a/nbs/41_NV_agents/11_NV_erm_agents.ipynb +++ b/nbs/41_NV_agents/11_NV_erm_agents.ipynb @@ -46,7 +46,7 @@ "\n", "from ddopnew.envs.base import BaseEnvironment\n", "from ddopnew.agents.base import BaseAgent\n", - "from ddopnew.utils import MDPInfo, Parameter, DatasetWrapper\n", + "from ddopnew.utils import MDPInfo, Parameter, DatasetWrapper, DatasetWrapperMeta\n", "from ddopnew.torch_utils.loss_functions import TorchQuantileLoss\n", "from ddopnew.torch_utils.obsprocessors import FlattenTimeDim\n", "\n", @@ -95,7 +95,9 @@ " self.device = self.set_device(device)\n", " \n", " self.set_dataloader(dataloader, dataloader_params)\n", + "\n", " self.set_model(input_shape, output_shape)\n", + " self.loss_function_params=None # default\n", " self.set_loss_function()\n", " self.set_optimizer(optimizer_params)\n", " self.set_learning_rate_scheduler(learning_rate_scheduler)\n", @@ -129,8 +131,12 @@ " Set the dataloader for the agent by wrapping it into a Torch Dataset\n", " \n", " \"\"\"\n", - " dataset = DatasetWrapper(dataloader)\n", - " self.dataloader = torch.utils.data.DataLoader(dataset, **dataloader_params)\n", + "\n", + " # check if class already have a dataloader\n", + " if not hasattr(self, 'dataloader'):\n", + "\n", + " dataset = DatasetWrapper(dataloader)\n", + " self.dataloader = torch.utils.data.DataLoader(dataset, **dataloader_params)\n", "\n", " @abstractmethod\n", " def set_loss_function(self):\n", @@ -145,18 +151,21 @@ " def set_optimizer(self, optimizer_params: dict): # dict with keys: optimizer, lr, weight_decay\n", " \n", " \"\"\" Set the optimizer for the model \"\"\"\n", - " optimizer = optimizer_params[\"optimizer\"]\n", - " optimizer_params_copy = optimizer_params.copy()\n", - " del optimizer_params_copy[\"optimizer\"]\n", - "\n", - " if optimizer == \"Adam\":\n", - " self.optimizer = torch.optim.Adam(self.model.parameters(), **optimizer_params_copy)\n", - " elif optimizer == \"SGD\":\n", - " self.optimizer = torch.optim.SGD(self.model.parameters(), **optimizer_params_copy)\n", - " elif optimizer == \"RMSprop\":\n", - " self.optimizer = torch.optim.RMSprop(self.model.parameters(), **optimizer_params_copy)\n", - " else:\n", - " raise ValueError(f\"Optimizer {optimizer} not supported\")\n", + "\n", + " if not hasattr(self, 'optimizer'):\n", + " \n", + " optimizer = optimizer_params[\"optimizer\"]\n", + " optimizer_params_copy = optimizer_params.copy()\n", + " del optimizer_params_copy[\"optimizer\"]\n", + "\n", + " if optimizer == \"Adam\":\n", + " self.optimizer = torch.optim.Adam(self.model.parameters(), **optimizer_params_copy)\n", + " elif optimizer == \"SGD\":\n", + " self.optimizer = torch.optim.SGD(self.model.parameters(), **optimizer_params_copy)\n", + " elif optimizer == \"RMSprop\":\n", + " self.optimizer = torch.optim.RMSprop(self.model.parameters(), **optimizer_params_copy)\n", + " else:\n", + " raise ValueError(f\"Optimizer {optimizer} not supported\")\n", " \n", " def set_learning_rate_scheduler(self, learning_rate_scheduler: None = None): #\n", " \"\"\" Set learning rate scheudler (can be None) \"\"\"\n", @@ -175,7 +184,11 @@ "\n", " for i, output in enumerate(self.dataloader):\n", " \n", - " X, y = output\n", + " if len(output)==3:\n", + " X, y, loss_function_params = output\n", + " else:\n", + " X, y = output\n", + " loss_function_params = None\n", "\n", " # convert X and y to float32\n", " X = X.type(torch.float32)\n", @@ -190,10 +203,12 @@ "\n", " y_pred = self.model(X)\n", "\n", - " if self.loss_function_params==None:\n", - " loss = self.loss_function(y_pred, y)\n", + " if loss_function_params is not None:\n", + " loss = self.loss_function(y_pred, y, **loss_function_params)\n", + " elif self.loss_function_params is not None:\n", + " loss = self.loss_function(y_pred, y, **self.loss_function_params)\n", " else:\n", - " loss = self.loss_function(y_pred, y, **self.loss_function_params) # TODO: add reduction param when defining loss function\n", + " loss = self.loss_function(y_pred, y)\n", "\n", " loss.backward()\n", " self.optimizer.step()\n", @@ -312,7 +327,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L26){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L27){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## SGDBaseAgent\n", "\n", @@ -345,7 +360,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L26){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L27){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## SGDBaseAgent\n", "\n", @@ -422,7 +437,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L81){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L86){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_dataloader\n", "\n", @@ -441,7 +456,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L81){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L86){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_dataloader\n", "\n", @@ -477,7 +492,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L94){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L103){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_loss_function\n", "\n", @@ -488,7 +503,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L94){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L103){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_loss_function\n", "\n", @@ -516,7 +531,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L99){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L108){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_model\n", "\n", @@ -527,7 +542,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L99){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L108){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_model\n", "\n", @@ -555,7 +570,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L103){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L112){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_optimizer\n", "\n", @@ -570,7 +585,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L103){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L112){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_optimizer\n", "\n", @@ -602,7 +617,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L119){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L131){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_learning_rate_scheduler\n", "\n", @@ -619,7 +634,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L119){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L131){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_learning_rate_scheduler\n", "\n", @@ -653,7 +668,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L126){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L138){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.fit_epoch\n", "\n", @@ -664,7 +679,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L126){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L138){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.fit_epoch\n", "\n", @@ -692,7 +707,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L165){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L186){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.draw_action_\n", "\n", @@ -708,7 +723,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L165){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L186){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.draw_action_\n", "\n", @@ -741,7 +756,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L175){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L196){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.predict\n", "\n", @@ -757,7 +772,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L175){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L196){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.predict\n", "\n", @@ -790,7 +805,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L197){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L218){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.train\n", "\n", @@ -801,7 +816,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L197){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L218){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.train\n", "\n", @@ -829,7 +844,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L223){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.eval\n", "\n", @@ -840,7 +855,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L223){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.eval\n", "\n", @@ -868,7 +883,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L207){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L228){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.to\n", "\n", @@ -883,7 +898,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L207){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L228){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.to\n", "\n", @@ -915,7 +930,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L211){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L232){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.save\n", "\n", @@ -931,7 +946,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L211){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L232){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.save\n", "\n", @@ -964,7 +979,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L239){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L260){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.load\n", "\n", @@ -979,7 +994,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L239){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L260){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.load\n", "\n", @@ -1033,7 +1048,6 @@ " agent_name: str | None = None,\n", " ):\n", "\n", - " \n", " cu = self.convert_to_numpy_array(cu)\n", " co = self.convert_to_numpy_array(co)\n", " \n", @@ -1053,12 +1067,13 @@ " device=device,\n", " agent_name=agent_name\n", " ) \n", + " \n", " def set_loss_function(self):\n", - "\n", + " \n", " \"\"\"Set the loss function for the model to the quantile loss. For training\n", " the model uses quantile loss and not the pinball loss with specific cu and \n", " co values to ensure similar scale of the feedback signal during training.\"\"\"\n", - "\n", + " \n", " self.loss_function_params = {\"quantile\": self.sl}\n", " self.loss_function = TorchQuantileLoss(reduction=\"mean\")\n", " \n", @@ -1075,7 +1090,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L263){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L284){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NVBaseAgent\n", "\n", @@ -1113,7 +1128,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L263){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L284){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NVBaseAgent\n", "\n", @@ -1168,7 +1183,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L307){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L328){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NVBaseAgent.set_loss_function\n", "\n", @@ -1181,7 +1196,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L307){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L328){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NVBaseAgent.set_loss_function\n", "\n", @@ -1283,7 +1298,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L319){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L340){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NewsvendorlERMAgent\n", "\n", @@ -1324,7 +1339,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L319){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L340){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NewsvendorlERMAgent\n", "\n", @@ -1395,7 +1410,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L370){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L391){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NewsvendorlERMAgent.set_model\n", "\n", @@ -1406,7 +1421,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L370){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L391){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NewsvendorlERMAgent.set_model\n", "\n", @@ -1440,21 +1455,568 @@ "name": "stdout", "output_type": "stream", "text": [ - "-18.726206754896214 -17.81568786174066\n" + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "-22.005190308678777 -21.025734475368946\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n", + "determining cost\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:00<00:00, 53.21it/s]" + " 0%| | 0/2 [00:00 None: \n", + "\n", + " \"\"\" \"\"\"\n", + "\n", + " # check if class already have a dataloader\n", + "\n", + " print(\"setting meta datloader\")\n", + "\n", + " dataset = DatasetWrapperMeta(\n", + " dataloader = dataloader,\n", + " draw_parameter_function = draw_parameter_function,\n", + " distribution = distribution,\n", + " bounds_low = bounds_low,\n", + " bounds_high = bounds_high,\n", + " obsprocessor = obsprocessor,\n", + " parameter_names = parameter_names,\n", + " )\n", + "\n", + " self.dataloader = torch.utils.data.DataLoader(dataset, **dataloader_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class NewsvendorlERMMetaAgent(NewsvendorlERMAgent, BaseMetaAgent):\n", + "\n", + " \"\"\"\n", + " Newsvendor agent implementing Empirical Risk Minimization (ERM) approach \n", + " based on a linear (regression) model. In addition to the features, the agent\n", + " also gets the sl as input to be able to forecast the optimal order quantity\n", + " for different sl values. Depending on the training pipeline, this model can be \n", + " adapted to become a full meta-learning algorithm cross products and cross sls.\n", + "\n", + " \"\"\"\n", + "\n", + " def __init__(self, \n", + " # Parameters for meta Agent\n", + " dataset_meta_params: dict, # Parameters for meta dataloader\n", + "\n", + " # Parameters for lERM agent\n", + " environment_info: MDPInfo,\n", + " dataloader: BaseDataLoader,\n", + " cu: np.ndarray | Parameter,\n", + " co: np.ndarray | Parameter,\n", + " input_shape: Tuple,\n", + " output_shape: Tuple,\n", + " optimizer_params: dict | None = None, # default: {\"optimizer\": \"Adam\", \"lr\": 0.01, \"weight_decay\": 0.0}\n", + " learning_rate_scheduler = None, # TODO: add base class for learning rate scheduler for typing\n", + " model_params: dict | None = None, # default: {\"relu_output\": False}\n", + " dataloader_params: dict | None = None, # default: {\"batch_size\": 32, \"shuffle\": True}\n", + " obsprocessors: list | None = None, # default: []\n", + " torch_obsprocessors: list | None = None, # default: [FlattenTimeDim(allow_2d=False)]\n", + " device: str = \"cpu\", # \"cuda\" or \"cpu\"\n", + " agent_name: str | None = \"lERMMeta\"\n", + " ):\n", + "\n", + " self.set_meta_dataloader(dataloader, dataloader_params, **dataset_meta_params)\n", + "\n", + " super().__init__(\n", + " environment_info=environment_info,\n", + " dataloader=dataloader,\n", + " cu=cu,\n", + " co=co,\n", + " input_shape=input_shape,\n", + " output_shape=output_shape,\n", + " optimizer_params=optimizer_params,\n", + " learning_rate_scheduler=learning_rate_scheduler,\n", + " model_params=model_params,\n", + " dataloader_params=dataloader_params,\n", + " obsprocessors=obsprocessors,\n", + " torch_obsprocessors=torch_obsprocessors,\n", + " device=device,\n", + " agent_name=agent_name\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class NewsvendorDLMetaAgent(NewsvendorDLAgent, BaseMetaAgent):\n", + "\n", + " \"\"\"\n", + " Newsvendor agent implementing Empirical Risk Minimization (ERM) approach \n", + " based on a Neural Network. In addition to the features, the agent\n", + " also gets the sl as input to be able to forecast the optimal order quantity\n", + " for different sl values. Depending on the training pipeline, this model can be \n", + " adapted to become a full meta-learning algorithm cross products and cross sls.\n", + "\n", + " \"\"\"\n", + "\n", + " def __init__(self, \n", + " # Parameters for meta Agent\n", + " dataset_meta_params: dict, # Parameters for meta dataloader\n", + "\n", + " environment_info: MDPInfo,\n", + " dataloader: BaseDataLoader,\n", + " cu: np.ndarray | Parameter,\n", + " co: np.ndarray | Parameter,\n", + " input_shape: Tuple,\n", + " output_shape: Tuple,\n", + " learning_rate_scheduler = None, # TODO: add base class for learning rate scheduler for typing\n", + " \n", + " # parameters in yaml file\n", + " optimizer_params: dict | None = None, # default: {\"optimizer\": \"Adam\", \"lr\": 0.01, \"weight_decay\": 0.0}\n", + " model_params: dict | None = None, # default: {\"hidden_layers\": [64, 64], \"drop_prob\": 0.0, \"batch_norm\": False, \"relu_output\": False}\n", + " dataloader_params: dict | None = None, # default: {\"batch_size\": 32, \"shuffle\": True}\n", + " device: str = \"cpu\", # \"cuda\" or \"cpu\"\n", + "\n", + " obsprocessors: list | None = None, # default: []\n", + " torch_obsprocessors: list | None = None, # default: [FlattenTimeDim(allow_2d=False)]\n", + " agent_name: str | None = \"DLNV\",\n", + " ):\n", + "\n", + " self.set_meta_dataloader(dataloader, dataloader_params, **dataset_meta_params)\n", + "\n", + " super().__init__(\n", + " environment_info=environment_info,\n", + " dataloader=dataloader,\n", + " cu=cu,\n", + " co=co,\n", + " input_shape=input_shape,\n", + " output_shape=output_shape,\n", + " learning_rate_scheduler=learning_rate_scheduler,\n", + "\n", + " optimizer_params=optimizer_params,\n", + " model_params=model_params,\n", + " dataloader_params=dataloader_params,\n", + " device=device,\n", + "\n", + " obsprocessors=obsprocessors,\n", + " torch_obsprocessors=torch_obsprocessors,\n", + " agent_name=agent_name\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/51_RL_agents/10_SAC_agents.ipynb b/nbs/51_RL_agents/10_SAC_agents.ipynb index cacbc7f..cd32a40 100644 --- a/nbs/51_RL_agents/10_SAC_agents.ipynb +++ b/nbs/51_RL_agents/10_SAC_agents.ipynb @@ -65,6 +65,7 @@ "import torch\n", "import torch.nn.functional as F\n", "from torchinfo import summary\n", + "from IPython import get_ipython\n", "\n", "from copy import deepcopy\n", "\n", @@ -211,7 +212,10 @@ " else:\n", " input_tensor = torch.randn(batch_dim, *actor_mu_params[\"input_shape\"]).to(self.device)\n", " input_tuple = (input_tensor,)\n", - " print(summary(self.actor, input_data=input_tuple, device=self.device))\n", + " if get_ipython() is not None:\n", + " print(summary(self.actor, input_data=input_tuple, device=self.device))\n", + " else:\n", + " summary(self.actor, input_data=input_tuple, device=self.device)\n", " time.sleep(0.2)\n", "\n", " logging.info(\"################################################################################\")\n", @@ -228,7 +232,11 @@ " state_mlp_sample = torch.randn(batch_dim, *critic_params[\"input_shape\"][0][1]).to(self.device)\n", " state_sample = torch.cat((state_sample, state_mlp_sample), dim=1)\n", " input_tuple = (state_sample, action_sample)\n", - " print(summary(self.critic, input_data=input_tuple, device=self.device))\n", + " if get_ipython() is not None:\n", + " print(summary(self.critic, input_data=input_tuple, device=self.device))\n", + " else:\n", + " summary(self.critic, input_data=input_tuple, device=self.device)\n", + " # print(summary(self.critic, input_data=input_tuple, device=self.device))\n", "\n", " def get_network_list(self, set_actor_critic_attributes: bool = True):\n", " \"\"\" Get the list of networks in the agent for the save and load functions\n", @@ -252,7 +260,7 @@ " def predict_(self, observation: np.ndarray) -> np.ndarray: #\n", " \"\"\" Do one forward pass of the model directly and return the prediction.\n", " Apply tanh as implemented for the SAC actor in mushroom_rl\"\"\"\n", - "\n", + " \n", " # make observation torch tensor\n", " device = next(self.actor.parameters()).device\n", " observation = torch.tensor(observation, dtype=torch.float32).to(device)\n", @@ -475,8 +483,8 @@ "Params size (MB): 0.02\n", "Estimated Total Size (MB): 0.02\n", "==========================================================================================\n", - "-254.0208357996548 -159.8911204707865\n", - "-254.0208357996548 -159.8911204707865\n" + "-609.9476045297464 -385.24070455687774\n", + "-609.9476045297464 -385.24070455687774\n" ] } ], @@ -788,8 +796,8 @@ "Params size (MB): 0.09\n", "Estimated Total Size (MB): 0.09\n", "==========================================================================================\n", - "-420.9567259533292 -266.3596900995244\n", - "-420.9567259533292 -266.3596900995244\n" + "-309.1563886085103 -195.45124971970722\n", + "-309.1563886085103 -195.45124971970722\n" ] } ], diff --git a/nbs/80_datasets/datasets.ipynb b/nbs/80_datasets/datasets.ipynb index 3f380d7..33e1290 100644 --- a/nbs/80_datasets/datasets.ipynb +++ b/nbs/80_datasets/datasets.ipynb @@ -271,7 +271,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L124){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L110){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## DatasetLoader\n", "\n", @@ -282,7 +282,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L124){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L110){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## DatasetLoader\n", "\n", @@ -310,32 +310,32 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L138){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L128){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### DatasetLoader.show_dataset_types\n", "\n", - "> DatasetLoader.show_dataset_types (show_num_datasets_per_type=True)\n", + "> DatasetLoader.show_dataset_types (show_num_datasets_per_type=False)\n", "\n", "*Show an overview of all dataset types available in the repository.*\n", "\n", "| | **Type** | **Default** | **Details** |\n", "| -- | -------- | ----------- | ----------- |\n", - "| show_num_datasets_per_type | bool | True | Whether to show the number of datasets per type |" + "| show_num_datasets_per_type | bool | False | Whether to show the number of datasets per type |" ], "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L138){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L128){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### DatasetLoader.show_dataset_types\n", "\n", - "> DatasetLoader.show_dataset_types (show_num_datasets_per_type=True)\n", + "> DatasetLoader.show_dataset_types (show_num_datasets_per_type=False)\n", "\n", "*Show an overview of all dataset types available in the repository.*\n", "\n", "| | **Type** | **Default** | **Details** |\n", "| -- | -------- | ----------- | ----------- |\n", - "| show_num_datasets_per_type | bool | True | Whether to show the number of datasets per type |" + "| show_num_datasets_per_type | bool | False | Whether to show the number of datasets per type |" ] }, "execution_count": null, @@ -357,12 +357,13 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L147){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L145){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### DatasetLoader.load_dataset\n", "\n", "> DatasetLoader.load_dataset (dataset_type:str, dataset_number:int,\n", - "> overwrite:bool=False, version:str='latest')\n", + "> overwrite:bool=False, version:str='latest',\n", + "> token:str=None)\n", "\n", "*Load a dataset from the GitHub repository.*\n", "\n", @@ -371,17 +372,19 @@ "| dataset_type | str | | |\n", "| dataset_number | int | | |\n", "| overwrite | bool | False | Whether to overwrite the dataset if it already exists |\n", - "| version | str | latest | Which version of the dataset to load, \"latest\" or a specific version |" + "| version | str | latest | Which version of the dataset to load, \"latest\" or a specific version, |\n", + "| token | str | None | GitHub token to enable more requests (otherwise limited to 60 requests per hour) |" ], "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L147){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/datasets.py#L145){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### DatasetLoader.load_dataset\n", "\n", "> DatasetLoader.load_dataset (dataset_type:str, dataset_number:int,\n", - "> overwrite:bool=False, version:str='latest')\n", + "> overwrite:bool=False, version:str='latest',\n", + "> token:str=None)\n", "\n", "*Load a dataset from the GitHub repository.*\n", "\n", @@ -390,7 +393,8 @@ "| dataset_type | str | | |\n", "| dataset_number | int | | |\n", "| overwrite | bool | False | Whether to overwrite the dataset if it already exists |\n", - "| version | str | latest | Which version of the dataset to load, \"latest\" or a specific version |" + "| version | str | latest | Which version of the dataset to load, \"latest\" or a specific version, |\n", + "| token | str | None | GitHub token to enable more requests (otherwise limited to 60 requests per hour) |" ] }, "execution_count": null, @@ -443,16 +447,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:Dataset arma_10_10_dataset_1 has already been downloaded.\n", - "WARNING:root:Keeping existing dataset.\n" - ] - } - ], + "outputs": [], "source": [ "download_test = False\n", "\n",