From 7c3eb8e182a3fc66e3284410ccc73f8a103b3cf2 Mon Sep 17 00:00:00 2001 From: Magnus Maichle Date: Thu, 15 Aug 2024 00:09:07 +0200 Subject: [PATCH 1/8] added device test for predict_ function --- ddopnew/agents/rl/mushroom_rl.py | 3 ++- nbs/51_RL_agents/10_mushroom_base_agent.ipynb | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ddopnew/agents/rl/mushroom_rl.py b/ddopnew/agents/rl/mushroom_rl.py index 6e603cc..c329ab5 100644 --- a/ddopnew/agents/rl/mushroom_rl.py +++ b/ddopnew/agents/rl/mushroom_rl.py @@ -128,7 +128,8 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: # """ Do one forward pass of the model directly and return the prediction Overwrite for agents that have additional steps such as SAC""" - observation = torch.tensor(observation, dtype=torch.float32).to(self.device) + device = next(self.actor.parameters()).device + observation = torch.tensor(observation, dtype=torch.float32).to(device) action = self.actor.forward(observation) action = action.cpu().detach().numpy() diff --git a/nbs/51_RL_agents/10_mushroom_base_agent.ipynb b/nbs/51_RL_agents/10_mushroom_base_agent.ipynb index bbfa877..222d0da 100644 --- a/nbs/51_RL_agents/10_mushroom_base_agent.ipynb +++ b/nbs/51_RL_agents/10_mushroom_base_agent.ipynb @@ -168,7 +168,8 @@ " \"\"\" Do one forward pass of the model directly and return the prediction\n", " Overwrite for agents that have additional steps such as SAC\"\"\"\n", "\n", - " observation = torch.tensor(observation, dtype=torch.float32).to(self.device)\n", + " device = next(self.actor.parameters()).device\n", + " observation = torch.tensor(observation, dtype=torch.float32).to(device)\n", " action = self.actor.forward(observation)\n", " action = action.cpu().detach().numpy()\n", "\n", From 3ad3d8a2b023ff0651048a45513e5c8fde8334c0 Mon Sep 17 00:00:00 2001 From: Magnus Maichle Date: Thu, 15 Aug 2024 00:11:39 +0200 Subject: [PATCH 2/8] added device test for predict_ function --- ddopnew/agents/rl/sac.py | 3 ++- nbs/51_RL_agents/10_SAC_agents.ipynb | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ddopnew/agents/rl/sac.py b/ddopnew/agents/rl/sac.py index 1595c55..fc0d108 100644 --- a/ddopnew/agents/rl/sac.py +++ b/ddopnew/agents/rl/sac.py @@ -189,8 +189,9 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: # Apply tanh as implemented for the SAC actor in mushroom_rl""" # make observation torch tensor + device = next(self.actor.parameters()).device + observation = torch.tensor(observation, dtype=torch.float32).to(device) - observation = torch.tensor(observation, dtype=torch.float32).to(self.device) action = self.actor.forward(observation) # print("a before tanh: ", action) action = torch.tanh(action) diff --git a/nbs/51_RL_agents/10_SAC_agents.ipynb b/nbs/51_RL_agents/10_SAC_agents.ipynb index 21eb78c..8c7bf24 100644 --- a/nbs/51_RL_agents/10_SAC_agents.ipynb +++ b/nbs/51_RL_agents/10_SAC_agents.ipynb @@ -229,8 +229,9 @@ " Apply tanh as implemented for the SAC actor in mushroom_rl\"\"\"\n", "\n", " # make observation torch tensor\n", + " device = next(self.actor.parameters()).device\n", + " observation = torch.tensor(observation, dtype=torch.float32).to(device)\n", "\n", - " observation = torch.tensor(observation, dtype=torch.float32).to(self.device)\n", " action = self.actor.forward(observation)\n", " # print(\"a before tanh: \", action)\n", " action = torch.tanh(action)\n", From 7e048ed66e76cf095e63c4d407b5482c21f37a1f Mon Sep 17 00:00:00 2001 From: Magnus Maichle Date: Thu, 15 Aug 2024 01:03:07 +0200 Subject: [PATCH 3/8] use device for summary --- ddopnew/agents/rl/sac.py | 4 ++-- nbs/51_RL_agents/10_SAC_agents.ipynb | 22 ++++++++++++++-------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/ddopnew/agents/rl/sac.py b/ddopnew/agents/rl/sac.py index fc0d108..1596f0c 100644 --- a/ddopnew/agents/rl/sac.py +++ b/ddopnew/agents/rl/sac.py @@ -156,14 +156,14 @@ def __init__(self, logging.info("Actor network (mu network):") if logging.getLogger().isEnabledFor(logging.INFO): input_size = self.add_batch_dimension_for_shape(actor_mu_params["input_shape"]) - print(summary(self.actor, input_size=input_size)) + print(summary(self.actor, input_size=input_size, device=self.device)) time.sleep(0.2) logging.info("################################################################################") logging.info("Critic network:") if logging.getLogger().isEnabledFor(logging.INFO): input_size = self.add_batch_dimension_for_shape(critic_params["input_shape"]) - print(summary(self.critic, input_size=input_size)) + print(summary(self.critic, input_size=input_size, device=self.device)) def get_network_list(self, set_actor_critic_attributes: bool = True): """ Get the list of networks in the agent for the save and load functions diff --git a/nbs/51_RL_agents/10_SAC_agents.ipynb b/nbs/51_RL_agents/10_SAC_agents.ipynb index 8c7bf24..f8efea6 100644 --- a/nbs/51_RL_agents/10_SAC_agents.ipynb +++ b/nbs/51_RL_agents/10_SAC_agents.ipynb @@ -196,14 +196,14 @@ " logging.info(\"Actor network (mu network):\")\n", " if logging.getLogger().isEnabledFor(logging.INFO):\n", " input_size = self.add_batch_dimension_for_shape(actor_mu_params[\"input_shape\"])\n", - " print(summary(self.actor, input_size=input_size))\n", + " print(summary(self.actor, input_size=input_size, device=self.device))\n", " time.sleep(0.2)\n", "\n", " logging.info(\"################################################################################\")\n", " logging.info(\"Critic network:\")\n", " if logging.getLogger().isEnabledFor(logging.INFO):\n", " input_size = self.add_batch_dimension_for_shape(critic_params[\"input_shape\"])\n", - " print(summary(self.critic, input_size=input_size))\n", + " print(summary(self.critic, input_size=input_size, device=self.device))\n", "\n", " def get_network_list(self, set_actor_critic_attributes: bool = True):\n", " \"\"\" Get the list of networks in the agent for the save and load functions\n", @@ -374,7 +374,13 @@ "text": [ "/Users/magnus/miniforge3/envs/inventory_gym_2/lib/python3.11/site-packages/gymnasium/spaces/box.py:130: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", " gym.logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n", - "INFO:root:Actor network (mu network):\n", + "INFO:root:Actor network (mu network):\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ "/Users/magnus/miniforge3/envs/inventory_gym_2/lib/python3.11/site-packages/torchinfo/torchinfo.py:462: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", " action_fn=lambda data: sys.getsizeof(data.storage()),\n" ] @@ -445,8 +451,8 @@ "Params size (MB): 0.02\n", "Estimated Total Size (MB): 0.02\n", "==========================================================================================\n", - "-379.58216079994116 -239.97422234974908\n", - "-379.58216079994116 -239.97422234974908\n" + "-452.62149851193686 -285.9303368933814\n", + "-452.62149851193686 -285.9303368933814\n" ] } ], @@ -712,7 +718,7 @@ "==========================================================================================\n", "Layer (type:depth-idx) Output Shape Param #\n", "==========================================================================================\n", - "RNNStateAction [1, 1] --\n", + "RNNStateAction -- --\n", "├─RNNMLPHybrid: 1-1 [1, 1] --\n", "│ └─Sequential: 2-1 [1, 6, 64] --\n", "│ │ └─SpecificRNNWrapper: 3-1 [1, 6, 64] 13,248\n", @@ -736,8 +742,8 @@ "Params size (MB): 0.09\n", "Estimated Total Size (MB): 0.09\n", "==========================================================================================\n", - "-314.40625595864066 -198.89509302210004\n", - "-314.40625595864066 -198.89509302210004\n" + "-427.75928171089686 -270.38093617005046\n", + "-427.75928171089686 -270.38093617005046\n" ] } ], From 9921ff533e92c584af3886f7ff9f1e33c97b2516 Mon Sep 17 00:00:00 2001 From: Magnus Maichle Date: Thu, 15 Aug 2024 01:35:11 +0200 Subject: [PATCH 4/8] bring model to device when instantiated --- ddopnew/_modidx.py | 2 + ddopnew/agents/newsvendor/erm.py | 20 ++- nbs/41_NV_agents/11_NV_erm_agents.ipynb | 208 +++++++++++++----------- 3 files changed, 134 insertions(+), 96 deletions(-) diff --git a/ddopnew/_modidx.py b/ddopnew/_modidx.py index 3f3a3e5..f043281 100644 --- a/ddopnew/_modidx.py +++ b/ddopnew/_modidx.py @@ -180,6 +180,8 @@ 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.SGDBaseAgent.set_dataloader': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.set_dataloader', 'ddopnew/agents/newsvendor/erm.py'), + 'ddopnew.agents.newsvendor.erm.SGDBaseAgent.set_device': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.set_device', + 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.SGDBaseAgent.set_learning_rate_scheduler': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.set_learning_rate_scheduler', 'ddopnew/agents/newsvendor/erm.py'), 'ddopnew.agents.newsvendor.erm.SGDBaseAgent.set_loss_function': ( '41_NV_agents/nv_erm_agents.html#sgdbaseagent.set_loss_function', diff --git a/ddopnew/agents/newsvendor/erm.py b/ddopnew/agents/newsvendor/erm.py index 437f7ab..d26a73a 100644 --- a/ddopnew/agents/newsvendor/erm.py +++ b/ddopnew/agents/newsvendor/erm.py @@ -52,7 +52,7 @@ def __init__(self, dataloader_params = dataloader_params or {"batch_size": 32, "shuffle": True} self.torch_obsprocessors = torch_obsprocessors or [] - self.device = device + self.device = self.set_device(device) self.set_dataloader(dataloader, dataloader_params) self.set_model(input_shape, output_shape) @@ -62,6 +62,24 @@ def __init__(self, super().__init__(environment_info = environment_info, obsprocessors = obsprocessors, agent_name = agent_name) + self.to(self.device) + + def set_device(self, device: str): + + """ Set the device for the model """ + + if device == "cuda": + if torch.cuda.is_available(): + return "cuda" + else: + logging.warning("CUDA is not available. Using CPU instead.") + return "cpu" + elif device == "cpu": + return "cpu" + else: + raise ValueError(f"Device {device} not currently not supported, use 'cuda' or 'cpu'") + + def set_dataloader(self, dataloader: BaseDataLoader, dataloader_params: dict, # dict with keys: batch_size, shuffle diff --git a/nbs/41_NV_agents/11_NV_erm_agents.ipynb b/nbs/41_NV_agents/11_NV_erm_agents.ipynb index 372ae5b..cc408d2 100644 --- a/nbs/41_NV_agents/11_NV_erm_agents.ipynb +++ b/nbs/41_NV_agents/11_NV_erm_agents.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -92,7 +92,7 @@ " dataloader_params = dataloader_params or {\"batch_size\": 32, \"shuffle\": True}\n", " self.torch_obsprocessors = torch_obsprocessors or []\n", "\n", - " self.device = device\n", + " self.device = self.set_device(device)\n", " \n", " self.set_dataloader(dataloader, dataloader_params)\n", " self.set_model(input_shape, output_shape)\n", @@ -102,6 +102,24 @@ "\n", " super().__init__(environment_info = environment_info, obsprocessors = obsprocessors, agent_name = agent_name)\n", "\n", + " self.to(self.device)\n", + "\n", + " def set_device(self, device: str):\n", + "\n", + " \"\"\" Set the device for the model \"\"\"\n", + "\n", + " if device == \"cuda\":\n", + " if torch.cuda.is_available():\n", + " return \"cuda\"\n", + " else:\n", + " logging.warning(\"CUDA is not available. Using CPU instead.\")\n", + " return \"cpu\"\n", + " elif device == \"cpu\":\n", + " return \"cpu\"\n", + " else:\n", + " raise ValueError(f\"Device {device} not currently not supported, use 'cuda' or 'cpu'\")\n", + "\n", + "\n", " def set_dataloader(self,\n", " dataloader: BaseDataLoader,\n", " dataloader_params: dict, # dict with keys: batch_size, shuffle\n", @@ -286,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -358,7 +376,7 @@ "| agent_name | str \\| None | None | |" ] }, - "execution_count": null, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -396,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -404,7 +422,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L65){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L81){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_dataloader\n", "\n", @@ -423,7 +441,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L65){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L81){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_dataloader\n", "\n", @@ -440,7 +458,7 @@ "| **Returns** | **None** | |" ] }, - "execution_count": null, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -451,7 +469,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -459,7 +477,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L78){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L94){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_loss_function\n", "\n", @@ -470,7 +488,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L78){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L94){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_loss_function\n", "\n", @@ -479,7 +497,7 @@ "*Set loss function for the model*" ] }, - "execution_count": null, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -490,7 +508,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -498,7 +516,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L83){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L99){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_model\n", "\n", @@ -509,7 +527,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L83){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L99){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_model\n", "\n", @@ -518,7 +536,7 @@ "*Set the model for the agent*" ] }, - "execution_count": null, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -529,7 +547,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -537,7 +555,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L87){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L103){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_optimizer\n", "\n", @@ -552,7 +570,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L87){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L103){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_optimizer\n", "\n", @@ -565,7 +583,7 @@ "| optimizer_params | dict | dict with keys: optimizer, lr, weight_decay |" ] }, - "execution_count": null, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -576,7 +594,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -584,7 +602,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L103){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L119){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_learning_rate_scheduler\n", "\n", @@ -601,7 +619,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L103){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L119){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.set_learning_rate_scheduler\n", "\n", @@ -616,7 +634,7 @@ "| learning_rate_scheduler | None | None | |" ] }, - "execution_count": null, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -627,7 +645,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -635,7 +653,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L110){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L126){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.fit_epoch\n", "\n", @@ -646,7 +664,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L110){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L126){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.fit_epoch\n", "\n", @@ -655,7 +673,7 @@ "*Fit the model for one epoch using the dataloader*" ] }, - "execution_count": null, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -666,7 +684,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -674,7 +692,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L149){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L165){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.draw_action_\n", "\n", @@ -690,7 +708,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L149){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L165){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.draw_action_\n", "\n", @@ -704,7 +722,7 @@ "| **Returns** | **ndarray** | |" ] }, - "execution_count": null, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -715,7 +733,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -723,7 +741,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L159){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L175){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.predict\n", "\n", @@ -739,7 +757,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L159){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L175){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.predict\n", "\n", @@ -753,7 +771,7 @@ "| **Returns** | **ndarray** | |" ] }, - "execution_count": null, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -764,7 +782,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -772,7 +790,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L181){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L197){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.train\n", "\n", @@ -783,7 +801,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L181){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L197){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.train\n", "\n", @@ -792,7 +810,7 @@ "*set the internal state of the agent and its model to train*" ] }, - "execution_count": null, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -803,7 +821,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -811,7 +829,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L186){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.eval\n", "\n", @@ -822,7 +840,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L186){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L202){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.eval\n", "\n", @@ -831,7 +849,7 @@ "*set the internal state of the agent and its model to eval*" ] }, - "execution_count": null, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -842,7 +860,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -850,7 +868,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L191){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L207){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.to\n", "\n", @@ -865,7 +883,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L191){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L207){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.to\n", "\n", @@ -878,7 +896,7 @@ "| device | str | |" ] }, - "execution_count": null, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -889,7 +907,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -897,7 +915,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L195){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L211){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.save\n", "\n", @@ -913,7 +931,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L195){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L211){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.save\n", "\n", @@ -927,7 +945,7 @@ "| overwrite | bool | True | Allow overwriting; if False, a FileExistsError will be raised if the file exists. |" ] }, - "execution_count": null, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -938,7 +956,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -946,7 +964,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L223){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L239){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.load\n", "\n", @@ -961,7 +979,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L223){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L239){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### SGDBaseAgent.load\n", "\n", @@ -974,7 +992,7 @@ "| path | str | Only the path to the folder is needed, not the file itself |" ] }, - "execution_count": null, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -985,7 +1003,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -1049,7 +1067,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -1057,7 +1075,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L247){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L263){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NVBaseAgent\n", "\n", @@ -1095,7 +1113,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L247){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L263){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NVBaseAgent\n", "\n", @@ -1131,7 +1149,7 @@ "| agent_name | str \\| None | None | |" ] }, - "execution_count": null, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -1142,7 +1160,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -1150,7 +1168,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L291){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L307){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NVBaseAgent.set_loss_function\n", "\n", @@ -1163,7 +1181,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L291){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L307){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NVBaseAgent.set_loss_function\n", "\n", @@ -1174,7 +1192,7 @@ "co values to ensure similar scale of the feedback signal during training.*" ] }, - "execution_count": null, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -1185,7 +1203,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -1257,7 +1275,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -1265,7 +1283,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L303){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L319){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NewsvendorlERMAgent\n", "\n", @@ -1306,7 +1324,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L303){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L319){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NewsvendorlERMAgent\n", "\n", @@ -1345,7 +1363,7 @@ "| agent_name | str \\| None | lERM | |" ] }, - "execution_count": null, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -1369,7 +1387,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -1377,7 +1395,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L354){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L370){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NewsvendorlERMAgent.set_model\n", "\n", @@ -1388,7 +1406,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L354){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L370){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NewsvendorlERMAgent.set_model\n", "\n", @@ -1397,7 +1415,7 @@ "*Set the model for the agent to a linear model*" ] }, - "execution_count": null, + "execution_count": 24, "metadata": {}, "output_type": "execute_result" } @@ -1415,28 +1433,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "-19.642909679086138 -18.738758522299577\n" + "-18.726206754896214 -17.81568786174066\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:00<00:00, 45.80it/s]" + "100%|██████████| 2/2 [00:00<00:00, 53.21it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "-17.47671092500587 -16.66984812958351\n" + "-14.977437276938298 -14.26477751347144\n" ] }, { @@ -1501,7 +1519,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -1577,7 +1595,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -1585,7 +1603,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L367){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L383){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NewsvendorDLAgent\n", "\n", @@ -1625,7 +1643,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L367){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L383){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "## NewsvendorDLAgent\n", "\n", @@ -1663,7 +1681,7 @@ "| agent_name | str \\| None | DLNV | |" ] }, - "execution_count": null, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -1687,7 +1705,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -1695,7 +1713,7 @@ "text/markdown": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L422){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L438){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NewsvendorDLAgent.set_model\n", "\n", @@ -1706,7 +1724,7 @@ "text/plain": [ "---\n", "\n", - "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L422){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", + "[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/newsvendor/erm.py#L438){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", "\n", "### NewsvendorDLAgent.set_model\n", "\n", @@ -1715,7 +1733,7 @@ "*Set the model for the agent to an MLP*" ] }, - "execution_count": null, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1733,28 +1751,28 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "-18.410649699581896 -17.569069829704183\n" + "-16.116648853786373 -15.333838859835726\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 2/2 [00:00<00:00, 34.24it/s]" + "100%|██████████| 2/2 [00:00<00:00, 45.15it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "-15.268177127606906 -14.56018896878008\n" + "-14.678779590466775 -13.980560189243262\n" ] }, { @@ -1813,7 +1831,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ From 96adccec09d57054d8c5ee264a8aff1b95a891f8 Mon Sep 17 00:00:00 2001 From: Magnus Maichle Date: Thu, 15 Aug 2024 01:37:26 +0200 Subject: [PATCH 5/8] put observation to device when predict_ --- ddopnew/agents/rl/sac.py | 3 - nbs/41_NV_agents/11_NV_erm_agents.ipynb | 100 ++++++------- nbs/51_RL_agents/10_SAC_agents.ipynb | 181 +----------------------- 3 files changed, 52 insertions(+), 232 deletions(-) diff --git a/ddopnew/agents/rl/sac.py b/ddopnew/agents/rl/sac.py index 1596f0c..629d559 100644 --- a/ddopnew/agents/rl/sac.py +++ b/ddopnew/agents/rl/sac.py @@ -193,11 +193,8 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: # observation = torch.tensor(observation, dtype=torch.float32).to(device) action = self.actor.forward(observation) - # print("a before tanh: ", action) action = torch.tanh(action) - # print("a after tanh: ", action) action = action * self.agent.policy._delta_a + self.agent.policy._central_a - # print("a after scaling: ", action) action = action.cpu().detach().numpy() return action diff --git a/nbs/41_NV_agents/11_NV_erm_agents.ipynb b/nbs/41_NV_agents/11_NV_erm_agents.ipynb index cc408d2..8739168 100644 --- a/nbs/41_NV_agents/11_NV_erm_agents.ipynb +++ b/nbs/41_NV_agents/11_NV_erm_agents.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -304,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -376,7 +376,7 @@ "| agent_name | str \\| None | None | |" ] }, - "execution_count": 5, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -414,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -458,7 +458,7 @@ "| **Returns** | **None** | |" ] }, - "execution_count": 6, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -469,7 +469,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -497,7 +497,7 @@ "*Set loss function for the model*" ] }, - "execution_count": 7, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -508,7 +508,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -536,7 +536,7 @@ "*Set the model for the agent*" ] }, - "execution_count": 8, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -547,7 +547,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -583,7 +583,7 @@ "| optimizer_params | dict | dict with keys: optimizer, lr, weight_decay |" ] }, - "execution_count": 9, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -594,7 +594,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -634,7 +634,7 @@ "| learning_rate_scheduler | None | None | |" ] }, - "execution_count": 10, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -645,7 +645,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -673,7 +673,7 @@ "*Fit the model for one epoch using the dataloader*" ] }, - "execution_count": 11, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -684,7 +684,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -722,7 +722,7 @@ "| **Returns** | **ndarray** | |" ] }, - "execution_count": 12, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -733,7 +733,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -771,7 +771,7 @@ "| **Returns** | **ndarray** | |" ] }, - "execution_count": 13, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -782,7 +782,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -810,7 +810,7 @@ "*set the internal state of the agent and its model to train*" ] }, - "execution_count": 14, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -821,7 +821,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -849,7 +849,7 @@ "*set the internal state of the agent and its model to eval*" ] }, - "execution_count": 15, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -860,7 +860,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -896,7 +896,7 @@ "| device | str | |" ] }, - "execution_count": 16, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -907,7 +907,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -945,7 +945,7 @@ "| overwrite | bool | True | Allow overwriting; if False, a FileExistsError will be raised if the file exists. |" ] }, - "execution_count": 17, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -956,7 +956,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -992,7 +992,7 @@ "| path | str | Only the path to the folder is needed, not the file itself |" ] }, - "execution_count": 18, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1003,7 +1003,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1067,7 +1067,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1149,7 +1149,7 @@ "| agent_name | str \\| None | None | |" ] }, - "execution_count": 20, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1160,7 +1160,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1192,7 +1192,7 @@ "co values to ensure similar scale of the feedback signal during training.*" ] }, - "execution_count": 21, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1203,7 +1203,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1275,7 +1275,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1363,7 +1363,7 @@ "| agent_name | str \\| None | lERM | |" ] }, - "execution_count": 23, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1387,7 +1387,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1415,7 +1415,7 @@ "*Set the model for the agent to a linear model*" ] }, - "execution_count": 24, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1433,7 +1433,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1519,7 +1519,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1595,7 +1595,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1681,7 +1681,7 @@ "| agent_name | str \\| None | DLNV | |" ] }, - "execution_count": 27, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1705,7 +1705,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1733,7 +1733,7 @@ "*Set the model for the agent to an MLP*" ] }, - "execution_count": 28, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1751,7 +1751,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1831,7 +1831,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/nbs/51_RL_agents/10_SAC_agents.ipynb b/nbs/51_RL_agents/10_SAC_agents.ipynb index f8efea6..1d619ba 100644 --- a/nbs/51_RL_agents/10_SAC_agents.ipynb +++ b/nbs/51_RL_agents/10_SAC_agents.ipynb @@ -233,11 +233,8 @@ " observation = torch.tensor(observation, dtype=torch.float32).to(device)\n", "\n", " action = self.actor.forward(observation)\n", - " # print(\"a before tanh: \", action)\n", " action = torch.tanh(action)\n", - " # print(\"a after tanh: \", action)\n", " action = action * self.agent.policy._delta_a + self.agent.policy._central_a\n", - " # print(\"a after scaling: \", action)\n", " action = action.cpu().detach().numpy()\n", "\n", " return action" @@ -367,95 +364,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/magnus/miniforge3/envs/inventory_gym_2/lib/python3.11/site-packages/gymnasium/spaces/box.py:130: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", - " gym.logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n", - "INFO:root:Actor network (mu network):\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/magnus/miniforge3/envs/inventory_gym_2/lib/python3.11/site-packages/torchinfo/torchinfo.py:462: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", - " action_fn=lambda data: sys.getsizeof(data.storage()),\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "MLPActor [1, 1] --\n", - "├─Sequential: 1-1 [1, 1] --\n", - "│ └─Linear: 2-1 [1, 64] 192\n", - "│ └─ReLU: 2-2 [1, 64] --\n", - "│ └─Dropout: 2-3 [1, 64] --\n", - "│ └─Linear: 2-4 [1, 64] 4,160\n", - "│ └─ReLU: 2-5 [1, 64] --\n", - "│ └─Dropout: 2-6 [1, 64] --\n", - "│ └─Linear: 2-7 [1, 1] 65\n", - "│ └─Identity: 2-8 [1, 1] --\n", - "==========================================================================================\n", - "Total params: 4,417\n", - "Trainable params: 4,417\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 0.00\n", - "==========================================================================================\n", - "Input size (MB): 0.00\n", - "Forward/backward pass size (MB): 0.00\n", - "Params size (MB): 0.02\n", - "Estimated Total Size (MB): 0.02\n", - "==========================================================================================\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:root:################################################################################\n", - "INFO:root:Critic network:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "MLPStateAction -- --\n", - "├─Sequential: 1-1 [1, 1] --\n", - "│ └─Linear: 2-1 [1, 64] 256\n", - "│ └─ReLU: 2-2 [1, 64] --\n", - "│ └─Dropout: 2-3 [1, 64] --\n", - "│ └─Linear: 2-4 [1, 64] 4,160\n", - "│ └─ReLU: 2-5 [1, 64] --\n", - "│ └─Dropout: 2-6 [1, 64] --\n", - "│ └─Linear: 2-7 [1, 1] 65\n", - "│ └─Identity: 2-8 [1, 1] --\n", - "==========================================================================================\n", - "Total params: 4,481\n", - "Trainable params: 4,481\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 0.00\n", - "==========================================================================================\n", - "Input size (MB): 0.00\n", - "Forward/backward pass size (MB): 0.00\n", - "Params size (MB): 0.02\n", - "Estimated Total Size (MB): 0.02\n", - "==========================================================================================\n", - "-452.62149851193686 -285.9303368933814\n", - "-452.62149851193686 -285.9303368933814\n" - ] - } - ], + "outputs": [], "source": [ "from ddopnew.envs.inventory import NewsvendorEnv\n", "from ddopnew.dataloaders.tabular import XYDataLoader\n", @@ -660,93 +569,7 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/magnus/miniforge3/envs/inventory_gym_2/lib/python3.11/site-packages/gymnasium/spaces/box.py:130: UserWarning: \u001b[33mWARN: Box bound precision lowered by casting to float32\u001b[0m\n", - " gym.logger.warn(f\"Box bound precision lowered by casting to {self.dtype}\")\n", - "INFO:root:Actor network (mu network):\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "RNNActor [1, 1] --\n", - "├─RNNMLPHybrid: 1-1 [1, 1] --\n", - "│ └─Sequential: 2-1 [1, 6, 64] --\n", - "│ │ └─SpecificRNNWrapper: 3-1 [1, 6, 64] 13,248\n", - "│ │ └─ReLU: 3-2 [1, 6, 64] --\n", - "│ └─Sequential: 2-2 [1, 1] --\n", - "│ │ └─Linear: 3-3 [1, 64] 4,160\n", - "│ │ └─ReLU: 3-4 [1, 64] --\n", - "│ │ └─Dropout: 3-5 [1, 64] --\n", - "│ │ └─Linear: 3-6 [1, 64] 4,160\n", - "│ │ └─ReLU: 3-7 [1, 64] --\n", - "│ │ └─Dropout: 3-8 [1, 64] --\n", - "│ │ └─Linear: 3-9 [1, 1] 65\n", - "==========================================================================================\n", - "Total params: 21,633\n", - "Trainable params: 21,633\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 0.09\n", - "==========================================================================================\n", - "Input size (MB): 0.00\n", - "Forward/backward pass size (MB): 0.00\n", - "Params size (MB): 0.09\n", - "Estimated Total Size (MB): 0.09\n", - "==========================================================================================\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:root:################################################################################\n", - "INFO:root:Critic network:\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==========================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "==========================================================================================\n", - "RNNStateAction -- --\n", - "├─RNNMLPHybrid: 1-1 [1, 1] --\n", - "│ └─Sequential: 2-1 [1, 6, 64] --\n", - "│ │ └─SpecificRNNWrapper: 3-1 [1, 6, 64] 13,248\n", - "│ │ └─ReLU: 3-2 [1, 6, 64] --\n", - "│ └─Sequential: 2-2 [1, 1] --\n", - "│ │ └─Linear: 3-3 [1, 64] 4,224\n", - "│ │ └─ReLU: 3-4 [1, 64] --\n", - "│ │ └─Dropout: 3-5 [1, 64] --\n", - "│ │ └─Linear: 3-6 [1, 64] 4,160\n", - "│ │ └─ReLU: 3-7 [1, 64] --\n", - "│ │ └─Dropout: 3-8 [1, 64] --\n", - "│ │ └─Linear: 3-9 [1, 1] 65\n", - "==========================================================================================\n", - "Total params: 21,697\n", - "Trainable params: 21,697\n", - "Non-trainable params: 0\n", - "Total mult-adds (M): 0.09\n", - "==========================================================================================\n", - "Input size (MB): 0.00\n", - "Forward/backward pass size (MB): 0.00\n", - "Params size (MB): 0.09\n", - "Estimated Total Size (MB): 0.09\n", - "==========================================================================================\n", - "-427.75928171089686 -270.38093617005046\n", - "-427.75928171089686 -270.38093617005046\n" - ] - } - ], + "outputs": [], "source": [ "from ddopnew.envs.inventory import NewsvendorEnv\n", "from ddopnew.dataloaders.tabular import XYDataLoader\n", From 42ba6e0257ea8f299222a8a64b6dfbe0a6e15dd1 Mon Sep 17 00:00:00 2001 From: Magnus Maichle Date: Thu, 15 Aug 2024 02:01:51 +0200 Subject: [PATCH 6/8] change model train eval directly with agent mode --- ddopnew/agents/rl/mushroom_rl.py | 5 ++++- nbs/51_RL_agents/10_mushroom_base_agent.ipynb | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/ddopnew/agents/rl/mushroom_rl.py b/ddopnew/agents/rl/mushroom_rl.py index c329ab5..edd7ea8 100644 --- a/ddopnew/agents/rl/mushroom_rl.py +++ b/ddopnew/agents/rl/mushroom_rl.py @@ -138,12 +138,15 @@ def predict_(self, observation: np.ndarray) -> np.ndarray: # def train(self): """set the internal state of the agent and its model to train""" self.mode = "train" + for network in self.network_list: + network.train() def eval(self): """set the internal state of the agent and its model to eval""" self.mode = "eval" + for network in self.network_list: + network.eval() - def to(self, device: str): # """Move the model to the specified device""" diff --git a/nbs/51_RL_agents/10_mushroom_base_agent.ipynb b/nbs/51_RL_agents/10_mushroom_base_agent.ipynb index 222d0da..6d29a0d 100644 --- a/nbs/51_RL_agents/10_mushroom_base_agent.ipynb +++ b/nbs/51_RL_agents/10_mushroom_base_agent.ipynb @@ -178,12 +178,15 @@ " def train(self):\n", " \"\"\"set the internal state of the agent and its model to train\"\"\"\n", " self.mode = \"train\"\n", + " for network in self.network_list:\n", + " network.train()\n", "\n", " def eval(self):\n", " \"\"\"set the internal state of the agent and its model to eval\"\"\"\n", " self.mode = \"eval\"\n", + " for network in self.network_list:\n", + " network.eval()\n", " \n", - "\n", " def to(self, device: str): #\n", " \"\"\"Move the model to the specified device\"\"\"\n", "\n", From 27a531edbcb161351cc69652b3fc58ebb6fe3c22 Mon Sep 17 00:00:00 2001 From: Magnus Maichle Date: Thu, 15 Aug 2024 02:15:40 +0200 Subject: [PATCH 7/8] go to train mode again after warmup steps --- ddopnew/experiment_functions.py | 1 + nbs/30_experiment_functions/10_experiment_functions.ipynb | 1 + 2 files changed, 2 insertions(+) diff --git a/ddopnew/experiment_functions.py b/ddopnew/experiment_functions.py index 36eafb1..3dd041d 100644 --- a/ddopnew/experiment_functions.py +++ b/ddopnew/experiment_functions.py @@ -380,6 +380,7 @@ def run_experiment( agent: BaseAgent, for epoch in trange(n_epochs): env.set_return_truncation(False) # For mushroom Core to work, the step function should not return the truncation flag + agent.train() core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit, quiet=True) env.set_return_truncation(True) # Set back to standard gynmasium behavior diff --git a/nbs/30_experiment_functions/10_experiment_functions.ipynb b/nbs/30_experiment_functions/10_experiment_functions.ipynb index ff5b9dc..d7433d3 100644 --- a/nbs/30_experiment_functions/10_experiment_functions.ipynb +++ b/nbs/30_experiment_functions/10_experiment_functions.ipynb @@ -575,6 +575,7 @@ " for epoch in trange(n_epochs):\n", "\n", " env.set_return_truncation(False) # For mushroom Core to work, the step function should not return the truncation flag\n", + " agent.train()\n", " core.learn(n_steps=n_steps, n_steps_per_fit=n_steps_per_fit, quiet=True)\n", " env.set_return_truncation(True) # Set back to standard gynmasium behavior\n", "\n", From 77ae904164ea8f0a9ba10ab7df4c6382a444f469 Mon Sep 17 00:00:00 2001 From: Magnus Maichle Date: Thu, 15 Aug 2024 02:43:27 +0200 Subject: [PATCH 8/8] made data return as dict --- ddopnew/datasets.py | 15 ++-- nbs/80_datasets/datasets.ipynb | 126 ++------------------------------- 2 files changed, 16 insertions(+), 125 deletions(-) diff --git a/ddopnew/datasets.py b/ddopnew/datasets.py index 44b941a..8da4f03 100644 --- a/ddopnew/datasets.py +++ b/ddopnew/datasets.py @@ -14,7 +14,7 @@ import zipfile -# %% ../nbs/80_datasets/datasets.ipynb 7 +# %% ../nbs/80_datasets/datasets.ipynb 6 def get_all_release_tags(token=None): url = "https://api.github.com/repos/opimwue/ddopnew/releases" headers = {'Authorization': f'Bearer {token}'} if token else {} @@ -89,21 +89,24 @@ def unzip_file(zip_file_path, output_dir, delete_zip_file=True): os.remove(zip_file_path) def load_data_from_directory(dir): - data = list() + data = dict() for file in os.listdir(dir): if file.endswith(".csv"): - data.append(pd.read_csv(os.path.join(dir, file))) + key = os.path.splitext(file)[0] + data[key] = pd.read_csv(os.path.join(dir, file)) elif file.endswith(".pkl"): - data.append(pd.read_pickle(os.path.join(dir, file))) + key = os.path.splitext(file)[0] + data[key] = pd.read_pickle(os.path.join(dir, file)) elif file.endswith(".npy"): - data.append(np.load(os.path.join(dir, file))) + key = os.path.splitext(file)[0] + data[key] = np.load(os.path.join(dir, file)) else: raise ValueError(f"File {file} is not a valid file type (csv, pkl, or npy)") return data -# %% ../nbs/80_datasets/datasets.ipynb 9 +# %% ../nbs/80_datasets/datasets.ipynb 8 class DatasetLoader(): """ diff --git a/nbs/80_datasets/datasets.ipynb b/nbs/80_datasets/datasets.ipynb index 0e62fbe..8abdec3 100644 --- a/nbs/80_datasets/datasets.ipynb +++ b/nbs/80_datasets/datasets.ipynb @@ -70,121 +70,6 @@ "## Helper functions to load datasets" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# #| export\n", - "\n", - "# def get_all_release_tags():\n", - "\n", - "# url = f\"https://api.github.com/repos/opimwue/ddopnew/releases\"\n", - "# response = requests.get(url)\n", - " \n", - "# if response.status_code == 200:\n", - "# releases = response.json()\n", - "# tags = [release['tag_name'] for release in releases]\n", - "# return tags\n", - "# else:\n", - "# raise ValueError(f\"Failed to fetch releases: {response.status_code} with message: {response.text}\")\n", - "\n", - "# def get_release_tag(dataset_type, version):\n", - "\n", - "# release_tags = get_all_release_tags()\n", - "# release_tags_filtered = [tag for tag in release_tags if dataset_type in tag]\n", - "\n", - "# if version == \"latest\":\n", - "# release_tags_filtered.sort(key=lambda x: [int(num) if num.isdigit() else num for num in re.findall(r'\\d+|\\D+', x.split('_v')[-1])])\n", - "# release_tag = release_tags_filtered[-1]\n", - "# else:\n", - "# release_tag = f\"{dataset_type}_{version}\"\n", - " \n", - "# return release_tag\n", - " \n", - "# print(f\"Filtered release tags: {release_tags_filtered}\")\n", - "\n", - "\n", - "\n", - "# def get_dataset_url(dataset_type, dataset_number, release_tag):\n", - "# # Define the repository and release tag\n", - "\n", - "# # GitHub API URL for the release\n", - "# api_url = f\"https://api.github.com/repos/opimwue/ddopnew/releases/tags/{release_tag}\"\n", - " \n", - "# # Make the request to the GitHub API\n", - "# response = requests.get(api_url)\n", - "\n", - "# # Check if the request was successful\n", - "# if response.status_code == 200:\n", - "# release_info = response.json()\n", - "# assets = release_info.get(\"assets\", [])\n", - "\n", - "# # get asset where the name contains the f\"{dataset_type}_dataset_{dataset_number}\"\n", - "# assets = [asset for asset in assets if f\"{dataset_type}_dataset_{dataset_number}_\" in asset['name']]\n", - "\n", - "# for asset in assets:\n", - "# logging.debug(f\"Found dataset: {asset['name']}\")\n", - "\n", - "# if len(assets) == 0:\n", - "# raise ValueError(f\"Dataset {dataset_type}_dataset_{dataset_number} not found in release {release_tag}\")\n", - "# elif len(assets) > 1:\n", - "# raise ValueError(f\"Multiple datasets found for {dataset_type}_dataset_{dataset_number} in release {release_tag}\")\n", - "# else:\n", - "# asset = assets[0]\n", - "# return asset['browser_download_url']\n", - "# else:\n", - "# raise ValueError(f\"Failed to fetch release information: {response.status_code}\")\n", - "\n", - "# def get_asset_url(dataset_type, dataset_number, version=\"latest\"):\n", - "\n", - "# release_tag = get_release_tag(dataset_type, version)\n", - "\n", - "# asset_url = get_dataset_url(dataset_type, dataset_number, release_tag)\n", - "\n", - "# return asset_url\n", - "\n", - "\n", - "# def download_file_from_github(url, output_path):\n", - "# response = requests.get(url, stream=True)\n", - "# if response.status_code == 200:\n", - "# with open(output_path, 'wb') as file:\n", - "# for chunk in response.iter_content(chunk_size=1024):\n", - "# if chunk:\n", - "# file.write(chunk)\n", - "# logging.debug(f\"File downloaded successfully: {output_path}\")\n", - "# else:\n", - "# logging.error(f\"Failed to download file: {response.status_code}\")\n", - "\n", - "# def unzip_file(zip_file_path, output_dir, delete_zip_file=True):\n", - "\n", - "# with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:\n", - "# zip_ref.extractall(output_dir)\n", - "\n", - "# if delete_zip_file:\n", - "# os.remove(zip_file_path)\n", - "\n", - "\n", - "# def load_data_from_directory(dir):\n", - "\n", - "# data = list()\n", - "# for file in os.listdir(dir):\n", - "# # check if the file is a csv, pkl, or numpy file\n", - "# # load data into pandas dataframe array with name of the file before the extension\n", - "\n", - "# if file.endswith(\".csv\"):\n", - "# data.append(pd.read_csv(os.path.join(dir, file)))\n", - "# elif file.endswith(\".pkl\"):\n", - "# data.append(pd.read_pickle(os.path.join(dir, file)))\n", - "# elif file.endswith(\".npy\"):\n", - "# data.append(np.load(os.path.join(dir, file)))\n", - "# else:\n", - "# raise ValueError(f\"File {file} is not a valid file type (csv, pkl, or npy)\")\n", - " \n", - "# return data" - ] - }, { "cell_type": "code", "execution_count": null, @@ -266,14 +151,17 @@ " os.remove(zip_file_path)\n", "\n", "def load_data_from_directory(dir):\n", - " data = list()\n", + " data = dict()\n", " for file in os.listdir(dir):\n", " if file.endswith(\".csv\"):\n", - " data.append(pd.read_csv(os.path.join(dir, file)))\n", + " key = os.path.splitext(file)[0]\n", + " data[key] = pd.read_csv(os.path.join(dir, file))\n", " elif file.endswith(\".pkl\"):\n", - " data.append(pd.read_pickle(os.path.join(dir, file)))\n", + " key = os.path.splitext(file)[0]\n", + " data[key] = pd.read_pickle(os.path.join(dir, file))\n", " elif file.endswith(\".npy\"):\n", - " data.append(np.load(os.path.join(dir, file)))\n", + " key = os.path.splitext(file)[0]\n", + " data[key] = np.load(os.path.join(dir, file))\n", " else:\n", " raise ValueError(f\"File {file} is not a valid file type (csv, pkl, or npy)\")\n", " \n",