diff --git a/.gitignore b/.gitignore index 0669572a..39e393db 100644 --- a/.gitignore +++ b/.gitignore @@ -136,5 +136,6 @@ rever/ # VS Code .vscode/ -# Generated requirements.txt +# Generated requirements.txt and uv lock file requirements.txt +uv.lock diff --git a/docs/api/hub.external_auth_client.md b/docs/api/hub.external_client.md similarity index 67% rename from docs/api/hub.external_auth_client.md rename to docs/api/hub.external_client.md index 5b233685..edde82ba 100644 --- a/docs/api/hub.external_auth_client.md +++ b/docs/api/hub.external_client.md @@ -1,4 +1,4 @@ -::: polaris.hub.external_auth_client.ExternalAuthClient +::: polaris.hub.external_client.ExternalAuthClient options: merge_init_into_class: true filters: ["!create_authorization_url", "!fetch_token"] diff --git a/docs/api/hub.storage.md b/docs/api/hub.storage.md new file mode 100644 index 00000000..5ef63a0c --- /dev/null +++ b/docs/api/hub.storage.md @@ -0,0 +1,10 @@ +::: polaris.hub.storage.StorageSession + options: + merge_init_into_class: true + +--- + +::: polaris.hub.storage.S3Store + options: + merge_init_into_class: true +--- diff --git a/env.yml b/env.yml index 0f3f0b1a..52ae6ae5 100644 --- a/env.yml +++ b/env.yml @@ -12,6 +12,9 @@ dependencies: - pydantic-settings >=2 - fsspec - yaspin + - typing-extensions >=4.12.0 + - boto3 >=1.35.0 + # Hub client - authlib @@ -45,6 +48,7 @@ dependencies: - ruff - jupyterlab - ipywidgets + - moto >=5.0.0 # Doc - mkdocs @@ -58,4 +62,4 @@ dependencies: - mike >=1.0.0 - pip: - - fastpdb + - fastpdb diff --git a/mkdocs.yml b/mkdocs.yml index 5f00a4eb..0fd09511 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -39,7 +39,7 @@ nav: - Competiton Evaluation: api/competition.evaluation.md - Hub: - Client: api/hub.client.md - - External Auth Client: api/hub.external_auth_client.md + - External Auth Client: api/hub.external_client.md - PolarisFileSystem: api/hub.polarisfs.md - Additional: - Dataset Factory: api/factory.md @@ -122,9 +122,9 @@ plugins: - mkdocs-jupyter: execute: False remove_tag_config: - remove_cell_tags: [remove_cell] - remove_all_outputs_tags: [remove_output] - remove_input_tags: [remove_input] + remove_cell_tags: [ remove_cell ] + remove_all_outputs_tags: [ remove_output ] + remove_input_tags: [ remove_input ] - mike: version_selector: true diff --git a/polaris/_artifact.py b/polaris/_artifact.py index 609e4fcd..2a108f75 100644 --- a/polaris/_artifact.py +++ b/polaris/_artifact.py @@ -1,5 +1,5 @@ import json -from typing import Dict, Optional, Union +from typing import ClassVar import fsspec from loguru import logger @@ -13,10 +13,11 @@ field_validator, ) from pydantic.alias_generators import to_camel +from typing_extensions import Self -import polaris as po -from polaris.utils.misc import sluggify -from polaris.utils.types import HubOwner, SlugCompatibleStringType +import polaris +from polaris.utils.misc import slugify +from polaris.utils.types import ArtifactUrn, HubOwner, SlugCompatibleStringType, SlugStringType class BaseArtifactModel(BaseModel): @@ -29,30 +30,47 @@ class BaseArtifactModel(BaseModel): Only when uploading to the Hub, some of the attributes are required. Attributes: - name: A slug-compatible name for the dataset. - Together with the owner, this is used by the Hub to uniquely identify the benchmark. - description: A beginner-friendly, short description of the dataset. - tags: A list of tags to categorize the benchmark by. This is used by the hub to search over benchmarks. + name: A slug-compatible name for the artifact. 
+ Together with the owner, this is used by the Hub to uniquely identify the artifact. + description: A beginner-friendly, short description of the artifact. + tags: A list of tags to categorize the artifact by. This is used by the hub to search over artifacts. user_attributes: A dict with additional, textual user attributes. - owner: A slug-compatible name for the owner of the dataset. - If the dataset comes from the Polaris Hub, this is the associated owner (organization or user). - Together with the name, this is used by the Hub to uniquely identify the benchmark. + owner: A slug-compatible name for the owner of the artifact. + If the artifact comes from the Polaris Hub, this is the associated owner (organization or user). + Together with the name, this is used by the Hub to uniquely identify the artifact. polaris_version: The version of the Polaris library that was used to create the artifact. """ + _artifact_type: ClassVar[str] + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, arbitrary_types_allowed=True) - name: Optional[SlugCompatibleStringType] = None + # Model attributes + name: SlugCompatibleStringType | None = None description: str = "" tags: list[str] = Field(default_factory=list) - user_attributes: Dict[str, str] = Field(default_factory=dict) - owner: Optional[HubOwner] = None - polaris_version: str = po.__version__ + user_attributes: dict[str, str] = Field(default_factory=dict) + owner: HubOwner | None = None + polaris_version: str = polaris.__version__ @computed_field @property - def artifact_id(self) -> Optional[str]: - return f"{self.owner}/{sluggify(self.name)}" if self.owner and self.name else None + def slug(self) -> SlugStringType | None: + return slugify(self.name) if self.name else None + + @computed_field + @property + def artifact_id(self) -> str | None: + if self.owner and self.slug: + return f"{self.owner}/{self.slug}" + return None + + @computed_field + @property + def urn(self) -> ArtifactUrn | None: + if self.owner and self.slug: + return self.urn_for(self.owner, self.slug) + return None @field_validator("polaris_version") @classmethod @@ -61,7 +79,7 @@ def _validate_version(cls, value: str) -> str: # Make sure it is a valid semantic version Version(value) - current_version = po.__version__ + current_version = polaris.__version__ if value != current_version: logger.info( f"The version of Polaris that was used to create the artifact ({value}) is different " @@ -71,31 +89,35 @@ def _validate_version(cls, value: str) -> str: @field_validator("owner", mode="before") @classmethod - def _validate_owner(cls, value: Union[str, HubOwner, None]): + def _validate_owner(cls, value: str | HubOwner | None): if isinstance(value, str): return HubOwner(slug=value) return value @field_serializer("owner") - def _serialize_owner(self, value: HubOwner) -> Union[str, None]: + def _serialize_owner(self, value: HubOwner) -> str | None: return value.slug if value else None @classmethod - def from_json(cls, path: str): - """Loads a benchmark from a JSON file. + def from_json(cls, path: str) -> Self: + """Loads an artifact from a JSON file. Args: - path: Loads a benchmark specification from a JSON file. + path: Path to a JSON file containing the artifact definition. """ with fsspec.open(path, "r") as f: data = json.load(f) - return cls.model_validate(data) + return cls.model_validate(data) - def to_json(self, path: str): - """Saves the benchmark to a JSON file. + def to_json(self, path: str) -> None: + """Saves an artifact to a JSON file. 
Args: - path: Saves the benchmark specification to a JSON file. + path: Path to save the artifact definition as JSON. """ with fsspec.open(path, "w") as f: - json.dump(self.model_dump(), f) + f.write(self.model_dump_json()) + + @classmethod + def urn_for(cls, owner: str | HubOwner, name: str) -> ArtifactUrn: + return f"urn:polaris:{cls._artifact_type}:{owner}:{slugify(name)}" diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index b71c91c4..8ac70385 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -1,7 +1,7 @@ import json from hashlib import md5 from itertools import chain -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TypeAlias, Union import fsspec import numpy as np @@ -36,7 +36,7 @@ TaskType, ) -ColumnsType = Union[str, list[str]] +ColumnsType: TypeAlias = str | list[str] class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin): @@ -95,6 +95,8 @@ class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin): For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class. """ + _artifact_type = "benchmark" + # Public attributes # Data dataset: Union[DatasetV1, CompetitionDataset, str, dict[str, Any]] diff --git a/polaris/competition/_competition.py b/polaris/competition/_competition.py index b7d6e444..38f4fa8f 100644 --- a/polaris/competition/_competition.py +++ b/polaris/competition/_competition.py @@ -2,7 +2,7 @@ from typing import Optional from polaris.benchmark import BenchmarkSpecification -from polaris.evaluate._results import CompetitionPredictions +from polaris.evaluate import CompetitionPredictions from polaris.hub.settings import PolarisHubSettings from polaris.utils.types import HubOwner @@ -18,6 +18,8 @@ class CompetitionSpecification(BenchmarkSpecification): end_time: The time at which the competition ends and is no longer interactable. 
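Note: `BaseArtifactModel` above now exposes `slug`, `artifact_id` and `urn` computed fields, with `urn_for` composing identifiers of the form `urn:polaris:{artifact_type}:{owner}:{slug}`. A minimal sketch of the scheme; the real slugification lives in `polaris.utils.misc.slugify`, and the lowercase/hyphenate behaviour is assumed here:

def urn_for(artifact_type: str, owner: str, name: str) -> str:
    # Stand-in for polaris.utils.misc.slugify (assumed to lowercase and hyphenate)
    slug = name.strip().lower().replace(" ", "-")
    return f"urn:polaris:{artifact_type}:{owner}:{slug}"

print(urn_for("dataset", "my-org", "My Dataset"))       # urn:polaris:dataset:my-org:my-dataset
print(urn_for("competition", "my-org", "ADMET 2024"))   # urn:polaris:competition:my-org:admet-2024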
""" + _artifact_type = "competition" + # Additional properties specific to Competitions owner: HubOwner start_time: datetime | None = None diff --git a/polaris/dataset/__init__.py b/polaris/dataset/__init__.py index e82249cc..d7c2eac4 100644 --- a/polaris/dataset/__init__.py +++ b/polaris/dataset/__init__.py @@ -1,7 +1,6 @@ from polaris.dataset._column import ColumnAnnotation, KnownContentType, Modality from polaris.dataset._competition_dataset import CompetitionDataset -from polaris.dataset._dataset import DatasetV1 -from polaris.dataset._dataset import DatasetV1 as Dataset +from polaris.dataset._dataset import DatasetV1, DatasetV1 as Dataset from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files from polaris.dataset._subset import Subset diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index 72753b0e..e3ae59b0 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -1,7 +1,7 @@ import abc import json from pathlib import Path -from typing import Any, Dict, MutableMapping, Optional, Union +from typing import Any, MutableMapping import numpy as np import zarr @@ -21,11 +21,11 @@ from polaris.dataset._column import ColumnAnnotation from polaris.dataset.zarr import MemoryMappedDirectoryStore from polaris.dataset.zarr._utils import load_zarr_group_to_memory -from polaris.hub.polarisfs import PolarisFileSystem from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import ( AccessType, + ChecksumStrategy, DatasetIndex, HttpUrlString, HubOwner, @@ -58,7 +58,7 @@ class BaseDataset(BaseArtifactModel, abc.ABC): source: The data source, e.g. a DOI, Github repo or URI. license: The dataset license. Polaris only supports some Creative Commons licenses. See [`SupportedLicenseType`][polaris.utils.types.SupportedLicenseType] for accepted ID values. curation_reference: A reference to the curation process, e.g. a DOI, Github repo or URI. - cache_dir: Where the dataset would be cached if you call the `cache()` method. + _cache_dir: Where the dataset would be cached if you call the `cache()` method. For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class. 
Raises: @@ -67,24 +67,22 @@ class BaseDataset(BaseArtifactModel, abc.ABC): # Public attributes # Data - default_adapters: Dict[str, Adapter] = Field(default_factory=dict) - zarr_root_path: Optional[str] = None + default_adapters: dict[str, Adapter] = Field(default_factory=dict) + zarr_root_path: str | None = None # Additional meta-data readme: str = "" - annotations: Dict[str, ColumnAnnotation] = Field(default_factory=dict) - source: Optional[HttpUrlString] = None - license: Optional[SupportedLicenseType] = None - curation_reference: Optional[HttpUrlString] = None - - # Config - cache_dir: Optional[Path] = None + annotations: dict[str, ColumnAnnotation] = Field(default_factory=dict) + source: HttpUrlString | None = None + license: SupportedLicenseType | None = None + curation_reference: HttpUrlString | None = None # Private attributes - _zarr_root: Optional[zarr.Group] = PrivateAttr(None) - _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) - _client = PrivateAttr(None) # Optional[PolarisHubClient] + _zarr_root: zarr.Group | None = PrivateAttr(None) + _zarr_data: MutableMapping[str, np.ndarray] | None = PrivateAttr(None) _warn_about_remote_zarr: bool = PrivateAttr(True) + _cache_dir: str | None = PrivateAttr(None) # Where to cache the data to if cache() is called. + _verify_checksum_strategy: ChecksumStrategy = PrivateAttr("verify_unless_zarr") @field_validator("default_adapters", mode="before") def _validate_adapters(cls, value): @@ -96,13 +94,6 @@ def _serialize_adapters(self, value: dict[str, Adapter]): """Serializes the adapters""" return {k: v.name for k, v in value.items()} - @field_serializer("cache_dir", "zarr_root_path") - def _serialize_paths(value): - """Serialize the paths""" - if value is not None: - value = str(value) - return value - @model_validator(mode="after") def _validate_base_dataset_model(self) -> Self: # Verify that all annotations are for columns that exist @@ -125,17 +116,6 @@ def _validate_base_dataset_model(self) -> Self: return self - @property - def client(self): - """The Polaris Hub client used to interact with the Polaris Hub.""" - - # Import it here to prevent circular imports - from polaris.hub.client import PolarisHubClient - - if self._client is None: - self._client = PolarisHubClient() - return self._client - @property def uses_zarr(self) -> bool: """Whether any of the data in this dataset is stored in a Zarr Archive.""" @@ -168,14 +148,16 @@ def zarr_root(self) -> zarr.Group | None: See also [`Dataset.load_to_memory`][polaris.dataset.Dataset.load_to_memory]. 
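Note: the remote branch of `zarr_root` below no longer goes through `PolarisFileSystem`; when the path uses the `polarisfs` protocol, the archive is opened through a read-scoped `StorageSession`. A condensed sketch of just that branch (error handling and the memory-mapped local branch omitted; the helper name is ours):

import zarr
from polaris.hub.client import PolarisHubClient
from polaris.hub.storage import StorageSession

def open_remote_zarr_root(dataset):
    # Only taken when dataset.zarr_root_path starts with the "polarisfs" protocol
    with PolarisHubClient() as client:
        with StorageSession(client, "read", dataset.urn) as storage:
            return zarr.open_consolidated(storage.extension_store)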
""" + from polaris.hub.client import PolarisHubClient + from polaris.hub.storage import StorageSession + if self._zarr_root is not None: return self._zarr_root if self.zarr_root_path is None: return None - # We open the archive in read-only mode if it is saved on the Hub - saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) + saved_on_hub = self.zarr_root_path.startswith(StorageSession.polaris_protocol) if self._warn_about_remote_zarr: saved_remote = saved_on_hub or not Path(self.zarr_root_path).exists() @@ -190,7 +172,9 @@ def zarr_root(self) -> zarr.Group | None: try: if saved_on_hub: - self._zarr_root = self.client.open_zarr_file(self.owner, self.name, self.zarr_root_path, "r+") + with PolarisHubClient() as client: + with StorageSession(client, "read", self.urn) as storage: + self._zarr_root = zarr.open_consolidated(storage.extension_store) else: # We use memory mapping by default because our experiments show that it's consistently faster store = MemoryMappedDirectoryStore(self.zarr_root_path) @@ -273,7 +257,7 @@ def get_data( raise NotImplementedError @abc.abstractmethod - def upload_to_hub(self, access: AccessType = "private", owner: Union[HubOwner, str, None] = None): + def upload_to_hub(self, access: AccessType = "private", owner: HubOwner | str | None = None): """Uploads the dataset to the Polaris Hub.""" raise NotImplementedError @@ -316,8 +300,8 @@ def cache(self) -> str: Returns: The path to the cache directory. """ - self.to_json(self.cache_dir, load_zarr_from_new_location=True) - return self.cache_dir + self.to_json(self._cache_dir, load_zarr_from_new_location=True) + return self._cache_dir def size(self) -> tuple[int, int]: return self.n_rows, self.n_columns @@ -350,8 +334,3 @@ def __repr__(self): def __str__(self): return self.__repr__() - - def __del__(self): - """Close the connection of the client""" - if self._client is not None: - self._client.close() diff --git a/polaris/dataset/_competition_dataset.py b/polaris/dataset/_competition_dataset.py index 77217162..8f8e74cb 100644 --- a/polaris/dataset/_competition_dataset.py +++ b/polaris/dataset/_competition_dataset.py @@ -13,6 +13,8 @@ class CompetitionDataset(DatasetV1): of the training data for a given competition. 
""" + _artifact_type = "competitionDataset" + @model_validator(mode="after") def _validate_model(self) -> Self: """We reject the instantiation of competition datasets which leverage Zarr for the time being""" diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 740446a7..6aa8a621 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -1,8 +1,8 @@ import json -import uuid from hashlib import md5 from pathlib import Path -from typing import Any, ClassVar, List, Literal, Optional, Union +from typing import Any, ClassVar, List, Literal +from uuid import uuid4 import fsspec import numpy as np @@ -10,19 +10,25 @@ import zarr from datamol.utils import fs as dmfs from loguru import logger -from pydantic import PrivateAttr, computed_field, field_validator, model_validator +from pydantic import ( + PrivateAttr, + computed_field, + field_serializer, + field_validator, + model_validator, +) from typing_extensions import Self from polaris.dataset._adapters import Adapter -from polaris.dataset._base import _CACHE_SUBDIR, BaseDataset +from polaris.dataset._base import BaseDataset, _CACHE_SUBDIR from polaris.dataset.zarr import ZarrFileChecksum, compute_zarr_checksum -from polaris.mixins._checksum import ChecksumMixin +from polaris.mixins import ChecksumMixin from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import ( AccessType, + ChecksumStrategy, HubOwner, - TimeoutTypes, ZarrConflictResolution, ) @@ -35,7 +41,7 @@ class DatasetV1(BaseDataset, ChecksumMixin): """First version of a Polaris Dataset. Stores datapoints in a Pandas DataFrame and implements _pointer columns_ to support the storage of XXL data - outside of the DataFrame in a Zarr archive. + outside the DataFrame in a Zarr archive. Info: Pointer columns For complex data, such as images, we support storing the content in external blobs of data. @@ -51,25 +57,35 @@ class DatasetV1(BaseDataset, ChecksumMixin): InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification. """ + _artifact_type = "dataset" + version: ClassVar[Literal[1]] = 1 + # Public attributes # Data table: pd.DataFrame - version: ClassVar[Literal[1]] = 1 - _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) + # Private attributes + _zarr_md5sum_manifest: list[ZarrFileChecksum] = PrivateAttr(default_factory=list) @field_validator("table", mode="before") - def _validate_table(cls, v): + @classmethod + def _load_table(cls, v) -> pd.DataFrame: """ - If the table is not a dataframe yet, assume it's a path and try load it. - We also make sure that the pandas index is contiguous and starts at 0, and - that all columns are named and unique. + Load from path if not a dataframe """ - # Load from path if not a dataframe - if not isinstance(v, pd.DataFrame): + if isinstance(v, str): if not dmfs.is_file(v) or dmfs.get_extension(v) not in _SUPPORTED_TABLE_EXTENSIONS: raise InvalidDatasetError(f"{v} is not a valid DataFrame or .parquet path.") v = pd.read_parquet(v) + return v + + @field_validator("table") + @classmethod + def _validate_table(cls, v: pd.DataFrame) -> pd.DataFrame: + """ + Make sure that the pandas index is contiguous and starts at 0, and + that all columns are named and unique. 
+ """ # Check if there are any duplicate columns if any(v.columns.duplicated()): raise InvalidDatasetError("The table contains duplicate columns") @@ -90,15 +106,44 @@ def _validate_v1_dataset_model(self) -> Self: "The zarr_root_path should only be specified when there are pointer columns" ) - # Set the default cache dir if none and make sure it exists - if self.cache_dir is None: - dataset_id = self._md5sum if self.has_md5sum else str(uuid.uuid4()) - self.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id - self.cache_dir.mkdir(parents=True, exist_ok=True) + return self + + @model_validator(mode="after") + def _ensure_cache_dir_exists(self) -> Self: + """ + Set the default cache dir if none and make sure it exists + """ + if self._cache_dir is None: + dataset_id = self._md5sum if self.has_md5sum else str(uuid4()) + self._cache_dir = str(Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id) + fs, path = fsspec.url_to_fs(self._cache_dir) + fs.mkdirs(path, exist_ok=True) return self - def _compute_checksum(self) -> str: + @field_validator("default_adapters", mode="before") + def _validate_adapters(cls, value): + """Validate the adapters""" + return {k: Adapter[v] if isinstance(v, str) else v for k, v in value.items()} + + @field_serializer("default_adapters") + def _serialize_adapters(self, value: dict[str, Adapter]) -> dict[str, str]: + """Serializes the adapters""" + return {k: v.name for k, v in value.items()} + + def should_verify_checksum(self, strategy: ChecksumStrategy) -> bool: + """ + Determines whether to verify the checksum of the dataset based on the strategy. + """ + match strategy: + case "ignore": + return False + case "verify": + return True + case "verify_unless_zarr": + return not self.uses_zarr + + def _compute_checksum(self): """Computes a hash of the dataset. This is meant to uniquely identify the dataset and can be used to verify the version. @@ -134,7 +179,7 @@ def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: """ if len(self._zarr_md5sum_manifest) == 0 and not self.has_md5sum: # The manifest is set as an instance variable - # as a side-effect of the compute_checksum method + # as a side effect of the compute_checksum method self.md5sum = self._compute_checksum() return self._zarr_md5sum_manifest @@ -197,34 +242,30 @@ def get_data( return arr - def upload_to_hub( - self, - access: Optional[AccessType] = "private", - owner: Union[HubOwner, str, None] = None, - timeout: TimeoutTypes = (10, 200), - ): + def upload_to_hub(self, access: AccessType = "private", owner: HubOwner | str | None = None): """ Very light, convenient wrapper around the [`PolarisHubClient.upload_dataset`][polaris.hub.client.PolarisHubClient.upload_dataset] method. """ - self.client.upload_dataset(self, access=access, owner=owner, timeout=timeout) + from polaris.hub.client import PolarisHubClient + + with PolarisHubClient() as client: + client.upload_dataset(self, owner=owner, access=access) @classmethod def from_json(cls, path: str): - """ - Loads a dataset from a JSON file. + """Loads a dataset from a JSON file. Args: - path: The path to the JSON file to load the dataset from. 
+ path: The path to the JSON file to load the dataset from .ColumnAnnotation """ with fsspec.open(path, "r") as f: data = json.load(f) - data.pop("cache_dir", None) return cls.model_validate(data) def to_json( self, - destination: str, + destination: str | Path, if_exists: ZarrConflictResolution = "replace", load_zarr_from_new_location: bool = False, ) -> str: @@ -256,8 +297,8 @@ def to_json( dataset_path = str(destination / "dataset.json") new_zarr_root_path = str(destination / "data.zarr") - # Lu: Avoid serilizing and sending None to hub app. - serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True) + # Lu: Avoid serializing and sending None to hub app. + serialized = self.model_dump(exclude_none=True) serialized["table"] = table_path # Copy over Zarr data to the destination @@ -288,7 +329,7 @@ def to_json( return dataset_path - def cache(self, verify_checksum: bool = False): + def cache(self, verify_checksum: bool = False) -> str: """Cache the dataset to the cache directory. Args: diff --git a/polaris/dataset/_factory.py b/polaris/dataset/_factory.py index dfd550b0..a2a13224 100644 --- a/polaris/dataset/_factory.py +++ b/polaris/dataset/_factory.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Literal, Optional +from typing import Literal import datamol as dm import pandas as pd @@ -11,7 +11,7 @@ from polaris.dataset.converters import Converter, PDBConverter, SDFConverter, ZarrConverter -def create_dataset_from_file(path: str, zarr_root_path: Optional[str] = None) -> DatasetV1: +def create_dataset_from_file(path: str, zarr_root_path: str | None = None) -> DatasetV1: """ This function is a convenience function to create a dataset from a file. @@ -28,7 +28,7 @@ def create_dataset_from_file(path: str, zarr_root_path: Optional[str] = None) -> def create_dataset_from_files( - paths: List[str], zarr_root_path: Optional[str] = None, axis: Literal[0, 1, "index", "columns"] = 0 + paths: list[str], zarr_root_path: str | None = None, axis: Literal[0, 1, "index", "columns"] = 0 ) -> DatasetV1: """ This function is a convenience function to create a dataset from multiple files. @@ -54,7 +54,7 @@ class DatasetFactory: """ The `DatasetFactory` makes it easier to create complex datasets. - It is based on the the factory design pattern and allows a user to specify specific handlers + It is based on the factory design pattern and allows a user to specify specific handlers (i.e. [`Converter`][polaris.dataset.converters._base.Converter] objects) for different file types. These converters are used to convert commonly used file types in drug discovery to something that can be used within Polaris while losing as little information as possible. @@ -73,12 +73,12 @@ class DatasetFactory: Question: How to make adding meta-data easier? The `DatasetFactory` is designed to more easily pull together data from different sources. - However, adding meta-data remains a laborous process. How could we make this simpler through + However, adding meta-data remains a laborious process. How could we make this simpler through the Python API? """ def __init__( - self, zarr_root_path: Optional[str] = None, converters: Optional[Dict[str, Converter]] = None + self, zarr_root_path: str | None = None, converters: dict[str, Converter] | None = None ) -> None: """ Create a new factory object. 
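Note: since the `DatasetFactory` surface is touched throughout this hunk, a minimal usage sketch may help; the column names and values are purely illustrative:

import pandas as pd
from polaris.dataset import DatasetFactory

factory = DatasetFactory()
factory.add_column(pd.Series(["CCO", "CCC"], name="smiles"))
factory.add_column(pd.Series([0.12, 0.34], name="logP"))
dataset = factory.build()  # returns a DatasetV1 backed by the assembled table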
@@ -93,7 +93,7 @@ def __init__( if converters is None: converters = {} - self._converters: Dict[str, Converter] = converters + self._converters: dict[str, Converter] = converters self.reset(zarr_root_path=zarr_root_path) @property @@ -139,8 +139,8 @@ def register_converter(self, ext: str, converter: Converter): def add_column( self, column: pd.Series, - annotation: Optional[ColumnAnnotation] = None, - adapters: Optional[Adapter] = None, + annotation: ColumnAnnotation | None = None, + adapters: Adapter | None = None, ): """ Add a single column to the DataFrame @@ -150,7 +150,7 @@ def add_column( 1. The name attribute of the column to be set. 2. The name attribute of the column to be unique. 3. If the column is a pointer column, the `zarr_root_path` needs to be set. - 4. The length of the column to match the length of the alredy constructed table. + 4. The length of the column to match the length of the already constructed table. Args: column: The column to add to the dataset. @@ -182,9 +182,9 @@ def add_column( def add_columns( self, df: pd.DataFrame, - annotations: Optional[Dict[str, ColumnAnnotation]] = None, - adapters: Optional[Dict[str, Adapter]] = None, - merge_on: Optional[str] = None, + annotations: dict[str, ColumnAnnotation] | None = None, + adapters: dict[str, Adapter] | None = None, + merge_on: str | None = None, ): """ Add multiple columns to the dataset based on another dataframe. @@ -237,7 +237,7 @@ def add_from_file(self, path: str): table, annotations, adapters = converter.convert(path, self) self.add_columns(table, annotations, adapters) - def add_from_files(self, paths: List[str], axis: Literal[0, 1, "index", "columns"]): + def add_from_files(self, paths: list[str], axis: Literal[0, 1, "index", "columns"]): """ Uses the registered converters to parse the data from a specific files and add them to the dataset. If no converter is found for the file extension, it raises an error. @@ -275,7 +275,7 @@ def build(self) -> DatasetV1: zarr_root_path=self.zarr_root_path, ) - def reset(self, zarr_root_path: Optional[str] = None): + def reset(self, zarr_root_path: str | None = None): """ Resets the factory to its initial state to start building the next dataset from scratch. Note that this will not reset the registered converters. diff --git a/polaris/evaluate/_results.py b/polaris/evaluate/_results.py index 252cc597..4de807b6 100644 --- a/polaris/evaluate/_results.py +++ b/polaris/evaluate/_results.py @@ -20,7 +20,7 @@ from polaris.hub.settings import PolarisHubSettings from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidResultError -from polaris.utils.misc import sluggify +from polaris.utils.misc import slugify from polaris.utils.types import ( AccessType, CompetitionPredictionsType, @@ -208,13 +208,15 @@ class BenchmarkResults(EvaluationResult): Together with the benchmark name, this uniquely identifies the benchmark on the Hub. """ + _artifact_type = "result" + benchmark_name: SlugCompatibleStringType = Field(..., frozen=True) benchmark_owner: Optional[HubOwner] = Field(None, frozen=True) @computed_field @property def benchmark_artifact_id(self) -> str: - return f"{self.benchmark_owner}/{sluggify(self.benchmark_name)}" + return f"{self.benchmark_owner}/{slugify(self.benchmark_name)}" def upload_to_hub( self, @@ -248,13 +250,15 @@ class CompetitionResults(EvaluationResult): Together with the competition name, this uniquely identifies the competition on the Hub. 
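Note: back in `DatasetV1`, checksum verification is now strategy-driven via `should_verify_checksum` rather than the old module-level helper. A small sketch of how the three strategies resolve (`resolve` is a stand-in name):

def resolve(strategy: str, uses_zarr: bool) -> bool:
    # Mirrors DatasetV1.should_verify_checksum: Zarr-backed datasets skip
    # verification under the default "verify_unless_zarr" strategy.
    match strategy:
        case "ignore":
            return False
        case "verify":
            return True
        case "verify_unless_zarr":
            return not uses_zarr

assert resolve("verify_unless_zarr", uses_zarr=True) is False
assert resolve("verify", uses_zarr=True) is True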
""" + _artifact_type = "competitionResult" + competition_name: SlugCompatibleStringType = Field(..., frozen=True) competition_owner: Optional[HubOwner] = Field(None, frozen=True) @computed_field @property def competition_artifact_id(self) -> str: - return f"{self.competition_owner}/{sluggify(self.competition_name)}" + return f"{self.competition_owner}/{slugify(self.competition_name)}" class CompetitionPredictions(ResultsMetadata): diff --git a/polaris/experimental/_dataset_v2.py b/polaris/experimental/_dataset_v2.py index e7987ab3..17f6a893 100644 --- a/polaris/experimental/_dataset_v2.py +++ b/polaris/experimental/_dataset_v2.py @@ -1,8 +1,8 @@ import json import re -import uuid from pathlib import Path from typing import Any, ClassVar, Literal +from uuid import uuid4 import fsspec import numpy as np @@ -12,7 +12,7 @@ from typing_extensions import Self from polaris.dataset._adapters import Adapter -from polaris.dataset._base import _CACHE_SUBDIR, BaseDataset +from polaris.dataset._base import BaseDataset, _CACHE_SUBDIR from polaris.dataset.zarr._manifest import calculate_file_md5, generate_zarr_manifest from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.errors import InvalidDatasetError @@ -42,6 +42,8 @@ class DatasetV2(BaseDataset): InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification. """ + _artifact_type = "dataset" + version: ClassVar[Literal[2]] = 2 _zarr_manifest_path: str | None = PrivateAttr(None) _zarr_manifest_md5sum: str | None = PrivateAttr(None) @@ -78,11 +80,18 @@ def _validate_v2_dataset_model(self) -> Self: f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}" ) - # Set the default cache dir if none and make sure it exists - if self.cache_dir is None: - dataset_id = self._zarr_manifest_md5sum if self.has_zarr_manifest_md5sum else str(uuid.uuid4()) - self.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id - self.cache_dir.mkdir(parents=True, exist_ok=True) + return self + + @model_validator(mode="after") + def _ensure_cache_dir_exists(self) -> Self: + """ + Set the default cache dir if none and make sure it exists + """ + if self._cache_dir is None: + dataset_id = self._zarr_manifest_md5sum if self.has_zarr_manifest_md5sum else str(uuid4()) + self._cache_dir = str(Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id) + fs, path = fsspec.url_to_fs(self._cache_dir) + fs.mkdirs(path, exist_ok=True) return self @@ -122,7 +131,7 @@ def dtypes(self) -> dict[str, np.dtype]: @property def zarr_manifest_path(self) -> str: if self._zarr_manifest_path is None: - zarr_manifest_path = generate_zarr_manifest(self.zarr_root_path, self.cache_dir) + zarr_manifest_path = generate_zarr_manifest(self.zarr_root_path, self._cache_dir) self._zarr_manifest_path = zarr_manifest_path return self._zarr_manifest_path @@ -202,7 +211,6 @@ def from_json(cls, path: str): """ with fsspec.open(path, "r") as f: data = json.load(f) - data.pop("cache_dir", None) return cls.model_validate(data) def to_json( @@ -231,7 +239,7 @@ def to_json( new_zarr_root_path = str(destination / "data.zarr") # Lu: Avoid serilizing and sending None to hub app. 
- serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True) + serialized = self.model_dump(exclude_none=True) serialized["zarrRootPath"] = new_zarr_root_path # Copy over Zarr data to the destination diff --git a/polaris/hub/client.py b/polaris/hub/client.py index f2f9e796..05cdc5b3 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -23,14 +23,13 @@ SingleTaskBenchmarkSpecification, ) from polaris.competition import CompetitionSpecification -from polaris.dataset import CompetitionDataset, DatasetV1 -from polaris.evaluate import BenchmarkResults, CompetitionResults -from polaris.evaluate._results import CompetitionPredictions -from polaris.hub.external_auth_client import ExternalAuthClient +from polaris.dataset import CompetitionDataset, Dataset, DatasetV1 +from polaris.evaluate import BenchmarkResults, CompetitionPredictions, CompetitionResults +from polaris.hub.external_client import ExternalAuthClient from polaris.hub.oauth import CachedTokenAuth -from polaris.hub.polarisfs import PolarisFileSystem from polaris.hub.settings import PolarisHubSettings -from polaris.utils.context import ProgressIndicator, tmp_attribute_change +from polaris.hub.storage import StorageSession +from polaris.utils.context import ProgressIndicator from polaris.utils.errors import ( InvalidDatasetError, PolarisCreateArtifactError, @@ -38,13 +37,11 @@ PolarisRetrieveArtifactError, PolarisUnauthorizedError, ) -from polaris.utils.misc import should_verify_checksum from polaris.utils.types import ( AccessType, ArtifactSubtype, ChecksumStrategy, HubOwner, - IOMode, SupportedLicenseType, TimeoutTypes, ZarrConflictResolution, @@ -140,7 +137,7 @@ def _prepare_token_endpoint_body(self, body, grant_type, **kwargs): ) return super()._prepare_token_endpoint_body(body, grant_type, **kwargs) - def ensure_active_token(self, token: OAuth2Token) -> bool: + def ensure_active_token(self, token: OAuth2Token | None = None) -> bool: """ Override the active check to trigger a refetch of the token if it is not active. """ @@ -149,7 +146,7 @@ def ensure_active_token(self, token: OAuth2Token) -> bool: return True # Check if external token is still valid - if not self.external_client.ensure_active_token(self.external_client.token): + if not self.external_client.ensure_active_token(): return False # If so, use it to get a new Hub token @@ -256,7 +253,7 @@ def login(self, overwrite: bool = False, auto_open_browser: bool = True): overwrite: Whether to overwrite the current token if the user is already logged in. auto_open_browser: Whether to automatically open the browser to visit the authorization URL. 
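Note: the login flow below now leans on the reworked `ensure_active_token`: a cached Hub token is reused while valid, silently re-fetched while the external token is still active, and only falls back to an interactive login otherwise. Typical usage is unchanged:

from polaris.hub.client import PolarisHubClient

with PolarisHubClient() as client:
    # Reuses the cached token when it is still active; pass overwrite=True
    # to force a fresh interactive login.
    client.login()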
""" - if overwrite or self.token is None or not self.ensure_active_token(self.token): + if overwrite or self.token is None or not self.ensure_active_token(): self.external_client.interactive_login(overwrite=overwrite, auto_open_browser=auto_open_browser) self.token = self.fetch_token() @@ -309,10 +306,10 @@ def get_dataset( def _get_dataset( self, - owner: Union[str, HubOwner], + owner: str | HubOwner, name: str, artifact_type: ArtifactSubtype, - verify_checksum: bool = True, + verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> DatasetV1: """Loads either a standard or competition dataset from Polaris Hub @@ -336,75 +333,29 @@ def _get_dataset( else f"/v2/competition/dataset/{owner}/{name}" ) response = self._base_request_to_hub(url=url, method="GET") - storage_response = self.get(response["tableContent"]["url"]) - - # This should be a 307 redirect with the signed URL - if storage_response.status_code != 307: - try: - storage_response.raise_for_status() - except HTTPStatusError as error: - raise PolarisHubError( - message="Could not get signed URL from Polaris Hub.", response=storage_response - ) from error - storage_response = storage_response.json() - url = storage_response["url"] - headers = storage_response["headers"] + # Disregard the Zarr root in the response. We'll get it from the storage token instead. + response.pop("zarrRootPath", None) - response["table"] = self._load_from_signed_url(url=url, headers=headers, load_fn=pd.read_parquet) + # Load the dataset table and optional Zarr archive + with StorageSession(self, "read", Dataset.urn_for(owner, name)) as storage: + table = pd.read_parquet(storage.get_root()) + zarr_root_path = str(storage.paths.extension) if artifact_type == ArtifactSubtype.COMPETITION: - dataset = CompetitionDataset(**response) - md5Sum = response["maskedMd5Sum"] + dataset = CompetitionDataset(table=table, zarr_root_path=zarr_root_path, **response) + md5sum = response["maskedMd5Sum"] else: - dataset = DatasetV1(**response) - md5Sum = response["md5Sum"] + dataset = DatasetV1(table=table, zarr_root_path=zarr_root_path, **response) + md5sum = response["md5Sum"] - if should_verify_checksum(verify_checksum, dataset): - dataset.verify_checksum(md5Sum) + if dataset.should_verify_checksum(verify_checksum): + dataset.verify_checksum(md5sum) else: - dataset.md5sum = md5Sum + dataset.md5sum = md5sum return dataset - def open_zarr_file( - self, owner: str | HubOwner, name: str, path: str, mode: IOMode, as_consolidated: bool = True - ) -> zarr.hierarchy.Group: - """Open a Zarr file from a Polaris dataset - - Args: - owner: Which Hub user or organization owns the artifact. - name: Name of the dataset. - path: Path to the Zarr file within the dataset. - mode: The mode in which the file is opened. - as_consolidated: Whether to open the store with consolidated metadata for optimized reading. - This is only applicable in 'r' and 'r+' modes. - - Returns: - The Zarr object representing the dataset. 
- """ - if as_consolidated and mode not in ["r", "r+"]: - raise ValueError("Consolidated archives can only be used with 'r' or 'r+' mode.") - - polaris_fs = PolarisFileSystem( - polaris_client=self, - dataset_owner=owner, - dataset_name=name, - ) - - try: - store = zarr.storage.FSStore(path, fs=polaris_fs) - if mode in ["r", "r+"] and as_consolidated: - return zarr.open_consolidated(store, mode=mode) - return zarr.open(store, mode=mode) - - except HTTPStatusError as error: - # In this case, we can pass the response to provide more information - raise PolarisHubError(message="Error opening Zarr store", response=error.response) from error - # This catches all other types of exceptions - except Exception as error: - raise PolarisHubError(message="Error opening Zarr store") from error - def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: """List all available benchmarks on the Polaris Hub. @@ -469,7 +420,7 @@ def get_benchmark( benchmark = benchmark_cls(**response) - if should_verify_checksum(verify_checksum, benchmark.dataset): + if benchmark.dataset.should_verify_checksum(verify_checksum): benchmark.verify_checksum() else: benchmark.md5sum = response["md5Sum"] @@ -541,16 +492,16 @@ def upload_dataset( ): """Wrapper method for uploading standard datasets to Polaris Hub""" return self._upload_dataset( - dataset, ArtifactSubtype.STANDARD.value, access, timeout, owner, if_exists + dataset, ArtifactSubtype.STANDARD.value, owner, access, timeout, if_exists ) def _upload_dataset( self, dataset: DatasetV1, artifact_type: ArtifactSubtype, + owner: HubOwner | str, access: AccessType = "private", timeout: TimeoutTypes = (10, 200), - owner: Union[HubOwner, str, None] = None, if_exists: ZarrConflictResolution = "replace", ): """Upload the dataset to the Polaris Hub. @@ -595,25 +546,22 @@ def _upload_dataset( # Get the serialized data-model # We exclude the table as it handled separately and we exclude the cache_dir as it is user-specific dataset.owner = HubOwner.normalize(owner or dataset.owner) - dataset_json = dataset.model_dump( - exclude={"cache_dir", "table"}, exclude_none=True, by_alias=True - ) + dataset_json = dataset.model_dump(exclude={"table"}, exclude_none=True, by_alias=True) # If the dataset uses Zarr, we will save the Zarr archive to the Hub as well if dataset.uses_zarr: - dataset_json["zarrRootPath"] = f"{PolarisFileSystem.protocol}://data.zarr" + dataset_json["zarrRootPath"] = f"{StorageSession.polaris_protocol}://data.zarr" # Uploading a dataset is a three-step process. # 1. Upload the dataset meta data to the hub and prepare the hub to receive the data # 2. Upload the parquet file to the hub # 3. Upload the associated Zarr archive - # TODO: Revert step 1 in case step 2 fails - Is this needed? Or should this be taken care of by the hub? # Prepare the parquet file - buffer = BytesIO() - dataset.table.to_parquet(buffer, engine="auto") - parquet_size = len(buffer.getbuffer()) - parquet_md5 = md5(buffer.getbuffer()).hexdigest() + in_memory_parquet = BytesIO() + dataset.table.to_parquet(in_memory_parquet) + parquet_size = len(in_memory_parquet.getbuffer()) + parquet_md5 = md5(in_memory_parquet.getbuffer()).hexdigest() # Step 1: Upload meta-data # Instead of directly uploading the data, we announce to the hub that we intend to upload it. 
@@ -639,65 +587,32 @@ def _upload_dataset( timeout=timeout, ) - # Step 2: Upload the parquet file - # create an empty PUT request to get the table content URL from cloudflare - hub_response = self.request( - url=response["tableContent"]["url"], - method="PUT", - headers={ - "Content-type": "application/vnd.apache.parquet", - }, - timeout=timeout, - json={"artifactType": artifact_type}, - ) - - if hub_response.status_code == 307: - # If the hub returns a 307 redirect, we need to follow it to get the signed URL - hub_response_body = hub_response.json() - - # Upload the data to the cloudflare url - bucket_response = self.request( - url=hub_response_body["url"], - method=hub_response_body["method"], - headers={ - "Content-type": "application/vnd.apache.parquet", - **hub_response_body["headers"], - }, - content=buffer.getvalue(), - auth=None, - timeout=timeout, # required for large size dataset - ) - bucket_response.raise_for_status() + with StorageSession(self, "write", dataset.urn) as storage: + # Step 2: Upload the parquet file + logger.info("Copying Parquet file to the Hub. This may take a while.") + storage.set_root(in_memory_parquet.getvalue()) - else: - hub_response.raise_for_status() + # Step 3: Upload any associated Zarr archive + if dataset.uses_zarr: + logger.info("Copying Zarr archive to the Hub. This may take a while.") - # Step 3: Upload any associated Zarr archive - if dataset.uses_zarr: - with tmp_attribute_change(self.settings, "default_timeout", timeout): - # Copy the Zarr archive to the hub - dest = self.open_zarr_file( - owner=dataset.owner, - name=dataset.name, - path=dataset_json["zarrRootPath"], - mode="w", - as_consolidated=False, - ) + destination = storage.extension_store # Locally consolidate Zarr archive metadata. Future updates on handling consolidated # metadata based on Zarr developers' recommendations can be tracked at: # https://github.com/zarr-developers/zarr-python/issues/1731 zarr.consolidate_metadata(dataset.zarr_root.store.store) zmetadata_content = dataset.zarr_root.store.store[".zmetadata"] - dest.store[".zmetadata"] = zmetadata_content + destination[".zmetadata"] = zmetadata_content - logger.info("Copying Zarr archive to the Hub. This may take a while.") + # Copy the Zarr archive to the hub zarr.copy_store( source=dataset.zarr_root.store.store, - dest=dest.store, + dest=destination, log=logger.debug, if_exists=if_exists, ) + base_artifact_url = ( "datasets" if artifact_type == ArtifactSubtype.STANDARD.value else "/competition/datasets" ) @@ -791,7 +706,10 @@ def _upload_benchmark( return response def get_competition( - self, owner: Union[str, HubOwner], name: str, verify_checksum: bool = True + self, + owner: str | HubOwner, + name: str, + verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> CompetitionSpecification: """Load a competition from the Polaris Hub. 
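Note: the storage-backed half of the new upload flow (steps 2 and 3 above) boils down to the sketch below. `push_dataset` is a hypothetical helper; real uploads should keep going through `client.upload_dataset`, which also performs the metadata announcement in step 1:

import zarr
from polaris.hub.client import PolarisHubClient
from polaris.hub.storage import StorageSession

def push_dataset(dataset, parquet_bytes: bytes) -> None:
    with PolarisHubClient() as client:
        with StorageSession(client, "write", dataset.urn) as storage:
            storage.set_root(parquet_bytes)  # step 2: the Parquet table at the root path
            if dataset.uses_zarr:
                # step 3: mirror the local Zarr archive into the extension store
                zarr.copy_store(source=dataset.zarr_root.store.store, dest=storage.extension_store)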
@@ -861,7 +779,7 @@ def evaluate_competition( success_msg="Evaluated competition predictions.", error_msg="Failed to evaluate competition predictions.", ) as progress_indicator: - competition.owner = HubOwner(**competition.owner) + competition.owner = HubOwner.normalize(competition.owner) response = self._base_request_to_hub( url=f"/v2/competition/{competition.owner}/{competition.name}/evaluate", diff --git a/polaris/hub/external_auth_client.py b/polaris/hub/external_client.py similarity index 88% rename from polaris/hub/external_auth_client.py rename to polaris/hub/external_client.py index ed388446..35829186 100644 --- a/polaris/hub/external_auth_client.py +++ b/polaris/hub/external_client.py @@ -1,15 +1,18 @@ import webbrowser -from typing import Optional +from typing import Literal, Optional, TypeAlias from authlib.common.security import generate_token from authlib.integrations.base_client import OAuthError from authlib.integrations.httpx_client import OAuth2Client -from authlib.oauth2 import TokenAuth +from authlib.oauth2 import OAuth2Error, TokenAuth +from authlib.oauth2.rfc6749 import OAuth2Token from loguru import logger from polaris.hub.oauth import ExternalCachedTokenAuth from polaris.hub.settings import PolarisHubSettings -from polaris.utils.errors import PolarisUnauthorizedError +from polaris.utils.errors import PolarisHubError, PolarisUnauthorizedError + +Scope: TypeAlias = Literal["read", "write"] class ExternalAuthClient(OAuth2Client): @@ -65,11 +68,16 @@ def create_authorization_url(self, **kwargs) -> tuple[str, Optional[str]]: def fetch_token(self, **kwargs) -> dict: """Light wrapper to automatically pass in the right URL""" - return super().fetch_token( - url=self.settings.token_fetch_url, code_verifier=self.code_verifier, **kwargs - ) + try: + return super().fetch_token( + url=self.settings.token_fetch_url, code_verifier=self.code_verifier, **kwargs + ) + except OAuth2Error as error: + raise PolarisHubError( + message=f"Could not obtain a token from the external OAuth2 server. 
Error was: {error.error} - {error.description}" + ) from error - def ensure_active_token(self, token) -> bool: + def ensure_active_token(self, token: OAuth2Token | None = None) -> bool: try: return super().ensure_active_token(token) or False except OAuthError: diff --git a/polaris/hub/oauth.py b/polaris/hub/oauth.py index 8be61595..a39e0d22 100644 --- a/polaris/hub/oauth.py +++ b/polaris/hub/oauth.py @@ -1,9 +1,16 @@ import json +import re +from datetime import datetime, timedelta, timezone from pathlib import Path +from time import time +from typing import Any, Literal from authlib.integrations.httpx_client import OAuth2Auth +from pydantic import BaseModel, PositiveInt, computed_field, model_validator +from typing_extensions import Self from polaris.utils.constants import DEFAULT_CACHE_DIR +from polaris.utils.types import AnyUrlString, HttpUrlString class CachedTokenAuth(OAuth2Auth): @@ -50,3 +57,70 @@ def __init__( filename="external_auth_token.json", ): super().__init__(token, token_placement, client, cache_dir, filename) + + +class StoragePaths(BaseModel): + root: AnyUrlString + extension: AnyUrlString | None = None + + @computed_field + @property + def relative_root(self) -> str: + return re.sub(r"^\w+://", "", self.root) + + @computed_field + @property + def relative_extension(self) -> str | None: + if self.extension: + return re.sub(r"^\w+://", "", self.extension) + return None + + +class StorageTokenData(BaseModel): + key: str + secret: str + endpoint: HttpUrlString + paths: StoragePaths + + +class HubOAuth2Token(BaseModel): + """ + Model to parse and validate tokens obtained from the Polaris Hub. + """ + + issued_token_type: Literal["urn:ietf:params:oauth:token-type:jwt"] = ( + "urn:ietf:params:oauth:token-type:jwt" + ) + token_type: Literal["Bearer"] = "Bearer" + expires_in: PositiveInt | None = None + expires_at: datetime | None = None + access_token: str + extra_data: None + + @model_validator(mode="after") + def set_expires_at(self) -> Self: + if self.expires_at is None and self.expires_in is not None: + self.expires_at = datetime.fromtimestamp(time() + self.expires_in, timezone.utc) + return self + + def is_expired(self, leeway=60) -> bool | None: + if not self.expires_at: + return None + # Small timedelta to consider token as expired before it actually expires + expiration_threshold = self.expires_at - timedelta(seconds=leeway) + return datetime.now(timezone.utc) >= expiration_threshold + + def __getitem__(self, item) -> Any | None: + """ + Compatibility with authlib's expectation that this is a dict + """ + return getattr(self, item) + + +class HubStorageOAuth2Token(HubOAuth2Token): + """ + Specialized model for storage tokens. 
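Note: the expiry bookkeeping on `HubOAuth2Token` above is worth spelling out: `expires_at` is derived from `expires_in` at validation time, and `is_expired` applies a small leeway so a token is refreshed slightly before it actually lapses. The same logic in isolation:

from datetime import datetime, timedelta, timezone
from time import time

expires_in = 3600
expires_at = datetime.fromtimestamp(time() + expires_in, timezone.utc)

def is_expired(expires_at: datetime, leeway: int = 60) -> bool:
    # Treat the token as expired `leeway` seconds early
    return datetime.now(timezone.utc) >= expires_at - timedelta(seconds=leeway)

print(is_expired(expires_at))  # False for a freshly issued one-hour token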
+ """ + + token_type: Literal["Storage"] = "Storage" + extra_data: StorageTokenData diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py index 365a4329..9bf2cb5a 100644 --- a/polaris/hub/polarisfs.py +++ b/polaris/hub/polarisfs.py @@ -1,11 +1,11 @@ from hashlib import md5 -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union import fsspec from loguru import logger from polaris.utils.errors import PolarisChecksumError, PolarisHubError -from polaris.utils.misc import sluggify +from polaris.utils.misc import slugify from polaris.utils.types import TimeoutTypes if TYPE_CHECKING: @@ -53,7 +53,7 @@ def __init__( self.default_timeout = self.polaris_client.settings.default_timeout # Prefix to remove from ls entries - self.prefix = f"dataset/{dataset_owner}/{sluggify(dataset_name)}/" + self.prefix = f"dataset/{dataset_owner}/{slugify(dataset_name)}/" # Base path for uploading. Please pay attention on path version update. self.base_path = f"/v1/storage/{self.prefix.rstrip('/')}" diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py new file mode 100644 index 00000000..57dde15d --- /dev/null +++ b/polaris/hub/storage.py @@ -0,0 +1,432 @@ +from base64 import b64encode +from concurrent.futures import ThreadPoolExecutor, as_completed +from contextlib import contextmanager +from hashlib import md5 +from io import BytesIO +from pathlib import Path +from typing import Any, Generator, Literal, Mapping, Sequence, TypeAlias + +import boto3 +from authlib.integrations.httpx_client import OAuth2Client +from authlib.oauth2 import OAuth2Error +from authlib.oauth2.rfc6749 import OAuth2Token +from botocore.exceptions import BotoCoreError, ClientError +from typing_extensions import Self +from zarr.context import Context +from zarr.storage import Store + +from polaris.hub.oauth import HubStorageOAuth2Token, StoragePaths +from polaris.utils.errors import PolarisHubError +from polaris.utils.types import ArtifactUrn + +Scope: TypeAlias = Literal["read", "write"] + + +class S3StoreException(Exception): + """ + Base exception for S3Store. + """ + + +class S3StoreCredentialsExpiredException(S3StoreException): + """ + Exception raised when the S3 credentials have expired. + """ + + +@contextmanager +def handle_s3_errors(): + """ + Standardize error handling for S3 operations. + """ + try: + yield + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code == "ExpiredToken": + raise S3StoreCredentialsExpiredException(f"Error in S3Store: Credentials expired: {e}") from e + else: + raise S3StoreException(f"Error in S3Store: {e}") from e + except BotoCoreError as e: + raise S3StoreException(f"Error in S3Store: {e}") from e + + +class S3Store(Store): + """ + A Zarr store implementation using a S3 bucket as the backend storage. + + It supports multipart uploads for large objects and handles S3-specific exceptions. 
+ """ + + _erasable = False + + def __init__( + self, + path: str, + access_key: str, + secret_key: str, + token: str, + endpoint_url: str, + part_size: int = 10 * 1024 * 1024, # 10MB + content_type: str = "application/octet-stream", + ) -> None: + bucket_name, prefix = path.split("/", 1) + + self.s3_client = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + aws_session_token=token, + endpoint_url=endpoint_url, + ) + self.bucket_name = bucket_name + self.prefix = prefix + self.part_size = part_size + self.content_type = content_type + + def _full_key(self, key: str) -> str: + """ + Converts a relative key to the full bucket key. + """ + return f"{self.prefix}/{key}" + + def _multipart_upload(self, key: str, value: bytes) -> None: + """ + For large files, use multipart to split the work. + """ + full_key = self._full_key(key) + md5_hash = md5(value) + with handle_s3_errors(): + upload = self.s3_client.create_multipart_upload( + Bucket=self.bucket_name, + Key=full_key, + ContentType=self.content_type, + Metadata={ + "md5sum": md5_hash.hexdigest(), + }, + ) + upload_id = upload["UploadId"] + + parts = [] + for i in range(0, len(value), self.part_size): + part_number = i // self.part_size + 1 + part = value[i : i + self.part_size] + response = self.s3_client.upload_part( + Bucket=self.bucket_name, + Key=full_key, + PartNumber=part_number, + UploadId=upload_id, + Body=part, + ContentMD5=b64encode(md5(part).digest()).decode(), + ) + parts.append({"ETag": response["ETag"], "PartNumber": part_number}) + + self.s3_client.complete_multipart_upload( + Bucket=self.bucket_name, Key=full_key, UploadId=upload_id, MultipartUpload={"Parts": parts} + ) + + def listdir(self, path: str = "") -> Generator[str, None, None]: + """ + For a given path, list all the "subdirectories" and "files" for that path. + The returned paths are relative to the input path. + + Uses pagination and return a generator to handle very large number of keys. + Note: This might not help with some Zarr operations that materialize the whole sequence. + """ + prefix = self._full_key(path) + + # Ensure a trailing slash to avoid the path looking like one specific key + if prefix and not prefix.endswith("/"): + prefix += "/" + + with handle_s3_errors(): + # `list_objects_v2` returns a max of 1000 keys per request, so paginate requests + paginator = self.s3_client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix, Delimiter="/") + + for page in page_iterator: + # Contents are "files" + for obj in page.get("Contents", []): + key = obj["Key"][len(prefix) :] + if key: + yield key.split("/")[0] + + # CommonPrefixes are "subdirectories" + for common_prefix in page.get("CommonPrefixes", []): + yield common_prefix["Prefix"][len(prefix) :].strip("/") + + def getitems(self, keys: Sequence[str], *, contexts: Mapping[str, Context]) -> dict[str, Any]: + """ + More efficient implementation of getitems using concurrent fetching through multiple threads. + + The default implementation uses __contains__ to check existence before fetching the value, + which doubles the number of requests. 
+ """ + + def fetch_key(key: str) -> tuple[str, Any]: + """ + Wrapper that looks up a key's value in the store + """ + try: + return key, self[key] + except KeyError: + return key, None + + results = {} + with ThreadPoolExecutor() as executor: + # Create a future for each key to look up + future_to_key = {executor.submit(fetch_key, key): key for key in keys} + + # As each future completes, collect the results, if any + for future in as_completed(future_to_key): + key, value = future.result() + if value is not None: + results[key] = value + + return results + + def getsize(self, key: str) -> int: + """ + Return the size (in bytes) of the object at the given key. + """ + with handle_s3_errors(): + response = self.s3_client.head_object(Bucket=self.bucket_name, Key=self._full_key(key)) + return response["ContentLength"] + + ## MutableMapping implementation, expected by Zarr + + def __getitem__(self, key: str) -> bytes: + """ + Retrieves the value for the given key from the store. + + Makes no provision to handle overly large values returned. + """ + with handle_s3_errors(): + try: + full_key = self._full_key(key) + response = self.s3_client.get_object(Bucket=self.bucket_name, Key=full_key) + return response["Body"].read() + except self.s3_client.exceptions.NoSuchKey: + raise KeyError(key) + + def __setitem__(self, key: str, value: bytes | bytearray | memoryview) -> None: + """ + Persists the given value in the store. + + Based on value size, will use multipart upload for large files, + or a single put_object call. + """ + if isinstance(value, memoryview): + value = value.tobytes() + + if len(value) > self.part_size: + self._multipart_upload(key, value) + else: + with handle_s3_errors(): + md5_hash = md5(value) + self.s3_client.put_object( + Bucket=self.bucket_name, + Key=self._full_key(key), + Body=value, + ContentType=self.content_type, + ContentMD5=b64encode(md5_hash.digest()).decode(), + Metadata={ + "md5sum": md5_hash.hexdigest(), + }, + ) + + def __delitem__(self, key: str) -> None: + """ + Removing a key from the store is not supported. + """ + raise NotImplementedError(f'{type(self)} is not erasable, cannot call "del store[key]"') + + def __contains__(self, key: str) -> bool: + """ + Checks the existence of a key in the store. + + If the intent is to download the value after this check, it is more efficient to + attempt tp retrieve it and handle the KeyError from a non-existent key. + """ + with handle_s3_errors(): + try: + self.s3_client.head_object(Bucket=self.bucket_name, Key=self._full_key(key)) + return True + except self.s3_client.exceptions.NoSuchKey: + return False + except ClientError as e: + if e.response["Error"]["Code"] == "404": + return False + raise e + + def __iter__(self) -> Generator[str, None, None]: + """ + Iterate through all the keys in the store. + """ + with handle_s3_errors(): + # `list_objects_v2` returns a max of 1000 keys per request, so paginate requests + paginator = self.s3_client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix) + + for page in page_iterator: + for obj in page.get("Contents", []): + yield obj["Key"][len(self.prefix) + 1 :] + + def __len__(self) -> int: + """ + Number of keys in the store. 
+ """ + with handle_s3_errors(): + # `list_objects_v2` returns a max of 1000 keys per request, so paginate requests + paginator = self.s3_client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix) + + return sum((page["KeyCount"] for page in page_iterator)) + + +class StorageTokenAuth: + token: HubStorageOAuth2Token | None + + def __init__(self, token: dict[str, Any] | None, *args) -> None: + self.token = None + if token: + self.set_token(token) + + def set_token(self, token: dict[str, Any] | HubStorageOAuth2Token) -> None: + self.token = HubStorageOAuth2Token(**token) if isinstance(token, dict) else token + + +class StorageSession(OAuth2Client): + """ + A context manager for managing a storage session, with token exchange and token refresh capabilities. + Each session is associated with a specific scope and resource. + """ + + polaris_protocol = "polarisfs" + + token_auth_class = StorageTokenAuth + + def __init__(self, hub_client, scope: Scope, resource: ArtifactUrn): + self.hub_client = hub_client + self.resource = resource + + super().__init__( + # OAuth2Client + token_endpoint=hub_client.settings.hub_token_url, + token_endpoint_auth_method="none", + grant_type="urn:ietf:params:oauth:grant-type:token-exchange", + scope=scope, + # httpx.Client + cert=hub_client.settings.ca_bundle, + ) + + def __enter__(self) -> Self: + self.ensure_active_token() + return self + + def _prepare_token_endpoint_body(self, body, grant_type, **kwargs) -> str: + """ + Override to support required fields for the token exchange grant type. + See https://datatracker.ietf.org/doc/html/rfc8693#name-request + """ + if grant_type == "urn:ietf:params:oauth:grant-type:token-exchange": + kwargs.update( + { + "subject_token": self.hub_client.token.get("access_token"), + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "requested_token_type": "urn:ietf:params:oauth:token-type:jwt", + "resource": self.resource, + } + ) + return super()._prepare_token_endpoint_body(body, grant_type, **kwargs) + + def fetch_token(self, **kwargs) -> dict[str, Any]: + """ + Error handling for token fetching. + """ + try: + return super().fetch_token() + except OAuth2Error as error: + raise PolarisHubError( + message=f"Could not obtain a token to access the storage backend. Error was: {error.error} - {error.description}" + ) from error + + def ensure_active_token(self, token: OAuth2Token | None = None) -> bool: + """ + Override the active check to trigger a re-fetch of the token if it is not active. + """ + if token is None: + token = self.token + + if token and super().ensure_active_token(token): + return True + + # Check if external token is still valid + if not self.hub_client.ensure_active_token(): + return False + + # If so, use it to get a new Hub token + self.token = self.fetch_token() + return True + + @property + def paths(self) -> StoragePaths: + return self.token.extra_data.paths + + def set_root(self, value: bytes | bytearray) -> None: + """ + Set a value at the root path. 
+ """ + storage_data = self.token.extra_data + path = Path(self.paths.relative_root) +
+ # Try to be smart about the content type + match path.suffix: + case ".parquet": + content_type = "application/vnd.apache.parquet" + case _: + content_type = "application/octet-stream" +
+ store = S3Store( + path=str(path.parent), + access_key=storage_data.key, + secret_key=storage_data.secret, + token=f"jwt/{self.token.access_token}", + endpoint_url=storage_data.endpoint, + content_type=content_type, + ) + store[path.name] = value +
+ def get_root(self) -> BytesIO: + """ + Get the value at the root path. + """ + storage_data = self.token.extra_data + path = Path(self.paths.relative_root) +
+ store = S3Store( + path=str(path.parent), + access_key=storage_data.key, + secret_key=storage_data.secret, + token=f"jwt/{self.token.access_token}", + endpoint_url=storage_data.endpoint, + ) + return BytesIO(store[path.name]) +
+ @property + def extension_store(self) -> S3Store | None: + """ + Returns a Zarr store for the extension path, if available, backed by an S3-compatible bucket. + """ + storage_data = self.token.extra_data + return ( + S3Store( + path=self.paths.relative_extension, + access_key=storage_data.key, + secret_key=storage_data.secret, + token=f"jwt/{self.token.access_token}", + endpoint_url=storage_data.endpoint, + ) + if self.paths.relative_extension + else None + )
diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 797f7b78..0cb125cc 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -9,7 +9,6 @@ ) from polaris.dataset import DatasetV1, create_dataset_from_file from polaris.hub.client import PolarisHubClient -from polaris.utils.misc import should_verify_checksum from polaris.utils.types import ChecksumStrategy
@@ -46,7 +45,7 @@ def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_z dataset = create_dataset_from_file(path) # Verify checksum if requested - if should_verify_checksum(verify_checksum, dataset): + if dataset.should_verify_checksum(verify_checksum): dataset.verify_checksum() return dataset
@@ -89,7 +88,7 @@ def load_benchmark(path: str, verify_checksum: ChecksumStrategy = "verify_unless benchmark = cls.from_json(path) # Verify checksum if requested - if should_verify_checksum(verify_checksum, benchmark.dataset): + if benchmark.dataset.should_verify_checksum(verify_checksum): benchmark.verify_checksum() return benchmark
diff --git a/polaris/utils/errors.py b/polaris/utils/errors.py index ad7726bf..62b2f575 100644 --- a/polaris/utils/errors.py +++ b/polaris/utils/errors.py @@ -33,6 +33,10 @@ class InvalidZarrChecksum(Exception):
class PolarisHubError(Exception): + BOLD = "\033[1m" + YELLOW = "\033[93m" + _END_CODE = "\033[0m" +
def __init__(self, message: str = "", response: Response | None = None): prefix = "The request to the Polaris Hub failed."
@@ -41,6 +45,12 @@ def __init__(self, message: str = "", response: Response | None = None): super().__init__("\n".join([prefix, message])) + def format(self, text: str, codes: str | list[str]): + if not isinstance(codes, list): + codes = [codes] + + return "".join(codes) + text + self._END_CODE + class PolarisUnauthorizedError(PolarisHubError): def __init__(self, response: Response | None = None): diff --git a/polaris/utils/misc.py b/polaris/utils/misc.py index 2622fde4..b9156ea5 100644 --- a/polaris/utils/misc.py +++ b/polaris/utils/misc.py @@ -1,9 +1,6 @@ -from typing import TYPE_CHECKING, Any +from typing import Any -from polaris.utils.types import ChecksumStrategy, SlugCompatibleStringType - -if TYPE_CHECKING: - from polaris.dataset import DatasetV1 +from polaris.utils.types import SlugCompatibleStringType, SlugStringType def listit(t: Any): @@ -14,20 +11,8 @@ def listit(t: Any): return list(map(listit, t)) if isinstance(t, (list, tuple)) else t -def sluggify(sluggable: SlugCompatibleStringType): - """ - Converts a string to a slug-compatible string. - """ - return sluggable.lower().replace("_", "-") - - -def should_verify_checksum(strategy: ChecksumStrategy, dataset: "DatasetV1") -> bool: +def slugify(sluggable: SlugCompatibleStringType) -> SlugStringType: """ - Determines whether a checksum should be verified. + Converts a slug-compatible string to a slug. """ - if strategy == "ignore": - return False - elif strategy == "verify": - return True - else: - return not dataset.uses_zarr + return sluggable.lower().replace("_", "-").strip("-") diff --git a/polaris/utils/types.py b/polaris/utils/types.py index 9bac47d5..27c94992 100644 --- a/polaris/utils/types.py +++ b/polaris/utils/types.py @@ -3,6 +3,7 @@ import numpy as np from pydantic import ( + AnyUrl, BaseModel, BeforeValidator, ConfigDict, @@ -11,7 +12,7 @@ TypeAdapter, ) from pydantic.alias_generators import to_camel -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias SplitIndicesType: TypeAlias = list[int] """ @@ -65,7 +66,6 @@ The string must be at least 4 and at most 64 characters long. """ - HubUser: TypeAlias = SlugCompatibleStringType """ A user on the Polaris Hub is identified by a username, @@ -74,12 +74,18 @@ HttpUrlAdapter = TypeAdapter(HttpUrl) HttpUrlString: TypeAlias = Annotated[str, BeforeValidator(lambda v: HttpUrlAdapter.validate_python(v) and v)] - """ -A validated URL that will be turned into a string. +A validated HTTP URL that will be turned into a string. This is useful for interactions with httpx and authlib, who have their own URL types. """ +AnyUrlAdapter = TypeAdapter(AnyUrl) +AnyUrlString: TypeAlias = Annotated[str, BeforeValidator(lambda v: AnyUrlAdapter.validate_python(v) and v)] +""" +A validated generic URL that will be turned into a string. +This is useful for interactions with other libraries that expect a string. +""" + DirectionType: TypeAlias = float | Literal["min", "max"] """ The direction of any variable to be sorted. @@ -118,6 +124,11 @@ Type to specify which action to take to verify the data integrity of an artifact through a checksum. """ +ArtifactUrn: TypeAlias = Annotated[str, StringConstraints(pattern=r"^urn:polaris:\w+:\w+:\w+$")] +""" +A Uniform Resource Name (URN) for an artifact on the Polaris Hub. 
+""" + RowIndex: TypeAlias = int | str ColumnIndex: TypeAlias = str DatasetIndex: TypeAlias = RowIndex | tuple[RowIndex, ColumnIndex] @@ -150,7 +161,7 @@ def __str__(self): return self.slug @staticmethod - def normalize(owner: Union[str, "HubOwner"]) -> "HubOwner": + def normalize(owner: str | Self) -> Self: """ Normalize a string or `HubOwner` instance to a `HubOwner` instance. """ diff --git a/pyproject.toml b/pyproject.toml index c55ff7fd..2b477d57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,8 @@ dependencies = [ "httpx", "tenacity", "filelock", - "numpy < 2", # We need to pin numpy to avoid issues with fastpdb/biotite. + "numpy < 2", + # We need to pin numpy to avoid issues with fastpdb/biotite. "pandas", "scipy", "scikit-learn", @@ -52,9 +53,11 @@ dependencies = [ "datamol >=0.12.1", "fastpdb", "zarr", - "pyarrow", + "pyarrow < 18", "fsspec[http]", - "yaspin" + "yaspin", + "typing-extensions>=4.12.0", + "boto3>=1.35.0", ] [project.optional-dependencies] @@ -64,7 +67,8 @@ dev = [ "pytest-cov", "ruff", "jupyterlab", - "ipywidgets" + "ipywidgets", + "moto[s3]>=5.0.14", ] doc = [ "mkdocs", @@ -133,10 +137,10 @@ output = "coverage.xml" lint.ignore = [ "E501", # Never enforce `E501` (line length violations). ] -line-length = 110 -target-version = "py310" lint.per-file-ignores."__init__.py" = [ "F401", # imported but unused "E402", # Module level import not at top of file ] +line-length = 110 +target-version = "py310" diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 5319bb7e..d84f50d0 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -30,7 +30,7 @@ def test_load_data(tmp_path, with_slice, with_caching): dataset = DatasetV1(table=table, annotations={"A": {"is_pointer": True}}, zarr_root_path=zarr_path) if with_caching: - dataset.cache_dir = fs.join(tmpdir, "cache") + dataset._cache_dir = fs.join(tmpdir, "cache") dataset.cache() data = dataset.get_data(row=0, col="A") @@ -132,7 +132,7 @@ def test_dataset_caching(zarr_archive, tmpdir): cached_dataset = create_dataset_from_file(zarr_archive, tmpdir.join("original2")) assert original_dataset == cached_dataset - cached_dataset.cache_dir = tmpdir.join("cached").strpath + cached_dataset._cache_dir = tmpdir.join("cached").strpath cache_dir = cached_dataset.cache(verify_checksum=True) assert cached_dataset.zarr_root_path.startswith(cache_dir) diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index d0141152..7113004b 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -13,7 +13,7 @@ from polaris.dataset._factory import DatasetFactory from polaris.dataset.converters._pdb import PDBConverter from polaris.dataset.zarr._manifest import generate_zarr_manifest -from polaris.experimental._dataset_v2 import _INDEX_ARRAY_KEY, DatasetV2 +from polaris.experimental._dataset_v2 import DatasetV2, _INDEX_ARRAY_KEY def test_dataset_v2_get_columns(test_dataset_v2): @@ -76,7 +76,7 @@ def test_dataset_v2_serialization(test_dataset_v2, tmpdir): def test_dataset_v2_caching(test_dataset_v2, tmpdir): cache_dir = tmpdir.join("cache").strpath - test_dataset_v2.cache_dir = cache_dir + test_dataset_v2._cache_dir = cache_dir test_dataset_v2.cache() assert str(test_dataset_v2.zarr_root_path).startswith(cache_dir) @@ -243,7 +243,7 @@ def test_zarr_manifest(test_dataset_v2): root = zarr.open(test_dataset_v2.zarr_root_path, "a") root.array("C", data=np.random.random((100, 2048)), chunks=(1, None)) - generate_zarr_manifest(test_dataset_v2.zarr_root_path, test_dataset_v2.cache_dir) + 
generate_zarr_manifest(test_dataset_v2.zarr_root_path, test_dataset_v2._cache_dir) # Get the length of the updated manifest file post_change_manifest_length = len(pd.read_parquet(test_dataset_v2.zarr_manifest_path)) diff --git a/tests/test_factory.py b/tests/test_factory.py index 44c09fd5..92d4d5d3 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -1,10 +1,11 @@ import datamol as dm import pandas as pd -from zarr.errors import ContainsArrayError import pytest from fastpdb import struc +from zarr.errors import ContainsArrayError -from polaris.dataset import DatasetFactory, create_dataset_from_file, create_dataset_from_files +from polaris.dataset import DatasetFactory, create_dataset_from_file +from polaris.dataset._factory import create_dataset_from_files from polaris.dataset.converters import PDBConverter, SDFConverter, ZarrConverter diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 00000000..e67e371c --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,128 @@ +import os + +import boto3 +import pytest +from moto import mock_aws + +from polaris.hub.storage import S3Store + + +@pytest.fixture(scope="function") +def aws_credentials(): + """ + Mocked AWS Credentials for moto. + """ + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + + +@pytest.fixture(scope="function") +def mocked_aws(aws_credentials): + """ + Mock all AWS interactions + Requires you to create your own boto3 clients + """ + with mock_aws(): + yield + + +@pytest.fixture +def s3_store(mocked_aws): + # Setup mock S3 environment + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "test-bucket" + s3.create_bucket(Bucket=bucket_name) + + # Create an instance of your S3Store + store = S3Store( + path=f"{bucket_name}/prefix", + access_key="fake-access-key", + secret_key="fake-secret-key", + token="fake-token", + endpoint_url="https://s3.amazonaws.com", + ) + + yield store + + +def test_set_and_get_item(s3_store): + key = "test-key" + value = b"test-value" + s3_store[key] = value + + retrieved_value = s3_store[key] + assert retrieved_value == value + + +def test_get_nonexistent_item(s3_store): + with pytest.raises(KeyError): + _ = s3_store["nonexistent-key"] + + +def test_contains_item(s3_store): + key = "test-key" + value = b"test-value" + s3_store[key] = value + + assert key in s3_store + assert "nonexistent-key" not in s3_store + + +def test_store_iterator_empty(s3_store): + stored_keys = list(s3_store) + assert stored_keys == [] + + +def test_store_iterator(s3_store): + keys = ["dir1/subdir1", "dir1/subdir2", "dir1/file1.ext", "dir2/file2.ext"] + for key in keys: + s3_store[key] = b"test" + + stored_keys = list(s3_store) + assert sorted(stored_keys) == sorted(keys) + + +def test_store_length(s3_store): + keys = ["dir1/subdir1", "dir1/subdir2", "dir1/file1.ext", "dir2/file2.ext"] + for key in keys: + s3_store[key] = b"test" + + assert len(s3_store) == len(keys) + + +def test_listdir(s3_store): + keys = ["dir1/subdir1", "dir1/subdir2", "dir1/file1.ext", "dir2/file2.ext"] + for key in keys: + s3_store[key] = b"test" + + dir1_contents = list(s3_store.listdir("dir1")) + assert set(dir1_contents) == {"file1.ext", "subdir1", "subdir2"} + + dir1_contents = list(s3_store.listdir()) + assert set(dir1_contents) == {"dir1", "dir2"} + + +def test_getsize(s3_store): + key = "test-key" + value = b"test-value" + s3_store[key] = value + + size = 
s3_store.getsize(key) + assert size == len(value) + + +def test_getitems(s3_store): + keys = ["dir1/subdir1", "dir1/subdir2", "dir1/file1.ext", "dir2/file2.ext"] + for key in keys: + s3_store[key] = b"test" + + items = s3_store.getitems(keys, contexts={}) + assert len(items) == len(keys) + assert all(key in items for key in keys) + + +def test_delete_item_not_supported(s3_store): + with pytest.raises(NotImplementedError): + del s3_store["some-key"]
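
For reference, a minimal usage sketch of the S3Store added above, consumed as a Zarr-compatible store. The bucket path, credentials, and endpoint below are placeholders; in practice they come from the storage token obtained through a StorageSession, as shown in set_root and get_root above. Opening the hierarchy read-only with zarr.open_group is standard zarr usage and relies only on the MutableMapping interface the store implements; it is not part of this diff.

import zarr

from polaris.hub.storage import S3Store

# Placeholder credentials and endpoint, for illustration only.
store = S3Store(
    path="example-bucket/some/prefix",
    access_key="<access-key>",
    secret_key="<secret-key>",
    token="<session-token>",
    endpoint_url="https://s3.amazonaws.com",
)

# The store is a MutableMapping, so Zarr can read from it directly.
root = zarr.open_group(store=store, mode="r")
print(root.tree())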