Add docs for samplers and improve API
dirmeier committed Feb 2, 2025
1 parent 4290340 commit 3524900
Showing 15 changed files with 104 additions and 87 deletions.
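Throughout the diff below, the separate prior_sampler_fn/prior_log_density_fn handles are replaced by a single self.prior object whose sample and log_prob methods are called directly. A minimal sketch of the two call styles, assuming a tensorflow_probability-style prior; the concrete distribution and shapes here are illustrative only:

import jax.random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

# Illustrative prior; in sbijax it is whatever the user's model_fns[0]() returns.
prior = tfd.MultivariateNormalDiag(loc=[0.0, 0.0], scale_diag=[1.0, 1.0])

# Old style (removed in this commit): two separate function handles.
prior_sampler_fn, prior_log_density_fn = prior.sample, prior.log_prob
thetas = prior_sampler_fn(seed=jr.PRNGKey(0), sample_shape=(100,))
log_probs = prior_log_density_fn(thetas)

# New style: keep the distribution itself and call its methods where needed.
thetas = prior.sample(seed=jr.PRNGKey(0), sample_shape=(100,))
log_probs = prior.log_prob(thetas)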
1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"diffrax>=0.6.1",
"blackjax-nightly>=1.0.0.post17",
"dm-haiku>=0.0.9",
"einops>=0.8.0",
"matplotlib>=3.6.2",
"optax>=0.1.3",
"seaborn>=0.12.2",
2 changes: 1 addition & 1 deletion sbijax/_src/_ne_base.py
@@ -112,7 +112,7 @@ def simulate_parameters(
"""
if params is None or len(params) == 0:
diagnostics = None
new_thetas = self.prior_sampler_fn(
new_thetas = self.prior.sample(
seed=rng_key,
sample_shape=(n_simulations,),
)
9 changes: 1 addition & 8 deletions sbijax/_src/_sbi_base.py
@@ -1,7 +1,5 @@
import abc

from jax import random as jr

from sbijax._src.util.dataloader import as_batch_iterators


@@ -16,13 +14,8 @@ def __init__(self, model_fns):
Args:
model_fns: tuple
"""
prior = model_fns[0]()
self.prior_sampler_fn, self.prior_log_density_fn = (
prior.sample,
prior.log_prob,
)
self.prior = model_fns[0]()
self.simulator_fn = model_fns[1]
self._len_theta = len(self.prior_sampler_fn(seed=jr.PRNGKey(123)))

@staticmethod
def as_iterators(
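For context on the model_fns tuple unpacked above: a minimal sketch of how a prior/simulator pair might be assembled, assuming a tensorflow_probability-style prior. The names prior_fn and simulator_fn, and the simulator's (seed, theta) signature, are illustrative assumptions, not taken from this diff:

import jax.numpy as jnp
import jax.random as jr
from tensorflow_probability.substrates.jax import distributions as tfd

def prior_fn():
    # model_fns[0] is called once in __init__, so it should return the prior distribution.
    return tfd.Independent(tfd.Normal(jnp.zeros(2), 1.0), reinterpreted_batch_ndims=1)

def simulator_fn(seed, theta):
    # model_fns[1] maps parameters to synthetic observations (hypothetical toy simulator).
    return theta + 0.1 * jr.normal(seed, theta.shape)

model_fns = (prior_fn, simulator_fn)
prior = model_fns[0]()                    # what __init__ now stores as self.prior
theta = prior.sample(seed=jr.PRNGKey(0))  # a single prior draw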
14 changes: 3 additions & 11 deletions sbijax/_src/experimental/aio_test.py
@@ -9,17 +9,9 @@

def test_aio(prior_simulator_tuple):
y_observed = jnp.array([-1.0, 1.0])
estim = AiO(prior_simulator_tuple, make_simformer_based_score_model(2, jnp.eye(4)))
data, params = None, {}
for i in range(2):
data, _ = estim.simulate_data_and_possibly_append(
jr.PRNGKey(1),
params=params,
observable=y_observed,
data=data,
n_simulations=100,
)
params, info = estim.fit(jr.PRNGKey(2), data=data, n_iter=2)
estim = AiO(prior_simulator_tuple, make_simformer_based_score_model(2, jnp.eye(4), 1, 1))
data, _ = estim.simulate_data(jr.PRNGKey(1), n_simulations=100)
params, info = estim.fit(jr.PRNGKey(2), data=data, n_iter=2)
_ = estim.sample_posterior(
jr.PRNGKey(3),
params,
2 changes: 1 addition & 1 deletion sbijax/_src/experimental/nn/make_score_network.py
@@ -333,7 +333,7 @@ def drift_fn(inputs):
solver,
self._time_eps,
self._time_max,
-self._time_delta,
self._time_delta,
(inputs, 0.0),
)
(latents,), (delta_log_likelihood,) = sol.ys
4 changes: 2 additions & 2 deletions sbijax/_src/experimental/npse.py
@@ -59,7 +59,7 @@ def _init_params(self, rng_key, **init_data):
return params

def get_truncated_prior(self, rng_key, params, observable, n_samples):
samp = self.prior_sampler_fn(seed=jr.PRNGKey(0), sample_shape=())
samp = self.prior.sample(seed=jr.PRNGKey(0), sample_shape=())
_, unravel_fn = ravel_pytree(samp)

sample_key, rng_key = jr.split(rng_key)
@@ -85,7 +85,7 @@ def get_truncated_prior(self, rng_key, params, observable, n_samples):
jax.tree.map(lambda x: x.max(axis=0), posterior_samples),
)
sample_key, rng_key = jr.split(rng_key)
prior_samples = self.prior_sampler_fn(
prior_samples = self.prior.sample(
seed=sample_key, sample_shape=(int(1e6),)
)
min_prior, max_prior = (
6 changes: 2 additions & 4 deletions sbijax/_src/fmpe.py
@@ -209,7 +209,7 @@ def sample_posterior(
thetas = None
n_curr = n_samples
n_total_simulations_round = 0
_, unravel_fn = ravel_pytree(self.prior_sampler_fn(seed=jr.PRNGKey(1)))
_, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(1)))
while n_curr > 0:
n_sim = jnp.minimum(1024, jnp.maximum(1024, n_curr))
n_total_simulations_round += n_sim
@@ -221,9 +221,7 @@
context=jnp.tile(observable, [n_sim, 1]),
is_training=False,
)
proposal_probs = self.prior_log_density_fn(
jax.vmap(unravel_fn)(proposal)
)
proposal_probs = self.prior.log_prob(jax.vmap(unravel_fn)(proposal))
proposal_accepted = proposal[jnp.isfinite(proposal_probs)]
if thetas is None:
thetas = proposal_accepted
11 changes: 8 additions & 3 deletions sbijax/_src/nass.py
@@ -64,9 +64,6 @@ class NASS(NE):
Chen, Yanzhi et al. "Neural Approximate Sufficient Statistics for Implicit Models". ICLR, 2021
"""

def sample_posterior(self, rng_key, params, observable, *args, **kwargs):
raise NotImplementedError()

def __init__(self, model_fns, summary_net):
super().__init__(model_fns, summary_net)

@@ -220,3 +217,11 @@ def simulate_data(
return super().simulate_data(
rng_key, n_simulations=n_simulations, **kwargs
)

def _simulate_parameters_with_model(
self, rng_key, params, observable, *args, **kwargs
):
raise NotImplementedError()

def sample_posterior(self, rng_key, params, observable, *args, **kwargs):
raise NotImplementedError()
44 changes: 26 additions & 18 deletions sbijax/_src/nle.py
@@ -320,30 +320,17 @@ def _log_likelihood_fn(theta):
theta = jnp.tile(theta, [observable.shape[0], 1])
return part(x=theta)

def _joint_logdensity_fn(theta):
lp_prior = self.prior_log_density_fn(theta)
def _prop_posterior_density(theta):
lp_prior = self.prior.log_prob(theta)
lp = _log_likelihood_fn(theta)
return jnp.sum(lp) + jnp.sum(lp_prior)

if "sampler" in kwargs and kwargs["sampler"] == "slice":

def lp__(theta):
return jax.vmap(_joint_logdensity_fn)(theta)

sampler = kwargs.pop("sampler", None)
else:

def lp__(theta):
return _joint_logdensity_fn(theta)

# take whatever sampler is or per default nuts
sampler = kwargs.pop("sampler", "nuts")

sampler = kwargs.pop("sampler", "nuts")
sampling_fn = getattr(mcmc, "sample_with_" + sampler)
samples = sampling_fn(
rng_key=rng_key,
lp=lp__,
prior=self.prior_sampler_fn,
lp=_prop_posterior_density,
prior=self.prior,
n_chains=n_chains,
n_samples=n_samples,
n_warmup=n_warmup,
@@ -355,6 +342,27 @@ def lp__(theta):
diagnostics = mcmc_diagnostics(inference_data)
return inference_data, diagnostics

def _simulate_parameters_with_model(
self,
rng_key,
params,
observable,
*,
n_chains=4,
n_samples=2_000,
n_warmup=1_000,
**kwargs,
):
return self.sample_posterior(
rng_key=rng_key,
params=params,
observable=observable,
n_chains=n_chains,
n_samples=n_samples,
n_warmup=n_warmup,
**kwargs,
)

@staticmethod
def plot(inference_data: InferenceData):
arviz.plot_trace(inference_data)
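The simplified dispatch above pops a "sampler" keyword (defaulting to "nuts") and resolves it to the matching mcmc.sample_with_<name> function, passing self.prior along. A hedged usage sketch; estim and params stand for an already fitted NLE estimator and its trained parameters and are hypothetical here, as is the observable:

import jax.numpy as jnp
import jax.random as jr

y_observed = jnp.array([-1.0, 1.0])

# Default: no "sampler" keyword, so sampling falls back to mcmc.sample_with_nuts.
inference_data, diagnostics = estim.sample_posterior(
    jr.PRNGKey(3), params, y_observed, n_chains=4, n_samples=2_000, n_warmup=1_000
)

# Any name for which an mcmc.sample_with_<name> function exists can be passed,
# e.g. the slice sampler referenced in the removed branch above.
inference_data, diagnostics = estim.sample_posterior(
    jr.PRNGKey(3), params, y_observed, sampler="slice"
)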
4 changes: 2 additions & 2 deletions sbijax/_src/nn/make_consistency_model.py
@@ -5,10 +5,10 @@
from jax import numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd

__all__ = ["ConsistencyModel", "make_cm"]

from sbijax._src.nn.make_resnet import _ResnetBlock

__all__ = ["ConsistencyModel", "make_cm"]


# ruff: noqa: PLR0913,D417
class ConsistencyModel(hk.Module):
12 changes: 6 additions & 6 deletions sbijax/_src/nn/make_continuous_flow.py
@@ -112,9 +112,9 @@ def __init__(self, n_dimension: int, transform: Callable, sigma_min=0.001):
"""
super().__init__()
self._n_dimension = n_dimension
self._score_model = transform
self._score_net = transform
self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)
self.sigma_min = sigma_min
self._sigma_min = sigma_min

def __call__(self, method, **kwargs):
"""Aplpy the flow.
@@ -140,7 +140,7 @@ def sample(self, context, **kwargs):
def ode_fn(time, theta_t):
theta_t = theta_t.reshape(-1, self._n_dimension)
time = jnp.repeat(time, theta_t.shape[0])
ret = self._score_model(
ret = self._score_net(
inputs=theta_t, time=time, context=context, **kwargs
)
return ret.reshape(-1)
@@ -161,15 +161,15 @@ def loss(self, inputs, context, is_training, **kwargs):
n, _ = inputs.shape
times = jr.uniform(hk.next_rng_key(), shape=(n,))
theta_t = sample_theta_t(
hk.next_rng_key(), inputs, times, self.sigma_min
hk.next_rng_key(), inputs, times, self._sigma_min
)
vs = self._score_model(
vs = self._score_net(
inputs=theta_t,
time=times,
context=context,
is_training=is_training,
)
uts = ut(theta_t, inputs, times, self.sigma_min)
uts = ut(theta_t, inputs, times, self._sigma_min)
loss = jnp.mean(jnp.square(vs - uts))
return loss

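The loss above relies on two helpers, sample_theta_t and ut, that are outside this hunk. A minimal sketch of the conditional flow-matching quantities they typically compute (following Lipman et al.'s optimal-transport path; the repository's exact formulas may differ):

import jax.numpy as jnp
import jax.random as jr

def sample_theta_t(rng_key, theta, times, sigma_min):
    # Interpolate between Gaussian noise and the data point theta:
    # theta_t = t * theta + (1 - (1 - sigma_min) * t) * noise, noise ~ N(0, I).
    noise = jr.normal(rng_key, theta.shape)
    t = times[:, None]
    return t * theta + (1.0 - (1.0 - sigma_min) * t) * noise

def ut(theta_t, theta, times, sigma_min):
    # Target conditional vector field for that interpolation path.
    t = times[:, None]
    return (theta - (1.0 - sigma_min) * theta_t) / (1.0 - (1.0 - sigma_min) * t)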
32 changes: 24 additions & 8 deletions sbijax/_src/npe.py
@@ -144,7 +144,7 @@ def _fit_model_single_round(
state = optimizer.init(params)

n_round = self.n_round
_, unravel_fn = ravel_pytree(self.prior_sampler_fn(seed=jr.PRNGKey(1)))
_, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(1)))

if n_round == 0:

@@ -250,7 +250,7 @@ def _proposal_posterior_log_prob(self, params, rng, n_atoms, theta, y):
params, None, method="log_prob", y=atomic_theta, x=repeated_y
)
log_prob_posterior = log_prob_posterior.reshape(n, n_atoms)
log_prob_prior = self.prior_log_density_fn(atomic_theta)
log_prob_prior = self.prior.log_prob(atomic_theta)
log_prob_prior = log_prob_prior.reshape(n, n_atoms)

unnormalized_log_prob = log_prob_posterior - log_prob_prior
@@ -262,9 +262,7 @@ def _proposal_posterior_log_prob(self, params, rng, n_atoms, theta, y):

def _validation_loss(self, rng_key, params, val_iter, n_atoms):
if self.n_round == 0:
_, unravel_fn = ravel_pytree(
self.prior_sampler_fn(seed=jr.PRNGKey(1))
)
_, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(1)))

def loss_fn(rng, **batch):
theta, y = batch["theta"], batch["y"]
@@ -336,7 +334,7 @@ def sample_posterior(
thetas = None
n_curr = n_samples
n_total_simulations_round = 0
_, unravel_fn = ravel_pytree(self.prior_sampler_fn(seed=jr.PRNGKey(1)))
_, unravel_fn = ravel_pytree(self.prior.sample(seed=jr.PRNGKey(1)))
while n_curr > 0:
n_sim = jnp.minimum(200, jnp.maximum(200, n_curr))
n_total_simulations_round += n_sim
@@ -351,7 +349,7 @@
if hasattr(self, "_prior_bijectors"):
proposal = jax.vmap(unravel_fn)(proposal)
proposal = self._prior_bijectors.forward(proposal)
proposal_probs = self.prior_log_density_fn(proposal)
proposal_probs = self.prior.log_prob(proposal)
proposal = jax.vmap(lambda x: ravel_pytree(x)[0])(proposal)
else:
proposal_probs = self.prior_log_density_fn(
Expand All @@ -364,7 +362,6 @@ def sample_posterior(
else:
thetas = jnp.vstack([thetas, proposal])
n_curr -= proposal.shape[0]
self.n_total_simulations += n_total_simulations_round

ess = float(thetas.shape[0] / n_total_simulations_round)

@@ -377,3 +374,22 @@ def reshape(p):
thetas = jax.tree_map(reshape, jax.vmap(unravel_fn)(thetas[:n_samples]))
inference_data = as_inference_data(thetas, jnp.squeeze(observable))
return inference_data, ess

def _simulate_parameters_with_model(
self,
rng_key,
params,
observable,
*,
n_samples=4_000,
check_proposal_probs=True,
**kwargs,
):
return self.sample_posterior(
rng_key=rng_key,
params=params,
observable=observable,
n_samples=n_samples,
check_proposal_probs=check_proposal_probs,
**kwargs,
)
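Both the FMPE and NPE sample_posterior loops shown in this commit follow the same pattern: draw a batch of proposals from the trained model, evaluate their prior log-density, and keep only draws with finite log-probability (i.e. inside the prior's support), repeating until n_samples are accepted; the reported ESS is the ratio of accepted draws to proposals drawn. A stripped-down sketch of the acceptance step, with a hypothetical helper name:

import jax.numpy as jnp

def accept_in_prior_support(prior, proposal):
    # proposal: (n_sim, len_theta) draws from the learned posterior model.
    # Keep only rows whose prior log-probability is finite; out-of-support
    # draws are rejected and re-drawn by the surrounding while-loop.
    log_probs = prior.log_prob(proposal)  # assumes the prior broadcasts over rows
    return proposal[jnp.isfinite(log_probs)]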
42 changes: 25 additions & 17 deletions sbijax/_src/nre.py
@@ -373,31 +373,18 @@ def _sample_posterior(
part = partial(self.model.apply, params, is_training=False)

def _joint_logdensity_fn(theta):
lp_prior = self.prior_log_density_fn(theta)
lp_prior = self.prior.log_prob(theta)
theta, _ = ravel_pytree(theta)
theta = theta.reshape(observable.shape[0], -1)
lp = part(jnp.concatenate([observable, theta], axis=-1))
return jnp.sum(lp_prior) + jnp.sum(lp)

if "sampler" in kwargs and kwargs["sampler"] == "slice":

def lp__(theta):
return jax.vmap(_joint_logdensity_fn)(theta)

sampler = kwargs.pop("sampler", None)
else:

def lp__(theta):
return _joint_logdensity_fn(theta)

# take whatever sampler is or per default nuts
sampler = kwargs.pop("sampler", "nuts")

sampler = kwargs.pop("sampler", "nuts")
sampling_fn = getattr(mcmc, "sample_with_" + sampler)
samples = sampling_fn(
rng_key=rng_key,
lp=lp__,
prior=self.prior_sampler_fn,
lp=_joint_logdensity_fn,
prior=self.prior,
n_chains=n_chains,
n_samples=n_samples,
n_warmup=n_warmup,
@@ -408,3 +395,24 @@ def lp__(theta):
inference_data = as_inference_data(samples, jnp.squeeze(observable))
diagnostics = mcmc_diagnostics(inference_data)
return inference_data, diagnostics

def _simulate_parameters_with_model(
self,
rng_key,
params,
observable,
*,
n_chains=4,
n_samples=2_000,
n_warmup=1_000,
**kwargs,
):
return self.sample_posterior(
rng_key=rng_key,
params=params,
observable=observable,
n_samples=n_samples,
n_warmup=n_warmup,
n_chains=n_chains,
**kwargs,
)