diff --git a/MANIFEST.in b/MANIFEST.in index dde140d5..afb1ce78 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,8 +6,7 @@ recursive-include papermill *.keep recursive-include papermill *.txt include setup.py -include requirements.txt -include requirements-dev.txt +include requirements*.txt include tox.ini include pytest.ini include README.rst @@ -27,3 +26,5 @@ prune docs/_build prune binder # Scripts graft scripts +# Test env +prune .tox diff --git a/README.md b/README.md index 509a06c2..029e8d0f 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,13 @@ From the command line: pip install papermill ``` +For all optional io dependencies, you can specify individual bundles +like `s3`, or `azure` -- or use `all` + +``` {.sourceCode .bash} +pip install papermill[all] +``` + Installing In-Notebook bindings ------------------------------- diff --git a/papermill/exceptions.py b/papermill/exceptions.py index 38fb91fc..c212ff37 100644 --- a/papermill/exceptions.py +++ b/papermill/exceptions.py @@ -32,3 +32,17 @@ def __init__(self, exec_count, source, ename, evalue, traceback): message += "\n" super(PapermillExecutionError, self).__init__(message) + + +class PapermillOptionalDepedencyException(PapermillException): + """Raised when an exception is encountered when an optional plugin is missing.""" + + +def missing_dependency_generator(package, dep): + def missing_dep(): + raise PapermillOptionalDepedencyException( + "The {package} optional depedency is missing. " + "Please run pip install papermill[{dep}] to install this dependency" + .format(package=package, dep=dep) + ) + return missing_dep diff --git a/papermill/iorw.py b/papermill/iorw.py index 136ab06b..69c97453 100644 --- a/papermill/iorw.py +++ b/papermill/iorw.py @@ -11,17 +11,32 @@ import requests import warnings import entrypoints -import gcsfs from contextlib import contextmanager from . import __version__ -from .s3 import S3 + from .log import logger -from .adl import ADL -from .abs import AzureBlobStore + from .utils import chdir -from .exceptions import PapermillException +from .exceptions import PapermillException, missing_dependency_generator + +try: + from .s3 import S3 +except ImportError: + S3 = missing_dependency_generator("boto3", "s3") +try: + from .adl import ADL +except ImportError: + ADL = missing_dependency_generator("azure.datalake.store", "azure") +try: + from .abs import AzureBlobStore +except ImportError: + AzureBlobStore = missing_dependency_generator("azure.storage.blob", "azure") +try: + from gcsfs import GCSFileSystem +except ImportError: + GCSFileSystem = missing_dependency_generator("gcsfs", "gcs") try: FileNotFoundError @@ -170,56 +185,48 @@ def pretty_path(cls, path): class ADLHandler(object): - _client = None + def __init__(self): + self._client = None - @classmethod - def _get_client(cls): - if cls._client is None: - cls._client = ADL() - return cls._client + def _get_client(self): + if self._client is None: + self._client = ADL() + return self._client - @classmethod - def read(cls, path): - lines = cls._get_client().read(path) + def read(self, path): + lines = self._get_client().read(path) return "\n".join(lines) - @classmethod - def listdir(cls, path): - return cls._get_client().listdir(path) + def listdir(self, path): + return self._get_client().listdir(path) - @classmethod - def write(cls, buf, path): - return cls._get_client().write(buf, path) + def write(self, buf, path): + return self._get_client().write(buf, path) - @classmethod - def pretty_path(cls, path): + def pretty_path(self, path): return path class ABSHandler(object): - _client = None + def __init__(self): + self._client = None - @classmethod - def _get_client(cls): - if cls._client is None: - cls._client = AzureBlobStore() - return cls._client + def _get_client(self): + if self._client is None: + self._client = AzureBlobStore() + return self._client - @classmethod - def read(cls, path): - lines = cls._get_client().read(path) + def read(self, path): + lines = self._get_client().read(path) return "\n".join(lines) - @classmethod - def listdir(cls, path): - return cls._get_client().listdir(path) + def listdir(self, path): + return self._get_client().listdir(path) - @classmethod - def write(cls, buf, path): - return cls._get_client().write(buf, path) + def write(self, buf, path): + return self._get_client().write(buf, path) - @classmethod - def pretty_path(cls, path): + def pretty_path(self, path): return path @@ -229,7 +236,7 @@ def __init__(self): def _get_client(self): if self._client is None: - self._client = gcsfs.GCSFileSystem() + self._client = GCSFileSystem() return self._client def read(self, path): @@ -252,7 +259,7 @@ def pretty_path(self, path): papermill_io.register("local", LocalHandler()) papermill_io.register("s3://", S3Handler) papermill_io.register("adl://", ADLHandler) -papermill_io.register("abs://", ABSHandler) +papermill_io.register("abs://", ABSHandler()) papermill_io.register("http://", HttpHandler) papermill_io.register("https://", HttpHandler) papermill_io.register("gs://", GCSHandler()) diff --git a/papermill/tests/test_gcs.py b/papermill/tests/test_gcs.py index 8a9cdc6d..76721ea1 100644 --- a/papermill/tests/test_gcs.py +++ b/papermill/tests/test_gcs.py @@ -1,4 +1,4 @@ -from papermill.iorw import GCSHandler +from ..iorw import GCSHandler class MockGCSFileSystem(object): @@ -28,7 +28,7 @@ def write(self, data): def test_gcs_read(mocker): - mocker.patch('gcsfs.GCSFileSystem', MockGCSFileSystem) + mocker.patch('papermill.iorw.GCSFileSystem', MockGCSFileSystem) gcs_handler = GCSHandler() client = gcs_handler._get_client() assert gcs_handler.read('gs://bucket/test.ipynb') == 'default value' @@ -38,7 +38,7 @@ def test_gcs_read(mocker): def test_gcs_write(mocker): - mocker.patch('gcsfs.GCSFileSystem', MockGCSFileSystem) + mocker.patch('papermill.iorw.GCSFileSystem', MockGCSFileSystem) gcs_handler = GCSHandler() client = gcs_handler._get_client() gcs_handler.write('new value', 'gs://bucket/test.ipynb') @@ -48,7 +48,7 @@ def test_gcs_write(mocker): def test_gcs_listdir(mocker): - mocker.patch('gcsfs.GCSFileSystem', MockGCSFileSystem) + mocker.patch('papermill.iorw.GCSFileSystem', MockGCSFileSystem) gcs_handler = GCSHandler() client = gcs_handler._get_client() gcs_handler.listdir('testdir') diff --git a/papermill/tests/test_iorw.py b/papermill/tests/test_iorw.py index 776687b7..5f4ee094 100644 --- a/papermill/tests/test_iorw.py +++ b/papermill/tests/test_iorw.py @@ -224,26 +224,22 @@ class TestADLHandler(unittest.TestCase): """ def setUp(self): - self.old_client = ADLHandler._client - self.mock_client = Mock( + self.handler = ADLHandler() + self.handler._client = Mock( read=Mock(return_value=["foo", "bar", "baz"]), listdir=Mock(return_value=["foo", "bar", "baz"]), write=Mock(), ) - ADLHandler._client = self.mock_client - - def tearDown(self): - ADLHandler._client = self.old_client def test_read(self): - self.assertEqual(ADLHandler.read("some_path"), "foo\nbar\nbaz") + self.assertEqual(self.handler.read("some_path"), "foo\nbar\nbaz") def test_listdir(self): - self.assertEqual(ADLHandler.listdir("some_path"), ["foo", "bar", "baz"]) + self.assertEqual(self.handler.listdir("some_path"), ["foo", "bar", "baz"]) def test_write(self): - ADLHandler.write("foo", "bar") - self.mock_client.write.assert_called_once_with("foo", "bar") + self.handler.write("foo", "bar") + self.handler._client.write.assert_called_once_with("foo", "bar") class TestHttpHandler(unittest.TestCase): diff --git a/requirements-azure.txt b/requirements-azure.txt new file mode 100644 index 00000000..528e1b44 --- /dev/null +++ b/requirements-azure.txt @@ -0,0 +1,2 @@ +azure-datalake-store >= 0.0.30 +azure-storage-blob diff --git a/requirements-gcs.txt b/requirements-gcs.txt new file mode 100644 index 00000000..fc4c1081 --- /dev/null +++ b/requirements-gcs.txt @@ -0,0 +1 @@ +gcsfs diff --git a/requirements-s3.txt b/requirements-s3.txt new file mode 100644 index 00000000..30ddf823 --- /dev/null +++ b/requirements-s3.txt @@ -0,0 +1 @@ +boto3 diff --git a/requirements.txt b/requirements.txt index 6f27da7f..89c987bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ ansiwrap -boto3 click future futures ; python_version < "3.0" @@ -13,7 +12,5 @@ tqdm >= 4.29.1 jupyter_client pandas requests -azure-datalake-store >= 0.0.30 -azure-storage-blob entrypoints -gcsfs + diff --git a/setup.py b/setup.py index 334f67f7..004ffc30 100644 --- a/setup.py +++ b/setup.py @@ -46,14 +46,23 @@ def read(fname): with open(fname, 'rU' if python_2 else 'r') as fhandle: return fhandle.read() - -req_path = os.path.join(here, 'requirements.txt') -required = [req.strip() for req in read(req_path).splitlines() if req.strip()] - -test_req_path = os.path.join(here, 'requirements-dev.txt') -test_required = [req.strip() for req in read(test_req_path).splitlines() if req.strip()] -extras_require = {"test": test_required, "dev": test_required} - +def read_reqs(fname): + req_path = os.path.join(here, fname) + return [req.strip() for req in read(req_path).splitlines() if req.strip()] + +s3_reqs = read_reqs('requirements-s3.txt') +azure_reqs = read_reqs('requirements-azure.txt') +gcs_reqs = read_reqs('requirements-gcs.txt') +all_reqs = s3_reqs + azure_reqs + gcs_reqs +dev_reqs = read_reqs('requirements-dev.txt') + all_reqs +extras_require = { + "test": dev_reqs, + "dev": dev_reqs, + "all": all_reqs, + "s3": s3_reqs, + "azure": azure_reqs, + "gcs": gcs_reqs, +} # Tox has a weird issue where it can't import pip from it's virtualenv when skipping normal installs if not bool(int(os.environ.get('SKIP_PIP_CHECK', 0))): @@ -70,12 +79,8 @@ def read(fname): 'Your pip version is out of date. Papermill requires pip >= 9.0.1. \n' 'pip {} detected. Please install pip >= 9.0.1.'.format(pip.__version__) ) - except ImportError: - pip_message = ( - 'No pip detected; we were unable to import pip. \n' - 'To use papermill, please install pip >= 9.0.1.' - ) except Exception: + # We only want to optimistically report old versions pass if pip_message: @@ -100,7 +105,7 @@ def read(fname): long_description_content_type='text/markdown', url='https://github.com/nteract/papermill', packages=['papermill'], - install_requires=required, + install_requires=read_reqs('requirements.txt'), extras_require=extras_require, entry_points={'console_scripts': ['papermill = papermill.cli:papermill']}, project_urls={