Skip to content

Commit

Permalink
fix: Allow shared vars to differ in expand and logp
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed May 13, 2024
1 parent 8e43498 commit 3632551
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
32 changes: 21 additions & 11 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,20 @@ def update_user_data(user_data, user_data_storage):
return np.asarray(user_data)


def make_user_data(func, shared_data):
shared_vars = func.get_shared()
def make_user_data(shared_vars, shared_data):
record_dtype = np.dtype(
[
(
"shared",
[
("data", [(var.name, np.uintp) for var in shared_vars]),
("size", [(var.name, np.uintp) for var in shared_vars]),
("data", [(var_name, np.uintp) for var_name in shared_vars]),
("size", [(var_name, np.uintp) for var_name in shared_vars]),
(
"shape",
[(var.name, np.uint, (var.ndim,)) for var in shared_vars],
[
(var_name, np.uint, (var.ndim,))
for var_name, var in shared_vars.items()
],
),
],
)
Expand Down Expand Up @@ -192,16 +194,23 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
shape_info,
) = _make_functions(model)

shared_data = {val.name: val.get_value().copy() for val in logp_fn_pt.get_shared()}
shared_data = {}
shared_vars = {}
seen = set()
for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]:
if val.name in shared_data and val not in seen:
raise ValueError(f"Shared variables must have unique names: {val.name}")
shared_data[val.name] = val.get_value().copy()
shared_vars[val.name] = val

for val in shared_data.values():
val.flags.writeable = False

shared_logp = [var.name for var in logp_fn_pt.get_shared()]

user_data = make_user_data(logp_fn_pt, shared_data)
user_data = make_user_data(shared_vars, shared_data)

logp_shared_names = [var.name for var in logp_fn_pt.get_shared()]
logp_numba_raw, c_sig = _make_c_logp_func(
n_dim, logp_fn, user_data, shared_logp, shared_data
n_dim, logp_fn, user_data, logp_shared_names, shared_data
)
with warnings.catch_warnings():
warnings.filterwarnings(
Expand All @@ -212,8 +221,9 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:

logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)

expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
expand_numba_raw, c_sig_expand = _make_c_expand_func(
n_dim, n_expanded, expand_fn, user_data, shared_expand, shared_data
n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data
)
with warnings.catch_warnings():
warnings.filterwarnings(
Expand Down
11 changes: 11 additions & 0 deletions tests/test_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,14 @@ def test_pymc_model_shared():
compiled3 = compiled.with_data(mu=0.5, sigma=3 * np.ones(4))
with pytest.raises(RuntimeError):
nutpie.sample(compiled3, chains=1)


def test_missing():
with pm.Model(coords={"obs": range(4)}) as model:
mu = pm.Normal("mu")
pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs")

compiled = nutpie.compile_pymc_model(model)
tr = nutpie.sample(compiled, chains=1, seed=1)
print(tr.posterior)
assert hasattr(tr.posterior, "y_unobserved")

0 comments on commit 3632551

Please sign in to comment.