diff --git a/src/pytroll_watchers/local_watcher.py b/src/pytroll_watchers/local_watcher.py index b6de7f4..80e0426 100644 --- a/src/pytroll_watchers/local_watcher.py +++ b/src/pytroll_watchers/local_watcher.py @@ -8,10 +8,13 @@ from urllib.parse import urlunparse from upath import UPath +from upath._flavour import WrappedFileSystemFlavour from pytroll_watchers.backends.local import listen_to_local_events from pytroll_watchers.publisher import SecurityError, file_publisher_from_generator, parse_metadata +# This is a workaround for a but in universal_pathlib, see +WrappedFileSystemFlavour.protocol_config["netloc_is_anchor"].add("ssh") logger = logging.getLogger(__name__) @@ -72,7 +75,7 @@ def file_generator(directory, observer_type="os", file_pattern=None, protocol=No except ValueError: continue if protocol is not None: - uri = urlunparse((protocol, None, path, None, None, None)) + uri = urlunparse((protocol, None, str(path), None, None, None)) yield UPath(uri, **storage_options), file_metadata else: yield Path(path), file_metadata diff --git a/tests/test_main_interface.py b/tests/test_main_interface.py index 144a7f0..7da22c5 100644 --- a/tests/test_main_interface.py +++ b/tests/test_main_interface.py @@ -1,9 +1,12 @@ """Tests for the gathered publisher functions.""" +import json import logging +from unittest import mock import pytest import yaml +from posttroll.message import Message from posttroll.testing import patched_publisher from pytroll_watchers.local_watcher import file_generator as local_generator from pytroll_watchers.local_watcher import file_publisher as local_publisher @@ -16,6 +19,7 @@ from pytroll_watchers.minio_notification_watcher import file_generator as minio_generator from pytroll_watchers.minio_notification_watcher import file_publisher as minio_publisher from pytroll_watchers.testing import patched_bucket_listener, patched_local_events # noqa +from upath import UPath def test_getting_right_publisher(): @@ -54,6 +58,33 @@ def test_pass_config_to_file_publisher_for_local_backend(tmp_path, patched_local assert str(filename) in msgs[0] +def test_pass_config_to_file_publisher_for_local_backend_with_protocol(tmp_path, patched_local_events, monkeypatch): # noqa + """Test passing a config to create a file publisher from a local fs with protocol.""" + new_fs = mock.Mock() + host = "myhost.pytroll.org" + fs = dict(cls="fsspec.implementations.sftp.SFTPFileSystem", + protocol="sftp", + args=[], + host=host) + new_fs.to_json.return_value = json.dumps(fs) + monkeypatch.setattr(UPath, "fs", new_fs) + local_settings = dict(directory=tmp_path, protocol="ssh", storage_options=dict(host=host)) + publisher_settings = dict(nameservers=False, port=1979) + message_settings = dict(subject="/segment/viirs/l1b/", atype="file", data=dict(sensor="viirs")) + config = dict(backend="local", + fs_config=local_settings, + publisher_config=publisher_settings, + message_config=message_settings) + with patched_publisher() as msgs: + filename = tmp_path / "bla" + with patched_local_events([filename]): + publish_from_config(config) + assert len(msgs) == 1 + msg = Message.decode(msgs[0]) + assert msg.data["path"] == str(filename) + + + def test_pass_config_to_object_publisher_for_minio_backend(patched_bucket_listener): # noqa """Test passing a config to create an objec publisher from minio bucket.""" s3_settings = dict(endpoint_url="someendpoint",