@lru_cache(1)
def get_lhotse_msc_override_protocols() -> Any:
    """
    Return the value of the ``LHOTSE_MSC_OVERRIDE_PROTOCOLS`` environment variable
    (``None`` when unset). The value is read once and cached; call
    ``get_lhotse_msc_override_protocols.cache_clear()`` to force a re-read.
    """
    return os.getenv("LHOTSE_MSC_OVERRIDE_PROTOCOLS", None)


@lru_cache(1)
def get_lhotse_msc_profile() -> Any:
    """
    Return the value of the ``LHOTSE_MSC_PROFILE`` environment variable
    (``None`` when unset). The value is read once and cached; call
    ``get_lhotse_msc_profile.cache_clear()`` to force a re-read.
    """
    return os.getenv("LHOTSE_MSC_PROFILE", None)


@lru_cache(1)
def get_lhotse_io_backend() -> Any:
    """
    Return the value of the ``LHOTSE_IO_BACKEND`` environment variable
    (``None`` when unset). The value is read once and cached.
    """
    return os.getenv("LHOTSE_IO_BACKEND", None)


MSC_PREFIX = "msc"


class MSCIOBackend(IOBackend):
    """
    Uses multi-storage client (``multistorageclient``) to access data in an object store.
    """

    def open(self, identifier: str, mode: str):
        """
        Open ``identifier`` with ``multistorageclient.open``, converting it to
        an ``msc://`` URL first when configured to do so.

        * URLs that already start with ``msc://`` (e.g. ``msc://profile/path/to/my/object1``)
          are passed to ``msc.open`` unchanged.
        * For URLs that have not been migrated yet (e.g. ``protocol://bucket/path/to/my/object2``):

          1. If the URL scheme is listed in the comma-separated env variable
             ``LHOTSE_MSC_OVERRIDE_PROTOCOLS``, the scheme is replaced with ``msc``:
             ``msc://bucket/path/to/my/object2``.
          2. If the env variable ``LHOTSE_MSC_PROFILE`` is set, the bucket
             (network location) is replaced with that profile name:
             ``msc://profile/path/to/my/object2``. When it is not set, the MSC profile
             name is expected to match the bucket name.

        :param identifier: URL of the object to open.
        :param mode: file mode forwarded to ``msc.open``.
        :return: whatever ``msc.open`` returns.
        :raises Exception: any error raised by ``msc.open`` is logged and re-raised.
        """
        import multistorageclient as msc

        from lhotse.utils import replace_bucket_with_profile_name

        # ``is_applicable`` accepts any Pathlike, so normalize before using str methods.
        identifier = str(identifier)
        msc_scheme = f"{MSC_PREFIX}://"

        # Already an MSC URL: nothing to rewrite.
        if identifier.startswith(msc_scheme):
            return msc.open(identifier, mode)

        # Override the protocol if requested.
        override_protocols = get_lhotse_msc_override_protocols()
        if override_protocols:
            for entry in override_protocols.split(","):
                protocol = entry.strip()
                # Tolerate entries written as "s3://" as well as "s3".
                if protocol.endswith("://"):
                    protocol = protocol[: -len("://")]
                if not protocol:
                    continue
                scheme = f"{protocol}://"
                # Match the full scheme (so "s3" does not match "s3a://...") and
                # rewrite only the scheme, never later occurrences in the path.
                if identifier.startswith(scheme):
                    identifier = msc_scheme + identifier[len(scheme) :]
                    break

        # Override the bucket with the profile name if requested.
        profile = get_lhotse_msc_profile()
        if profile:
            identifier = replace_bucket_with_profile_name(identifier, profile)

        try:
            return msc.open(identifier, mode)
        except Exception as e:
            import logging

            # Record which (possibly rewritten) URL failed, then propagate the
            # original exception with its traceback intact.
            logging.getLogger(__name__).error(
                "exception: %s, identifier: %s", e, identifier
            )
            raise

    @classmethod
    def is_available(cls) -> bool:
        """Return True when the ``multistorageclient`` module is available."""
        return is_module_available("multistorageclient")

    def handles_special_case(self, identifier: Pathlike) -> bool:
        """Return True when ``identifier`` is explicitly an ``msc://`` URL."""
        return str(identifier).startswith(f"{MSC_PREFIX}://")

    def is_applicable(self, identifier: Pathlike) -> bool:
        """Return True when ``identifier`` is a valid URL."""
        return is_valid_url(identifier)
def replace_bucket_with_profile_name(identifier: str, profile_name: str) -> str:
    """
    Return ``identifier`` with its network location (the bucket) replaced by
    ``profile_name``; every other URL component is kept as parsed.

    Example::

        >>> replace_bucket_with_profile_name("msc://bucket/path/to/object", "profile")
        'msc://profile/path/to/object'

    :param identifier: URL whose bucket should be replaced.
    :param profile_name: name to put in place of the bucket. When empty or
        ``None`` there is nothing to substitute and ``identifier`` is returned unchanged.
    :return: the rewritten URL.
    """
    if not profile_name:
        # Substituting an empty netloc would silently drop the bucket from the URL.
        return identifier
    parsed = urlparse(identifier)
    return urlunparse(parsed._replace(netloc=profile_name))
import sys
import types

import pytest


def _install_fake_msc(monkeypatch, opened_urls):
    """
    Register a stand-in ``multistorageclient`` module whose ``open`` records the URL
    it receives in ``opened_urls``.

    ``monkeypatch.setitem`` is used so the entry in ``sys.modules`` is restored
    when the test finishes and the fake never leaks into other tests.
    """
    fake_msc = types.ModuleType("multistorageclient")

    def fake_open(url, mode):
        opened_urls.append(url)
        return None

    fake_msc.open = fake_open
    monkeypatch.setitem(sys.modules, "multistorageclient", fake_msc)


@pytest.fixture
def clear_msc_env_caches():
    """Reset the cached MSC env lookups before and after a test."""
    from lhotse.serialization import (
        get_lhotse_msc_override_protocols,
        get_lhotse_msc_profile,
    )

    def _reset():
        get_lhotse_msc_profile.cache_clear()
        get_lhotse_msc_override_protocols.cache_clear()

    _reset()
    yield
    # Without this, values read from the patched environment would stay cached
    # and be visible to tests that run afterwards.
    _reset()


@pytest.mark.parametrize(
    "identifier,expected_output,lhotse_msc_profile",
    [
        # No change for msc:// prefix
        ("msc://profile/path/to/object", "msc://profile/path/to/object", "profile"),
        # Override only protocol
        ("s3://bucket/path/to/object", "msc://bucket/path/to/object", ""),
        # Override protocol and bucket
        ("s3://bucket/path", "msc://profile/path", "profile"),
    ],
)
def test_msc_io_backend_url_conversion(
    monkeypatch, clear_msc_env_caches, identifier, expected_output, lhotse_msc_profile
):
    from lhotse.serialization import MSCIOBackend

    # Mock environment variables
    monkeypatch.setenv("LHOTSE_MSC_OVERRIDE_PROTOCOLS", "s3")
    if lhotse_msc_profile:
        monkeypatch.setenv("LHOTSE_MSC_PROFILE", lhotse_msc_profile)
    else:
        # Make the "no profile" case independent of the ambient environment.
        monkeypatch.delenv("LHOTSE_MSC_PROFILE", raising=False)

    # Mock multistorageclient.open to capture the transformed URL
    opened_urls = []
    _install_fake_msc(monkeypatch, opened_urls)

    # Create backend and test URL transformation
    backend = MSCIOBackend()
    backend.open(identifier, mode="r")

    # Asserting here (not inside the mock) also fails when open() is never called.
    assert opened_urls == [expected_output]


@pytest.mark.parametrize(
    "protocols",
    [
        "s3",  # Single protocol
        "s3,gs",  # Multiple protocols
    ],
)
def test_msc_io_backend_multiple_protocols(
    monkeypatch, clear_msc_env_caches, protocols
):
    from lhotse.serialization import MSCIOBackend

    # Mock environment variables
    monkeypatch.setenv("LHOTSE_MSC_OVERRIDE_PROTOCOLS", protocols)
    monkeypatch.delenv("LHOTSE_MSC_PROFILE", raising=False)

    # Mock multistorageclient.open to capture the transformed URLs
    opened_urls = []
    _install_fake_msc(monkeypatch, opened_urls)

    # Create backend and test URL transformation
    backend = MSCIOBackend()

    # Test with first protocol
    backend.open("s3://bucket/path", mode="r")
    expected_calls = 1

    if "," in protocols:
        # Test with second protocol if multiple
        backend.open("gs://bucket/path", mode="r")
        expected_calls = 2

    assert len(opened_urls) == expected_calls
    assert all(url.startswith("msc://") for url in opened_urls)


def test_msc_io_backend_availability(monkeypatch):
    from lhotse.serialization import MSCIOBackend

    # Test when multistorageclient is not available
    monkeypatch.setitem(sys.modules, "multistorageclient", None)
    assert not MSCIOBackend.is_available()

    # Test when multistorageclient is available
    mock_module = types.ModuleType("multistorageclient")
    mock_module.__spec__ = types.SimpleNamespace(name="multistorageclient")
    monkeypatch.setitem(sys.modules, "multistorageclient", mock_module)
    assert MSCIOBackend.is_available()