Feat: We added horizon weighting to the distribution losses, similar to the way point losses like MAE already implement it. This has proven useful in our applications.
mwamsojo committed Dec 18, 2024
1 parent df8c431 commit bb41743
Showing 31 changed files with 293 additions and 97,424 deletions.
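The new horizon_weight argument on DistributionLoss mirrors the one already available on point losses such as MAE: one weight per horizon step, applied when averaging the negative log-likelihood. A minimal, hedged usage sketch based on the diff below; the model choice, horizon, frequency, and weight values are illustrative assumptions, not part of this commit:

# Sketch: weight later horizon steps more heavily in a DistributionLoss.
# Assumes the post-commit API; NHITS settings and weights are illustrative.
import numpy as np
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS
from neuralforecast.losses.pytorch import DistributionLoss

h = 12
# One weight per horizon step; later steps count more in the training loss.
horizon_weight = np.linspace(1.0, 2.0, h)

loss = DistributionLoss(distribution="Normal", level=[80, 90],
                        horizon_weight=horizon_weight)  # new argument from this commit
model = NHITS(h=h, input_size=2 * h, loss=loss, max_steps=100)
nf = NeuralForecast(models=[model], freq="M")
# nf.fit(df=Y_df)  # Y_df: long-format dataframe with unique_id, ds, y columns
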
394 changes: 8 additions & 386 deletions nbs/docs/capabilities/03_exogenous_variables.ipynb

731 changes: 8 additions & 723 deletions nbs/docs/capabilities/04_hyperparameter_tuning.ipynb

199 changes: 4 additions & 195 deletions nbs/docs/capabilities/05_predictInsample.ipynb

261 changes: 5 additions & 256 deletions nbs/docs/capabilities/06_save_load_models.ipynb

329 changes: 9 additions & 320 deletions nbs/docs/capabilities/07_time_series_scaling.ipynb

437 changes: 13 additions & 424 deletions nbs/docs/capabilities/08_cross_validation.ipynb

184 changes: 4 additions & 180 deletions nbs/docs/getting-started/02_quickstart.ipynb

377 changes: 4 additions & 373 deletions nbs/docs/getting-started/05_datarequirements.ipynb

832 changes: 15 additions & 817 deletions nbs/docs/tutorials/01_getting_started_complete.ipynb

561 changes: 10 additions & 551 deletions nbs/docs/tutorials/02_cross_validation.ipynb

196 changes: 5 additions & 191 deletions nbs/docs/tutorials/03_uncertainty_quantification.ipynb

237 changes: 8 additions & 229 deletions nbs/docs/tutorials/04_longhorizon_nhits.ipynb

306 changes: 6 additions & 300 deletions nbs/docs/tutorials/05_longhorizon_transformers.ipynb

259 changes: 3 additions & 256 deletions nbs/docs/tutorials/06_longhorizon_probabilistic.ipynb

217 changes: 5 additions & 212 deletions nbs/docs/tutorials/07_forecasting_tft.ipynb

335 changes: 8 additions & 327 deletions nbs/docs/tutorials/08_multivariate_tsmixer.ipynb

302 changes: 5 additions & 297 deletions nbs/docs/tutorials/09_hierarchical_forecasting.ipynb

291 changes: 5 additions & 286 deletions nbs/docs/tutorials/10_distributed_neuralforecast.ipynb

514 changes: 7 additions & 507 deletions nbs/docs/tutorials/11_intermittent_data.ipynb

101 changes: 3 additions & 98 deletions nbs/docs/tutorials/12_using_mlflow.ipynb

482 changes: 8 additions & 474 deletions nbs/docs/tutorials/13_robust_forecasting.ipynb

213 changes: 8 additions & 205 deletions nbs/docs/tutorials/14_interpretable_decompositions.ipynb

87,993 changes: 27 additions & 87,966 deletions nbs/docs/tutorials/15_comparing_methods.ipynb

834 changes: 7 additions & 827 deletions nbs/docs/tutorials/16_temporal_classification.ipynb

473 changes: 7 additions & 466 deletions nbs/docs/tutorials/19_large_datasets.ipynb

13 changes: 1 addition & 12 deletions nbs/docs/tutorials/20_conformal_prediction.ipynb

153 changes: 4 additions & 149 deletions nbs/docs/use-cases/electricity_peak_forecasting.ipynb

377 changes: 7 additions & 370 deletions nbs/docs/use-cases/predictive_maintenance.ipynb

84 changes: 58 additions & 26 deletions nbs/losses.pytorch.ipynb
@@ -2475,7 +2475,7 @@
"\n",
" \"\"\"\n",
" def __init__(self, distribution, level=[80, 90], quantiles=None,\n",
" num_samples=1000, return_params=False, **distribution_kwargs):\n",
" num_samples=1000, return_params=False, horizon_weight = None, **distribution_kwargs):\n",
" super(DistributionLoss, self).__init__()\n",
"\n",
" qs, self.output_names = level_to_outputs(level)\n",
@@ -2489,6 +2489,12 @@
" self.quantiles = torch.nn.Parameter(qs, requires_grad=False)\n",
" num_qk = len(self.quantiles)\n",
"\n",
" # Generate a horizon weight tensor from the array\n",
" if horizon_weight is not None:\n",
" horizon_weight = torch.Tensor(horizon_weight.flatten())\n",
" self.horizon_weight = horizon_weight\n",
"\n",
"\n",
" if \"num_pieces\" not in distribution_kwargs:\n",
" num_pieces = 5\n",
" else:\n",
@@ -2610,36 +2616,62 @@
"\n",
" return samples, sample_mean, quants\n",
"\n",
" def __call__(self,\n",
" y: torch.Tensor,\n",
" distr_args: torch.Tensor,\n",
" mask: Union[torch.Tensor, None] = None):\n",
"\n",
"\n",
" def _compute_weights(self, y, mask):\n",
" \"\"\"\n",
" Compute final weights for each datapoint (based on all weights and all masks)\n",
" Set horizon_weight to a ones[H] tensor if not set.\n",
" If set, check that it has the same length as the horizon in x.\n",
" \"\"\"\n",
" Computes the negative log-likelihood objective function. \n",
" To estimate the following predictive distribution:\n",
" if mask is None:\n",
" mask = torch.ones_like(y, device=y.device)\n",
" else:\n",
" mask = mask.unsqueeze(1) # Add Q dimension.\n",
"\n",
" $$\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta) \\\\quad \\mathrm{and} \\\\quad -\\log(\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta))$$\n",
"\n",
" where $\\\\theta$ represents the distributions parameters. It aditionally \n",
" summarizes the objective signal using a weighted average using the `mask` tensor. \n",
" # get uniform weights if none\n",
" if self.horizon_weight is None:\n",
" self.horizon_weight = torch.ones(mask.shape[-1])\n",
" else:\n",
" assert mask.shape[-1] == len(self.horizon_weight), \\\n",
" 'horizon_weight must have same length as Y'\n",
" weights = self.horizon_weight.clone()\n",
" weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)\n",
" return weights * mask\n",
" \n",
"\n",
" **Parameters**<br>\n",
" `y`: tensor, Actual values.<br>\n",
" `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
" `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n",
" of the resulting distribution.<br>\n",
" `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n",
" of the resulting distribution.<br>\n",
" `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
"\n",
" **Returns**<br>\n",
" `loss`: scalar, weighted loss function against which backpropagation will be performed.<br>\n",
" \"\"\"\n",
" # Instantiate Scaled Decoupled Distribution\n",
" distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)\n",
" loss_values = -distr.log_prob(y)\n",
" loss_weights = mask\n",
" return weighted_average(loss_values, weights=loss_weights)"
" def __call__(self,\n",
" y: torch.Tensor,\n",
" distr_args: torch.Tensor,\n",
" mask: Union[torch.Tensor, None] = None):\n",
" \"\"\"\n",
" Computes the negative log-likelihood objective function. \n",
" To estimate the following predictive distribution:\n",
"\n",
" $$\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta) \\\\quad \\mathrm{and} \\\\quad -\\log(\\mathrm{P}(\\mathbf{y}_{\\\\tau}\\,|\\,\\\\theta))$$\n",
"\n",
" where $\\\\theta$ represents the distributions parameters. It aditionally \n",
" summarizes the objective signal using a weighted average using the `mask` tensor. \n",
" \n",
" **Parameters**<br>\n",
" `y`: tensor, Actual values.<br>\n",
" `distr_args`: Constructor arguments for the underlying Distribution type.<br>\n",
" `loc`: Optional tensor, of the same shape as the batch_shape + event_shape\n",
" of the resulting distribution.<br>\n",
" `scale`: Optional tensor, of the same shape as the batch_shape+event_shape \n",
" of the resulting distribution.<br>\n",
" `mask`: tensor, Specifies date stamps per serie to consider in loss.<br>\n",
"\n",
" **Returns**<br>\n",
" `loss`: scalar, weighted loss function against which backpropagation will be performed.<br>\n",
" \"\"\"\n",
" # Instantiate Scaled Decoupled Distribution\n",
" distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)\n",
" loss_values = -distr.log_prob(y)\n",
" loss_weights = self._compute_weights(y=y, mask=mask)\n",
" return weighted_average(loss_values, weights=loss_weights)"
]
},
{
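Note on the constructor change above: horizon_weight is converted with torch.Tensor(horizon_weight.flatten()), so it should be passed as something with a .flatten() method, such as a NumPy array or a torch tensor, rather than a plain Python list. A small illustration of that conversion (values are arbitrary):

# Illustration of the conversion performed in __init__ (per the diff above).
import numpy as np
import torch

horizon_weight = np.array([[1.0, 1.0, 2.0, 2.0]])  # any shape; flattened to 1-D
horizon_weight = torch.Tensor(horizon_weight.flatten())
print(horizon_weight)  # tensor([1., 1., 2., 2.])
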
2 changes: 2 additions & 0 deletions neuralforecast/_modidx.py
@@ -284,6 +284,8 @@
'neuralforecast/losses/pytorch.py'),
'neuralforecast.losses.pytorch.DistributionLoss.__init__': ( 'losses.pytorch.html#distributionloss.__init__',
'neuralforecast/losses/pytorch.py'),
'neuralforecast.losses.pytorch.DistributionLoss._compute_weights': ( 'losses.pytorch.html#distributionloss._compute_weights',
'neuralforecast/losses/pytorch.py'),
'neuralforecast.losses.pytorch.DistributionLoss.get_distribution': ( 'losses.pytorch.html#distributionloss.get_distribution',
'neuralforecast/losses/pytorch.py'),
'neuralforecast.losses.pytorch.DistributionLoss.sample': ( 'losses.pytorch.html#distributionloss.sample',
30 changes: 29 additions & 1 deletion neuralforecast/losses/pytorch.py
@@ -1873,6 +1873,7 @@ def __init__(
quantiles=None,
num_samples=1000,
return_params=False,
horizon_weight=None,
**distribution_kwargs,
):
super(DistributionLoss, self).__init__()
@@ -1888,6 +1889,11 @@
self.quantiles = torch.nn.Parameter(qs, requires_grad=False)
num_qk = len(self.quantiles)

# Generate a horizon weight tensor from the array
if horizon_weight is not None:
horizon_weight = torch.Tensor(horizon_weight.flatten())
self.horizon_weight = horizon_weight

if "num_pieces" not in distribution_kwargs:
num_pieces = 5
else:
@@ -2011,6 +2017,28 @@ def sample(self, distr_args: torch.Tensor, num_samples: Optional[int] = None):

return samples, sample_mean, quants

def _compute_weights(self, y, mask):
"""
Compute final weights for each datapoint (based on all weights and all masks)
Set horizon_weight to a ones[H] tensor if not set.
If set, check that it has the same length as the horizon in x.
"""
if mask is None:
mask = torch.ones_like(y, device=y.device)
else:
mask = mask.unsqueeze(1) # Add Q dimension.

# get uniform weights if none
if self.horizon_weight is None:
self.horizon_weight = torch.ones(mask.shape[-1])
else:
assert mask.shape[-1] == len(
self.horizon_weight
), "horizon_weight must have same length as Y"
weights = self.horizon_weight.clone()
weights = torch.ones_like(mask, device=mask.device) * weights.to(mask.device)
return weights * mask

def __call__(
self,
y: torch.Tensor,
@@ -2041,7 +2069,7 @@ def __call__(
# Instantiate Scaled Decoupled Distribution
distr = self.get_distribution(distr_args=distr_args, **self.distribution_kwargs)
loss_values = -distr.log_prob(y)
loss_weights = mask
loss_weights = self._compute_weights(y=y, mask=mask)
return weighted_average(loss_values, weights=loss_weights)

# %% ../../nbs/losses.pytorch.ipynb 74
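For readers tracing the new weighting path, here is a standalone paraphrase of _compute_weights as it appears in the diff above, applied to toy tensors. The shapes are illustrative assumptions, not taken from the library:

# Standalone paraphrase of DistributionLoss._compute_weights (per the diff above).
# Tensor shapes below are illustrative assumptions.
import torch

def compute_weights(y, mask, horizon_weight=None):
    # Default mask keeps every timestamp.
    if mask is None:
        mask = torch.ones_like(y)
    else:
        mask = mask.unsqueeze(1)  # add the quantile/sample dimension, as in the diff
    # Uniform weights when no horizon_weight is given.
    if horizon_weight is None:
        horizon_weight = torch.ones(mask.shape[-1])
    assert mask.shape[-1] == len(horizon_weight), \
        "horizon_weight must have same length as the horizon"
    # Broadcast the per-step weights over the batch (and Q) dimensions,
    # then zero out masked timestamps.
    weights = torch.ones_like(mask) * horizon_weight.to(mask.device)
    return weights * mask

y = torch.zeros(2, 1, 4)  # batch=2, Q=1, H=4 (illustrative shape)
w = compute_weights(y, mask=None, horizon_weight=torch.tensor([1.0, 1.0, 2.0, 2.0]))
print(w.shape)  # torch.Size([2, 1, 4]); the last axis holds [1, 1, 2, 2] for every series

The returned tensor is what weighted_average receives as loss_weights in __call__, so masked-out timestamps contribute nothing and the remaining ones are rescaled per horizon step.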
