Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Method TSDataset.tsdataset_idx_slice loses hierarchical structure #618

Merged
merged 6 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Breaking:** Bump minimum `optuna` version to 4.0 ([#599](https://github.com/etna-team/etna/pull/599))
- **Breaking:** Bump minimum `statsforecast` version to 2.0 ([#599](https://github.com/etna-team/etna/pull/599))
- Optimize performance of exogenous variables addition to the dataset ([#596](https://github.com/etna-team/etna/pull/596))
-
-

### Fixed
- Fix possibility of silent handling of duplicate features when updating dataset with `TSDataset.update_columns_from_pandas` ([#522](https://github.com/etna-team/etna/pull/552))
Expand All @@ -59,6 +59,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Breaking:** rename `DaylySeasonalitySSM` to `DailySeasonalitySSM` ([#615](https://github.com/etna-team/etna/pull/615))
- Fix `TSDataset.train_test_split` to pass all features to train and test parts ([#545](https://github.com/etna-team/etna/pull/545))
- Fix `ConfigSampler` to handle trials without hash ([#616](https://github.com/etna-team/etna/pull/616))
- Fix method `TSDataset.tsdataset_idx_slice` to not lose hierarchical structure ([#618](https://github.com/etna-team/etna/pull/618))
-

### Removed
- **Breaking:** Remove `FutureMixin`, `OutliersTransform.outliers_timestamps` and `OutliersTransform.original_values` ([#577](https://github.com/etna-team/etna/pull/577))
Expand Down
64 changes: 27 additions & 37 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from etna.datasets.utils import DataFrameFormat
from etna.datasets.utils import _check_features_in_segments
from etna.datasets.utils import _check_timestamp_param
from etna.datasets.utils import _slice_index_wide_dataframe
from etna.datasets.utils import _TorchDataset
from etna.datasets.utils import apply_alignment
from etna.datasets.utils import get_level_dataframe
Expand Down Expand Up @@ -520,16 +521,25 @@ def tsdataset_idx_slice(self, start_idx: Optional[int] = None, end_idx: Optional
:
TSDataset based on indexing slice.
"""
df_slice = self.df.iloc[start_idx:end_idx].copy(deep=True)
tsdataset_slice = TSDataset(df=df_slice, freq=self.freq)
# can't put known_future into constructor, _check_known_future fails with df_exog=None
tsdataset_slice.known_future = deepcopy(self.known_future)
tsdataset_slice._regressors = deepcopy(self.regressors)
if self.df_exog is not None:
tsdataset_slice.df_exog = self.df_exog.copy(deep=True)
tsdataset_slice._target_components_names = deepcopy(self._target_components_names)
tsdataset_slice._prediction_intervals_names = deepcopy(self._prediction_intervals_names)
return tsdataset_slice
self_df = self.df
self_raw_df = self.raw_df

try:
# we do this to avoid redundant copying of data
self.df = None
self.raw_df = None

ts_slice = deepcopy(self)
ts_slice.df = _slice_index_wide_dataframe(df=self_df, start=start_idx, stop=end_idx, label_indexing=False)
ts_slice.raw_df = _slice_index_wide_dataframe(
df=self_raw_df, start=start_idx, stop=end_idx, label_indexing=False
)

finally:
self.df = self_df
self.raw_df = self_raw_df

return ts_slice

@staticmethod
def _check_known_future(
Expand Down Expand Up @@ -1260,36 +1270,16 @@ def train_test_split(
# we do this to avoid redundant copying of data
self.df = None
self.raw_df = None
train = deepcopy(self)

# we want to make sure it makes only one copy
train_df = self_df.loc[train_start_defined:train_end_defined]
if train_df._is_view or train_df._is_copy is not None:
train.df = train_df.copy()
else:
train.df = train_df

# we want to make sure it makes only one copy
train_raw_df = self_raw_df.loc[train_start_defined:train_end_defined]
if train_raw_df._is_view or train_raw_df._is_copy is not None:
train.raw_df = train_raw_df.copy()
else:
train.raw_df = train_raw_df
train = deepcopy(self)
train.df = _slice_index_wide_dataframe(df=self_df, start=train_start_defined, stop=train_end_defined)
train.raw_df = _slice_index_wide_dataframe(
df=self_raw_df, start=train_start_defined, stop=train_end_defined
)

# we want to make sure it makes only one copy
test = deepcopy(self)
test_df = self_df.loc[test_start_defined:test_end_defined]
if test_df._is_view or test_df._is_copy is not None:
test.df = test_df.copy()
else:
test.df = test_df

# we want to make sure it makes only one copy
test_raw_df = self_raw_df.loc[train_start_defined:test_end_defined]
if test_raw_df._is_view or test_raw_df._is_copy is not None:
test.raw_df = test_raw_df.copy()
else:
test.raw_df = test_raw_df
test.df = _slice_index_wide_dataframe(df=self_df, start=test_start_defined, stop=test_end_defined)
test.raw_df = _slice_index_wide_dataframe(df=self_raw_df, start=train_start_defined, stop=test_end_defined)

finally:
self.df = self_df
Expand Down
17 changes: 17 additions & 0 deletions etna/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,3 +756,20 @@ def _check_features_in_segments(columns: pd.MultiIndex, segments: Optional[List[
raise ValueError(
f"There is a mismatch in feature sets between segments '{compare_segment}' and '{segment}'!"
)


def _slice_index_wide_dataframe(
df: pd.DataFrame,
start: Optional[Union[int, str, pd.Timestamp]] = None,
stop: Optional[Union[int, str, pd.Timestamp]] = None,
label_indexing: bool = True,
) -> pd.DataFrame:
"""Slice index of the dataframe in the wide format with copy."""
indexer = df.loc if label_indexing else df.iloc

# we want to make sure it makes only one copy
df = indexer[start:stop] # type: ignore
if df._is_view or df._is_copy is not None:
df = df.copy(deep=None)

return df
10 changes: 10 additions & 0 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,16 @@ def test_tsdataset_idx_slice_pass_prediction_intervals_to_output(ts_with_predict
)


def test_tsdataset_idx_slice_pass_hierarchical_structure_to_output(product_level_constant_forecast_with_quantiles):
    """Check that slicing a dataset by index keeps its hierarchical structure intact."""
    original_structure = product_level_constant_forecast_with_quantiles.hierarchical_structure
    sliced_ts = product_level_constant_forecast_with_quantiles.tsdataset_idx_slice(start_idx=1, end_idx=2)
    sliced_structure = sliced_ts.hierarchical_structure

    assert sliced_structure is not None
    assert sliced_structure.level_names == original_structure.level_names
    assert sliced_structure.level_structure == original_structure.level_structure


def test_to_torch_dataset_without_drop(tsdf_with_exog):
def make_samples(df):
return [{"target": df.target.values, "segment": df["segment"].values[0]}]
Expand Down
15 changes: 15 additions & 0 deletions tests/test_datasets/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from etna.datasets import generate_ar_df
from etna.datasets.utils import DataFrameFormat
from etna.datasets.utils import _check_features_in_segments
from etna.datasets.utils import _slice_index_wide_dataframe
from etna.datasets.utils import _TorchDataset
from etna.datasets.utils import apply_alignment
from etna.datasets.utils import determine_freq
Expand Down Expand Up @@ -1013,3 +1014,17 @@ def test_check_features_in_segments_ok(columns):
)
def test_check_features_in_segments_ok_with_expected_segments(columns):
_check_features_in_segments(columns=columns, segments=[1, 2])


@pytest.mark.parametrize("start, stop", ((0, 4), (4, -1), (-5, -1), (None, 6), (5, None), (None, None)))
def test_slice_index_wide_dataframe_int_idx(df_aligned_datetime, start, stop):
    """Positional slicing through the helper must match plain ``iloc`` slicing."""
    expected = df_aligned_datetime.iloc[start:stop]
    result = _slice_index_wide_dataframe(df=df_aligned_datetime, start=start, stop=stop, label_indexing=False)
    pd.testing.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
    "start, stop", (("2020-01-01", "2020-01-04"), (None, "2020-01-10"), ("2020-01-09", None), (None, None))
)
def test_slice_index_wide_dataframe_label_idx(df_aligned_datetime, start, stop):
    """Label-based slicing through the helper must match plain ``loc`` slicing."""
    expected = df_aligned_datetime.loc[start:stop]
    result = _slice_index_wide_dataframe(df=df_aligned_datetime, start=start, stop=stop, label_indexing=True)
    pd.testing.assert_frame_equal(result, expected)