WIP: Support pickling lru_cache on CPython #309

Open
wants to merge 3 commits into base: master
32 changes: 31 additions & 1 deletion cloudpickle/cloudpickle.py
@@ -43,7 +43,7 @@
from __future__ import print_function

import dis
from functools import partial
import functools
import io
import itertools
import logging
@@ -1067,6 +1067,28 @@ def save_mappingproxy(self, obj):

dispatch[types.MappingProxyType] = save_mappingproxy

# In CPython, functions decorated with functools.lru_cache are actually
# instances of a non-serializable built-in type. We pickle them by pickling
# the underlying function, along with the size of the lru cache. We do
# **not** attempt to pickle the contents of the function's cache.
if hasattr(functools, 'lru_cache'): # pragma: no branch
_lru_cache_instance = functools.lru_cache()(lambda: None)

# PyPy's lru_cache returns a regular function object that closes over
# the cache state. We can't easily treat this specially because
# Pickle's dispatching is purely type-based.
if not isinstance(_lru_cache_instance, types.FunctionType):
# Assume CPython native LRU Cache.
def save_lru_cached_function(self, obj):
self.save_reduce(
_rebuild_lru_cached_function,
(obj.cache_info().maxsize, obj.__wrapped__),
obj=obj,
)

dispatch[type(_lru_cache_instance)] = save_lru_cached_function
del _lru_cache_instance # Remove from class namespace.

"""Special functions for Add-on libraries"""
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
@@ -1395,3 +1417,11 @@ def _is_dynamic(module):
except ImportError:
return True
return False


def _rebuild_lru_cached_function(maxsize, func):
"""Reconstruct a function that was decorated with functools.lru_cache.

The rebuilt function will have an empty cache.
"""
return functools.lru_cache(maxsize)(func)
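Taken together, the reducer and the rebuild helper above give lru_cache-decorated functions a normal cloudpickle round trip. A minimal usage sketch, assuming a cloudpickle install built from this branch (the function and variable names below are illustrative, not part of the patch):

import functools
import pickle

import cloudpickle  # assumed to include this branch

@functools.lru_cache(maxsize=4)
def add(x, y):
    return x + y

add(1, 2)  # populate the original wrapper's cache
clone = pickle.loads(cloudpickle.dumps(add))

assert clone.cache_info().maxsize == 4   # maxsize survives the round trip
assert clone.cache_info().currsize == 0  # cache contents are deliberately dropped
assert clone(1, 2) == 3                  # the rebuilt wrapper still works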
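The CPython-versus-PyPy guard in the hunk above boils down to a type probe: decorating a throwaway lambda reveals whether lru_cache returns a dedicated wrapper type (CPython) or a plain function (PyPy). A standalone sketch of that check, with illustrative names:

import functools
import types

# Decorate a throwaway lambda to expose this interpreter's wrapper type.
probe = functools.lru_cache()(lambda: None)

if isinstance(probe, types.FunctionType):
    # PyPy's pure-Python lru_cache returns an ordinary function, so a
    # type-based reducer cannot single it out.
    print("plain function wrapper; no reducer registered")
else:
    # CPython's lru_cache returns a built-in wrapper type that the
    # pickler's dispatch table can target directly.
    print("dedicated wrapper type:", type(probe))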
70 changes: 62 additions & 8 deletions tests/cloudpickle_test.py
@@ -50,6 +50,8 @@

_TEST_GLOBAL_VARIABLE = "default_value"

IS_PYPY = platform.python_implementation() == 'PyPy'


class RaiserOnPickle(object):

@@ -368,7 +370,7 @@ def test_partial(self):
partial_clone = pickle_depickle(partial_obj, protocol=self.protocol)
self.assertEqual(partial_clone(4), 1)

@pytest.mark.skipif(platform.python_implementation() == 'PyPy',
@pytest.mark.skipif(IS_PYPY,
reason="Skip numpy and scipy tests on PyPy")
def test_ufunc(self):
# test a numpy ufunc (universal function), which is a C-based function
@@ -476,7 +478,7 @@ def method(self, x):
self.assertEqual(mod.f(5), mod2.f(5))
self.assertEqual(mod.Foo().method(5), mod2.Foo().method(5))

if platform.python_implementation() != 'PyPy':
if not IS_PYPY:
# XXX: this fails with excessive recursion on PyPy.
mod3 = subprocess_pickle_echo(mod, protocol=self.protocol)
self.assertEqual(mod.x, mod3.x)
@@ -639,7 +641,7 @@ def test_is_dynamic_module(self):
dynamic_module = types.ModuleType('dynamic_module')
assert _is_dynamic(dynamic_module)

if platform.python_implementation() == 'PyPy':
if IS_PYPY:
import _codecs
assert not _is_dynamic(_codecs)

@@ -674,8 +676,7 @@ def test_builtin_function(self):
# builtin function from a "regular" module
assert pickle_depickle(mkdir, protocol=self.protocol) is mkdir

@pytest.mark.skipif(platform.python_implementation() == 'PyPy' and
sys.version_info[:2] == (3, 5),
@pytest.mark.skipif(IS_PYPY and sys.version_info[:2] == (3, 5),
reason="bug of pypy3.5 in builtin-type constructors")
def test_builtin_type_constructor(self):
# Due to a bug in pypy3.5, cloudpickling builtin-type constructors
@@ -750,7 +751,7 @@ def test_builtin_classmethod(self):
# Roundtripping a classmethod_descriptor results in a
# builtin_function_or_method (CPython upstream issue).
assert depickled_clsdict_meth(arg) == clsdict_clsmethod(float, arg)
if platform.python_implementation() == 'PyPy':
if IS_PYPY:
# builtin-classmethods are simple classmethod in PyPy (not
# callable). We test equality of types and the functionality of the
# __func__ attribute instead. We do not test the identity of
@@ -781,7 +782,7 @@ def test_builtin_slotmethod(self):
assert depickled_clsdict_meth is clsdict_slotmethod

@pytest.mark.skipif(
platform.python_implementation() == "PyPy" or
IS_PYPY or
sys.version_info[:1] < (3,),
reason="No known staticmethod example in the python 2 / pypy stdlib")
def test_builtin_staticmethod(self):
@@ -1499,7 +1500,7 @@ class A:
""".format(protocol=self.protocol)
assert_run_python_script(code)

@pytest.mark.skipif(platform.python_implementation() == 'PyPy',
@pytest.mark.skipif(IS_PYPY,
reason="Skip PyPy because memory grows too much")
def test_interactive_remote_function_calls_no_memory_leak(self):
code = """if __name__ == "__main__":
@@ -1876,6 +1877,59 @@ def __getattr__(self, name):
with pytest.raises(pickle.PicklingError, match='recursion'):
cloudpickle.dumps(a)

@unittest.skipIf(IS_PYPY or not hasattr(functools, "lru_cache"),
"Old versions of Python do not have lru_cache. "
"PyPy's lru_cache is a regular function.")
def test_pickle_lru_cached_function(self):

for maxsize in None, 1, 2:

@functools.lru_cache(maxsize=maxsize)
def func(x, y):
return x + y

# Populate original function's cache.
func(1, 2)

new_func = pickle_depickle(func, protocol=self.protocol)
assert type(new_func) == type(func)

# We don't attempt to pickle the original function's cache, so the
# new function should have an empty cache.
self._expect_cache_info(
new_func.cache_info(),
hits=0,
misses=0,
maxsize=maxsize,
currsize=0,
)

assert new_func(1, 2) == 3

self._expect_cache_info(
new_func.cache_info(),
hits=0,
misses=1,
maxsize=maxsize,
currsize=1,
)

assert new_func(1, 2) == 3

self._expect_cache_info(
new_func.cache_info(),
hits=1,
misses=1,
maxsize=maxsize,
currsize=1,
)

def _expect_cache_info(self, cache_info, hits, misses, maxsize, currsize):
assert cache_info.hits == hits
assert cache_info.misses == misses
assert cache_info.maxsize == maxsize
assert cache_info.currsize == currsize


class Protocol2CloudPickleTest(CloudPickleTest):
