From 1d001bf0ab36417ebdf87567cd72fad6209d749e Mon Sep 17 00:00:00 2001 From: "hasnain3257@gmail.com" Date: Fri, 24 Nov 2023 12:33:35 +0100 Subject: [PATCH] Update RVTransform --- pymc_experimental/utils/prior.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pymc_experimental/utils/prior.py b/pymc_experimental/utils/prior.py index 962b01bc0..30d4e9507 100644 --- a/pymc_experimental/utils/prior.py +++ b/pymc_experimental/utils/prior.py @@ -19,12 +19,12 @@ import numpy as np import pymc as pm import pytensor.tensor as pt -from pymc.logprob.transforms import RVTransform +from pymc.logprob.transforms import Transform class ParamCfg(TypedDict): name: str - transform: Optional[RVTransform] + transform: Optional[Transform] dims: Optional[Union[str, Tuple[str]]] @@ -44,14 +44,14 @@ class FlatInfo(TypedDict): info: List[VarInfo] -def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, RVTransform, str, Tuple]] = None): +def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, Transform, str, Tuple]] = None): if value is None: cfg = ParamCfg(name=key, transform=None, dims=None) elif isinstance(value, Tuple): cfg = ParamCfg(name=key, transform=None, dims=value) elif isinstance(value, str): cfg = ParamCfg(name=value, transform=None, dims=None) - elif isinstance(value, RVTransform): + elif isinstance(value, Transform): cfg = ParamCfg(name=key, transform=value, dims=None) else: cfg = value.copy() @@ -62,7 +62,7 @@ def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, RVTransform, str, Tup def _parse_args( - var_names: Sequence[str], **kwargs: Union[ParamCfg, RVTransform, str, Tuple] + var_names: Sequence[str], **kwargs: Union[ParamCfg, Transform, str, Tuple] ) -> Dict[str, ParamCfg]: results = dict() for var in var_names: @@ -133,7 +133,7 @@ def prior_from_idata( name="trace_prior_", *, var_names: Sequence[str] = (), - **kwargs: Union[ParamCfg, RVTransform, str, Tuple] + **kwargs: Union[ParamCfg, Transform, str, Tuple] ) -> Dict[str, pt.TensorVariable]: """ Create a prior from posterior using MvNormal approximation. @@ -153,7 +153,7 @@ def prior_from_idata( Inference data with posterior group var_names: Sequence[str] names of variables to take as is from the posterior - kwargs: Union[ParamCfg, RVTransform, str, Tuple] + kwargs: Union[ParamCfg, Transform, str, Tuple] names of variables with additional configuration, see more in Examples Examples