Skip to content

Commit

Permalink
Add logic to handle conditional nodes for observed variables
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Jan 19, 2025
1 parent e895a5c commit d813da1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
8 changes: 6 additions & 2 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,13 @@ def draw(
return [np.stack(v) for v in drawn_values]


def observed_dependent_deterministics(model: Model):
def observed_dependent_deterministics(model: Model, extra_observeds=None):
"""Find deterministics that depend directly on observed variables."""
if extra_observeds is None:
extra_observeds = []

deterministics = model.deterministics
observed_rvs = set(model.observed_RVs)
observed_rvs = set(model.observed_RVs + extra_observeds)
blockers = model.basic_RVs
return [
deterministic
Expand Down Expand Up @@ -821,6 +824,7 @@ def sample_posterior_predictive(
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
if observed_data is not None:
vars_ += [model[x] for x in observed_data if x in model and x not in vars_]
vars_ += observed_dependent_deterministics(model, vars_)

vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))

Expand Down
5 changes: 2 additions & 3 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,6 @@ def test_external_trace(self):
)
assert list(ppc.keys()) == ["a"]

@pytest.mark.xfail(reason="Auto-imputation of variables not supported in this setting")
def test_external_trace_det(self):
nchains = 2
ndraws = 500
Expand All @@ -578,12 +577,12 @@ def test_external_trace_det(self):
with pm.Model() as model_ppc:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1)
b = pm.Deterministic("b", a + 1)
c = pm.Deterministic("c", a + 1)

ppc = pm.sample_posterior_predictive(
trace=trace, model=model_ppc, return_inferencedata=False
)
assert list(ppc.keys()) == ["a", "b"]
assert list(ppc.keys()) == ["a", "c"]

def test_normal_vector(self):
with pm.Model() as model:
Expand Down

0 comments on commit d813da1

Please sign in to comment.