From d98888e4c0c8e1e6e4fe822a983d66818e89ba14 Mon Sep 17 00:00:00 2001 From: Magnus Maichle <110975541+majoma7@users.noreply.github.com> Date: Thu, 26 Sep 2024 18:02:00 +0200 Subject: [PATCH] Minor fixes (#20) * fixed error in selection of time_sku features * fixed error in force save agent * fixed parameter name for dataset * accelerated getitem method * minor fix in predict function * minor fix in predict function * enables out-of-sample experiments * print steps while testing * fixed error in normalization for oos products * fixed error in normalizing oos demand * fixed error in normalization of oos SKUs * fixed error in normalization of oos SKUs * added option for run_experiment to return score * fixed average calc tune! * implemented xgb * renaming --- ddopnew/agents/newsvendor/erm.py | 2 +- ddopnew/dataloaders/tabular.py | 14 ++++++++++ nbs/00_utils/00_utils.ipynb | 4 +++ nbs/00_utils/01_loss_functions.ipynb | 4 +++ nbs/00_utils/10_postprocessors.ipynb | 4 +++ nbs/00_utils/11_obsprocessors.ipynb | 4 +++ nbs/00_utils/20_torch_loss_functions.ipynb | 4 +++ .../12_tabular_dataloaders.ipynb | 14 ++++++++++ .../00_inventory_utils.ipynb | 5 ++++ .../20_single_period_envs.ipynb | 6 ++++- .../30_multi_period_envs.ipynb | 4 +++ .../10_experiment_functions.ipynb | 9 ++++++- .../20_meta_experiment_functions.ipynb | 18 +++---------- nbs/41_NV_agents/10_NV_saa_agents.ipynb | 27 ++++++++++++++++++- nbs/41_NV_agents/11_NV_erm_agents.ipynb | 2 +- 15 files changed, 102 insertions(+), 19 deletions(-) diff --git a/ddopnew/agents/newsvendor/erm.py b/ddopnew/agents/newsvendor/erm.py index 3d92d82..bbdbdf7 100644 --- a/ddopnew/agents/newsvendor/erm.py +++ b/ddopnew/agents/newsvendor/erm.py @@ -219,7 +219,7 @@ def __init__(self, test_batch_size: int = 1024, receive_batch_dim: bool = False, ): - + # Initialize default values for mutable arguments optimizer_params = optimizer_params or {"optimizer": "Adam", "lr": 0.01, "weight_decay": 0.0} dataloader_params = dataloader_params or {"batch_size": 32, "shuffle": True} diff --git a/ddopnew/dataloaders/tabular.py b/ddopnew/dataloaders/tabular.py index 58c7b6a..ca93beb 100644 --- a/ddopnew/dataloaders/tabular.py +++ b/ddopnew/dataloaders/tabular.py @@ -307,6 +307,7 @@ def __init__(self, demand_normalization: Literal['minmax', 'standard', 'no_normalization'] = 'no_normalization', # 'standard' or 'minmax' demand_unit_size: float | None = None, # use same convention as for other dataloaders and enviornments, but here only full decimal values are allowed provide_additional_target: bool = False, # follows ICL convention by providing actual demand to token, with the last token receiving 0 + permutate_inputs: bool = False, # if the inputs shall be permutated during training for meta-learning ): logging.info("Setting main env attributes") @@ -317,6 +318,7 @@ def __init__(self, self.time_features = time_features self.time_SKU_features = time_SKU_features self.mask = mask + self.permutate_inputs = permutate_inputs # convert dtypes to float self.demand = self.demand.astype(float) @@ -1115,6 +1117,18 @@ def __getitem__(self, idx: int): item[:,:,-extra_info:,:] = additional_info + if self.dataset_type == "train": + if self.permutate_inputs: + start_index_to_permutate = len_SKU_features + end_index_to_permutate = item.shape[2] + if self.provide_additional_target: + end_index_to_permutate -= 1 # target shall always be at the end + indices_for_permutation = np.arange(start_index_to_permutate, end_index_to_permutate) + print("original order:", indices_for_permutation) + indices = 
np.random.permutation(indices_for_permutation) + item[:,:,(start_index_to_permutate):end_index_to_permutate,:] = item[:,:,(indices),:] + print("new order:", indices) + if self.meta_learn_units: if self.dataset_type == "train": if item.shape[-1] == 1: diff --git a/nbs/00_utils/00_utils.ipynb b/nbs/00_utils/00_utils.ipynb index 15d31c5..27baf0b 100644 --- a/nbs/00_utils/00_utils.ipynb +++ b/nbs/00_utils/00_utils.ipynb @@ -1239,6 +1239,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/00_utils/01_loss_functions.ipynb b/nbs/00_utils/01_loss_functions.ipynb index f84cc0d..00dcdf2 100644 --- a/nbs/00_utils/01_loss_functions.ipynb +++ b/nbs/00_utils/01_loss_functions.ipynb @@ -266,6 +266,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/00_utils/10_postprocessors.ipynb b/nbs/00_utils/10_postprocessors.ipynb index 045a10a..2c414b8 100644 --- a/nbs/00_utils/10_postprocessors.ipynb +++ b/nbs/00_utils/10_postprocessors.ipynb @@ -642,6 +642,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/00_utils/11_obsprocessors.ipynb b/nbs/00_utils/11_obsprocessors.ipynb index 6364679..c1993d6 100644 --- a/nbs/00_utils/11_obsprocessors.ipynb +++ b/nbs/00_utils/11_obsprocessors.ipynb @@ -643,6 +643,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/00_utils/20_torch_loss_functions.ipynb b/nbs/00_utils/20_torch_loss_functions.ipynb index 4f1cf14..8b0b7c4 100644 --- a/nbs/00_utils/20_torch_loss_functions.ipynb +++ b/nbs/00_utils/20_torch_loss_functions.ipynb @@ -534,6 +534,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/10_dataloaders/12_tabular_dataloaders.ipynb b/nbs/10_dataloaders/12_tabular_dataloaders.ipynb index 189be27..f10491d 100644 --- a/nbs/10_dataloaders/12_tabular_dataloaders.ipynb +++ b/nbs/10_dataloaders/12_tabular_dataloaders.ipynb @@ -896,6 +896,7 @@ " demand_normalization: Literal['minmax', 'standard', 'no_normalization'] = 'no_normalization', # 'standard' or 'minmax'\n", " demand_unit_size: float | None = None, # use same convention as for other dataloaders and enviornments, but here only full decimal values are allowed\n", " provide_additional_target: bool = False, # follows ICL convention by providing actual demand to token, with the last token receiving 0\n", + " permutate_inputs: bool = False, # if the inputs shall be permutated during training for meta-learning\n", " ):\n", " \n", " logging.info(\"Setting main env attributes\")\n", @@ -906,6 +907,7 @@ " self.time_features = time_features\n", " self.time_SKU_features = time_SKU_features\n", " self.mask = mask\n", + " self.permutate_inputs = permutate_inputs\n", "\n", " # convert dtypes to float\n", " self.demand = self.demand.astype(float)\n", @@ -1704,6 +1706,18 @@ "\n", " item[:,:,-extra_info:,:] = additional_info\n", "\n", + " if self.dataset_type == \"train\":\n", + " if self.permutate_inputs:\n", + " start_index_to_permutate = len_SKU_features\n", + " end_index_to_permutate = item.shape[2]\n", + " if 
self.provide_additional_target:\n", + " end_index_to_permutate -= 1 # target shall always be at the end\n", + " indices_for_permutation = np.arange(start_index_to_permutate, end_index_to_permutate)\n", + " print(\"original order:\", indices_for_permutation)\n", + " indices = np.random.permutation(indices_for_permutation)\n", + " item[:,:,(start_index_to_permutate):end_index_to_permutate,:] = item[:,:,(indices),:]\n", + " print(\"new order:\", indices)\n", + "\n", " if self.meta_learn_units:\n", " if self.dataset_type == \"train\":\n", " if item.shape[-1] == 1:\n", diff --git a/nbs/21_envs_inventory/00_inventory_utils.ipynb b/nbs/21_envs_inventory/00_inventory_utils.ipynb index 2607432..7ea2c6e 100644 --- a/nbs/21_envs_inventory/00_inventory_utils.ipynb +++ b/nbs/21_envs_inventory/00_inventory_utils.ipynb @@ -63,6 +63,7 @@ "source": [ "#| export\n", "class OrderPipeline():\n", + " \n", " \"\"\"\n", " Class to handle the order pipeline in the inventory environments. It is used to keep track of the orders\n", " that are placed. It can account for fixed and variable lead times.\n", @@ -537,6 +538,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/21_envs_inventory/20_single_period_envs.ipynb b/nbs/21_envs_inventory/20_single_period_envs.ipynb index 811f55a..d9dcd72 100644 --- a/nbs/21_envs_inventory/20_single_period_envs.ipynb +++ b/nbs/21_envs_inventory/20_single_period_envs.ipynb @@ -678,7 +678,7 @@ "\n", " # print(\"env mode:\", self.mode)\n", " # print(\"dataloader mode:\", self.dataloader.dataset_type)\n", - " \n", + "\n", " X_item, Y_item = self.dataloader[self.index]\n", "\n", " # check if any value in X_item is nan.\n", @@ -789,6 +789,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/21_envs_inventory/30_multi_period_envs.ipynb b/nbs/21_envs_inventory/30_multi_period_envs.ipynb index e9e7ff7..76077ab 100644 --- a/nbs/21_envs_inventory/30_multi_period_envs.ipynb +++ b/nbs/21_envs_inventory/30_multi_period_envs.ipynb @@ -534,6 +534,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/30_experiment_functions/10_experiment_functions.ipynb b/nbs/30_experiment_functions/10_experiment_functions.ipynb index 0cc54a8..726727d 100644 --- a/nbs/30_experiment_functions/10_experiment_functions.ipynb +++ b/nbs/30_experiment_functions/10_experiment_functions.ipynb @@ -947,7 +947,10 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "\n", + "\n" + ] }, { "cell_type": "code", @@ -983,6 +986,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/30_experiment_functions/20_meta_experiment_functions.ipynb b/nbs/30_experiment_functions/20_meta_experiment_functions.ipynb index bc626ee..2ff6ecc 100644 --- a/nbs/30_experiment_functions/20_meta_experiment_functions.ipynb +++ b/nbs/30_experiment_functions/20_meta_experiment_functions.ipynb @@ -555,20 +555,6 @@ " return None, None" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], 
- "source": [] - }, { "cell_type": "markdown", "metadata": {}, @@ -706,6 +692,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/41_NV_agents/10_NV_saa_agents.ipynb b/nbs/41_NV_agents/10_NV_saa_agents.ipynb index 902ae07..7b19bab 100644 --- a/nbs/41_NV_agents/10_NV_saa_agents.ipynb +++ b/nbs/41_NV_agents/10_NV_saa_agents.ipynb @@ -50,7 +50,7 @@ "from ddopnew.obsprocessors import FlattenTimeDimNumpy\n", "\n", "from sklearn.ensemble import RandomForestRegressor\n", - "from sklearn.utils.validation import check_array\n" + "from sklearn.utils.validation import check_array" ] }, { @@ -1654,6 +1654,27 @@ "print(R, J)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -1670,6 +1691,10 @@ "display_name": "python3", "language": "python", "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.0" } }, "nbformat": 4, diff --git a/nbs/41_NV_agents/11_NV_erm_agents.ipynb b/nbs/41_NV_agents/11_NV_erm_agents.ipynb index bc6ab99..5d6cb60 100644 --- a/nbs/41_NV_agents/11_NV_erm_agents.ipynb +++ b/nbs/41_NV_agents/11_NV_erm_agents.ipynb @@ -261,7 +261,7 @@ " test_batch_size: int = 1024,\n", " receive_batch_dim: bool = False,\n", " ):\n", - "\n", + " \n", " # Initialize default values for mutable arguments\n", " optimizer_params = optimizer_params or {\"optimizer\": \"Adam\", \"lr\": 0.01, \"weight_decay\": 0.0}\n", " dataloader_params = dataloader_params or {\"batch_size\": 32, \"shuffle\": True}\n",