@@ -1,3 +1,3 @@
 # Adaptation with normalizing flows

-**Experimental**
+**Experimental and subject to change**

@@ -1 +1,179 @@

# Usage with PyMC models

This document shows how to use `nutpie` with PyMC models. We will use the
`pymc` package to define a simple model and sample from it using `nutpie`.

## Installation

The recommended way to install `pymc` is through the `conda` ecosystem. A good
package manager for conda packages is `pixi`. See the [pixi
documentation](https://pixi.sh) for instructions on how to install it.

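As an illustration only, the install one-liner for Linux and macOS currently
documented by the pixi project looks roughly like the following; treat the
exact URL and command as an assumption and prefer whatever the pixi
documentation says:

```bash
# Hypothetical install command -- check https://pixi.sh for the
# current, platform-specific instructions before running it.
curl -fsSL https://pixi.sh/install.sh | bash
```
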
We create a new project for this example:

```bash
pixi init pymc-example
```

This will create a new directory `pymc-example` with a `pixi.toml` file, which
you can edit to add project metadata.

We then add the `pymc`, `nutpie`, and `arviz` packages to the project:

```bash
cd pymc-example
pixi add pymc nutpie arviz
```

You can use Visual Studio Code (VSCode) or JupyterLab to write and run your code.
Both are excellent tools for working with Python and data science projects.

### Using VSCode

1. Open VSCode.
2. Open the `pymc-example` directory created earlier.
3. Create a new file named `model.ipynb`.
4. Select the pixi kernel to run the code.

### Using JupyterLab

1. Add JupyterLab to the project by running `pixi add jupyterlab`.
2. Open JupyterLab by running `pixi run jupyter lab` in your terminal (both
   commands are shown together below).
3. Create a new Python notebook.

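If you prefer to copy and paste, the two commands from the list above, run
from the project directory, are:

```bash
# Add JupyterLab to the pixi project, then launch it
pixi add jupyterlab
pixi run jupyter lab
```
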
## Defining and Sampling a Simple Model

We will define a simple Bayesian model using `pymc` and sample from it using
`nutpie`.

### Model Definition

In your `model.ipynb` file or Jupyter notebook, add the following code:

```python
import pymc as pm
import nutpie
import pandas as pd

coords = {"observation": range(3)}

with pm.Model(coords=coords) as model:
    # Prior distributions for the intercept and slope
    intercept = pm.Normal("intercept", mu=0, sigma=1)
    slope = pm.Normal("slope", mu=0, sigma=1)

    # Likelihood (sampling distribution) of observations
    x = [1, 2, 3]

    mu = intercept + slope * x
    sigma = pm.HalfNormal("sigma", sigma=1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=[1, 2, 3], dims="observation")
```

### Sampling

We can now compile the model using the `numba` backend and sample from it:

```python
compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled)
```

While sampling, nutpie shows a progress bar for each chain. It also includes
information about how each chain is doing:

- The current number of draws
- The step size of the integrator (very small step sizes are typically a bad
  sign)
- The number of divergences (divergences indicate that nutpie is probably not
  sampling the posterior correctly)
- The number of gradient evaluations nutpie uses for each draw. Large numbers
  (100 to 1000) are a sign that the parameterization of the model is not ideal
  and that the sampler is very inefficient.

After sampling, `nutpie.sample` returns an `arviz` `InferenceData` object that
you can use to analyze the trace.

For example, we should check the effective sample size:

```python
import arviz as az
az.ess(trace)
```

and have a look at a trace plot:

```python
az.plot_trace(trace)
```

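The divergence information from the progress bar is also stored in the
returned trace. A minimal sketch for inspecting it after sampling; the
variable name `diverging` follows the usual ArviZ convention and may differ
between nutpie versions:

```python
# Count divergent transitions across all chains and draws
print(int(trace.sample_stats["diverging"].sum()))

# List all per-draw statistics nutpie stored, in case the names differ
print(list(trace.sample_stats.data_vars))
```
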
### Choosing the backend

So far, we have been using the `numba` backend. This is the default backend for
`nutpie` when sampling from PyMC models. It tends to have relatively long
compilation times, but samples small models very efficiently. For larger models,
the `jax` backend sometimes outperforms `numba`.

First, we need to install the `jax` package:

```bash
pixi add jax
```

We can select the backend by passing the `backend` argument to
`compile_pymc_model`:

```python
compiled_jax = nutpie.compile_pymc_model(model, backend="jax")
trace = nutpie.sample(compiled_jax)
```

If you have an NVIDIA GPU, you can also use the `jax` backend on the GPU. We
have to install the `jaxlib` package with the `cuda` option:

```bash
pixi add jaxlib --build 'cuda12'
```

Restart the kernel and check that the GPU is available:

```python
import jax

# Should list the cuda device
jax.devices()
```

Sampling again should now use the GPU, which you can observe by checking the
GPU usage with `nvidia-smi` or `nvtop`.

### Changing the dataset without recompilation

If you want to use the same model with different datasets, you can modify the
dataset after compilation. Since `jax` does not like changes in shapes, this is
only recommended with the `numba` backend.

First, we define the model, but put our dataset in a `pm.Data` container:

```python
with pm.Model() as model:
    x = pm.Data("x", [1, 2, 3])
    intercept = pm.Normal("intercept", mu=0, sigma=1)
    slope = pm.Normal("slope", mu=0, sigma=1)
    mu = intercept + slope * x
    sigma = pm.HalfNormal("sigma", sigma=1)
    y = pm.Normal("y", mu=mu, sigma=sigma, observed=[1, 2, 3])
```

We can now compile the model:

```python
compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled)
```

After compilation, we can change the dataset:

```python
compiled2 = compiled.with_data({"x": [4, 5, 6]})
trace2 = nutpie.sample(compiled2)
```

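Both compiled models sample independently, so you can compare the fits for the
two datasets. A minimal sketch using ArviZ (assuming `arviz` is installed, as
in the earlier sections):

```python
import arviz as az

# Posterior summaries for the original and the modified dataset
print(az.summary(trace, var_names=["intercept", "slope"]))
print(az.summary(trace2, var_names=["intercept", "slope"]))
```
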
@@ -0,0 +1,78 @@

# Nutpie Documentation

`nutpie` is a high-performance library for Bayesian inference that provides
efficient sampling algorithms for probabilistic models. It can sample models
that are defined in PyMC or Stan (support for numpyro and custom hand-coded
likelihoods with gradients is coming soon).

- Faster sampling than either the PyMC or Stan default samplers (an average
  ~2x speedup on `posteriordb` compared to Stan).
- All the diagnostic information of PyMC and Stan, and some more.
- GPU support for PyMC models through jax.
- A more informative progress bar.
- Access to the incomplete trace during sampling.
- *Experimental* normalizing flow adaptation for more efficient sampling of
  difficult posteriors.

## Quickstart

Install `nutpie` with pip, uv, pixi, or conda:

For usage with pymc:

```bash
# One of
pip install "nutpie[pymc]"
uv add "nutpie[pymc]"
pixi add nutpie pymc numba
conda install -c conda-forge nutpie pymc numba
```

And then sample with:

```python
import nutpie
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    obs = pm.Normal("obs", mu=mu, sigma=1, observed=[1, 2, 3])

compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled)
```

Stan needs access to a compiler toolchain; you can find instructions for
setting one up
[here](https://mc-stan.org/docs/cmdstan-guide/installation.html#cpp-toolchain).
You can then install nutpie through pip or uv:

```bash
# One of
pip install "nutpie[stan]"
uv add "nutpie[stan]"
```

```python
import nutpie

model = """
data {
    int<lower=0> N;
    array[N] real y;
}
parameters {
    real mu;
}
model {
    mu ~ normal(0, 1);
    y ~ normal(mu, 1);
}
"""

compiled = (
    nutpie
    .compile_stan_model(code=model)
    .with_data({"N": 3, "y": [1, 2, 3]})
)
trace = nutpie.sample(compiled)
```

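As with the PyMC quickstart, `nutpie.sample` returns an ArviZ `InferenceData`
object. A minimal sketch for inspecting the result (assuming `arviz` is
installed alongside nutpie):

```python
import arviz as az

# Posterior summary table (mean, sd, ESS, R-hat)
print(az.summary(trace))
```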