Skip to content

Commit

Permalink
allowed in reset_index function for case with no dataloader (#32)
Browse files Browse the repository at this point in the history
majoma7 authored Oct 24, 2024
1 parent d85b2dc commit b17873a
Showing 2 changed files with 16 additions and 8 deletions.
12 changes: 8 additions & 4 deletions ddopai/envs/base.py
Original file line number Diff line number Diff line change
@@ -211,7 +211,8 @@ def get_start_index(self,
return start_index

def reset_index(self,
start_index: Union[int,str]):
start_index: Union[int,str],
) -> bool:

"""
@@ -235,9 +236,12 @@ def reset_index(self,
self.start_index = start_index
else:
raise ValueError("start_index must be an integer or 'random'")

self.max_index = self.dataloader.len_train if self.mode == "train" else self.dataloader.len_val if self.mode == "val" else self.dataloader.len_test
self.max_index -= 1

if self.dataloader.len_train is not None:
self.max_index = self.dataloader.len_train if self.mode == "train" else self.dataloader.len_val if self.mode == "val" else self.dataloader.len_test
self.max_index -= 1
else:
self.max_index = self.start_index+self.mdp_info.horizon
self.max_index_episode = np.minimum(self.max_index, self.start_index+self.mdp_info.horizon)
if self.mode == "test" or self.mode == "val":
self.max_index_episode += 1
12 changes: 8 additions & 4 deletions nbs/20_environments/20_base_env/10_base_env.ipynb
Original file line number Diff line number Diff line change
@@ -253,7 +253,8 @@
" return start_index\n",
"\n",
" def reset_index(self,\n",
" start_index: Union[int,str]):\n",
" start_index: Union[int,str], \n",
" ) -> bool:\n",
"\n",
" \"\"\"\n",
"\n",
@@ -277,9 +278,12 @@
" self.start_index = start_index\n",
" else:\n",
" raise ValueError(\"start_index must be an integer or 'random'\")\n",
"\n",
" self.max_index = self.dataloader.len_train if self.mode == \"train\" else self.dataloader.len_val if self.mode == \"val\" else self.dataloader.len_test\n",
" self.max_index -= 1\n",
" \n",
" if self.dataloader.len_train is not None:\n",
" self.max_index = self.dataloader.len_train if self.mode == \"train\" else self.dataloader.len_val if self.mode == \"val\" else self.dataloader.len_test\n",
" self.max_index -= 1\n",
" else:\n",
" self.max_index = self.start_index+self.mdp_info.horizon\n",
" self.max_index_episode = np.minimum(self.max_index, self.start_index+self.mdp_info.horizon)\n",
" if self.mode == \"test\" or self.mode == \"val\":\n",
" self.max_index_episode += 1\n",

0 comments on commit b17873a

Please sign in to comment.