Skip to content

Commit

Permalink
Rename PatchTSModel to PatchTSTModel (#601)
Browse files Browse the repository at this point in the history
* fix: rename patchts to patchtst

* chore: update changelog

* fix: update notebook output

* fix: fix name in examples README

---------

Co-authored-by: Egor Baturin <[email protected]>
  • Loading branch information
egoriyaa and Egor Baturin authored Feb 7, 2025
1 parent ae4556f commit 0a0145e
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 62 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Breaking:** Rename `TSDataset.index` to `TSDataset.timestamps` ([#593](https://github.com/etna-team/etna/pull/593))
- **Breaking:** Rename `TSDataset.add_columns_from_pandas` to `TSDataset.add_features_from_pandas` ([#593](https://github.com/etna-team/etna/pull/593))
- **Breaking:** Rename `TSDataset.update_columns_from_pandas` to `TSDataset.update_features_from_pandas` ([#593](https://github.com/etna-team/etna/pull/593))
-
- **Breaking:** Rename `PatchTSModel` to `PatchTSTModel` ([#601](https://github.com/etna-team/etna/pull/601))
-
-
- Fix `TSDataset.train_test_split` to pass all features to train and test parts ([#545](https://github.com/etna-team/etna/pull/545))
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api_reference/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Native neural network models:
nn.DeepStateModel
nn.NBeatsGenericModel
nn.NBeatsInterpretableModel
nn.PatchTSModel
nn.PatchTSTModel
nn.DeepARModel
nn.TFTModel

Expand Down
2 changes: 1 addition & 1 deletion etna/models/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from etna.models.nn.mlp import MLPModel
from etna.models.nn.nbeats import NBeatsGenericModel
from etna.models.nn.nbeats import NBeatsInterpretableModel
from etna.models.nn.patchts import PatchTSModel
from etna.models.nn.patchtst import PatchTSTModel
from etna.models.nn.rnn import RNNModel
from etna.models.nn.tft import TFTModel

Expand Down
22 changes: 11 additions & 11 deletions etna/models/nn/patchts.py → etna/models/nn/patchtst.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import torch.nn as nn


class PatchTSBatch(TypedDict):
"""Batch specification for PatchTS."""
class PatchTSTBatch(TypedDict):
"""Batch specification for PatchTST."""

encoder_target: "torch.Tensor"
decoder_target: "torch.Tensor"
Expand Down Expand Up @@ -52,8 +52,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.dropout(x)


class PatchTSNet(DeepBaseNet):
"""PatchTS based Lightning module."""
class PatchTSTNet(DeepBaseNet):
"""PatchTST based Lightning module."""

def __init__(
self,
Expand All @@ -68,7 +68,7 @@ def __init__(
loss: "torch.nn.Module",
optimizer_params: Optional[dict],
) -> None:
"""Init PatchTS.
"""Init PatchTST.
Parameters
----------
Expand Down Expand Up @@ -120,7 +120,7 @@ def __init__(
self.lr = lr
self.optimizer_params = {} if optimizer_params is None else optimizer_params

def forward(self, x: PatchTSBatch, *args, **kwargs): # type: ignore
def forward(self, x: PatchTSTBatch, *args, **kwargs): # type: ignore
"""Forward pass.
Parameters
Expand Down Expand Up @@ -157,7 +157,7 @@ def _get_prediction(self, x: torch.Tensor) -> torch.Tensor:

return self.projection(y) # (batch_size, 1)

def step(self, batch: PatchTSBatch, *args, **kwargs): # type: ignore
def step(self, batch: PatchTSTBatch, *args, **kwargs): # type: ignore
"""Step for loss computation for training or validation.
Parameters
Expand Down Expand Up @@ -239,8 +239,8 @@ def configure_optimizers(self) -> "torch.optim.Optimizer":
return optimizer


class PatchTSModel(DeepBaseModel):
"""PatchTS model using PyTorch layers. For more details read the `paper <https://arxiv.org/abs/2211.14730>`_.
class PatchTSTModel(DeepBaseModel):
"""PatchTST model using PyTorch layers. For more details read the `paper <https://arxiv.org/abs/2211.14730>`_.
Model uses only `target` column, other columns will be ignored.
Expand Down Expand Up @@ -271,7 +271,7 @@ def __init__(
val_dataloader_params: Optional[dict] = None,
split_params: Optional[dict] = None,
):
"""Init PatchTS model.
"""Init PatchTST model.
Parameters
----------
Expand Down Expand Up @@ -327,7 +327,7 @@ def __init__(
self.loss = loss if loss is not None else nn.MSELoss()
self.optimizer_params = optimizer_params
super().__init__(
net=PatchTSNet(
net=PatchTSTNet(
encoder_length,
patch_len=self.patch_len,
stride=self.stride,
Expand Down
28 changes: 15 additions & 13 deletions examples/202-NN_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
" * [MLP](#section_2_5)\n",
" * [Deep State Model](#section_2_6)\n",
" * [N-BEATS Model](#section_2_7)\n",
" * [PatchTS Model](#section_2_8)\n",
" * [PatchTST Model](#section_2_8)\n",
" * [Chronos Model](#section_2_9)\n",
" * [Chronos Bolt Model](#section_2_10)\n",
" * [TimesFM Model](#section_2_11)"
Expand Down Expand Up @@ -1043,7 +1043,7 @@
}
},
"source": [
"### 2.3 TFTNative <a class=\"anchor\" id=\"section_2_3\"></a>"
"### 2.3 TFT <a class=\"anchor\" id=\"section_2_3\"></a>"
]
},
{
Expand Down Expand Up @@ -1152,7 +1152,6 @@
"pycharm": {
"name": "#%%\n"
},
"scrolled": false,
"tags": []
},
"outputs": [
Expand Down Expand Up @@ -2784,7 +2783,7 @@
}
},
"source": [
"### 2.8 PatchTS Model <a class=\"anchor\" id=\"section_2_8\"></a>\n",
"### 2.8 PatchTST Model <a class=\"anchor\" id=\"section_2_8\"></a>\n",
"\n",
"Model with transformer encoder that uses patches of timeseries as input words and linear decoder."
]
Expand All @@ -2801,14 +2800,15 @@
},
"outputs": [],
"source": [
"from etna.models.nn import PatchTSModel"
"from etna.models.nn import PatchTSTModel"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "cc38238d",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
Expand Down Expand Up @@ -2951,7 +2951,7 @@
"source": [
"set_seed()\n",
"\n",
"model_patchts = PatchTSModel(\n",
"model_patchtst = PatchTSTModel(\n",
" decoder_length=HORIZON,\n",
" encoder_length=2 * HORIZON,\n",
" patch_len=1,\n",
Expand All @@ -2960,11 +2960,11 @@
" train_batch_size=64,\n",
")\n",
"\n",
"pipeline_patchts = Pipeline(\n",
" model=model_patchts, horizon=HORIZON, transforms=[StandardScalerTransform(in_column=\"target\")]\n",
"pipeline_patchtst = Pipeline(\n",
" model=model_patchtst, horizon=HORIZON, transforms=[StandardScalerTransform(in_column=\"target\")]\n",
")\n",
"\n",
"metrics_patchts, forecast_patchts, fold_info_patchs = pipeline_patchts.backtest(\n",
"metrics_patchtst, forecast_patchtst, fold_info_patchtst = pipeline_patchtst.backtest(\n",
" ts, metrics=metrics, n_folds=3, n_jobs=1\n",
")"
]
Expand All @@ -2974,6 +2974,7 @@
"execution_count": 57,
"id": "6394b96c",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
Expand All @@ -2987,20 +2988,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Average SMAPE for PatchTS: 7.583\n"
"Average SMAPE for PatchTST: 7.583\n"
]
}
],
"source": [
"score = metrics_patchts[\"SMAPE\"].mean()\n",
"print(f\"Average SMAPE for PatchTS: {score:.3f}\")"
"score = metrics_patchtst[\"SMAPE\"].mean()\n",
"print(f\"Average SMAPE for PatchTST: {score:.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "a514bd99",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
Expand All @@ -3022,7 +3024,7 @@
}
],
"source": [
"plot_backtest(forecast_patchts, ts, history_len=20)"
"plot_backtest(forecast_patchtst, ts, history_len=20)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ We have prepared a set of tutorials for an easy introduction:
- Loading dataset
- Testing models
- Baseline
- DeepARNative
- TFTNative
- DeepAR
- TFT
- RNN
- MLP
- Deep State Model
- N-BEATS Model
- PatchTS Model
- PatchTST Model
- Chronos Model
- Chronos Bolt Model
- TimesFM Model
Expand Down
24 changes: 12 additions & 12 deletions tests/test_models/test_inference/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from etna.models.nn import MLPModel
from etna.models.nn import NBeatsGenericModel
from etna.models.nn import NBeatsInterpretableModel
from etna.models.nn import PatchTSModel
from etna.models.nn import PatchTSTModel
from etna.models.nn import RNNModel
from etna.models.nn import TFTModel
from etna.models.nn import TimesFMModel
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_forecast_in_sample_full_no_target_failed_nans_sklearn(self, model, tran
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
DeepStateModel(
ssm=CompositeSSM(seasonal_ssms=[WeeklySeasonalitySSM()]),
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_forecast_in_sample_full_failed_nans_nn(self, model, transforms, dataset
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
DeepStateModel(
ssm=CompositeSSM(seasonal_ssms=[WeeklySeasonalitySSM()]),
Expand Down Expand Up @@ -434,7 +434,7 @@ def _test_forecast_in_sample_suffix_no_target(ts, model, transforms, num_skip_po
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[2, 3])],
Expand Down Expand Up @@ -533,7 +533,7 @@ class TestForecastInSampleSuffix:
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[2, 3])],
Expand Down Expand Up @@ -651,7 +651,7 @@ def _test_forecast_out_sample(ts, model, transforms, prediction_size=5):
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
Expand Down Expand Up @@ -729,7 +729,7 @@ def test_forecast_out_sample_datetime_timestamp(self, model, transforms, dataset
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
Expand Down Expand Up @@ -881,7 +881,7 @@ def _test_forecast_out_sample_prefix(ts, model, transforms, full_prediction_size
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
Expand Down Expand Up @@ -995,7 +995,7 @@ def _test_forecast_out_sample_suffix(ts, model, transforms, full_prediction_size
(SeasonalMovingAverageModel(), [], "example_tsds"),
(NaiveModel(lag=3), [], "example_tsds"),
(DeadlineMovingAverageModel(window=1), [], "example_tsds"),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
Expand Down Expand Up @@ -1211,7 +1211,7 @@ def _test_forecast_mixed_in_out_sample(ts, model, transforms, num_skip_points=50
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
Expand Down Expand Up @@ -1343,7 +1343,7 @@ def _test_forecast_subset_segments(self, ts, model, transforms, segments, predic
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
Expand Down Expand Up @@ -1455,7 +1455,7 @@ def _test_forecast_new_segments(self, ts, model, transforms, train_segments, pre
[],
"example_tsds",
),
(PatchTSModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(PatchTSTModel(encoder_length=7, decoder_length=7, trainer_params=dict(max_epochs=1)), [], "example_tsds"),
(
MLPModel(input_size=2, hidden_size=[10], decoder_length=7, trainer_params=dict(max_epochs=1)),
[LagTransform(in_column="target", lags=[5, 6])],
Expand Down
Loading

0 comments on commit 0a0145e

Please sign in to comment.