Skip to content

Commit

Permalink
Fix BaseTrace and MultiTrace typing; remove add_values/`remove_…
Browse files Browse the repository at this point in the history
…values`

The `add_values`/`remove_values` methods accessed attributes
that are not present on `BaseTrace` but only on `NDarray`,
therefore violating the signature.
  • Loading branch information
michaelosthege authored and twiecki committed Jan 14, 2023
1 parent 6c4d4eb commit 6ab0c03
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 112 deletions.
129 changes: 36 additions & 93 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@
import warnings

from abc import ABC
from typing import List, Sequence, Tuple, cast
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union, cast

import numpy as np
import pytensor.tensor as at

from pymc.backends.report import SamplerReport
from pymc.model import modelcontext
Expand Down Expand Up @@ -210,18 +209,18 @@ def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
"""Get sampler statistics."""
raise NotImplementedError()

def _slice(self, idx):
def _slice(self, idx: Union[int, slice]):
"""Slice trace object."""
raise NotImplementedError()

def point(self, idx):
def point(self, idx: int) -> Dict[str, np.ndarray]:
"""Return dictionary of point values at `idx` for current chain
with variables names as keys.
"""
raise NotImplementedError()

@property
def stat_names(self):
def stat_names(self) -> Set[str]:
names = set()
for vars in self.sampler_vars or []:
names.update(vars.keys())
Expand Down Expand Up @@ -280,12 +279,10 @@ class MultiTrace:
List of variable names in the trace(s)
"""

def __init__(self, straces):
self._straces = {}
for strace in straces:
if strace.chain in self._straces:
raise ValueError("Chains are not unique.")
self._straces[strace.chain] = strace
def __init__(self, straces: Sequence[BaseTrace]):
if len({t.chain for t in straces}) != len(straces):
raise ValueError("Chains are not unique.")
self._straces = {t.chain: t for t in straces}

self._report = SamplerReport()

Expand All @@ -294,15 +291,15 @@ def __repr__(self):
return template.format(self.__class__.__name__, self.nchains, len(self), len(self.varnames))

@property
def nchains(self):
def nchains(self) -> int:
return len(self._straces)

@property
def chains(self):
def chains(self) -> List[int]:
return list(sorted(self._straces.keys()))

@property
def report(self):
def report(self) -> SamplerReport:
return self._report

def __iter__(self):
Expand Down Expand Up @@ -367,12 +364,12 @@ def __len__(self):
return len(self._straces[chain])

@property
def varnames(self):
def varnames(self) -> List[str]:
chain = self.chains[-1]
return self._straces[chain].varnames

@property
def stat_names(self):
def stat_names(self) -> Set[str]:
if not self._straces:
return set()
sampler_vars = [s.sampler_vars for s in self._straces.values()]
Expand All @@ -386,74 +383,15 @@ def stat_names(self):
names.update(vars.keys())
return names

def add_values(self, vals, overwrite=False) -> None:
"""Add variables to traces.
Parameters
----------
vals: dict (str: array-like)
The keys should be the names of the new variables. The values are expected to be
array-like objects. For traces with more than one chain the length of each value
should match the number of total samples already in the trace `(chains * iterations)`,
otherwise a warning is raised.
overwrite: bool
If `False` (default) a ValueError is raised if the variable already exists.
Change to `True` to overwrite the values of variables
Returns
-------
None.
"""
for k, v in vals.items():
new_var = 1
if k in self.varnames:
if overwrite:
self.varnames.remove(k)
new_var = 0
else:
raise ValueError(f"Variable name {k} already exists.")

self.varnames.append(k)

chains = self._straces
l_samples = len(self) * len(self.chains)
l_v = len(v)
if l_v != l_samples:
warnings.warn(
"The length of the values you are trying to "
"add ({}) does not match the number ({}) of "
"total samples in the trace "
"(chains * iterations)".format(l_v, l_samples)
)

v = np.squeeze(v.reshape(len(chains), len(self), -1))

for idx, chain in enumerate(chains.values()):
if new_var:
dummy = at.as_tensor_variable([], k)
chain.vars.append(dummy)
chain.samples[k] = v[idx]

def remove_values(self, name):
"""remove variables from traces.
Parameters
----------
name: str
Name of the variable to remove. Raises KeyError if the variable is not present
"""
varnames = self.varnames
if name not in varnames:
raise KeyError(f"Unknown variable {name}")
self.varnames.remove(name)
chains = self._straces
for chain in chains.values():
for va in chain.vars:
if va.name == name:
chain.vars.remove(va)
del chain.samples[name]

def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True):
def get_values(
self,
varname: str,
burn: int = 0,
thin: int = 1,
combine: bool = True,
chains: Optional[Union[int, Sequence[int]]] = None,
squeeze: bool = True,
) -> List[np.ndarray]:
"""Get values from traces.
Parameters
Expand All @@ -479,13 +417,20 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze
if chains is None:
chains = self.chains
varname = get_var_name(varname)
try:
results = [self._straces[chain].get_values(varname, burn, thin) for chain in chains]
except TypeError: # Single chain passed.
results = [self._straces[chains].get_values(varname, burn, thin)]
if isinstance(chains, int):
chains = [chains]
results = [self._straces[chain].get_values(varname, burn, thin) for chain in chains]
return _squeeze_cat(results, combine, squeeze)

def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True, chains=None, squeeze=True):
def get_sampler_stats(
self,
stat_name: str,
burn: int = 0,
thin: int = 1,
combine: bool = True,
chains: Optional[Union[int, Sequence[int]]] = None,
squeeze: bool = True,
):
"""Get sampler statistics from the trace.
Parameters
Expand All @@ -508,9 +453,7 @@ def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True, chains=None

if chains is None:
chains = self.chains
try:
chains = iter(chains)
except TypeError:
if isinstance(chains, int):
chains = [chains]

results = [
Expand All @@ -526,7 +469,7 @@ def _slice(self, slice):
trace._report = self._report._slice(*idxs)
return trace

def point(self, idx, chain=None):
def point(self, idx: int, chain: Optional[int] = None) -> Dict[str, np.ndarray]:
"""Return a dictionary of point values at `idx`.
Parameters
Expand Down
19 changes: 0 additions & 19 deletions pymc/tests/backends/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,6 @@ def test_multitrace_nonunique(self):
base.MultiTrace([self.strace0, self.strace1])


class TestMultiTrace_add_remove_values(bf.ModelBackendSampledTestCase):
name = None
backend = ndarray.NDArray
shape = ()

def test_add_values(self):
mtrace = self.mtrace
orig_varnames = list(mtrace.varnames)
name = "new_var"
vals = mtrace[orig_varnames[0]]
mtrace.add_values({name: vals})
assert len(orig_varnames) == len(mtrace.varnames) - 1
assert name in mtrace.varnames
assert np.all(mtrace[orig_varnames[0]] == mtrace[name])
mtrace.remove_values(name)
assert len(orig_varnames) == len(mtrace.varnames)
assert name not in mtrace.varnames


class TestSqueezeCat:
def setup_method(self):
self.x = np.arange(10)
Expand Down

0 comments on commit 6ab0c03

Please sign in to comment.