Skip to content

Commit

Permalink
add_series_dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
elephaint committed Feb 25, 2025
1 parent d4944dc commit 53ffe8a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions nbs/common.base_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -671,12 +671,12 @@
" available_idx = temporal_cols.get_loc('available_mask') \n",
" available_condition = windows[:, :self.input_size, available_idx]\n",
" available_condition = torch.sum(available_condition, axis=(1, -1)) # Sum over time & series dimension\n",
" final_condition = (available_condition > self.data_availability_threshold * self.input_size)\n",
" final_condition = (available_condition > self.data_availability_threshold * self.input_size * self.n_series)\n",
" \n",
" if self.h > 0:\n",
" sample_condition = windows[:, self.input_size:, available_idx]\n",
" sample_condition = torch.sum(sample_condition, axis=(1, -1)) # Sum over time & series dimension\n",
" final_condition = (sample_condition > self.data_availability_threshold * self.h) & (available_condition > self.data_availability_threshold * self.input_size)\n",
" final_condition = (sample_condition > self.data_availability_threshold * self.h * self.n_series) & (available_condition > self.data_availability_threshold * self.input_size * self.n_series)\n",
" \n",
" windows = windows[final_condition]\n",
" \n",
Expand Down
8 changes: 5 additions & 3 deletions neuralforecast/common/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,8 @@ def _create_windows(self, batch, step, w_idxs=None):
available_condition, axis=(1, -1)
) # Sum over time & series dimension
final_condition = (
available_condition > self.data_availability_threshold * self.input_size
available_condition
> self.data_availability_threshold * self.input_size * self.n_series
)

if self.h > 0:
Expand All @@ -689,10 +690,11 @@ def _create_windows(self, batch, step, w_idxs=None):
sample_condition, axis=(1, -1)
) # Sum over time & series dimension
final_condition = (
sample_condition > self.data_availability_threshold * self.h
sample_condition
> self.data_availability_threshold * self.h * self.n_series
) & (
available_condition
> self.data_availability_threshold * self.input_size
> self.data_availability_threshold * self.input_size * self.n_series
)

windows = windows[final_condition]
Expand Down

0 comments on commit 53ffe8a

Please sign in to comment.