Skip to content

Commit

Permalink
add PyStan3 reloo example [docs] (#1583)
Browse files Browse the repository at this point in the history
* add pystan3 reloo example

* add nest_asyncio

* fix class imports

* fix super call

* update workbook

* update wrappers

* move pystan refitting files

* update sampling wrappers md.

* rename pystan3 to pystan

* update wrapper notebook

* fix mypy

* fix missing import

* black

* update doc references

Co-authored-by: Oriol (ZBook) <[email protected]>
  • Loading branch information
ahartikainen and OriolAbril authored Mar 30, 2021
1 parent 1b5d24f commit b71c83b
Show file tree
Hide file tree
Showing 8 changed files with 693 additions and 91 deletions.
4 changes: 2 additions & 2 deletions arviz/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Sampling wrappers."""
from .base import SamplingWrapper
from .wrap_pystan import PyStanSamplingWrapper
from .wrap_stan import PyStan2SamplingWrapper, PyStanSamplingWrapper

__all__ = ["SamplingWrapper", "PyStanSamplingWrapper"]
__all__ = ["SamplingWrapper", "PyStan2SamplingWrapper", "PyStanSamplingWrapper"]
7 changes: 3 additions & 4 deletions arviz/wrappers/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=too-many-instance-attributes,too-many-arguments
"""Base class for sampling wrappers."""
from xarray import apply_ufunc

Expand Down Expand Up @@ -230,11 +231,9 @@ def check_implemented_methods(self, methods):
if method in supported_methods_1arg:
if self._check_method_is_implemented(method, 1):
continue
else:
not_implemented.append(method)
not_implemented.append(method)
elif method in supported_methods_2args:
if self._check_method_is_implemented(method, 1, 1):
continue
else:
not_implemented.append(method)
not_implemented.append(method)
return not_implemented
70 changes: 61 additions & 9 deletions arviz/wrappers/wrap_pystan.py → arviz/wrappers/wrap_stan.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# pylint: disable=arguments-differ
"""Base class for PyStan wrappers."""
from ..data import from_pystan
from typing import Union

from ..data import from_cmdstanpy, from_pystan
from .base import SamplingWrapper


class PyStanSamplingWrapper(SamplingWrapper):
"""PyStan sampling wrapper base class.
# pylint: disable=abstract-method
class StanSamplingWrapper(SamplingWrapper):
"""Stan sampling wrapper base class.
See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
description. An example of ``PyStanSamplingWrapper`` usage can be found
Expand Down Expand Up @@ -47,16 +50,65 @@ def sel_observations(self, idx):
"""
raise NotImplementedError("sel_observations must be implemented on a model basis")

def sample(self, modified_observed_data):
"""Resample the PyStan model stored in self.model on modified_observed_data."""
fit = self.model.sampling(data=modified_observed_data, **self.sample_kwargs)
return fit

def get_inference_data(self, fit):
"""Convert the fit object returned by ``self.sample`` to InferenceData."""
idata = from_pystan(posterior=fit, **self.idata_kwargs)
if fit.__class__.__name__ == "CmdStanMCMC":
idata = from_cmdstanpy(posterior=fit, **self.idata_kwargs)
else:
idata = from_pystan(posterior=fit, **self.idata_kwargs)
return idata

def log_likelihood__i(self, excluded_obs_log_like, idata__i):
"""Retrieve the log likelihood of the excluded observations from ``idata__i``."""
return idata__i.log_likelihood[excluded_obs_log_like]


class PyStan2SamplingWrapper(StanSamplingWrapper):
    """PyStan (2.x) sampling wrapper base class.

    See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
    description. An example of ``PyStan2SamplingWrapper`` usage can be found
    in the :ref:`pystan2_refitting` notebook. For usage examples of other wrappers
    see the user guide pages on :ref:`wrapper_guide`.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    See Also
    --------
    SamplingWrapper
    """

    def sample(self, modified_observed_data):
        """Refit the PyStan model stored in ``self.model`` on ``modified_observed_data``.

        The data is forwarded as the ``data`` argument of ``self.model.sampling``
        together with ``self.sample_kwargs``, and the resulting fit object is
        returned as is.
        """
        return self.model.sampling(data=modified_observed_data, **self.sample_kwargs)


class PyStanSamplingWrapper(StanSamplingWrapper):
    """PyStan (3.0+) sampling wrapper base class.

    See the documentation on :class:`~arviz.SamplingWrapper` for a more detailed
    description. An example of ``PyStanSamplingWrapper`` usage can be found
    in the :ref:`pystan_refitting` notebook.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def sample(self, modified_observed_data):
        """Rebuild the Stan model with ``modified_observed_data`` and sample from it.

        ``self.model`` may hold either the Stan program code as a string or an
        object exposing a ``program_code`` attribute. It is overwritten with the
        object returned by ``stan.build``, and the result of calling ``sample`` on
        that object with ``self.sample_kwargs`` is returned.
        """
        import stan  # pylint: disable=import-error,import-outside-toplevel

        self.model: Union[str, stan.Model]
        source = self.model if isinstance(self.model, str) else self.model.program_code
        # The data is passed at build time, so every refit builds a new model.
        self.model = stan.build(source, data=modified_observed_data)
        return self.model.sample(**self.sample_kwargs)
1 change: 1 addition & 0 deletions doc/source/api/wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ Experimental feature

SamplingWrapper
PyStanSamplingWrapper
PyStan2SamplingWrapper
460 changes: 460 additions & 0 deletions doc/source/user_guide/pystan2_refitting.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(pystan_refitting_xr)=\n",
"# Refitting PyStan models with ArviZ (and xarray)\n",
"(pystan2_refitting_xr)=\n",
"# Refitting PyStan (2.x) models with ArviZ (and xarray)\n",
"\n",
"ArviZ is backend agnostic and therefore does not sample directly. In order to take advantage of algorithms that require refitting models several times, ArviZ uses {class}`~arviz.SamplingWrapper`s to convert the API of the sampling backend to a common set of functions. Hence, functions like Leave Future Out Cross Validation can be used in ArviZ independently of the sampling backend used."
]
Expand All @@ -14,7 +14,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Below there is one example of `SamplingWrapper` usage for PyStan."
"Below there is one example of `SamplingWrapper` usage for PyStan (2.x)."
]
},
{
Expand Down Expand Up @@ -124,7 +124,7 @@
" vector[N] y_hat;\n",
" \n",
" for (i in 1:N) {\n",
" // pointwise log likelihood will be calculated outside stan, \n",
" // pointwise log likelihood will be calculated outside Stan, \n",
" // posterior predictive however will be generated here, there are \n",
" // no restrictions on adding more generated quantities\n",
" y_hat[i] = normal_rng(b0 + b1 * x[i], sigma_e);\n",
Expand Down Expand Up @@ -195,7 +195,7 @@
"source": [
"We are now missing the `log_likelihood` group because we have not used the `log_likelihood` argument in `idata_kwargs`. We are doing this to ease the job of the sampling wrapper. Instead of going out of our way to get Stan to calculate the pointwise log likelihood values for each refit and for the excluded observation at every refit, we will compromise and manually write a function to calculate the pointwise log likelihood.\n",
"\n",
"Even though it is not ideal to lose part of the straight out of the box capabilities of PyStan-ArviZ integration, this should generally not be a problem. We are basically moving the pointwise log likelihood calculation from the Stan code to the Python code, in both cases we need to manyally write the function to calculate the pointwise log likelihood.\n",
"Even though it is not ideal to lose part of the straight out of the box capabilities of PyStan-ArviZ integration, this should generally not be a problem. We are basically moving the pointwise log likelihood calculation from the Stan code to the Python code; in both cases we need to manually write the function to calculate the pointwise log likelihood.\n",
"\n",
"Moreover, the Python computation could even be written to be compatible with Dask. Thus it will work even in cases where the large number of observations makes it impossible to store pointwise log likelihood values (with shape `n_samples * n_observations`) in memory."
]
Expand Down Expand Up @@ -3181,14 +3181,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We will create a subclass of {class}`~arviz.SamplingWrapper`. Therefore, instead of having to implement all functions required by {func}`~arviz.reloo` we only have to implement `sel_observations` (we are cloning `sample` and `get_inference_data` from the `PyStanSamplingWrapper` in order to use `apply_ufunc` instead of assuming the log likelihood is calculated within Stan). \n",
"We will create a subclass of {class}`~arviz.SamplingWrapper`. Therefore, instead of having to implement all functions required by {func}`~arviz.reloo` we only have to implement `sel_observations` (we are cloning `sample` and `get_inference_data` from the `PyStan2SamplingWrapper` in order to use `apply_ufunc` instead of assuming the log likelihood is calculated within Stan). \n",
"\n",
"Note that of the 2 outputs of `sel_observations`, `data__i` is a dictionary because it is an argument of `sample` which will pass it as is to `model.sampling`, whereas `data_ex` is a list because it is an argument to `log_likelihood__i` which will pass it as `*data_ex` to `apply_ufunc`. More on `data_ex` and `apply_ufunc` integration below."
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -3202,7 +3202,7 @@
" return data__i, data_ex\n",
" \n",
" def sample(self, modified_observed_data):\n",
" #Cloned from PyStanSamplingWrapper.\n",
" #Cloned from PyStan2SamplingWrapper.\n",
" fit = self.model.sampling(data=modified_observed_data, **self.sample_kwargs)\n",
" return fit\n",
"\n",
Expand Down Expand Up @@ -3265,7 +3265,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We initialize our sampling wrapper. Let's stop and analize each of the arguments. \n",
"We initialize our sampling wrapper. Let's stop and analyze each of the arguments. \n",
"\n",
"We then use the `log_lik_fun` and `posterior_vars` arguments to tell the wrapper how to call {func}`~xarray:xarray.apply_ufunc`. `log_lik_fun` is the function to be called, which is then called with the following positional arguments:\n",
"\n",
Expand Down Expand Up @@ -3419,5 +3419,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
219 changes: 154 additions & 65 deletions doc/source/user_guide/pystan_refitting.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion doc/source/user_guide/sampling_wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ whereas the second one externalizes the computation of the pointwise log
likelihood to the user who is expected to write it with xarray/numpy.

```{toctree}
pystan2_refitting
pystan_refitting
pymc3_refitting
numpyro_refitting
pystan_refitting_xr_lik
pystan2_refitting_xr_lik
pymc3_refitting_xr_lik
numpyro_refitting_xr_lik
```

0 comments on commit b71c83b

Please sign in to comment.