DL models fail when running on cuda without categorical features (#615)
* fix bug

* fix for tft

* fix: rename DaylySeasonalitySSM with DailySeasonalitySSM

* chore: update changelog

* run DL notebook

---------

Co-authored-by: Egor Baturin <[email protected]>
egoriyaa and Egor Baturin authored Feb 17, 2025
1 parent c778663 commit 7508561
Showing 11 changed files with 617 additions and 552 deletions.
6 changes: 3 additions & 3 deletions CHANGELOG.md
@@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `concat` resolver for `OmegaConf` ([#604](https://github.com/etna-team/etna/pull/604))
- Add 500m `TimesFM` model support ([#605](https://github.com/etna-team/etna/pull/605))
- Add `num_layers`, `use_positional_embedding`, `normalize_target`, `normalize_exog`, `forecast_with_exog_mode` parameters to `TimesFM` model ([#605](https://github.com/etna-team/etna/pull/605))
-
- Add logging loss during DL models training ([#615](https://github.com/etna-team/etna/pull/615))
- Add Python versions 3.11 and 3.12 ([#599](https://github.com/etna-team/etna/pull/599))
-
-
@@ -55,8 +55,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **Breaking:** Rename `TSDataset.add_columns_from_pandas` to `TSDataset.add_features_from_pandas` ([#593](https://github.com/etna-team/etna/pull/593))
- **Breaking:** Rename `TSDataset.update_columns_from_pandas` to `TSDataset.update_features_from_pandas` ([#593](https://github.com/etna-team/etna/pull/593))
- **Breaking:** Rename `PatchTSModel` to `PatchTSTModel` ([#601](https://github.com/etna-team/etna/pull/601))
-
-
- Fix device of empty tensor for categorical features in DL models ([#615](https://github.com/etna-team/etna/pull/615))
- **Breaking:** rename `DaylySeasonalitySSM` to `DailySeasonalitySSM` ([#615](https://github.com/etna-team/etna/pull/615))
- Fix `TSDataset.train_test_split` to pass all features to train and test parts ([#545](https://github.com/etna-team/etna/pull/545))
-

2 changes: 1 addition & 1 deletion docs/source/api_reference/models.rst
@@ -94,7 +94,7 @@ Utilities for :py:class:`~etna.models.nn.deepstate.deepstate.DeepStateModel`
nn.deepstate.LevelSSM
nn.deepstate.LevelTrendSSM
nn.deepstate.SeasonalitySSM
nn.deepstate.DaylySeasonalitySSM
nn.deepstate.DailySeasonalitySSM
nn.deepstate.SeasonalitySSM
nn.deepstate.YearlySeasonalitySSM

4 changes: 2 additions & 2 deletions etna/models/base.py
@@ -453,7 +453,7 @@ def training_step(self, batch: dict, *args, **kwargs): # type: ignore
loss
"""
loss, true_target, _ = self.step(batch, *args, **kwargs) # type: ignore
self.log("train_loss", loss, on_epoch=True, batch_size=len(true_target))
self.log("train_loss", loss, on_epoch=True, batch_size=len(true_target), prog_bar=True, on_step=False)
return loss

def validation_step(self, batch: dict, *args, **kwargs): # type: ignore
@@ -470,7 +470,7 @@ def validation_step(self, batch: dict, *args, **kwargs): # type: ignore
loss
"""
loss, true_target, _ = self.step(batch, *args, **kwargs) # type: ignore
self.log("val_loss", loss, on_epoch=True, batch_size=len(true_target))
self.log("val_loss", loss, on_epoch=True, batch_size=len(true_target), prog_bar=True, on_step=False)
return loss


24 changes: 20 additions & 4 deletions etna/models/nn/deepar/deepar.py
@@ -138,8 +138,16 @@ def forward(self, x: DeepARBatch, *args, **kwargs): # type: ignore
encoder_real[:, :, 0] = encoder_real[:, :, 0] / weights.unsqueeze(1)
decoder_real[:, :, 0] = decoder_real[:, :, 0] / weights.unsqueeze(1)

encoder_embeddings = self.embedding(encoder_categorical) if self.embedding is not None else torch.Tensor()
decoder_embeddings = self.embedding(decoder_categorical) if self.embedding is not None else torch.Tensor()
encoder_embeddings = (
self.embedding(encoder_categorical)
if self.embedding is not None
else torch.zeros((encoder_real.shape[0], encoder_real.shape[1], 0), device=encoder_real.device)
)
decoder_embeddings = (
self.embedding(decoder_categorical)
if self.embedding is not None
else torch.zeros((decoder_real.shape[0], decoder_real.shape[1], 0), device=decoder_real.device)
)

encoder_values = torch.concat((encoder_real, encoder_embeddings), dim=2)
decoder_values = torch.concat((decoder_real, decoder_embeddings), dim=2)
@@ -202,8 +210,16 @@ def step(self, batch: DeepARBatch, *args, **kwargs): # type: ignore
encoder_real[:, :, 0] = encoder_real[:, :, 0] / weights.unsqueeze(1)
decoder_real[:, :, 0] = decoder_real[:, :, 0] / weights.unsqueeze(1)

encoder_embeddings = self.embedding(encoder_categorical) if self.embedding is not None else torch.Tensor()
decoder_embeddings = self.embedding(decoder_categorical) if self.embedding is not None else torch.Tensor()
encoder_embeddings = (
self.embedding(encoder_categorical)
if self.embedding is not None
else torch.zeros((encoder_real.shape[0], encoder_real.shape[1], 0), device=encoder_real.device)
)
decoder_embeddings = (
self.embedding(decoder_categorical)
if self.embedding is not None
else torch.zeros((decoder_real.shape[0], decoder_real.shape[1], 0), device=decoder_real.device)
)

encoder_values = torch.concat((encoder_real, encoder_embeddings), dim=2)
decoder_values = torch.concat((decoder_real, decoder_embeddings), dim=2)
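The same fix recurs in every model below: `torch.Tensor()` allocates an empty tensor on the CPU, so when the real features live on a GPU the subsequent `torch.concat` fails with a device-mismatch error. Replacing it with a zero-width tensor created on the same device as the real features makes the concatenation a no-op that works on any device. A minimal sketch of the fixed pattern (a hypothetical helper, not an etna API; run on CPU here, the `device=` argument is what makes it safe on CUDA):

```python
import torch


def concat_with_optional_embeddings(real, embeddings=None):
    # When there are no categorical features, build a zero-width
    # placeholder on the same device as the real features so that
    # torch.concat never mixes CPU and CUDA tensors.
    if embeddings is None:
        embeddings = torch.zeros(
            (real.shape[0], real.shape[1], 0), device=real.device
        )
    return torch.concat((real, embeddings), dim=2)


real = torch.randn(4, 10, 3)  # (batch_size, seq_length, num_features)
out = concat_with_optional_embeddings(real)
print(out.shape)  # torch.Size([4, 10, 3])
```

Concatenating along `dim=2` with a tensor of width 0 leaves the features unchanged, which is why the feature dimension can simply be set to zero in the no-categoricals case.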
2 changes: 1 addition & 1 deletion etna/models/nn/deepstate/__init__.py
@@ -3,7 +3,7 @@
if SETTINGS.torch_required:
from etna.models.nn.deepstate.linear_dynamic_system import LDS
from etna.models.nn.deepstate.state_space_model import CompositeSSM
from etna.models.nn.deepstate.state_space_model import DaylySeasonalitySSM
from etna.models.nn.deepstate.state_space_model import DailySeasonalitySSM
from etna.models.nn.deepstate.state_space_model import LevelSSM
from etna.models.nn.deepstate.state_space_model import LevelTrendSSM
from etna.models.nn.deepstate.state_space_model import SeasonalitySSM
18 changes: 15 additions & 3 deletions etna/models/nn/deepstate/deepstate.py
@@ -126,7 +126,7 @@ def step(self, batch: DeepStateBatch, *args, **kwargs): # type: ignore
:, :, :seq_length
] # (num_models, batch_size, seq_length)

encoder_embeddings = self.embedding(encoder_categorical) if self.embedding is not None else torch.Tensor()
encoder_embeddings = (
self.embedding(encoder_categorical)
if self.embedding is not None
else torch.zeros((encoder_real.shape[0], encoder_real.shape[1], 0), device=encoder_real.device)
)
encoder_values = torch.concat((encoder_real, encoder_embeddings), dim=2)

output, (_, _) = self.RNN(encoder_values) # (batch_size, seq_length, latent_dim)
@@ -173,8 +177,16 @@ def forward(self, x: DeepStateBatch, *args, **kwargs): # type: ignore
:, :, seq_length:
] # (num_models, batch_size, horizon)

encoder_embeddings = self.embedding(encoder_categorical) if self.embedding is not None else torch.Tensor()
decoder_embeddings = self.embedding(decoder_categorical) if self.embedding is not None else torch.Tensor()
encoder_embeddings = (
self.embedding(encoder_categorical)
if self.embedding is not None
else torch.zeros((encoder_real.shape[0], encoder_real.shape[1], 0), device=encoder_real.device)
)
decoder_embeddings = (
self.embedding(decoder_categorical)
if self.embedding is not None
else torch.zeros((decoder_real.shape[0], decoder_real.shape[1], 0), device=decoder_real.device)
)

encoder_values = torch.concat((encoder_real, encoder_embeddings), dim=2)
decoder_values = torch.concat((decoder_real, decoder_embeddings), dim=2)
2 changes: 1 addition & 1 deletion etna/models/nn/deepstate/state_space_model.py
@@ -372,7 +372,7 @@ def get_timestamp_transform(self, x: pd.Timestamp):
return x.weekday()


class DaylySeasonalitySSM(SeasonalitySSM):
class DailySeasonalitySSM(SeasonalitySSM):
"""Class for Daily Seasonality State Space Model.
Note
12 changes: 10 additions & 2 deletions etna/models/nn/mlp.py
@@ -103,7 +103,11 @@ def forward(self, batch: MLPBatch): # type: ignore
decoder_real = batch["decoder_real"].float()
decoder_categorical = batch["decoder_categorical"] # each (batch_size, decoder_length, 1)

decoder_embeddings = self.embedding(decoder_categorical) if self.embedding is not None else torch.Tensor()
decoder_embeddings = (
self.embedding(decoder_categorical)
if self.embedding is not None
else torch.zeros((decoder_real.shape[0], decoder_real.shape[1], 0), device=decoder_real.device)
)

decoder_values = torch.concat((decoder_real, decoder_embeddings), dim=2)

@@ -126,7 +130,11 @@ def step(self, batch: MLPBatch, *args, **kwargs): # type: ignore
decoder_categorical = batch["decoder_categorical"] # each (batch_size, decoder_length, 1)
decoder_target = batch["decoder_target"].float()

decoder_embeddings = self.embedding(decoder_categorical) if self.embedding is not None else torch.Tensor()
decoder_embeddings = (
self.embedding(decoder_categorical)
if self.embedding is not None
else torch.zeros((decoder_real.shape[0], decoder_real.shape[1], 0), device=decoder_real.device)
)

decoder_values = torch.concat((decoder_real, decoder_embeddings), dim=2)
output = self.mlp(decoder_values)
24 changes: 20 additions & 4 deletions etna/models/nn/rnn.py
@@ -109,8 +109,16 @@ def forward(self, x: RNNBatch, *args, **kwargs): # type: ignore
decoder_target = x["decoder_target"].float() # (batch_size, decoder_length, 1)
decoder_length = decoder_real.shape[1]

encoder_embeddings = self.embedding(encoder_categorical) if self.embedding is not None else torch.Tensor()
decoder_embeddings = self.embedding(decoder_categorical) if self.embedding is not None else torch.Tensor()
encoder_embeddings = (
self.embedding(encoder_categorical)
if self.embedding is not None
else torch.zeros((encoder_real.shape[0], encoder_real.shape[1], 0), device=encoder_real.device)
)
decoder_embeddings = (
self.embedding(decoder_categorical)
if self.embedding is not None
else torch.zeros((decoder_real.shape[0], decoder_real.shape[1], 0), device=decoder_real.device)
)

encoder_values = torch.concat((encoder_real, encoder_embeddings), dim=2)
decoder_values = torch.concat((decoder_real, decoder_embeddings), dim=2)
@@ -152,8 +160,16 @@ def step(self, batch: RNNBatch, *args, **kwargs): # type: ignore

decoder_length = decoder_real.shape[1]

encoder_embeddings = self.embedding(encoder_categorical) if self.embedding is not None else torch.Tensor()
decoder_embeddings = self.embedding(decoder_categorical) if self.embedding is not None else torch.Tensor()
encoder_embeddings = (
self.embedding(encoder_categorical)
if self.embedding is not None
else torch.zeros((encoder_real.shape[0], encoder_real.shape[1], 0), device=encoder_real.device)
)
decoder_embeddings = (
self.embedding(decoder_categorical)
if self.embedding is not None
else torch.zeros((decoder_real.shape[0], decoder_real.shape[1], 0), device=decoder_real.device)
)

encoder_values = torch.concat((encoder_real, encoder_embeddings), dim=2)
decoder_values = torch.concat((decoder_real, decoder_embeddings), dim=2)
2 changes: 1 addition & 1 deletion etna/models/nn/tft/tft.py
@@ -319,7 +319,7 @@ def forward(self, x: TFTBatch, *args, **kwargs) -> torch.Tensor:
x=decoder_features
) # (batch_size, decoder_length, hidden_size)
else:
decoder_output = torch.zeros(batch_size, decoder_length, self.hidden_size)
decoder_output = torch.zeros(batch_size, decoder_length, self.hidden_size, device=encoder_output.device)
residual = torch.cat((encoder_output, decoder_output), dim=1)

# Pass encoder and decoder data through LSTM
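The TFT change has the same root cause as the embedding fix, but for a full-size tensor: `torch.zeros(...)` without a `device=` argument allocates on the CPU, so the `torch.cat` with the CUDA-resident `encoder_output` would fail. Tying the allocation to `encoder_output.device` keeps everything on one device. A minimal sketch (hypothetical shapes, falling back to CPU when CUDA is unavailable):

```python
import torch

# Use whatever device the model's activations live on; CPU fallback here.
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size, decoder_length, hidden_size = 2, 5, 8
encoder_output = torch.randn(batch_size, 3, hidden_size, device=device)

# Allocating the placeholder on encoder_output.device (instead of the
# default CPU) keeps torch.cat from mixing devices when CUDA is used.
decoder_output = torch.zeros(
    batch_size, decoder_length, hidden_size, device=encoder_output.device
)
residual = torch.cat((encoder_output, decoder_output), dim=1)
print(residual.shape)  # torch.Size([2, 8, 8])
```

Deriving the device from an existing activation tensor, rather than storing it as model state, means the code works unchanged under `model.to(device)` moves.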
1,073 changes: 543 additions & 530 deletions examples/202-NN_examples.ipynb

Large diffs are not rendered by default.
