Skip to content

Commit

Permalink
Update RVTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
HasnainRaz committed Nov 24, 2023
1 parent 79c4dc1 commit 1d001bf
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pymc_experimental/utils/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc.logprob.transforms import RVTransform
from pymc.logprob.transforms import Transform


class ParamCfg(TypedDict):
name: str
transform: Optional[RVTransform]
transform: Optional[Transform]
dims: Optional[Union[str, Tuple[str]]]


Expand All @@ -44,14 +44,14 @@ class FlatInfo(TypedDict):
info: List[VarInfo]


def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, RVTransform, str, Tuple]] = None):
def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, Transform, str, Tuple]] = None):
if value is None:
cfg = ParamCfg(name=key, transform=None, dims=None)
elif isinstance(value, Tuple):
cfg = ParamCfg(name=key, transform=None, dims=value)
elif isinstance(value, str):
cfg = ParamCfg(name=value, transform=None, dims=None)
elif isinstance(value, RVTransform):
elif isinstance(value, Transform):
cfg = ParamCfg(name=key, transform=value, dims=None)
else:
cfg = value.copy()
Expand All @@ -62,7 +62,7 @@ def _arg_to_param_cfg(key, value: Optional[Union[ParamCfg, RVTransform, str, Tup


def _parse_args(
var_names: Sequence[str], **kwargs: Union[ParamCfg, RVTransform, str, Tuple]
var_names: Sequence[str], **kwargs: Union[ParamCfg, Transform, str, Tuple]
) -> Dict[str, ParamCfg]:
results = dict()
for var in var_names:
Expand Down Expand Up @@ -133,7 +133,7 @@ def prior_from_idata(
name="trace_prior_",
*,
var_names: Sequence[str] = (),
**kwargs: Union[ParamCfg, RVTransform, str, Tuple]
**kwargs: Union[ParamCfg, Transform, str, Tuple]
) -> Dict[str, pt.TensorVariable]:
"""
Create a prior from posterior using MvNormal approximation.
Expand All @@ -153,7 +153,7 @@ def prior_from_idata(
Inference data with posterior group
var_names: Sequence[str]
names of variables to take as is from the posterior
kwargs: Union[ParamCfg, RVTransform, str, Tuple]
kwargs: Union[ParamCfg, Transform, str, Tuple]
names of variables with additional configuration, see more in Examples
Examples
Expand Down

0 comments on commit 1d001bf

Please sign in to comment.