Migration to torch distributions and scoringrules integration #70

Open · wants to merge 27 commits into base: main

Commits (27):
e25e7d6
Started migration from TFP to torch distributions and scoringrules in…
MicheleCattaneo Dec 6, 2024
6e6d05d
Fixed shape error in DistributionLossWrapper
MicheleCattaneo Dec 9, 2024
70c8068
Updated poetry lock
MicheleCattaneo Dec 9, 2024
d2f1f20
Fixed tests, updated multibranch net and deep&cross net. Removed py3.…
MicheleCattaneo Dec 10, 2024
ac6edfc
Added truncated normal, gamma and beta distributions
MicheleCattaneo Dec 12, 2024
7ed2076
Added multivariate normal
MicheleCattaneo Dec 13, 2024
01544f9
Started censored normal distribution
MicheleCattaneo Dec 13, 2024
621b3ee
Implemented k-th moment for TruncNormal and variance for CensoredNormal
MicheleCattaneo Dec 16, 2024
5727b8d
Added LogNormal and samples loss wrapper for scoringrules
MicheleCattaneo Dec 16, 2024
799c3ee
Added tests and defense against mc optimization without rsample
MicheleCattaneo Dec 17, 2024
d6a4f21
Changed logic, models always returns a dist, loss function samples fr…
MicheleCattaneo Dec 19, 2024
956956d
Re-introduced callbacks tests
MicheleCattaneo Dec 19, 2024
23b1ba8
Added torch distribution wrapper to unify sample() and sample_n()
MicheleCattaneo Dec 20, 2024
fb9fbff
Increased coverage to py3.12, added a test and small fix on censored …
MicheleCattaneo Jan 7, 2025
5eeb02b
fixed run-tests.yml
MicheleCattaneo Jan 7, 2025
3979611
update poetry up to py3.12
MicheleCattaneo Jan 7, 2025
50e8fa5
Updated README.ipynb
MicheleCattaneo Jan 9, 2025
f6279d9
Update readme
MicheleCattaneo Jan 10, 2025
3df603f
Fixed readme
MicheleCattaneo Jan 10, 2025
65af245
Added named losses
MicheleCattaneo Jan 10, 2025
c75642c
Added normalizer tests back
MicheleCattaneo Jan 10, 2025
e637011
Added metrics tests back
MicheleCattaneo Jan 10, 2025
e7f6321
Added expected_mae metric and fixed metrics to handle distributions
MicheleCattaneo Jan 10, 2025
a2c342d
Small naming and structure refactor
MicheleCattaneo Jan 13, 2025
fbb1dda
Added new parameter rules for distribution losses
MicheleCattaneo Jan 15, 2025
e000535
Added serialization and gradient flow tests for all layers
MicheleCattaneo Jan 15, 2025
dfd02d9
Updated readme
MicheleCattaneo Jan 16, 2025
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
@@ -14,7 +14,7 @@ jobs:
  strategy:
    fail-fast: false
    matrix:
-     python-version: ["3.9", "3.10"]
+     python-version: ["3.10","3.11","3.12"]

  steps:
    - uses: actions/checkout@v4
278 changes: 206 additions & 72 deletions README.ipynb

Large diffs are not rendered by default.

108 changes: 57 additions & 51 deletions README.md
@@ -1,8 +1,5 @@
# mlpp-lib

-[![.github/workflows/run-tests.yml](https://github.com/MeteoSwiss/mlpp-lib/actions/workflows/run-tests.yml/badge.svg)](https://github.com/MeteoSwiss/mlpp-lib/actions/workflows/run-tests.yml)
-[![pypi](https://img.shields.io/pypi/v/mlpp-lib.svg?colorB=<brightgreen>)](https://pypi.python.org/pypi/mlpp-lib/)
-
Collection of methods for ML-based postprocessing of weather forecasts.

:warning: **The code in this repository is currently work-in-progress and not recommended for production use.** :warning:
@@ -20,19 +17,10 @@ import pandas as pd
from mlpp_lib.datasets import DataModule, DataSplitter
```

-    2024-03-12 11:01:48.532698: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
-    2024-03-12 11:01:48.594233: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
-    2024-03-12 11:01:48.595154: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
-    To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
-
-    2024-03-12 11:01:49.442240: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
-

```python
LEADTIMES = np.arange(24)
-REFTIMES = pd.date_range("2018-01-01", "2018-03-31", freq="24H")
+REFTIMES = pd.date_range("2018-01-01", "2018-03-31", freq="24h")
STATIONS = [chr(i) * 3 for i in range(ord("A"), ord("Z"))]
SHAPE = (len(REFTIMES), len(LEADTIMES), len(STATIONS))
DIMS = ["forecast_reference_time", "lead_time", "station"]
@@ -89,18 +77,18 @@ print(features)

```

-    <xarray.Dataset>
+    <xarray.Dataset> Size: 2MB
    Dimensions: (forecast_reference_time: 90, lead_time: 24,
    station: 25)
    Coordinates:
-    * forecast_reference_time (forecast_reference_time) datetime64[ns] 2018-01...
-    * lead_time (lead_time) int64 0 1 2 3 4 5 ... 18 19 20 21 22 23
-    * station (station) <U3 'AAA' 'BBB' 'CCC' ... 'XXX' 'YYY'
+    * forecast_reference_time (forecast_reference_time) datetime64[ns] 720B 20...
+    * lead_time (lead_time) int64 192B 0 1 2 3 4 ... 19 20 21 22 23
+    * station (station) <U3 300B 'AAA' 'BBB' ... 'XXX' 'YYY'
    Data variables:
-    coe:x1 (forecast_reference_time, lead_time, station) float64 ...
-    coe:x2 (forecast_reference_time, lead_time, station) float64 ...
-    obs:x3 (forecast_reference_time, lead_time, station) float64 ...
-    dem:x4 (forecast_reference_time, lead_time, station) float64 ...
+    coe:x1 (forecast_reference_time, lead_time, station) float64 432kB ...
+    coe:x2 (forecast_reference_time, lead_time, station) float64 432kB ...
+    obs:x3 (forecast_reference_time, lead_time, station) float64 432kB ...
+    dem:x4 (forecast_reference_time, lead_time, station) float64 432kB ...



@@ -109,16 +97,16 @@ targets = targets_dataset()
print(targets)
```

-    <xarray.Dataset>
+    <xarray.Dataset> Size: 865kB
    Dimensions: (forecast_reference_time: 90, lead_time: 24,
    station: 25)
    Coordinates:
-    * forecast_reference_time (forecast_reference_time) datetime64[ns] 2018-01...
-    * lead_time (lead_time) int64 0 1 2 3 4 5 ... 18 19 20 21 22 23
-    * station (station) <U3 'AAA' 'BBB' 'CCC' ... 'XXX' 'YYY'
+    * forecast_reference_time (forecast_reference_time) datetime64[ns] 720B 20...
+    * lead_time (lead_time) int64 192B 0 1 2 3 4 ... 19 20 21 22 23
+    * station (station) <U3 300B 'AAA' 'BBB' ... 'XXX' 'YYY'
    Data variables:
-    obs:y1 (forecast_reference_time, lead_time, station) float64 ...
-    obs:y2 (forecast_reference_time, lead_time, station) float64 ...
+    obs:y1 (forecast_reference_time, lead_time, station) float64 432kB ...
+    obs:y2 (forecast_reference_time, lead_time, station) float64 432kB ...


## Preparing data
@@ -147,24 +135,40 @@ datamodule = DataModule(
datamodule.setup(stage=None)
```

    No normalizer found, data are standardized by default.


## Training
-The library builds on top of the tensorflow + keras API and provides some useful methods to quickly build probabilistic models, as well as a collection of probabilistic metrics. Of course, you're free to use tensorflow and tensorflow probability to build your own custom model. MLPP won't get in your way!
+The library builds on top of the PyTorch + Keras 3 API and provides some useful methods to quickly build probabilistic models, while integrating probabilistic metrics thanks to `scoringrules`. Of course, you're free to use torch and torch distributions to build your own custom model. MLPP won't get in your way!

+In the following example the model consists of a fully connected layer and a probabilistic layer modelling a normal distribution parametrized by the predicted parameters. The distribution can be optimized either via a closed-form CRPS or via a sample-based CRPS.
+
+For sample-based losses, the underlying distribution needs to provide a reparametrized sampling function. If that is not available, `SampleLossWrapper` will let you know.

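For reference (an editor's note, not part of the PR's README): the closed-form wrapper evaluates the CRPS of the predicted normal analytically, while the sample-based wrapper approximates it with the standard ensemble estimator

$$\mathrm{CRPS}(F, y) = \mathbb{E}|X - y| - \frac{1}{2}\,\mathbb{E}|X - X'| \approx \frac{1}{M}\sum_{i=1}^{M} |x_i - y| - \frac{1}{2M^2}\sum_{i=1}^{M}\sum_{j=1}^{M} |x_i - x_j|,$$

where $x_1, \dots, x_M$ are draws from the predicted distribution and $M$ corresponds to `num_samples`.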

```python
-from mlpp_lib.models import fully_connected_network
-from mlpp_lib.losses import crps_energy
-import tensorflow as tf
-
-model: tf.keras.Model = fully_connected_network(
-    input_shape = datamodule.train.x.shape[1:],
-    output_size = datamodule.train.y.shape[-1],
-    hidden_layers = [32, 32],
-    activations = "relu",
-    probabilistic_layer = "IndependentNormal"
-)
-
-model.compile(loss=crps_energy, optimizer="adam")
+from mlpp_lib.layers import FullyConnectedLayer
+from mlpp_lib.models import ProbabilisticModel
+from mlpp_lib.losses import DistributionLossWrapper, SampleLossWrapper
+from mlpp_lib.probabilistic_layers import BaseDistributionLayer, UniveriateGaussianModule
+import scoringrules as sr
+import keras
+
+encoder = FullyConnectedLayer(hidden_layers=[16, 8],
+                              batchnorm=False,
+                              skip_connection=False,
+                              dropout=0.1,
+                              mc_dropout=False,
+                              activations='sigmoid')
+prob_layer = BaseDistributionLayer(distribution=UniveriateGaussianModule())
+
+model = ProbabilisticModel(encoder_layer=encoder, probabilistic_layer=prob_layer)
+
+# crps_normal = DistributionLossWrapper(fn=sr.crps_normal)  # closed-form CRPS
+crps_normal = SampleLossWrapper(fn=sr.crps_ensemble, num_samples=100)  # sample-based CRPS
+
+model.compile(loss=crps_normal, optimizer=keras.optimizers.Adam(learning_rate=0.1))

history = model.fit(
datamodule.train.x, datamodule.train.y,
@@ -174,11 +178,6 @@ history = model.fit(
)
```

-    Epoch 1/2
-    689/689 [==============================] - 2s 2ms/step - loss: 0.5633 - val_loss: 0.5721
-
-    Epoch 2/2
-    689/689 [==============================] - 1s 2ms/step - loss: 0.5607 - val_loss: 0.5695
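For readers new to the torch backend, here is a minimal, self-contained sketch (plain `torch` only, not part of this PR) of the reparametrized-sampling requirement that `SampleLossWrapper` checks for: gradients can flow through `rsample()`, but a distribution without it cannot support sample-based optimization.

```python
import torch

# Normal supports reparametrized sampling: gradients flow through rsample()
loc = torch.tensor([0.0], requires_grad=True)
normal = torch.distributions.Normal(loc, torch.tensor([1.0]))
assert normal.has_rsample
normal.rsample((100,)).mean().backward()
print(loc.grad)  # a well-defined gradient w.r.t. the predicted parameter

# Poisson has no reparametrized sampling: a sample-based loss cannot backpropagate
poisson = torch.distributions.Poisson(torch.tensor([3.0]))
assert not poisson.has_rsample
```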


## Predictions
@@ -191,13 +190,20 @@ test_pred_ensemble = datamodule.test.dataset_from_predictions(test_pred_ensemble
print(test_pred_ensemble)
```

-    <xarray.Dataset>
+    <xarray.Dataset> Size: 363kB
    Dimensions: (realization: 21, forecast_reference_time: 18,
    lead_time: 24, station: 5)
    Coordinates:
-    * forecast_reference_time (forecast_reference_time) datetime64[ns] 2018-03...
-    * lead_time (lead_time) int64 0 1 2 3 4 5 ... 18 19 20 21 22 23
-    * station (station) <U3 'AAA' 'EEE' 'JJJ' 'PPP' 'RRR'
-    * realization (realization) int64 0 1 2 3 4 5 ... 16 17 18 19 20
+    * forecast_reference_time (forecast_reference_time) datetime64[ns] 144B 20...
+    * lead_time (lead_time) int64 192B 0 1 2 3 4 ... 19 20 21 22 23
+    * station (station) <U3 60B 'AAA' 'III' 'NNN' 'VVV' 'YYY'
+    * realization (realization) int64 168B 0 1 2 3 4 ... 17 18 19 20
    Data variables:
-    obs:y1 (realization, forecast_reference_time, lead_time, station) float64 ...
+    obs:y1 (realization, forecast_reference_time, lead_time, station) float64 363kB ...


+## Build the README
+
+```
+poetry run jupyter nbconvert --execute --to markdown README.ipynb
+```
2 changes: 1 addition & 1 deletion mlpp_lib/__init__.py
@@ -2,4 +2,4 @@

import os

-os.environ["TF_USE_LEGACY_KERAS"] = "1"
+os.environ["KERAS_BACKEND"] = "torch"
4 changes: 2 additions & 2 deletions mlpp_lib/callbacks.py
@@ -2,7 +2,7 @@

import numpy as np
import properscoring as ps
-from tensorflow.keras import callbacks
+from keras import callbacks


class EnsembleMetrics(callbacks.Callback):
@@ -16,7 +16,7 @@ def add_validation_data(self, validation_data) -> None:

    def on_epoch_end(self, epoch, logs):
        """Compute a range of probabilistic scores at the end of each epoch."""
-        y_pred = self.model(self.X_val).sample(self.n_samples)
+        y_pred = self.model(self.X_val).sample((self.n_samples,))

        y_pred = y_pred.numpy()[:, :, 0].T
        y_val = np.squeeze(self.y_val)
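The `.sample((self.n_samples,))` change follows the torch convention: `torch.distributions.Distribution.sample` takes a `sample_shape` tuple, whereas TFP's `sample` accepted an integer count. A small illustration in plain `torch` (editor's sketch):

```python
import torch

dist = torch.distributions.Normal(torch.zeros(5), torch.ones(5))

draws = dist.sample((100,))  # sample_shape is a tuple -> shape (100, 5)
print(draws.shape)           # torch.Size([100, 5])
```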
178 changes: 178 additions & 0 deletions mlpp_lib/custom_distributions.py
@@ -0,0 +1,178 @@
import torch
from torch.distributions import Distribution, constraints, Normal


class TruncatedNormalDistribution(Distribution):
    """
    Implementation of a truncated normal distribution in [a, b] with
    differentiable sampling.

    Source: The Truncated Normal Distribution, John Burkardt 2023
    """

    def __init__(self, mu_bar: torch.Tensor, sigma_bar: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
        """
        Args:
            mu_bar (torch.Tensor): The mean of the underlying Normal. It is not the true mean.
            sigma_bar (torch.Tensor): The std of the underlying Normal. It is not the true std.
            a (torch.Tensor): The left boundary.
            b (torch.Tensor): The right boundary.
        """
        self._n = Normal(mu_bar, sigma_bar)
        self.mu_bar = mu_bar
        self.sigma_bar = sigma_bar
        super().__init__()

        self.a = a
        self.b = b

    def icdf(self, p):
        # inverse cdf: map p from [0, 1] into [cdf(a), cdf(b)], then invert the latent normal
        p_ = self._n.cdf(self.a) + p * (self._n.cdf(self.b) - self._n.cdf(self.a))
        return self._n.icdf(p_)

    def mean(self) -> torch.Tensor:
        """
        Returns:
            torch.Tensor: Returns the true mean of the distribution.
        """
        alpha = (self.a - self.mu_bar) / self.sigma_bar
        beta = (self.b - self.mu_bar) / self.sigma_bar

        sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.mu_bar))

        scale = (torch.exp(sn.log_prob(beta)) - torch.exp(sn.log_prob(alpha))) / (sn.cdf(beta) - sn.cdf(alpha))

        return self.mu_bar - self.sigma_bar * scale

    def variance(self) -> torch.Tensor:
        """
        Returns:
            torch.Tensor: Returns the true variance of the distribution.
        """
        alpha = (self.a - self.mu_bar) / self.sigma_bar
        beta = (self.b - self.mu_bar) / self.sigma_bar

        sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.mu_bar))

        pdf_a = torch.exp(sn.log_prob(alpha))
        pdf_b = torch.exp(sn.log_prob(beta))
        CDF_a = sn.cdf(alpha)
        CDF_b = sn.cdf(beta)

        return self.sigma_bar**2 * (
            1.0
            - (beta * pdf_b - alpha * pdf_a) / (CDF_b - CDF_a)
            - ((pdf_b - pdf_a) / (CDF_b - CDF_a)) ** 2
        )

    def moment(self, k):
        # Source: A Recursive Formula for the Moments of a Truncated Univariate Normal Distribution (Eric Orjebin)
        if k == -1:
            return torch.zeros_like(self.mu_bar)
        if k == 0:
            return torch.ones_like(self.mu_bar)

        alpha = (self.a - self.mu_bar) / self.sigma_bar
        beta = (self.b - self.mu_bar) / self.sigma_bar
        sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.mu_bar))

        scale = (
            self.b ** (k - 1) * torch.exp(sn.log_prob(beta))
            - self.a ** (k - 1) * torch.exp(sn.log_prob(alpha))
        ) / (sn.cdf(beta) - sn.cdf(alpha))

        return (k - 1) * self.sigma_bar**2 * self.moment(k - 2) + self.mu_bar * self.moment(k - 1) - self.sigma_bar * scale

    def sample(self, shape):
        return self.rsample(shape)

    def rsample(self, shape):
        # get some random probability in [0, 1]
        p = torch.distributions.Uniform(torch.zeros_like(self.mu_bar), torch.ones_like(self.sigma_bar)).sample(shape)
        # apply the inverse cdf on p
        return self.icdf(p)

    @property
    def arg_constraints(self):
        return {
            'mu_bar': constraints.real,
            'sigma_bar': constraints.positive,
        }

    @property
    def has_rsample(self):
        return True

class CensoredNormalDistribution(Distribution):
    r"""Implements a censored Normal distribution.

    Values of the underlying normal that lie outside the range [a, b]
    are assigned to a and b respectively.

    .. math::
        Y =
        \begin{cases}
        a, & \text{if } X \leq a \\
        X, & \text{if } a < X < b \\
        b, & \text{if } X \geq b
        \end{cases}
        \qquad X \sim \mathcal{N}(\bar{\mu}, \bar{\sigma})
    """

    def __init__(self, mu_bar: torch.Tensor, sigma_bar: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
        """
        Args:
            mu_bar (torch.Tensor): The mean of the latent normal distribution
            sigma_bar (torch.Tensor): The std of the latent normal distribution
            a (torch.Tensor): The lower bound of the distribution.
            b (torch.Tensor): The upper bound of the distribution.
        """
        self._n = Normal(mu_bar, sigma_bar)
        self.mu_bar = mu_bar
        self.sigma_bar = sigma_bar
        super().__init__()

        self.a = a
        self.b = b

    def mean(self):
        alpha = (self.a - self.mu_bar) / self.sigma_bar
        beta = (self.b - self.mu_bar) / self.sigma_bar

        sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.mu_bar))
        E_z = TruncatedNormalDistribution(self.mu_bar, self.sigma_bar, self.a, self.b).mean()
        return (
            self.b * (1 - sn.cdf(beta))
            + self.a * sn.cdf(alpha)
            + E_z * (sn.cdf(beta) - sn.cdf(alpha))
        )

    def variance(self):
        # Variance := Var(Y) = E(Y^2) - E(Y)^2
        alpha = (self.a - self.mu_bar) / self.sigma_bar
        beta = (self.b - self.mu_bar) / self.sigma_bar
        sn = torch.distributions.Normal(torch.zeros_like(self.mu_bar), torch.ones_like(self.sigma_bar))
        tn = TruncatedNormalDistribution(mu_bar=self.mu_bar, sigma_bar=self.sigma_bar, a=self.a, b=self.b)

        # Law of total expectation:
        # E(Y^2) = E(Y^2|X>b)*P(X>b) + E(Y^2|X<a)*P(X<a) + E(Y^2|a<X<b)*P(a<X<b)
        #        = b^2 * P(X>b) + a^2 * P(X<a) + E(Z^2) * P(a<X<b),  Z ~ TruncNormal(mu, sigma, a, b)
        E_z2 = tn.moment(2)  # E(Z^2)
        E_y2 = (
            self.b**2 * (1 - sn.cdf(beta))
            + self.a**2 * sn.cdf(alpha)
            + E_z2 * (sn.cdf(beta) - sn.cdf(alpha))
        )  # E(Y^2)

        return E_y2 - self.mean() ** 2  # Var(Y) = E(Y^2) - E(Y)^2

    def sample(self, shape):
        # note: clipping degenerates the gradients.
        # Do not use for MC optimization.
        s = self._n.sample(shape)
        return torch.clip(s, min=self.a, max=self.b)

    @property
    def arg_constraints(self):
        return {
            'mu_bar': constraints.real,
            'sigma_bar': constraints.positive,  # enforce positive scale
        }
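A quick sanity check of the two new distributions (an editor's sketch, assuming the import path introduced by this PR): the analytic `mean()`/`variance()` should agree with Monte Carlo estimates from `rsample()`/`sample()`.

```python
import torch
from mlpp_lib.custom_distributions import (  # module path assumed from this PR
    TruncatedNormalDistribution,
    CensoredNormalDistribution,
)

mu, sigma = torch.tensor([0.5]), torch.tensor([1.2])
a, b = torch.tensor([0.0]), torch.tensor([2.0])

tn = TruncatedNormalDistribution(mu, sigma, a, b)
s = tn.rsample((200_000,))                 # differentiable inverse-CDF sampling
print(s.mean().item(), tn.mean().item())   # MC mean vs analytic mean
print(s.var().item(), tn.variance().item())

cn = CensoredNormalDistribution(mu, sigma, a, b)
s = cn.sample((200_000,))                  # clipped draws; not differentiable
print(s.mean().item(), cn.mean().item())
print(s.var().item(), cn.variance().item())
```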