From 067212347bb8b4b3bbc6a78d8e9fc3b213c38f91 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 13 May 2024 19:57:04 +0200 Subject: [PATCH] fix: Allow shared vars to differ in expand and logp --- python/nutpie/compile_pymc.py | 32 +++++++++++++++++++++----------- tests/test_pymc.py | 11 +++++++++++ 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 07c5e1c..5d9a048 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -129,18 +129,20 @@ def update_user_data(user_data, user_data_storage): return np.asarray(user_data) -def make_user_data(func, shared_data): - shared_vars = func.get_shared() +def make_user_data(shared_vars, shared_data): record_dtype = np.dtype( [ ( "shared", [ - ("data", [(var.name, np.uintp) for var in shared_vars]), - ("size", [(var.name, np.uintp) for var in shared_vars]), + ("data", [(var_name, np.uintp) for var_name in shared_vars]), + ("size", [(var_name, np.uintp) for var_name in shared_vars]), ( "shape", - [(var.name, np.uint, (var.ndim,)) for var in shared_vars], + [ + (var_name, np.uint, (var.ndim,)) + for var_name, var in shared_vars.items() + ], ), ], ) @@ -192,16 +194,23 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel: shape_info, ) = _make_functions(model) - shared_data = {val.name: val.get_value().copy() for val in logp_fn_pt.get_shared()} + shared_data = {} + shared_vars = {} + seen = set() + for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]: + if val.name in shared_data and val not in seen: + raise ValueError(f"Shared variables must have unique names: {val.name}") + shared_data[val.name] = val.get_value().copy() + shared_vars[val.name] = val + for val in shared_data.values(): val.flags.writeable = False - shared_logp = [var.name for var in logp_fn_pt.get_shared()] - - user_data = make_user_data(logp_fn_pt, shared_data) + user_data = make_user_data(shared_vars, shared_data) + logp_shared_names = [var.name for var in logp_fn_pt.get_shared()] logp_numba_raw, c_sig = _make_c_logp_func( - n_dim, logp_fn, user_data, shared_logp, shared_data + n_dim, logp_fn, user_data, logp_shared_names, shared_data ) with warnings.catch_warnings(): warnings.filterwarnings( @@ -212,8 +221,9 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel: logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw) + expand_shared_names = [var.name for var in expand_fn_pt.get_shared()] expand_numba_raw, c_sig_expand = _make_c_expand_func( - n_dim, n_expanded, expand_fn, user_data, shared_expand, shared_data + n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data ) with warnings.catch_warnings(): warnings.filterwarnings( diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 1f6f513..4149d45 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -94,3 +94,14 @@ def test_pymc_model_shared(): compiled3 = compiled.with_data(mu=0.5, sigma=3 * np.ones(4)) with pytest.raises(RuntimeError): nutpie.sample(compiled3, chains=1) + + +def test_missing(): + with pm.Model(coords={"obs": range(4)}) as model: + mu = pm.Normal("mu") + y = pm.Normal("y", mu, observed=[0, -1, 1, np.nan], dims="obs") + + compiled = nutpie.compile_pymc_model(model) + tr = nutpie.sample(compiled, chains=1, seed=1) + print(tr.posterior) + assert hasattr(tr.posterior, "y_unobserved")