From 0a0145e0709f701fb3df237f32823b9f096778a2 Mon Sep 17 00:00:00 2001 From: Egor Baturin <82458209+egoriyaa@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:17:06 +0300 Subject: [PATCH] Rename `PatchTSModel` to `PatchTSTModel` (#601) * fix: rename patchts to patchtst * chore: update changelog * fix: update notebook output * fix: fix name im examples README --------- Co-authored-by: Egor Baturin --- CHANGELOG.md | 2 +- docs/source/api_reference/models.rst | 2 +- etna/models/nn/__init__.py | 2 +- etna/models/nn/{patchts.py => patchtst.py} | 22 +++++++-------- examples/202-NN_examples.ipynb | 28 ++++++++++--------- examples/README.md | 6 ++-- .../test_inference/test_forecast.py | 24 ++++++++-------- .../test_inference/test_predict.py | 20 ++++++------- .../{test_patchts.py => test_patchtst.py} | 20 ++++++------- 9 files changed, 64 insertions(+), 62 deletions(-) rename etna/models/nn/{patchts.py => patchtst.py} (95%) rename tests/test_models/test_nn/{test_patchts.py => test_patchtst.py} (80%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a2990745..9caacc474 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **Breaking:** Rename `TSDataset.index` to `TSDataset.timestamps` ([#593](https://github.com/etna-team/etna/pull/593)) - **Breaking:** Rename `TSDataset.add_columns_from_pandas` to `TSDataset.add_features_from_pandas` ([#593](https://github.com/etna-team/etna/pull/593)) - **Breaking:** Rename `TSDataset.update_columns_from_pandas` to `TSDataset.update_features_from_pandas` ([#593](https://github.com/etna-team/etna/pull/593)) -- +- **Breaking:** Rename `PatchTSModel` to `PatchTSTModel` ([#601](https://github.com/etna-team/etna/pull/601)) - - - Fix `TSDataset.train_test_split` to pass all features to train and test parts ([#545](https://github.com/etna-team/etna/pull/545)) diff --git a/docs/source/api_reference/models.rst b/docs/source/api_reference/models.rst index e9db62263..acc1e78bf 100644 --- a/docs/source/api_reference/models.rst +++ b/docs/source/api_reference/models.rst @@ -80,7 +80,7 @@ Native neural network models: nn.DeepStateModel nn.NBeatsGenericModel nn.NBeatsInterpretableModel - nn.PatchTSModel + nn.PatchTSTModel nn.DeepARModel nn.TFTModel diff --git a/etna/models/nn/__init__.py b/etna/models/nn/__init__.py index 8cfc51f9f..d4fbf61d7 100644 --- a/etna/models/nn/__init__.py +++ b/etna/models/nn/__init__.py @@ -6,7 +6,7 @@ from etna.models.nn.mlp import MLPModel from etna.models.nn.nbeats import NBeatsGenericModel from etna.models.nn.nbeats import NBeatsInterpretableModel - from etna.models.nn.patchts import PatchTSModel + from etna.models.nn.patchtst import PatchTSTModel from etna.models.nn.rnn import RNNModel from etna.models.nn.tft import TFTModel diff --git a/etna/models/nn/patchts.py b/etna/models/nn/patchtst.py similarity index 95% rename from etna/models/nn/patchts.py rename to etna/models/nn/patchtst.py index 1be364853..f70efdc6e 100644 --- a/etna/models/nn/patchts.py +++ b/etna/models/nn/patchtst.py @@ -20,8 +20,8 @@ import torch.nn as nn -class PatchTSBatch(TypedDict): - """Batch specification for PatchTS.""" +class PatchTSTBatch(TypedDict): + """Batch specification for PatchTST.""" encoder_target: "torch.Tensor" decoder_target: "torch.Tensor" @@ -52,8 +52,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(x) -class PatchTSNet(DeepBaseNet): - """PatchTS based Lightning module.""" +class PatchTSTNet(DeepBaseNet): + """PatchTST based Lightning module.""" def __init__( self, @@ -68,7 +68,7 @@ def __init__( loss: "torch.nn.Module", optimizer_params: Optional[dict], ) -> None: - """Init PatchTS. + """Init PatchTST. Parameters ---------- @@ -120,7 +120,7 @@ def __init__( self.lr = lr self.optimizer_params = {} if optimizer_params is None else optimizer_params - def forward(self, x: PatchTSBatch, *args, **kwargs): # type: ignore + def forward(self, x: PatchTSTBatch, *args, **kwargs): # type: ignore """Forward pass. Parameters @@ -157,7 +157,7 @@ def _get_prediction(self, x: torch.Tensor) -> torch.Tensor: return self.projection(y) # (batch_size, 1) - def step(self, batch: PatchTSBatch, *args, **kwargs): # type: ignore + def step(self, batch: PatchTSTBatch, *args, **kwargs): # type: ignore """Step for loss computation for training or validation. Parameters @@ -239,8 +239,8 @@ def configure_optimizers(self) -> "torch.optim.Optimizer": return optimizer -class PatchTSModel(DeepBaseModel): - """PatchTS model using PyTorch layers. For more details read the `paper `_. +class PatchTSTModel(DeepBaseModel): + """PatchTST model using PyTorch layers. For more details read the `paper `_. Model uses only `target` column, other columns will be ignored. @@ -271,7 +271,7 @@ def __init__( val_dataloader_params: Optional[dict] = None, split_params: Optional[dict] = None, ): - """Init PatchTS model. + """Init PatchTST model. Parameters ---------- @@ -327,7 +327,7 @@ def __init__( self.loss = loss if loss is not None else nn.MSELoss() self.optimizer_params = optimizer_params super().__init__( - net=PatchTSNet( + net=PatchTSTNet( encoder_length, patch_len=self.patch_len, stride=self.stride, diff --git a/examples/202-NN_examples.ipynb b/examples/202-NN_examples.ipynb index c17c3c78a..6d14025bc 100644 --- a/examples/202-NN_examples.ipynb +++ b/examples/202-NN_examples.ipynb @@ -36,7 +36,7 @@ " * [MLP](#section_2_5)\n", " * [Deep State Model](#section_2_6)\n", " * [N-BEATS Model](#section_2_7)\n", - " * [PatchTS Model](#section_2_8)\n", + " * [PatchTST Model](#section_2_8)\n", " * [Chronos Model](#section_2_9)\n", " * [Chronos Bolt Model](#section_2_10)\n", " * [TimesFM Model](#section_2_11)" @@ -1043,7 +1043,7 @@ } }, "source": [ - "### 2.3 TFTNative " + "### 2.3 TFT " ] }, { @@ -1152,7 +1152,6 @@ "pycharm": { "name": "#%%\n" }, - "scrolled": false, "tags": [] }, "outputs": [ @@ -2784,7 +2783,7 @@ } }, "source": [ - "### 2.8 PatchTS Model \n", + "### 2.8 PatchTST Model \n", "\n", "Model with transformer encoder that uses patches of timeseries as input words and linear decoder." ] @@ -2801,7 +2800,7 @@ }, "outputs": [], "source": [ - "from etna.models.nn import PatchTSModel" + "from etna.models.nn import PatchTSTModel" ] }, { @@ -2809,6 +2808,7 @@ "execution_count": 56, "id": "cc38238d", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false }, @@ -2951,7 +2951,7 @@ "source": [ "set_seed()\n", "\n", - "model_patchts = PatchTSModel(\n", + "model_patchtst = PatchTSTModel(\n", " decoder_length=HORIZON,\n", " encoder_length=2 * HORIZON,\n", " patch_len=1,\n", @@ -2960,11 +2960,11 @@ " train_batch_size=64,\n", ")\n", "\n", - "pipeline_patchts = Pipeline(\n", - " model=model_patchts, horizon=HORIZON, transforms=[StandardScalerTransform(in_column=\"target\")]\n", + "pipeline_patchtst = Pipeline(\n", + " model=model_patchtst, horizon=HORIZON, transforms=[StandardScalerTransform(in_column=\"target\")]\n", ")\n", "\n", - "metrics_patchts, forecast_patchts, fold_info_patchs = pipeline_patchts.backtest(\n", + "metrics_patchtst, forecast_patchtst, fold_info_patchtst = pipeline_patchtst.backtest(\n", " ts, metrics=metrics, n_folds=3, n_jobs=1\n", ")" ] @@ -2974,6 +2974,7 @@ "execution_count": 57, "id": "6394b96c", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false }, @@ -2987,13 +2988,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "Average SMAPE for PatchTS: 7.583\n" + "Average SMAPE for PatchTST: 7.583\n" ] } ], "source": [ - "score = metrics_patchts[\"SMAPE\"].mean()\n", - "print(f\"Average SMAPE for PatchTS: {score:.3f}\")" + "score = metrics_patchtst[\"SMAPE\"].mean()\n", + "print(f\"Average SMAPE for PatchTST: {score:.3f}\")" ] }, { @@ -3001,6 +3002,7 @@ "execution_count": 58, "id": "a514bd99", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false }, @@ -3022,7 +3024,7 @@ } ], "source": [ - "plot_backtest(forecast_patchts, ts, history_len=20)" + "plot_backtest(forecast_patchtst, ts, history_len=20)" ] }, { diff --git a/examples/README.md b/examples/README.md index dbd30ae7e..5985e05d7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -55,13 +55,13 @@ We have prepared a set of tutorials for an easy introduction: - Loading dataset - Testing models - Baseline - - DeepARNative - - TFTNative + - DeepAR + - TFT - RNN - MLP - Deep State Model - N-BEATS Model - - PatchTS Model + - PatchTST Model - Chronos Model - Chronos Bolt Model - TimesFM Model diff --git a/tests/test_models/test_inference/test_forecast.py b/tests/test_models/test_inference/test_forecast.py index 54ac7ce13..a2d66deeb 100644 --- a/tests/test_models/test_inference/test_forecast.py +++ b/tests/test_models/test_inference/test_forecast.py @@ -38,7 +38,7 @@ from etna.models.nn import MLPModel from etna.models.nn import NBeatsGenericModel from etna.models.nn import NBeatsInterpretableModel -from etna.models.nn import PatchTSModel +from etna.models.nn import PatchTSTModel from etna.models.nn import RNNModel from etna.models.nn import TFTModel from etna.models.nn import TimesFMModel @@ -136,7 +136,7 @@ def test_forecast_in_sample_full_no_target_failed_nans_sklearn(self, model, tran [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( DeepStateModel( ssm=CompositeSSM(seasonal_ssms=[WeeklySeasonalitySSM()]), @@ -302,7 +302,7 @@ def test_forecast_in_sample_full_failed_nans_nn(self, model, transforms, dataset [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( DeepStateModel( ssm=CompositeSSM(seasonal_ssms=[WeeklySeasonalitySSM()]), @@ -434,7 +434,7 @@ def _test_forecast_in_sample_suffix_no_target(ts, model, transforms, num_skip_po [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[2, 3])], @@ -533,7 +533,7 @@ class TestForecastInSampleSuffix: [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[2, 3])], @@ -651,7 +651,7 @@ def _test_forecast_out_sample(ts, model, transforms, prediction_size=5): [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -729,7 +729,7 @@ def test_forecast_out_sample_datetime_timestamp(self, model, transforms, dataset [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -881,7 +881,7 @@ def _test_forecast_out_sample_prefix(ts, model, transforms, full_prediction_size [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -995,7 +995,7 @@ def _test_forecast_out_sample_suffix(ts, model, transforms, full_prediction_size (SeasonalMovingAverageModel(), [], "example_tsds"), (NaiveModel(lag=3), [], "example_tsds"), (DeadlineMovingAverageModel(window=1), [], "example_tsds"), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -1211,7 +1211,7 @@ def _test_forecast_mixed_in_out_sample(ts, model, transforms, num_skip_points=50 [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -1343,7 +1343,7 @@ def _test_forecast_subset_segments(self, ts, model, transforms, segments, predic [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -1455,7 +1455,7 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], diff --git a/tests/test_models/test_inference/test_predict.py b/tests/test_models/test_inference/test_predict.py index 93659d9b6..5eb1952a3 100644 --- a/tests/test_models/test_inference/test_predict.py +++ b/tests/test_models/test_inference/test_predict.py @@ -36,7 +36,7 @@ from etna.models.nn import MLPModel from etna.models.nn import NBeatsGenericModel from etna.models.nn import NBeatsInterpretableModel -from etna.models.nn import PatchTSModel +from etna.models.nn import PatchTSTModel from etna.models.nn import RNNModel from etna.models.nn import TFTModel from etna.models.nn import TimesFMModel @@ -134,7 +134,7 @@ def test_predict_in_sample_full_failed_not_enough_context(self, model, transform [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[2, 3])], @@ -240,7 +240,7 @@ def test_predict_in_sample_suffix_datetime_timestamp(self, model, transforms, da [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[2, 3])], @@ -372,7 +372,7 @@ def test_predict_in_sample_suffix_int_timestamp_failed(self, model, transforms, [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[2, 3])], @@ -478,7 +478,7 @@ def test_predict_out_sample(self, model, transforms, dataset_name, request): [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -612,7 +612,7 @@ def test_predict_out_sample_prefix(self, model, transforms, dataset_name, reques [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -749,7 +749,7 @@ def test_predict_out_sample_suffix(self, model, transforms, dataset_name, reques [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -913,7 +913,7 @@ def test_predict_mixed_in_out_sample(self, model, transforms, dataset_name, requ [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -1054,7 +1054,7 @@ def test_predict_subset_segments(self, model, transforms, dataset_name, request) [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], @@ -1158,7 +1158,7 @@ def test_predict_new_segments(self, model, transforms, dataset_name, request): [], "example_tsds", ), - (PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), + (PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"), ( MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)), [LagTransform(in_column="target", lags=[5, 6])], diff --git a/tests/test_models/test_nn/test_patchts.py b/tests/test_models/test_nn/test_patchtst.py similarity index 80% rename from tests/test_models/test_nn/test_patchts.py rename to tests/test_models/test_nn/test_patchtst.py index 6cef3292e..474c6df4e 100644 --- a/tests/test_models/test_nn/test_patchts.py +++ b/tests/test_models/test_nn/test_patchtst.py @@ -4,8 +4,8 @@ import pytest from etna.metrics import MAE -from etna.models.nn import PatchTSModel -from etna.models.nn.patchts import PatchTSNet +from etna.models.nn import PatchTSTModel +from etna.models.nn.patchtst import PatchTSTNet from etna.transforms import StandardScalerTransform from tests.test_models.utils import assert_model_equals_loaded_original from tests.test_models.utils import assert_sampling_is_valid @@ -15,13 +15,13 @@ "horizon", [8, 13, 15], ) -def test_patchts_model_run_weekly_overfit_with_scaler_small_patch(ts_dataset_weekly_function_with_horizon, horizon): +def test_patchtst_model_run_weekly_overfit_with_scaler_small_patch(ts_dataset_weekly_function_with_horizon, horizon): ts_train, ts_test = ts_dataset_weekly_function_with_horizon(horizon) std = StandardScalerTransform(in_column="target") ts_train.fit_transform([std]) encoder_length = 14 decoder_length = 14 - model = PatchTSModel( + model = PatchTSTModel( encoder_length=encoder_length, decoder_length=decoder_length, patch_len=1, trainer_params=dict(max_epochs=20) ) future = ts_train.make_future(horizon, transforms=[std], tail_steps=encoder_length) @@ -37,13 +37,13 @@ def test_patchts_model_run_weekly_overfit_with_scaler_small_patch(ts_dataset_wee "horizon", [8, 13, 15], ) -def test_patchts_model_run_weekly_overfit_with_scaler_medium_patch(ts_dataset_weekly_function_with_horizon, horizon): +def test_patchtst_model_run_weekly_overfit_with_scaler_medium_patch(ts_dataset_weekly_function_with_horizon, horizon): ts_train, ts_test = ts_dataset_weekly_function_with_horizon(horizon) std = StandardScalerTransform(in_column="target") ts_train.fit_transform([std]) encoder_length = 14 decoder_length = 14 - model = PatchTSModel( + model = PatchTSTModel( encoder_length=encoder_length, decoder_length=decoder_length, trainer_params=dict(max_epochs=20) ) future = ts_train.make_future(horizon, transforms=[std], tail_steps=encoder_length) @@ -56,14 +56,14 @@ def test_patchts_model_run_weekly_overfit_with_scaler_medium_patch(ts_dataset_we @pytest.mark.parametrize("df_name", ["example_make_samples_df", "example_make_samples_df_int_timestamp"]) -def test_patchts_make_samples(df_name, request): +def test_patchtst_make_samples(df_name, request): df = request.getfixturevalue(df_name) module = MagicMock() encoder_length = 8 decoder_length = 4 ts_samples = list( - PatchTSNet.make_samples(module, df=df, encoder_length=encoder_length, decoder_length=decoder_length) + PatchTSTNet.make_samples(module, df=df, encoder_length=encoder_length, decoder_length=decoder_length) ) assert len(ts_samples) == len(df) - encoder_length - decoder_length + 1 @@ -83,12 +83,12 @@ def test_patchts_make_samples(df_name, request): def test_save_load(example_tsds): - model = PatchTSModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1)) + model = PatchTSTModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1)) assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3) def test_params_to_tune(example_tsds): ts = example_tsds - model = PatchTSModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1)) + model = PatchTSTModel(encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1)) assert len(model.params_to_tune()) > 0 assert_sampling_is_valid(model=model, ts=ts)