Skip to content

Commit

Permalink
Remove redundant test
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Jan 20, 2025
1 parent d813da1 commit 2fcf395
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 30 deletions.
8 changes: 5 additions & 3 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,10 +821,12 @@ def sample_posterior_predictive(
if var_names is not None:
vars_ = [model[x] for x in var_names]
else:
vars_ = model.observed_RVs + observed_dependent_deterministics(model)
observed_vars = model.observed_RVs
if observed_data is not None:
vars_ += [model[x] for x in observed_data if x in model and x not in vars_]
vars_ += observed_dependent_deterministics(model, vars_)
observed_vars += [
model[x] for x in observed_data if x in model and x not in observed_vars
]
vars_ = observed_vars + observed_dependent_deterministics(model, observed_vars)

vars_to_sample = list(get_default_varnames(vars_, include_transformed=False))

Expand Down
28 changes: 1 addition & 27 deletions tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,38 +540,12 @@ def test_normal_scalar_idata(self):
ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False)
assert ppc["a"].shape == (nchains, ndraws)

def test_external_trace(self):
nchains = 2
ndraws = 500
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
trace = pm.sample(
draws=ndraws,
chains=nchains,
)

# test that trace is used in ppc
with pm.Model() as model_ppc:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1)

ppc = pm.sample_posterior_predictive(
trace=trace, model=model_ppc, return_inferencedata=False
)
assert list(ppc.keys()) == ["a"]

def test_external_trace_det(self):
nchains = 2
ndraws = 500
with pm.Model() as model:
mu = pm.Normal("mu", 0.0, 1.0)
a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
b = pm.Deterministic("b", a + 1)
trace = pm.sample(
draws=ndraws,
chains=nchains,
)
trace = pm.sample(tune=50, draws=50, chains=1, compute_convergence_checks=False)

# test that trace is used in ppc
with pm.Model() as model_ppc:
Expand Down

0 comments on commit 2fcf395

Please sign in to comment.