Skip to content

Commit

Permalink
Wrap function arguments with pm.Data if they support it.
Browse files Browse the repository at this point in the history
  • Loading branch information
zaxtax committed Jan 12, 2025
1 parent dcc353c commit 7cbbffb
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions pymc_extras/model/model_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from functools import wraps
from inspect import signature

from pymc import Model
import pytensor.tensor as pt

from pymc import Data, Model


def as_model(*model_args, **model_kwargs):
Expand All @@ -9,6 +12,8 @@ def as_model(*model_args, **model_kwargs):
This removes all need to think about context managers and lets you separate creating a generative model from using the model.
Additionally, a coords argument is added to the function so coords can be changed during function invocation
All parameters are wrapped with a `pm.Data` object if the underlying type of the data supports it.
Adapted from `Rob Zinkov's blog post <https://www.zinkov.com/posts/2023-alternative-frontends-pymc/>`_ and inspired by the `sampled <https://github.com/colcarroll/sampled>`_ decorator for PyMC3.
Examples
Expand Down Expand Up @@ -47,8 +52,19 @@ def decorator(f):
@wraps(f)
def make_model(*args, **kwargs):
coords = model_kwargs.pop("coords", {}) | kwargs.pop("coords", {})
sig = signature(f)
ba = sig.bind(*args, **kwargs)
ba.apply_defaults()

with Model(*model_args, coords=coords, **model_kwargs) as m:
f(*args, **kwargs)
for name, v in ba.arguments.items():
# Only wrap pm.Data around values pytensor can process
try:
_ = pt.as_tensor_variable(v)
ba.arguments[name] = Data(name, v)
except (NotImplementedError, TypeError, ValueError):
pass
f(*ba.args, **ba.kwargs)
return m

return make_model
Expand Down

0 comments on commit 7cbbffb

Please sign in to comment.