Skip to content

Commit

Permalink
Wrap function arguments with pm.Data if they support it.
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Jan 15, 2025
1 parent dcc353c commit 66e9d7f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
20 changes: 18 additions & 2 deletions pymc_extras/model/model_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from functools import wraps
from inspect import signature

from pymc import Model
import pytensor.tensor as pt

from pymc import Data, Model


def as_model(*model_args, **model_kwargs):
Expand All @@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
Additionally, a coords argument is added to the function so coords can be changed during function invocation
All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
Examples
Expand Down Expand Up @@ -47,8 +52,19 @@ def decorator(f):
@wraps(f)
def make_model(*args, **kwargs):
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
sig = signature(f)
ba = sig.bind(*args, **kwargs)
ba.apply_defaults()

with Model(*model_args, coords=coords, **model_kwargs) as m:
f(*args, **kwargs)
for name, v in ba.arguments.items():
# Only wrap pm.Data around values pytensor can process
try:
_ = pt.as_tensor_variable(v)
ba.arguments[name] = Data(name, v)
except (NotImplementedError, TypeError, ValueError):
pass
f(*ba.args, **ba.kwargs)
return m

return make_model
Expand Down
9 changes: 9 additions & 0 deletions tests/model/test_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,14 @@ def model_wrapped2():

mw2 = model_wrapped2(coords=coords)

@pmx.as_model()
def model_wrapped3(mu):
pm.Normal("x", mu, 1.0, dims="obs")

mw3 = model_wrapped3(0.0, coords=coords)
mw4 = model_wrapped3(np.array([np.nan]), coords=coords)

np.testing.assert_equal(model.point_logps(), mw.point_logps())
np.testing.assert_equal(mw.point_logps(), mw2.point_logps())
assert mw3["mu"] in mw3.data_vars
assert "mu" not in mw4

0 comments on commit 66e9d7f

Please sign in to comment.