Skip to content

Commit

Permalink
Merged in cbillington/labscript_utils/feature (pull request labscript…
Browse files Browse the repository at this point in the history
…-suite#19)

Utilities for test suites, PY2 constant, minor version bump, smarter check_version()
  • Loading branch information
chrisjbillington committed Aug 15, 2017
2 parents 8ebe212 + b651268 commit 5b93953
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 7 deletions.
59 changes: 52 additions & 7 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# #
#####################################################################

__version__ = '2.3.1'
__version__ = '2.4.0'


import sys
import os
import traceback

PY2 = sys.version_info[0] == 2

for path in sys.path:
if os.path.exists(os.path.join(path, '.is_labscript_suite_install_dir')):
Expand All @@ -28,17 +31,59 @@
class VersionException(Exception):
pass

def _get_version(module_name):
"""return the version string module.__version__ by importing the module,
and the exc_info for the exception (if any) raised during import, or None
if there was no exception. If the version string is defined prior to the
exception during import, then it will still be returned. Otherwise None
will be returned in its place. This can be useful since having
incompatible versions of packages can itself be the cause of exceptions
during import, so it is preferable to raise a 'wrong version' in addition
to, or instead of the exception that was raised during import"""

def check_version(module_name, at_least, less_than, version=None):
from labscript_utils.brute_import import brute_import

try:
module = __import__(module_name)
exc_info = None
except Exception:
exc_info = sys.exc_info()
# brute_import returns the exception, but if for some reason it's
# different we should return the one we got at the first atttempted
# import:
module, _ = brute_import(module_name)
return getattr(module, '__version__'), exc_info


def _reraise(exc_info):
type, value, traceback = exc_info
# handle python2/3 difference in raising exception
if PY2:
exec('raise type, value, traceback', globals(), locals())
else:
raise value.with_traceback(traceback)


def check_version(module_name, at_least, less_than, version=None):
from distutils.version import LooseVersion

if version is None:
version = __import__(module_name).__version__
version, exc_info = _get_version(module_name)

if version is not None:
at_least_version, less_than_version, installed_version = [LooseVersion(v) for v in [at_least, less_than, version]]
if not at_least_version <= installed_version < less_than_version:
msg = '{module_name} {version} found. {at_least} <= {module_name} < {less_than} required.'.format(**locals())
if exc_info is not None:
msg += '\n\n === In addition, the below exception was raised during import of {}: ===\n\n'.format(module_name)
msg += ''.join(traceback.format_exception(*exc_info))
raise VersionException(msg)

# Correct version string, but failed import:
if exc_info is not None:
_reraise(exc_info)

# Successful import but no version string:
if version is None:
raise ValueError('Invalid version string from package {}: {}'.format(module_name, version))
at_least_version, less_than_version, installed_version = [LooseVersion(v) for v in [at_least, less_than, version]]
if not at_least_version <= installed_version < less_than_version:
raise VersionException(
'{module_name} {version} found. {at_least} <= {module_name} < {less_than} required.'.format(**locals()))

73 changes: 73 additions & 0 deletions brute_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#####################################################################
# #
# brute_import.py #
# #
# Copyright 2017, Chris Billington #
# #
# This file is part of the labscript suite (see #
# http://labscriptsuite.org) and is licensed under the Simplified #
# BSD License. See the license.txt file in the root of the project #
# for the full license. #
# #
#####################################################################


import sys
import os
import types
import imp
import marshal


def _fallback(module_name):
# No module code to execute? Just import the usual way then and return an
# empty module upon exception:
try:
module = __import__(module_name)
return module, None
except Exception:
module = types.ModuleType(module_name)
return module, sys.exc_info()


def brute_import(module_name):
"""Execute a module as if it were being imported, catch exceptions, and
return the (possibly only partially initialised) module object as well as
the exc_info for the exception (or None if there was no exception). This
is useful for say, inspecting the __version__ string of a module that is
failing to import in order to raise a potentially more useful exception if
the module is failing to import *because* it is the wrong version."""

sourcefile, pathname, (_, _, module_type) = imp.find_module(module_name)
module = types.ModuleType(module_name)
sys.modules[module_name] = module

if module_type in [imp.PY_SOURCE, imp.PY_COMPILED]:
module.__file__ = pathname
elif module_type == imp.PKG_DIRECTORY:
module.__path__ = [pathname]
module.__file__ = os.path.join(pathname, '__init__.py')
sourcefile = open(module.__file__)
else:
return _fallback(module_name)

if module_type in [imp.PY_SOURCE, imp.PKG_DIRECTORY]:
code = compile(sourcefile.read(), module.__file__, 'exec', dont_inherit=True)
elif module_type == imp.PY_COMPILED:
if sourcefile.read(4) != imp.get_magic():
# Different python version, we can't execute:
return _fallback(module_name)
# skip timestamp:
_ = sourcefile.read(4)
code = marshal.load(sourcefile)
else:
# Some C extension or something. No code for us to execute.
return _fallback(module_name)

try:
# Execute the module code in its namespace:
exec(code, module.__dict__)
return module, None
except Exception:
exc_info = sys.exc_info()
return module, sys.exc_info()
160 changes: 160 additions & 0 deletions testing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
#####################################################################
# #
# testing_utils.py #
# #
# Copyright 2017, Chris Billington #
# #
# This file is part of the labscript suite (see #
# http://labscriptsuite.org) and is licensed under the Simplified #
# BSD License. See the license.txt file in the root of the project #
# for the full license. #
# #
#####################################################################


from __future__ import print_function
import os
import sys
import time
import threading
import unittest
from labscript_utils import PY2
if PY2:
import Queue as queue
import mock
else:
import queue
import unittest.mock as mock

from unittest import TestCase


class monkeypatch(object):
"""Context manager to temporarily monkeypatch an object attribute with
some mocked attribute"""

def __init__(self, obj, name, mocked_attr):
self.obj = obj
self.name = name
self.real_attr = getattr(obj, name)
self.mocked_attr = mocked_attr

def __enter__(self):
setattr(self.obj, self.name, self.mocked_attr)

def __exit__(self, *args):
setattr(self.obj, self.name, self.real_attr)


class dotdict(dict):
"""dot.notation access to dictionary attributes"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__


class Any(object):
"""A class whose instances equal any object of the given type or tuple of
types. For use with mock.Mock.assert_called_with when you don't care what
some of the arguments are"""
def __init__(self, types=object):
if isinstance(types, type):
self.types = (types,)
else:
self.types = types

def __eq__(self, other):
return any(isinstance(other, type_) for type_ in self.types)

# Instance of Any() that does not specify type:
ANY = Any()


class ThreadTestCase(TestCase):
"""Test case that runs tests in a new thread whilst providing a mainloop
that allows running scripts in the current thread. Those scripts can then
be tested from the testing thread."""

def __init__(self, *args, **kwargs):
TestCase.__init__(self, *args, **kwargs)
self._thread_return_value = queue.Queue()
self._command_queue = queue.Queue()

def run_script_as_main(self, filepath):
globals_dict = dotdict()
self._command_queue.put([filepath, globals_dict])
return globals_dict

def quit_mainloop(self):
self._command_queue.put([None, None])

def _run(self, *args, **kwargs):
"""Called in a thread to run the tests"""
exception = None
try:
print('about to run')
result = TestCase.run(self, *args, **kwargs)
except:
print('got exception')
self.quit_mainloop()
# Store for re-raising the exception in the calling thread:
exception = sys.exc_info()
result = None
finally:
self._thread_return_value.put([result, exception])

def run(self, *args, **kwargs):
test_thread = threading.Thread(target=self._run, args=args, kwargs=kwargs)
test_thread.start()
self._mainloop()
test_thread.join()
result, exception = self._thread_return_value.get()
if exception is not None:
type, value, traceback = exception
if PY2:
exec('raise type, value, traceback')
else:
raise value.with_traceback(traceback)
return result

def _mainloop(self):
while True:
filepath, globals_dict = self._command_queue.get()
if filepath is None:
break

if PY2:
filepath_native_string = filepath.encode(sys.getfilesystemencoding())
else:
filepath_native_string = filepath

globals_dict['__name__'] ='__main__'
globals_dict['__file__']= os.path.basename(filepath_native_string)

# Save the current working directory before changing it to the
# location of the script:
cwd = os.getcwd()
os.chdir(os.path.dirname(filepath))

# Run the script:
try:
with open(filepath) as f:
code = compile(f.read(), os.path.basename(filepath_native_string),
'exec', dont_inherit=True)
exec(code, globals_dict)
finally:
os.chdir(cwd)

@staticmethod
def wait_for(condition_func, timeout=5,
initial_poll_interval=0.005, max_poll_interval=0.5):
"""Busy wait for a condition to be true. Uses exponential backoff so
it's fast when things are fast and not a complete hog when they're
not"""
poll_interval = initial_poll_interval
start_time = time.time()
while not condition_func():
if time.time() - start_time > timeout:
raise Exception
time.sleep(poll_interval)
poll_interval = min(2*poll_interval, max_poll_interval)

0 comments on commit 5b93953

Please sign in to comment.