Skip to content

Commit

Permalink
Merge pull request #3652 from rpgoldman/fix-context-stack
Browse files Browse the repository at this point in the history
Fix context stack
  • Loading branch information
lucianopaz authored Nov 27, 2019
2 parents 9c4b740 + 55e6f59 commit ed55be2
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 68 deletions.
3 changes: 2 additions & 1 deletion pymc3/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict, List, Any
from copy import copy
import io
import os
Expand Down Expand Up @@ -232,7 +233,7 @@ class Minibatch(tt.TensorVariable):
>>> assert x.eval().shape == (2, 20, 20, 40, 10)
"""

RNG = collections.defaultdict(list)
RNG = collections.defaultdict(list) # type: Dict[str, List[Any]]

@theano.configparser.change_flags(compute_test_value='raise')
def __init__(self, data, batch_size=128, dtype=None, broadcastable=None, name='Minibatch',
Expand Down
17 changes: 4 additions & 13 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..memoize import memoize
from ..model import (
Model, get_named_nodes_and_relations, FreeRV,
ObservedRV, MultiObservedRV, Context, InitContextMeta
ObservedRV, MultiObservedRV, ContextMeta
)
from ..vartypes import string_types, theano_constant
from .shape_utils import (
Expand Down Expand Up @@ -449,23 +449,14 @@ def random(self, point=None, size=None, **kwargs):
"Define a custom random method and pass it as kwarg random")


class _DrawValuesContext(Context, metaclass=InitContextMeta):
class _DrawValuesContext(metaclass=ContextMeta, context_class='_DrawValuesContext'):
""" A context manager class used while drawing values with draw_values
"""

def __new__(cls, *args, **kwargs):
# resolves the parent instance
instance = super().__new__(cls)
if cls.get_contexts():
potential_parent = cls.get_contexts()[-1]
# We have to make sure that the context is a _DrawValuesContext
# and not a Model
if isinstance(potential_parent, _DrawValuesContext):
instance._parent = potential_parent
else:
instance._parent = None
else:
instance._parent = None
instance._parent = cls.get_context(error_if_none=False)
return instance

def __init__(self):
Expand All @@ -485,7 +476,7 @@ def parent(self):
return self._parent


class _DrawValuesContextBlocker(_DrawValuesContext, metaclass=InitContextMeta):
class _DrawValuesContextBlocker(_DrawValuesContext):
"""
Context manager that starts a new drawn variables context disregarding all
parent contexts. This can be used inside a random method to ensure that
Expand Down
186 changes: 134 additions & 52 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import itertools
import threading
import warnings
from typing import Optional
from typing import Optional, TypeVar, Type, List, Union, TYPE_CHECKING, Any, cast
from sys import modules

import numpy as np
from pandas import Series
Expand Down Expand Up @@ -55,10 +56,10 @@ def __call__(self, *args, **kwargs):
return getattr(self.obj, self.method_name)(*args, **kwargs)


def incorporate_methods(source, destination, methods, default=None,
def incorporate_methods(source, destination, methods,
wrapper=None, override=False):
"""
Add attributes to a destination object which points to
Add attributes to a destination object which point to
methods from from a source object.
Parameters
Expand All @@ -69,8 +70,6 @@ def incorporate_methods(source, destination, methods, default=None,
The destination object for the methods.
methods : list of str
Names of methods to incorporate.
default : object
The value used if the source does not have one of the listed methods.
wrapper : function
An optional function to allow the source method to be
wrapped. Should take the form my_wrapper(source, method_name)
Expand Down Expand Up @@ -162,49 +161,131 @@ def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
node_children.update(temp_tree)
return leaf_nodes, node_parents, node_children

T = TypeVar('T', bound='ContextMeta')

class Context:

class ContextMeta(type):
"""Functionality for objects that put themselves in a context using
the `with` statement.
"""
contexts = threading.local()

def __enter__(self):
type(self).get_contexts().append(self)
# self._theano_config is set in Model.__new__
if hasattr(self, '_theano_config'):
self._old_theano_config = set_theano_conf(self._theano_config)
return self

def __exit__(self, typ, value, traceback):
type(self).get_contexts().pop()
# self._theano_config is set in Model.__new__
if hasattr(self, '_old_theano_config'):
set_theano_conf(self._old_theano_config)

@classmethod
def get_contexts(cls):
# no race-condition here, cls.contexts is a thread-local object
def __new__(cls, name, bases, dct, **kargs): # pylint: disable=unused-argument
"Add __enter__ and __exit__ methods to the class."
def __enter__(self):
self.__class__.context_class.get_contexts().append(self)
# self._theano_config is set in Model.__new__
if hasattr(self, '_theano_config'):
self._old_theano_config = set_theano_conf(self._theano_config)
return self

def __exit__(self, typ, value, traceback): # pylint: disable=unused-argument
self.__class__.context_class.get_contexts().pop()
# self._theano_config is set in Model.__new__
if hasattr(self, '_old_theano_config'):
set_theano_conf(self._old_theano_config)

dct[__enter__.__name__] = __enter__
dct[__exit__.__name__] = __exit__

# We strip off keyword args, per the warning from
# StackExchange:
# DO NOT send "**kargs" to "type.__new__". It won't catch them and
# you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
return super().__new__(cls, name, bases, dct)

# FIXME: is there a more elegant way to automatically add methods to the class that
# are instance methods instead of class methods?
def __init__(cls, name, bases, nmspc, context_class: Optional[Type]=None, **kwargs): # pylint: disable=unused-argument
"""Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
if context_class is not None:
cls._context_class = context_class
super().__init__(name, bases, nmspc)



def get_context(cls, error_if_none=True) -> Optional[T]:
"""Return the most recently pushed context object of type ``cls``
on the stack, or ``None``. If ``error_if_none`` is True (default),
raise a ``TypeError`` instead of returning ``None``."""
idx = -1
while True:
try:
candidate = cls.get_contexts()[idx] # type: Optional[T]
except IndexError as e:
# Calling code expects to get a TypeError if the entity
# is unfound, and there's too much to fix.
if error_if_none:
raise TypeError("No %s on context stack"%str(cls))
return None
return candidate
idx = idx - 1

def get_contexts(cls) -> List[T]:
"""Return a stack of context instances for the ``context_class``
of ``cls``."""
# This lazily creates the context class's contexts
# thread-local object, as needed. This seems inelegant to me,
# but since the context class is not guaranteed to exist when
# the metaclass is being instantiated, I couldn't figure out a
# better way. [2019/10/11:rpg]

# no race-condition here, contexts is a thread-local object
# be sure not to override contexts in a subclass however!
if not hasattr(cls.contexts, 'stack'):
cls.contexts.stack = []
return cls.contexts.stack

@classmethod
def get_context(cls):
"""Return the deepest context on the stack."""
try:
return cls.get_contexts()[-1]
except IndexError:
raise TypeError("No context on context stack")
context_class = cls.context_class
assert isinstance(context_class, type), \
"Name of context class, %s was not resolvable to a class"%context_class
if not hasattr(context_class, 'contexts'):
context_class.contexts = threading.local()

contexts = context_class.contexts

if not hasattr(contexts, 'stack'):
contexts.stack = []
return contexts.stack

# the following complex property accessor is necessary because the
# context_class may not have been created at the point it is
# specified, so the context_class may be a class *name* rather
# than a class.
@property
def context_class(cls) -> Type:
def resolve_type(c: Union[Type, str]) -> Type:
if isinstance(c, str):
c = getattr(modules[cls.__module__], c)
if isinstance(c, type):
return c
raise ValueError("Cannot resolve context class %s"%c)
assert cls is not None
if isinstance(cls._context_class, str):
cls._context_class = resolve_type(cls._context_class)
if not isinstance(cls._context_class, (str, type)):
raise ValueError("Context class for %s, %s, is not of the right type"%\
(cls.__name__, cls._context_class))
return cls._context_class

# Inherit context class from parent
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.context_class = super().context_class

# Initialize object in its own context...
# Merged from InitContextMeta in the original.
def __call__(cls, *args, **kwargs):
instance = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance


def modelcontext(model: Optional['Model']) -> 'Model':
"""return the given model or try to find it in the context if there was
none supplied.
"""
Return the given model or, if none was supplied, try to find one in
the context stack.
"""
if model is None:
return Model.get_context()
model = Model.get_context(error_if_none=False)
if model is None:
raise ValueError("No model on context stack.")
return model


Expand Down Expand Up @@ -292,15 +373,6 @@ def logp_nojact(self):
return logp


class InitContextMeta(type):
"""Metaclass that executes `__init__` of instance in it's context"""
def __call__(cls, *args, **kwargs):
instance = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance


def withparent(meth):
"""Helper wrapper that passes calls to parent's instance"""
def wrapped(self, *args, **kwargs):
Expand Down Expand Up @@ -346,11 +418,18 @@ def __setitem__(self, key, value):
' able to determine '
'appropriate logic for it')

def __imul__(self, other):
# Added this because mypy didn't like having __imul__ without __mul__
# This is my best guess about what this should do. I might be happier
# to kill both of these if they are not used.
def __mul__ (self, other) -> 'treelist':
return cast('treelist', list.__mul__(self, other))

def __imul__(self, other) -> 'treelist':
t0 = len(self)
list.__imul__(self, other)
if self.parent is not None:
self.parent.extend(self[t0:])
return self # python spec says should return the result.


class treedict(dict):
Expand Down Expand Up @@ -555,7 +634,7 @@ def _build_joined(self, cost, args, vmap):
return args_joined, theano.clone(cost, replace=replace)


class Model(Context, Factor, WithMemoization, metaclass=InitContextMeta):
class Model(Factor, WithMemoization, metaclass=ContextMeta, context_class='Model'):
"""Encapsulates the variables and likelihood factors of a model.
Model class can be used for creating class based models. To create
Expand Down Expand Up @@ -643,15 +722,18 @@ def __init__(self, mean=0, sigma=1, name='', model=None):
CustomModel(mean=1, name='first')
CustomModel(mean=2, name='second')
"""

if TYPE_CHECKING:
def __enter__(self: 'Model') -> 'Model': ...
def __exit__(self: 'Model', *exc: Any) -> bool: ...

def __new__(cls, *args, **kwargs):
# resolves the parent instance
instance = super().__new__(cls)
if kwargs.get('model') is not None:
instance._parent = kwargs.get('model')
elif cls.get_contexts():
instance._parent = cls.get_contexts()[-1]
else:
instance._parent = None
instance._parent = cls.get_context(error_if_none=False)
theano_config = kwargs.get('theano_config', None)
if theano_config is None or 'compute_test_value' not in theano_config:
theano_config = {'compute_test_value': 'raise'}
Expand Down Expand Up @@ -694,7 +776,7 @@ def root(self):
def isroot(self):
return self.parent is None

@property
@property # type: ignore -- mypy can't handle decorated types.
@memoize(bound=True)
def bijection(self):
vars = inputvars(self.vars)
Expand Down
2 changes: 1 addition & 1 deletion pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_sample_after_set_data(self):
atol=1e-1)

def test_creation_of_data_outside_model_context(self):
with pytest.raises(TypeError) as error:
with pytest.raises((IndexError, TypeError)) as error:
pm.Data('data', [1.1, 2.2, 3.3])
error.match('No model on context stack')

Expand Down
7 changes: 6 additions & 1 deletion pymc3/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,15 @@ def test_setattr_properly_works(self):

def test_context_passes_vars_to_parent_model(self):
with pm.Model() as model:
assert pm.model.modelcontext(None) == model
assert pm.Model.get_context() == model
# a set of variables is created
NewModel()
nm = NewModel()
assert pm.Model.get_context() == model
# another set of variables are created but with prefix 'another'
usermodel2 = NewModel(name='another')
assert pm.Model.get_context() == model
assert usermodel2._parent == model
# you can enter in a context with submodel
with usermodel2:
usermodel2.Var('v3', pm.Normal.dist())
Expand Down
Loading

0 comments on commit ed55be2

Please sign in to comment.