Remove nonfunctioning GPJax
dirmeier committed Dec 28, 2024
1 parent 5f84ffb commit 51381c2
Showing 10 changed files with 908 additions and 112 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/ci.yaml
@@ -71,9 +71,6 @@ jobs:
       - name: Install dependencies
         run: |
           pip install hatch
-      - name: Build package
-        run: |
-          pip install jaxlib jax
       - name: Run tests
         run: |
           hatch run test:test
6 changes: 3 additions & 3 deletions .github/workflows/release.yaml
@@ -10,11 +10,11 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.9]
+        python-version: [3.11]
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v3
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install pypa/build
126 changes: 60 additions & 66 deletions examples/reconciliation.py
@@ -1,40 +1,37 @@
 import chex
-import distrax
-import gpjax as gpx
 import jax
 import numpy as np
-import optax
 import pandas as pd
+from einops import rearrange
+from jax import Array
 from jax import numpy as jnp
 from jax import random as jr
+from ramsey import NP, train_neural_process
+from ramsey.nn import MLP
 from statsmodels.tsa.arima_process import arma_generate_sample
+from tensorflow_probability.substrates.jax import distributions as tfd

 from reconcile.forecast import Forecaster
 from reconcile.grouping import Grouping
 from reconcile.probabilistic_reconciliation import ProbabilisticReconciliation

 jax.config.update("jax_enable_x64", True)


-class GPForecaster(Forecaster):
-    """Example implementation of a forecaster"""
-
+class NeuralProcessForecaster(Forecaster):
+    """Example implementation of a forecaster."""
     def __init__(self):
         super().__init__()
         self._models: list = []
-        self._xs: jax.Array = None
-        self._ys: jax.Array = None
+        self._xs: jax.Array
+        self._ys: jax.Array

     @property
-    def data(self):
-        """Returns the data"""
-        return self._ys, self._xs
+    def data(self) -> tuple[Array, Array]:
+        return self._xs, self._ys

     def fit(
         self, rng_key: jr.PRNGKey, ys: jax.Array, xs: jax.Array, niter=2000
     ):
-        """Fit a model to each of the time series"""
-
+        """Fit a model to each of the time series."""
         self._xs = xs
         self._ys = ys
         chex.assert_rank([ys, xs], [3, 3])
@@ -43,70 +40,67 @@ def fit(
         p = xs.shape[1]
         self._models = [None] * p
         for i in np.arange(p):
-            x, y = xs[:, [i], :], ys[:, [i], :]
+            x, y = xs[..., i, :], ys[..., i, :]
             # fit a model for each time series
-            opt_posterior, _, D = self._fit_one(rng_key, x, y, niter)
+            model, params = self._fit_one(rng_key, x, y, niter)
             # save the learned parameters and the original data
-            self._models[i] = opt_posterior, D
+            self._models[i] = model, params

     def _fit_one(self, rng_key, x, y, niter):
-        # here we use GPs to model the time series
-        D = gpx.Dataset(X=x.reshape(-1, 1), y=y.reshape(-1, 1))
-        elbo, q, likelihood = self._model(rng_key, D.n)
-
-        negative_elbo = jax.jit(elbo)
-        optimiser = optax.adam(learning_rate=5e-3)
-        opt_posterior, history = gpx.fit(
-            model=q,
-            objective=negative_elbo,
-            train_data=D,
-            optim=optimiser,
-            num_iters=niter,
-            key=rng_key,
+        # here we use neural processes to model the time series
+        model = self._model()
+        n_context, n_target = 10, 20
+        params, _ = train_neural_process(
+            rng_key,
+            model,
+            x=x.reshape(1, -1, 1),
+            y=y.reshape(1, -1, 1),
+            n_context=n_context,
+            n_target=n_target,
+            n_iter=1000,
+            batch_size=1,
         )
-        return opt_posterior, history, D
+        return model, params

     @staticmethod
-    def _model(rng_key, n):
-        z = jr.uniform(rng_key, (20, 1))
-        prior = gpx.gps.Prior(
-            mean_function=gpx.mean_functions.Constant(),
-            kernel=gpx.kernels.RBF(),
-        )
-        likelihood = gpx.likelihoods.Gaussian(num_datapoints=n)
-        posterior = prior * likelihood
-        q = gpx.variational_families.CollapsedVariationalGaussian(
-            posterior=posterior,
-            inducing_inputs=z,
-        )
-        elbo = gpx.objectives.CollapsedELBO(negative=True)
-        return elbo, q, likelihood
+    def _model():
+        def get_neural_process():
+            dim = 128
+            np = NP(
+                decoder=MLP([dim] * 3 + [2]),
+                latent_encoder=(MLP([dim] * 3), MLP([dim, dim * 2])),
+            )
+            return np
+        neural_process = get_neural_process()
+        return neural_process

     def posterior_predictive(self, rng_key, xs_test: jax.Array):
         """Compute the joint posterior predictive distribution at xs_test."""
         chex.assert_rank(xs_test, 3)

         q = xs_test.shape[1]
         means = [None] * q
-        covs = [None] * q
+        scales = [None] * q
         for i in np.arange(q):
-            x_test = xs_test[:, [i], :].reshape(-1, 1)
-            opt_posterior, D = self._models[i]
-            _, q, likelihood = self._model(rng_key, D.n)
-            latent_dist = opt_posterior(x_test, train_data=D)
-            predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
-            means[i] = predictive_dist.mean()
-            cov = predictive_dist.scale_tril
-            covs[i] = cov.reshape((1, *cov.shape))
-
-        # here we stack the means and covariance functions of all
-        # GP models we used
-        means = jnp.vstack(means)
-        covs = jnp.vstack(covs)
-
-        # here we use a single distrax distribution to model the predictive
+            x_context = self._xs[..., i, :]
+            y_context = self._ys[..., i, :]
+            x_test = xs_test[..., i, :]
+
+            model, params = self._models[i]
+            predictive_dist = model.apply(
+                variables=params,
+                rngs={"sample": rng_key},
+                x_context=x_context.reshape(1, -1, 1),
+                y_context=y_context.reshape(1, -1, 1),
+                x_target=x_test.reshape(1, -1, 1),
+            )
+            means[i] = predictive_dist.mean
+            scales[i] = predictive_dist.scale
+
+        means = rearrange(jnp.vstack(means), "b t ... -> ... b t")
+        scales = rearrange(jnp.vstack(scales), "b t ... -> ... b t")
         # posterior of _all_ models
-        posterior_predictive = distrax.MultivariateNormalTri(means, covs)
+        posterior_predictive = tfd.MultivariateNormalDiag(means, scales)
         return posterior_predictive

     def predictive_posterior_probability(
@@ -147,12 +141,12 @@ def sample_hierarchical_timeseries():
     and the second one is a pd.DataFrame of groups
     """

-    def _group_names():
+    def _hierarchy():
         hierarchy = ["A:10", "A:20", "B:10", "B:20", "B:30"]

         return pd.DataFrame.from_dict({"h1": hierarchy})

-    return _sample_timeseries(100, 5), _group_names()
+    return _sample_timeseries(100, 5), _hierarchy()


 def run():
@@ -161,7 +155,7 @@ def run():
     all_timeseries = grouping.all_timeseries(b)
     all_features = jnp.tile(x, [1, all_timeseries.shape[1], 1])

-    forecaster = GPForecaster()
+    forecaster = NeuralProcessForecaster()
     forecaster.fit(
         jr.PRNGKey(1),
         all_timeseries[:, :, :90],
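
A note on the rearrange calls above: the einops pattern "b t ... -> ... b t" moves the leading batch and time axes to the back. A minimal sketch, with shapes assumed purely for illustration:

import numpy as np
from einops import rearrange

# assumed shape: 5 stacked per-series predictions, 20 time points, 1 feature
x = np.zeros((5, 20, 1))
y = rearrange(x, "b t ... -> ... b t")
print(y.shape)  # (1, 5, 20): the feature axis now leads, batch and time trail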
44 changes: 20 additions & 24 deletions pyproject.toml
@@ -7,42 +7,40 @@ name = "probabilistic-reconciliation"
 description = "Probabilistic reconciliation of time series forecasts"
 authors = [{name = "Simon Dirmeier", email = "[email protected]"}]
 readme = "README.md"
-license = "Apache-2.0"
-homepage = "https://github.com/dirmeier/reconcile"
+license = {file = "LICENSE"}
 keywords = ["probabilistic reconciliation", "forecasting", "timeseries", "hierarchical time series"]
-classifiers=[
+classifiers= [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Science/Research",
     "License :: OSI Approved :: Apache Software License",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
-    "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
 ]
-requires-python = ">=3.9"
+requires-python = ">=3.11"
 dependencies = [
-    "blackjax-nightly>=0.9.6.post127",
-    "distrax>=0.1.2",
-    "chex>=0.1.5",
-    "jaxlib>=0.4.18",
-    "jax>=0.4.18",
-    "flax>=0.7.3",
-    "gpjax>=0.6.9",
-    "optax>=0.1.3",
-    "pandas>=1.5.1"
+    "blackjax>=1.2.4",
+    "chex>=0.1.8",
+    "einops>=0.8.0",
+    "flax>=0.10.2",
+    "jax>=0.4.38",
+    "optax>=0.2.4",
+    "pandas>=1.5.1",
+    "ramsey>=0.2.1",
+    "tfp-nightly[jax]>=0.26.0.dev20241227",
 ]
 dynamic = ["version"]

+[project.urls]
+homepage = "https://github.com/dirmeier/reconcile"
+
+[tool.hatch.build.targets.wheel]
+packages = ["reconcile"]
+
 [tool.hatch.version]
 path = "reconcile/__init__.py"

-[tool.hatch.build.targets.wheel]
-packages = ["reconcile"]
-
 [tool.hatch.build.targets.sdist]
 exclude = [
   "/.github",
@@ -55,20 +53,18 @@ dependencies = [
     "ruff>=0.3.0",
     "pytest>=7.2.0",
     "pytest-cov>=4.0.0",
-    "gpjax>=0.5.0",
     "statsmodels>=0.13.2"
 ]

+[tool.hatch.envs.test.scripts]
+lint = 'ruff check reconcile examples'
+test = 'pytest -v --doctest-modules --cov=./reconcile --cov-report=xml reconcile'
+
 [tool.hatch.envs.examples]
 dependencies = [
-    "gpjax>=0.5.0",
     "statsmodels>=0.13.2"
 ]

-[tool.hatch.envs.test.scripts]
-lint = 'ruff check reconcile examples'
-test = 'pytest -v --doctest-modules --cov=./reconcile --cov-report=xml reconcile'
-
 [tool.hatch.envs.examples.scripts]
 reconciliation = 'python examples/reconciliation.py'
2 changes: 1 addition & 1 deletion reconcile/__init__.py
@@ -1,6 +1,6 @@
 """reconcile: Probabilistic reconciliation of time series forecasts."""

-__version__ = "0.1.0"
+__version__ = "0.2.0"

 from reconcile.forecast import Forecaster
 from reconcile.grouping import Grouping
8 changes: 4 additions & 4 deletions reconcile/conftest.py
@@ -6,7 +6,7 @@
 from jax import numpy as jnp
 from jax import random

-from examples.reconciliation import GPForecaster
+from examples.reconciliation import NeuralProcessForecaster
 from reconcile import ProbabilisticReconciliation
 from reconcile.grouping import Grouping

@@ -40,11 +40,11 @@ def reconciliator():
     all_timeseries = grouping.all_timeseries(b)
     all_features = jnp.tile(x, [1, all_timeseries.shape[1], 1])

-    forecaster = GPForecaster()
+    forecaster = NeuralProcessForecaster()
     forecaster.fit(
         random.PRNGKey(1),
-        all_timeseries[:, :90, :],
-        all_features[:, :90, :],
+        all_timeseries[:, :, :90],
+        all_features[:, :, :90],
         100,
     )
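
The slice change above fixes which axis is split for training: with arrays laid out as (batch, series, time), [:, :, :90] takes the first 90 time points of every series rather than the first 90 series. A minimal sketch under that assumed layout:

import jax.numpy as jnp

all_timeseries = jnp.zeros((1, 7, 100))  # assumed layout: (batch, series, time)
train = all_timeseries[:, :, :90]
print(train.shape)  # (1, 7, 90): all 7 series, first 90 time points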
6 changes: 3 additions & 3 deletions reconcile/forecast.py
@@ -2,9 +2,9 @@

 import abc

-import distrax
 from jax import Array
 from jax import random as jr
+from tensorflow_probability.substrates.jax import distributions as tfp


 class Forecaster(metaclass=abc.ABCMeta):
@@ -49,7 +49,7 @@ def fit(self, rng_key: jr.PRNGKey, ys: Array, xs: Array) -> None:
     @abc.abstractmethod
     def posterior_predictive(
         self, rng_key: jr.PRNGKey, xs_test: Array
-    ) -> distrax.Distribution:
+    ) -> tfp.Distribution:
         """Computes the posterior predictive distribution at some input points.
@@ -61,7 +61,7 @@ def posterior_predictive(
                 elements as the original training data

         Return:
-            returns a distrax Distribution with batch shape (,P) and event
+            returns a TFP Distribution with batch shape (,P) and event
             shape (,M), such that a single sample has shape (P, M) and
             multiple samples have shape (S, P, M)
         """
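
A sketch of the shape contract described in the docstring above, with P=5 and M=20 chosen purely for illustration:

from jax import numpy as jnp
from jax import random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

P, M = 5, 20  # assumed: P time series, M prediction points
dist = tfd.MultivariateNormalDiag(jnp.zeros((P, M)), jnp.ones((P, M)))
print(dist.batch_shape, dist.event_shape)        # (5,), (20,)
print(dist.sample(seed=jr.PRNGKey(0)).shape)     # (5, 20)
print(dist.sample(3, seed=jr.PRNGKey(0)).shape)  # (3, 5, 20)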
4 changes: 2 additions & 2 deletions reconcile/grouping.py
@@ -79,12 +79,12 @@ def summing_matrix(self):

     def extract_bottom_timeseries(self, y):
         """Getter for the bottom time series."""
-        return y[:, self.n_upper_timeseries :, :]
+        return y[..., self.n_upper_timeseries:, :]

     def upper_time_series(self, b):
         """Getter for upper time series."""
         y = self.all_timeseries(b)
-        return y[:, : self.n_upper_timeseries, :]
+        return y[..., :self.n_upper_timeseries, :]

     @staticmethod
     def _paste0(a, b):
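
The ellipsis slices above make the getters agnostic to leading batch dimensions while always indexing the series axis (second to last). A minimal sketch, assuming y has shape (batch, series, time) and three upper-level series:

import jax.numpy as jnp

y = jnp.zeros((2, 5, 90))  # assumed layout: (batch, series, time)
n_upper = 3
upper = y[..., :n_upper, :]   # (2, 3, 90): aggregated (upper) series
bottom = y[..., n_upper:, :]  # (2, 2, 90): bottom-level series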