Skip to content

Commit

Permalink
Added isolated depedency installation options
Browse files Browse the repository at this point in the history
  • Loading branch information
MSeal committed Feb 4, 2019
1 parent 346f13b commit 7857c8a
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 75 deletions.
5 changes: 3 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ recursive-include papermill *.keep
recursive-include papermill *.txt

include setup.py
include requirements.txt
include requirements-dev.txt
include requirements*.txt
include tox.ini
include pytest.ini
include README.rst
Expand All @@ -27,3 +26,5 @@ prune docs/_build
prune binder
# Scripts
graft scripts
# Test env
prune .tox
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ From the command line:
pip install papermill
```

For all optional io dependencies, you can specify individual bundles
like `s3`, or `azure` -- or use `all`

``` {.sourceCode .bash}
pip install papermill[all]
```

Installing In-Notebook bindings
-------------------------------

Expand Down
14 changes: 14 additions & 0 deletions papermill/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,17 @@ def __init__(self, exec_count, source, ename, evalue, traceback):
message += "\n"

super(PapermillExecutionError, self).__init__(message)


class PapermillOptionalDepedencyException(PapermillException):
"""Raised when an exception is encountered when an optional plugin is missing."""


def missing_dependency_generator(package, dep):
def missing_dep():
raise PapermillOptionalDepedencyException(
"The {package} optional depedency is missing. "
"Please run pip install papermill[{dep}] to install this dependency"
.format(package=package, dep=dep)
)
return missing_dep
89 changes: 48 additions & 41 deletions papermill/iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,32 @@
import requests
import warnings
import entrypoints
import gcsfs

from contextlib import contextmanager

from . import __version__
from .s3 import S3

from .log import logger
from .adl import ADL
from .abs import AzureBlobStore

from .utils import chdir
from .exceptions import PapermillException
from .exceptions import PapermillException, missing_dependency_generator

try:
from .s3 import S3
except ImportError:
S3 = missing_dependency_generator("boto3", "s3")
try:
from .adl import ADL
except ImportError:
ADL = missing_dependency_generator("azure.datalake.store", "azure")
try:
from .abs import AzureBlobStore
except ImportError:
AzureBlobStore = missing_dependency_generator("azure.storage.blob", "azure")
try:
from gcsfs import GCSFileSystem
except ImportError:
GCSFileSystem = missing_dependency_generator("gcsfs", "gcs")

try:
FileNotFoundError
Expand Down Expand Up @@ -170,56 +185,48 @@ def pretty_path(cls, path):


class ADLHandler(object):
_client = None
def __init__(self):
self._client = None

@classmethod
def _get_client(cls):
if cls._client is None:
cls._client = ADL()
return cls._client
def _get_client(self):
if self._client is None:
self._client = ADL()
return self._client

@classmethod
def read(cls, path):
lines = cls._get_client().read(path)
def read(self, path):
lines = self._get_client().read(path)
return "\n".join(lines)

@classmethod
def listdir(cls, path):
return cls._get_client().listdir(path)
def listdir(self, path):
return self._get_client().listdir(path)

@classmethod
def write(cls, buf, path):
return cls._get_client().write(buf, path)
def write(self, buf, path):
return self._get_client().write(buf, path)

@classmethod
def pretty_path(cls, path):
def pretty_path(self, path):
return path


class ABSHandler(object):
_client = None
def __init__(self):
self._client = None

@classmethod
def _get_client(cls):
if cls._client is None:
cls._client = AzureBlobStore()
return cls._client
def _get_client(self):
if self._client is None:
self._client = AzureBlobStore()
return self._client

@classmethod
def read(cls, path):
lines = cls._get_client().read(path)
def read(self, path):
lines = self._get_client().read(path)
return "\n".join(lines)

@classmethod
def listdir(cls, path):
return cls._get_client().listdir(path)
def listdir(self, path):
return self._get_client().listdir(path)

@classmethod
def write(cls, buf, path):
return cls._get_client().write(buf, path)
def write(self, buf, path):
return self._get_client().write(buf, path)

@classmethod
def pretty_path(cls, path):
def pretty_path(self, path):
return path


Expand All @@ -229,7 +236,7 @@ def __init__(self):

def _get_client(self):
if self._client is None:
self._client = gcsfs.GCSFileSystem()
self._client = GCSFileSystem()
return self._client

def read(self, path):
Expand All @@ -252,7 +259,7 @@ def pretty_path(self, path):
papermill_io.register("local", LocalHandler())
papermill_io.register("s3://", S3Handler)
papermill_io.register("adl://", ADLHandler)
papermill_io.register("abs://", ABSHandler)
papermill_io.register("abs://", ABSHandler())
papermill_io.register("http://", HttpHandler)
papermill_io.register("https://", HttpHandler)
papermill_io.register("gs://", GCSHandler())
Expand Down
8 changes: 4 additions & 4 deletions papermill/tests/test_gcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from papermill.iorw import GCSHandler
from ..iorw import GCSHandler


class MockGCSFileSystem(object):
Expand Down Expand Up @@ -28,7 +28,7 @@ def write(self, data):


def test_gcs_read(mocker):
mocker.patch('gcsfs.GCSFileSystem', MockGCSFileSystem)
mocker.patch('papermill.iorw.GCSFileSystem', MockGCSFileSystem)
gcs_handler = GCSHandler()
client = gcs_handler._get_client()
assert gcs_handler.read('gs://bucket/test.ipynb') == 'default value'
Expand All @@ -38,7 +38,7 @@ def test_gcs_read(mocker):


def test_gcs_write(mocker):
mocker.patch('gcsfs.GCSFileSystem', MockGCSFileSystem)
mocker.patch('papermill.iorw.GCSFileSystem', MockGCSFileSystem)
gcs_handler = GCSHandler()
client = gcs_handler._get_client()
gcs_handler.write('new value', 'gs://bucket/test.ipynb')
Expand All @@ -48,7 +48,7 @@ def test_gcs_write(mocker):


def test_gcs_listdir(mocker):
mocker.patch('gcsfs.GCSFileSystem', MockGCSFileSystem)
mocker.patch('papermill.iorw.GCSFileSystem', MockGCSFileSystem)
gcs_handler = GCSHandler()
client = gcs_handler._get_client()
gcs_handler.listdir('testdir')
Expand Down
16 changes: 6 additions & 10 deletions papermill/tests/test_iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,26 +224,22 @@ class TestADLHandler(unittest.TestCase):
"""

def setUp(self):
self.old_client = ADLHandler._client
self.mock_client = Mock(
self.handler = ADLHandler()
self.handler._client = Mock(
read=Mock(return_value=["foo", "bar", "baz"]),
listdir=Mock(return_value=["foo", "bar", "baz"]),
write=Mock(),
)
ADLHandler._client = self.mock_client

def tearDown(self):
ADLHandler._client = self.old_client

def test_read(self):
self.assertEqual(ADLHandler.read("some_path"), "foo\nbar\nbaz")
self.assertEqual(self.handler.read("some_path"), "foo\nbar\nbaz")

def test_listdir(self):
self.assertEqual(ADLHandler.listdir("some_path"), ["foo", "bar", "baz"])
self.assertEqual(self.handler.listdir("some_path"), ["foo", "bar", "baz"])

def test_write(self):
ADLHandler.write("foo", "bar")
self.mock_client.write.assert_called_once_with("foo", "bar")
self.handler.write("foo", "bar")
self.handler._client.write.assert_called_once_with("foo", "bar")


class TestHttpHandler(unittest.TestCase):
Expand Down
2 changes: 2 additions & 0 deletions requirements-azure.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
azure-datalake-store >= 0.0.30
azure-storage-blob
1 change: 1 addition & 0 deletions requirements-gcs.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
gcsfs
1 change: 1 addition & 0 deletions requirements-s3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
boto3
5 changes: 1 addition & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
ansiwrap
boto3
click
future
futures ; python_version < "3.0"
Expand All @@ -13,7 +12,5 @@ tqdm >= 4.29.1
jupyter_client
pandas
requests
azure-datalake-store >= 0.0.30
azure-storage-blob
entrypoints
gcsfs

33 changes: 19 additions & 14 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,23 @@ def read(fname):
with open(fname, 'rU' if python_2 else 'r') as fhandle:
return fhandle.read()


req_path = os.path.join(here, 'requirements.txt')
required = [req.strip() for req in read(req_path).splitlines() if req.strip()]

test_req_path = os.path.join(here, 'requirements-dev.txt')
test_required = [req.strip() for req in read(test_req_path).splitlines() if req.strip()]
extras_require = {"test": test_required, "dev": test_required}

def read_reqs(fname):
req_path = os.path.join(here, fname)
return [req.strip() for req in read(req_path).splitlines() if req.strip()]

s3_reqs = read_reqs('requirements-s3.txt')
azure_reqs = read_reqs('requirements-azure.txt')
gcs_reqs = read_reqs('requirements-gcs.txt')
all_reqs = s3_reqs + azure_reqs + gcs_reqs
dev_reqs = read_reqs('requirements-dev.txt') + all_reqs
extras_require = {
"test": dev_reqs,
"dev": dev_reqs,
"all": all_reqs,
"s3": s3_reqs,
"azure": azure_reqs,
"gcs": gcs_reqs,
}

# Tox has a weird issue where it can't import pip from it's virtualenv when skipping normal installs
if not bool(int(os.environ.get('SKIP_PIP_CHECK', 0))):
Expand All @@ -70,12 +79,8 @@ def read(fname):
'Your pip version is out of date. Papermill requires pip >= 9.0.1. \n'
'pip {} detected. Please install pip >= 9.0.1.'.format(pip.__version__)
)
except ImportError:
pip_message = (
'No pip detected; we were unable to import pip. \n'
'To use papermill, please install pip >= 9.0.1.'
)
except Exception:
# We only want to optimistically report old versions
pass

if pip_message:
Expand All @@ -100,7 +105,7 @@ def read(fname):
long_description_content_type='text/markdown',
url='https://github.com/nteract/papermill',
packages=['papermill'],
install_requires=required,
install_requires=read_reqs('requirements.txt'),
extras_require=extras_require,
entry_points={'console_scripts': ['papermill = papermill.cli:papermill']},
project_urls={
Expand Down

0 comments on commit 7857c8a

Please sign in to comment.