diff --git a/env.yml b/env.yml
index d036303c..df6dc398 100644
--- a/env.yml
+++ b/env.yml
@@ -4,17 +4,15 @@ channels:
 dependencies:
   - python >=3.10
   - pip
-  - tqdm
-  - loguru
   - typer
   - pyyaml
   - pydantic >=2
   - pydantic-settings >=2
   - fsspec
-  - halo
   - typing-extensions >=4.12.0
   - boto3 <1.36.0
   - pyroaring
+  - rich

   # Hub client
   - authlib
diff --git a/polaris/__init__.py b/polaris/__init__.py
index 273793ba..0246f97b 100644
--- a/polaris/__init__.py
+++ b/polaris/__init__.py
@@ -1,14 +1,18 @@
-import os
-import sys
+import logging

-from loguru import logger
+from rich.logging import RichHandler

 from ._version import __version__
 from .loader import load_benchmark, load_competition, load_dataset

 __all__ = ["load_dataset", "load_benchmark", "load_competition", "__version__"]

-# Configure the default logging level
-os.environ["LOGURU_LEVEL"] = os.environ.get("LOGURU_LEVEL", "INFO")
-logger.remove()
-logger.add(sys.stderr, level=os.environ["LOGURU_LEVEL"])
+# Polaris specific logger
+logger = logging.getLogger(__name__)
+
+# Only add handler if the logger has not already been configured externally
+if not logger.handlers:
+    handler = RichHandler(rich_tracebacks=True)
+    handler.setFormatter(logging.Formatter("%(message)s", datefmt="[%Y-%m-%d %X]"))
+    logger.addHandler(handler)
+    logger.setLevel(logging.INFO)
diff --git a/polaris/_artifact.py b/polaris/_artifact.py
index 85a6c143..ec878ec3 100644
--- a/polaris/_artifact.py
+++ b/polaris/_artifact.py
@@ -1,8 +1,8 @@
 import json
+import logging
 from typing import ClassVar, Literal

 import fsspec
-from loguru import logger
 from packaging.version import Version
 from pydantic import (
     BaseModel,
@@ -19,6 +19,8 @@
 from polaris.utils.misc import slugify
 from polaris.utils.types import ArtifactUrn, HubOwner, SlugCompatibleStringType, SlugStringType

+logger = logging.getLogger(__name__)
+

 class BaseArtifactModel(BaseModel):
     """
diff --git a/polaris/benchmark/_split.py b/polaris/benchmark/_split.py
index cdc40707..3971b4c1 100644
--- a/polaris/benchmark/_split.py
+++ b/polaris/benchmark/_split.py
@@ -1,6 +1,6 @@
+import logging
 from itertools import chain

-from loguru import logger
 from pydantic import BaseModel, computed_field, field_serializer, model_validator
 from typing_extensions import Self

@@ -8,6 +8,8 @@
 from polaris.utils.misc import listit
 from polaris.utils.types import SplitType

+logger = logging.getLogger(__name__)
+

 class SplitSpecificationV1Mixin(BaseModel):
     """
diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py
index 875de6f4..92f17761 100644
--- a/polaris/dataset/_base.py
+++ b/polaris/dataset/_base.py
@@ -1,5 +1,6 @@
 import abc
 import json
+import logging
 from os import PathLike
 from pathlib import Path, PurePath
 from typing import Any, Iterable, MutableMapping
@@ -8,7 +9,6 @@
 import fsspec
 import numpy as np
 import zarr
-from loguru import logger
 from pydantic import (
     Field,
     PrivateAttr,
@@ -25,6 +25,7 @@
 from polaris.dataset.zarr import MemoryMappedDirectoryStore
 from polaris.dataset.zarr._utils import check_zarr_codecs, load_zarr_group_to_memory
 from polaris.utils.constants import DEFAULT_CACHE_DIR
+from polaris.utils.context import track_progress
 from polaris.utils.dict2html import dict2html
 from polaris.utils.errors import InvalidDatasetError
 from polaris.utils.types import (
@@ -37,6 +38,8 @@
     ZarrConflictResolution,
 )

+logger = logging.getLogger(__name__)
+

 # Constants
 _CACHE_SUBDIR = "datasets"
@@ -371,19 +374,24 @@ def _cache_zarr(self, destination: str | PathLike, if_exists: ZarrConflictResolu

         # Copy over Zarr data to the destination
         self._warn_about_remote_zarr = False
-        logger.info(f"Copying Zarr archive to {destination_zarr_root}. This may take a while.")
-        destination_store = zarr.open(str(destination_zarr_root), "w").store
-        source_store = self.zarr_root.store.store
-
-        if isinstance(source_store, S3Store):
-            source_store.copy_to_destination(destination_store, if_exists, logger.info)
-        else:
-            zarr.copy_store(
-                source=source_store,
-                dest=destination_store,
-                log=logger.info,
-                if_exists=if_exists,
-            )
-        self.zarr_root_path = str(destination_zarr_root)
-        self._zarr_root = None
-        self._zarr_data = None
+        with track_progress(description="Copying Zarr archive", total=1) as (
+            progress,
+            task,
+        ):
+            progress.log(f"[green]Copying to destination {destination_zarr_root}")
+            progress.log("[yellow]For large Zarr archives, this may take a while.")
+            destination_store = zarr.open(str(destination_zarr_root), "w").store
+            source_store = self.zarr_root.store.store
+
+            if isinstance(source_store, S3Store):
+                source_store.copy_to_destination(destination_store, if_exists, logger.info)
+            else:
+                zarr.copy_store(
+                    source=source_store,
+                    dest=destination_store,
+                    log=logger.info,
+                    if_exists=if_exists,
+                )
+            self.zarr_root_path = str(destination_zarr_root)
+            self._zarr_root = None
+            self._zarr_data = None
diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py
index 9aaa5deb..5829b837 100644
--- a/polaris/dataset/_dataset.py
+++ b/polaris/dataset/_dataset.py
@@ -332,6 +332,8 @@ def cache(

         if verify_checksum:
             self.verify_checksum()
+        else:
+            self._md5sum = None

         return str(destination)

diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py
index 4e35f062..5b5e26c2 100644
--- a/polaris/dataset/_dataset_v2.py
+++ b/polaris/dataset/_dataset_v2.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import re
 from os import PathLike
 from pathlib import Path
@@ -7,7 +8,6 @@
 import fsspec
 import numpy as np
 import zarr
-from loguru import logger
 from pydantic import PrivateAttr, computed_field, model_validator
 from typing_extensions import Self

@@ -17,6 +17,8 @@
 from polaris.utils.errors import InvalidDatasetError
 from polaris.utils.types import AccessType, ChecksumStrategy, HubOwner, ZarrConflictResolution

+logger = logging.getLogger(__name__)
+

 _INDEX_ARRAY_KEY = "__index__"

diff --git a/polaris/dataset/_factory.py b/polaris/dataset/_factory.py
index a2a13224..83597c9a 100644
--- a/polaris/dataset/_factory.py
+++ b/polaris/dataset/_factory.py
@@ -1,15 +1,17 @@
+import logging
 import os
 from typing import Literal

 import datamol as dm
 import pandas as pd
 import zarr
-from loguru import logger

 from polaris.dataset import ColumnAnnotation, DatasetV1
 from polaris.dataset._adapters import Adapter
 from polaris.dataset.converters import Converter, PDBConverter, SDFConverter, ZarrConverter

+logger = logging.getLogger(__name__)
+

 def create_dataset_from_file(path: str, zarr_root_path: str | None = None) -> DatasetV1:
     """
diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py
index 8a1a8348..ccf5d9fd 100644
--- a/polaris/dataset/zarr/_checksum.py
+++ b/polaris/dataset/zarr/_checksum.py
@@ -45,8 +45,8 @@
 import zarr.errors
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
-from tqdm import tqdm

+from polaris.utils.context import track_progress
 from polaris.utils.errors import InvalidZarrChecksum

 ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)"
@@ -56,8 +56,8 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
     r"""
     Implements an algorithm to compute the Zarr checksum.

-    Warning: This checksum is sensitive to Zarr configuration. 
-        This checksum is sensitive to change in the Zarr structure. For example, if you change the chunk size, 
+    Warning: This checksum is sensitive to Zarr configuration.
+        This checksum is sensitive to change in the Zarr structure. For example, if you change the chunk size,
         the checksum will also change.

     To understand how this works, consider the following directory structure:
@@ -67,17 +67,17 @@
          a   c
         /
        b
-    
+
     Within zarr, this would for example be:

     - `root`: A Zarr Group with a single Array.
     - `a`: A Zarr Array
     - `b`: A single chunk of the Zarr Array
-    - `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup) 
+    - `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup)

-    To compute the checksum, we first find all the trees in the node, in this case b and c. 
+    To compute the checksum, we first find all the trees in the node, in this case b and c.
     We compute the hash of the content (the raw bytes) for each of these files.
-    
+
     We then work our way up the tree. For any node (directory), we find all children of that node.
     In an sorted order, we then serialize a list with - for each of the children - the checksum, size,
     and number of children. The hash of the directory is then equal to the hash of the serialized JSON.
@@ -116,33 +116,40 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
     leaves = fs.find(zarr_root_path, detail=True)
     zarr_md5sum_manifest = []

-    for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"):
-        path = file["name"]
-
-        relpath = path.removeprefix(zarr_root_path)
-        relpath = relpath.lstrip("/")
-        relpath = Path(relpath)
-
-        size = file["size"]
-
-        # Compute md5sum of file
-        md5sum = hashlib.md5()
-        with fs.open(path, "rb") as f:
-            for chunk in iter(lambda: f.read(8192), b""):
-                md5sum.update(chunk)
-        digest = md5sum.hexdigest()
+    files = leaves.values()
+    with track_progress(description="Finding all files in the Zarr archive", total=len(files)) as (
+        progress,
+        task,
+    ):
+        for file in files:
+            path = file["name"]
+
+            relpath = path.removeprefix(zarr_root_path)
+            relpath = relpath.lstrip("/")
+            relpath = Path(relpath)
+
+            size = file["size"]
+
+            # Compute md5sum of file
+            md5sum = hashlib.md5()
+            with fs.open(path, "rb") as f:
+                for chunk in iter(lambda: f.read(8192), b""):
+                    md5sum.update(chunk)
+            digest = md5sum.hexdigest()
+
+            # Add a leaf to the tree
+            # (This actually adds the file's checksum to the parent directory's manifest)
+            tree.add_leaf(
+                path=relpath,
+                size=size,
+                digest=digest,
+            )

-        # Add a leaf to the tree
-        # (This actually adds the file's checksum to the parent directory's manifest)
-        tree.add_leaf(
-            path=relpath,
-            size=size,
-            digest=digest,
-        )
+            # We persist the checksums for leaf nodes separately,
+            # because this is what the Hub needs to verify data integrity.
+            zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))

-        # We persist the checksums for leaf nodes separately,
-        # because this is what the Hub needs to verify data integrity.
-        zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))
+            progress.update(task, advance=1, refresh=True)

     # Compute digest
     return tree.process(), zarr_md5sum_manifest
diff --git a/polaris/experimental/_split_v2.py b/polaris/experimental/_split_v2.py
index ef0b11f0..e209a0f9 100644
--- a/polaris/experimental/_split_v2.py
+++ b/polaris/experimental/_split_v2.py
@@ -1,8 +1,8 @@
+import logging
 from functools import cached_property
 from hashlib import md5
 from typing import Generator, Sequence

-from loguru import logger
 from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator, model_validator
 from pydantic.alias_generators import to_camel
 from pyroaring import BitMap
@@ -10,6 +10,8 @@

 from polaris.utils.errors import InvalidBenchmarkError

+logger = logging.getLogger(__name__)
+

 class IndexSet(BaseModel):
     """
diff --git a/polaris/hub/client.py b/polaris/hub/client.py
index 55a5b144..52f2e6a5 100644
--- a/polaris/hub/client.py
+++ b/polaris/hub/client.py
@@ -1,4 +1,5 @@
 import json
+import logging
 from hashlib import md5
 from io import BytesIO
 from typing import get_args
@@ -12,7 +13,6 @@
 from authlib.oauth2 import OAuth2Error, TokenAuth
 from authlib.oauth2.rfc6749 import OAuth2Token
 from httpx import ConnectError, HTTPStatusError, Response
-from loguru import logger
 from typing_extensions import Self

 from polaris.benchmark import (
@@ -28,7 +28,7 @@
 from polaris.hub.oauth import CachedTokenAuth
 from polaris.hub.settings import PolarisHubSettings
 from polaris.hub.storage import StorageSession
-from polaris.utils.context import ProgressIndicator
+from polaris.utils.context import track_progress
 from polaris.utils.errors import (
     InvalidDatasetError,
     PolarisCreateArtifactError,
@@ -46,6 +46,8 @@
     ZarrConflictResolution,
 )

+logger = logging.getLogger(__name__)
+

 _HTTPX_SSL_ERROR_CODE = "[SSL: CERTIFICATE_VERIFY_FAILED]"

@@ -247,7 +249,7 @@ def login(self, overwrite: bool = False, auto_open_browser: bool = True):
         self.external_client.interactive_login(overwrite=overwrite, auto_open_browser=auto_open_browser)
         self.token = self.fetch_token()

-        logger.success("You are successfully logged in to the Polaris Hub.")
+        logger.info("You are successfully logged in to the Polaris Hub.")

     # =========================
     # API Endpoints
@@ -264,11 +266,7 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
         Returns:
             A list of dataset names in the format `owner/dataset_name`.
         """
-        with ProgressIndicator(
-            start_msg="Fetching datasets...",
-            success_msg="Fetched datasets.",
-            error_msg="Failed to fetch datasets.",
-        ):
+        with track_progress(description="Fetching datasets", total=1):
             # Step 1: Fetch enough v2 datasets to cover the offset and limit
             v2_json_response = self._base_request_to_hub(
                 url="/v2/dataset", method="GET", params={"limit": limit, "offset": offset}
@@ -317,11 +315,7 @@ def get_dataset(
         Returns:
             A `Dataset` instance, if it exists.
         """
-        with ProgressIndicator(
-            start_msg="Fetching dataset...",
-            success_msg="Fetched dataset.",
-            error_msg="Failed to fetch dataset.",
-        ):
+        with track_progress(description="Fetching dataset", total=1):
             try:
                 return self._get_v1_dataset(owner, name, verify_checksum)
             except PolarisRetrieveArtifactError:
@@ -397,11 +391,7 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]:
         Returns:
             A list of benchmark names in the format `owner/benchmark_name`.
""" - with ProgressIndicator( - start_msg="Fetching benchmarks...", - success_msg="Fetched benchmarks.", - error_msg="Failed to fetch benchmarks.", - ): + with track_progress(description="Fetching benchmarks", total=1): # Step 1: Fetch enough v2 benchmarks to cover the offset and limit v2_json_response = self._base_request_to_hub( url="/v2/benchmark", method="GET", params={"limit": limit, "offset": offset} @@ -449,11 +439,7 @@ def get_benchmark( Returns: A `BenchmarkSpecification` instance, if it exists. """ - with ProgressIndicator( - start_msg="Fetching benchmark...", - success_msg="Fetched benchmark.", - error_msg="Failed to fetch benchmark.", - ): + with track_progress(description="Fetching benchmark", total=1): try: return self._get_v1_benchmark(owner, name, verify_checksum) except PolarisRetrieveArtifactError: @@ -530,11 +516,7 @@ def upload_results( access: Grant public or private access to result owner: Which Hub user or organization owns the artifact. Takes precedence over `results.owner`. """ - with ProgressIndicator( - start_msg="Uploading artifact...", - success_msg="Uploaded artifact.", - error_msg="Failed to upload result.", - ) as progress_indicator: + with track_progress(description="Uploading results", total=1) as (progress, task): # Get the serialized model data-structure results.owner = HubOwner.normalize(owner or results.owner) result_json = results.model_dump(by_alias=True, exclude_none=True) @@ -547,8 +529,8 @@ def upload_results( # Inform the user about where to find their newly created artifact. result_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location")) - progress_indicator.update_success_msg( - f"Your result has been successfully uploaded to the Hub. View it here: {result_url}" + progress.log( + f"[green]Your result has been successfully uploaded to the Hub. View it here: {result_url}" ) def upload_dataset( @@ -610,11 +592,7 @@ def _upload_v1_dataset( Upload a V1 dataset to the Polaris Hub. """ - with ProgressIndicator( - start_msg="Uploading artifact...", - success_msg="Uploaded artifact.", - error_msg="Failed to upload dataset.", - ) as progress_indicator: + with track_progress(description="Uploading dataset", total=1) as (progress, task): # Get the serialized data-model # We exclude the table as it handled separately dataset.owner = HubOwner.normalize(owner or dataset.owner) @@ -656,30 +634,30 @@ def _upload_v1_dataset( ) with StorageSession(self, "write", dataset.urn) as storage: - # Step 2: Upload the parquet file - logger.info("Copying Parquet file to the Hub. This may take a while.") - storage.set_file("root", in_memory_parquet.getvalue()) + with track_progress(description="Copying Parquet file", total=1) as (progress, task): + # Step 2: Upload the parquet file + progress.log("[yellow]This may take a while.") + storage.set_file("root", in_memory_parquet.getvalue()) - # Step 3: Upload any associated Zarr archive + # Step 3: Upload any associated Zarr archive if dataset.uses_zarr: - logger.info("Copying Zarr archive to the Hub. This may take a while.") - - destination = storage.store("extension") - - # Locally consolidate Zarr archive metadata. 
Future updates on handling consolidated - # metadata based on Zarr developers' recommendations can be tracked at: - # https://github.com/zarr-developers/zarr-python/issues/1731 - zarr.consolidate_metadata(dataset.zarr_root.store.store) - zmetadata_content = dataset.zarr_root.store.store[".zmetadata"] - destination[".zmetadata"] = zmetadata_content - - # Copy the Zarr archive to the hub - destination.copy_from_source( - dataset.zarr_root.store.store, if_exists=if_exists, log=logger.info - ) - - progress_indicator.update_success_msg( - f"Your dataset has been successfully uploaded to the Hub. " + with track_progress(description="Copying Zarr archive", total=1): + destination = storage.store("extension") + + # Locally consolidate Zarr archive metadata. Future updates on handling consolidated + # metadata based on Zarr developers' recommendations can be tracked at: + # https://github.com/zarr-developers/zarr-python/issues/1731 + zarr.consolidate_metadata(dataset.zarr_root.store.store) + zmetadata_content = dataset.zarr_root.store.store[".zmetadata"] + destination[".zmetadata"] = zmetadata_content + + # Copy the Zarr archive to the hub + destination.copy_from_source( + dataset.zarr_root.store.store, if_exists=if_exists, log=logger.info + ) + + progress.log( + f"[green]Your dataset has been successfully uploaded to the Hub. " f"View it here: {urljoin(self.settings.hub_url, f'datasets/{dataset.owner}/{dataset.name}')}" ) @@ -695,11 +673,7 @@ def _upload_v2_dataset( Upload a V2 dataset to the Polaris Hub. """ - with ProgressIndicator( - start_msg="Uploading artifact...", - success_msg="Uploaded artifact.", - error_msg="Failed to upload dataset.", - ) as progress_indicator: + with track_progress(description="Uploading dataset", total=1) as (progress, task): # Get the serialized data-model dataset.owner = HubOwner.normalize(owner or dataset.owner) dataset_json = dataset.model_dump(exclude_none=True, by_alias=True) @@ -721,30 +695,34 @@ def _upload_v2_dataset( with StorageSession(self, "write", dataset.urn) as storage: # Step 2: Upload the manifest file - logger.info("Copying the dataset manifest file to the Hub.") - with open(dataset.zarr_manifest_path, "rb") as manifest_file: - storage.set_file("manifest", manifest_file.read()) + with track_progress(description="Copying manifest file", total=1): + with open(dataset.zarr_manifest_path, "rb") as manifest_file: + storage.set_file("manifest", manifest_file.read()) # Step 3: Upload the Zarr archive - logger.info("Copying Zarr archive to the Hub. This may take a while.") + with track_progress(description="Copying Zarr archive", total=1) as ( + progress_zarr, + task_zarr, + ): + progress_zarr.log("[yellow]This may take a while.") - destination = storage.store("root") + destination = storage.store("root") - # Locally consolidate Zarr archive metadata. Future updates on handling consolidated - # metadata based on Zarr developers' recommendations can be tracked at: - # https://github.com/zarr-developers/zarr-python/issues/1731 - zarr.consolidate_metadata(dataset.zarr_root.store.store) - zmetadata_content = dataset.zarr_root.store.store[".zmetadata"] - destination[".zmetadata"] = zmetadata_content + # Locally consolidate Zarr archive metadata. 
Future updates on handling consolidated + # metadata based on Zarr developers' recommendations can be tracked at: + # https://github.com/zarr-developers/zarr-python/issues/1731 + zarr.consolidate_metadata(dataset.zarr_root.store.store) + zmetadata_content = dataset.zarr_root.store.store[".zmetadata"] + destination[".zmetadata"] = zmetadata_content - # Copy the Zarr archive to the hub - destination.copy_from_source( - dataset.zarr_root.store.store, if_exists=if_exists, log=logger.info - ) + # Copy the Zarr archive to the hub + destination.copy_from_source( + dataset.zarr_root.store.store, if_exists=if_exists, log=logger.info + ) dataset_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location")) - progress_indicator.update_success_msg( - f"Your V2 dataset has been successfully uploaded to the Hub. View it here: {dataset_url}" + progress.log( + f"[green]Your V2 dataset has been successfully uploaded to the Hub. View it here: {dataset_url}" ) def upload_benchmark( @@ -807,11 +785,7 @@ def _upload_v1_benchmark( access: Grant public or private access to result owner: Which Hub user or organization owns the artifact. Takes precedence over `benchmark.owner`. """ - with ProgressIndicator( - start_msg="Uploading artifact...", - success_msg="Uploaded artifact.", - error_msg="Failed to upload benchmark.", - ) as progress_indicator: + with track_progress(description="Uploading benchmark", total=1) as (progress, task): # Get the serialized data-model # We exclude the dataset as we expect it to exist on the hub already. benchmark.owner = HubOwner.normalize(owner or benchmark.owner) @@ -823,9 +797,8 @@ def _upload_v1_benchmark( url = f"{path_params}/{benchmark.owner}/{benchmark.name}" self._base_request_to_hub(url=url, method="PUT", json=benchmark_json) - progress_indicator.update_success_msg( - f"Your benchmark has been successfully uploaded to the Hub. " - f"View it here: {urljoin(self.settings.hub_url, url)}" + progress.log( + f"[green]Your benchmark has been successfully uploaded to the Hub. View it here: {urljoin(self.settings.hub_url, url)}" ) def _upload_v2_benchmark( @@ -834,11 +807,7 @@ def _upload_v2_benchmark( access: AccessType = "private", owner: HubOwner | str | None = None, ): - with ProgressIndicator( - start_msg="Uploading artifact...", - success_msg="Uploaded artifact.", - error_msg="Failed to upload benchmark.", - ) as progress_indicator: + with track_progress(description="Uploading benchmark", total=1) as (progress, task): # Get the serialized data-model # We exclude the dataset as we expect it to exist on the hub already. benchmark.owner = HubOwner.normalize(owner or benchmark.owner) @@ -860,13 +829,17 @@ def _upload_v2_benchmark( logger.info("Copying the benchmark split to the Hub. This may take a while.") # 2. Upload each index set bitmap - for label, index_set in benchmark.split: - logger.info(f"Copying index set {label} to the Hub.") - storage.set_file(label, index_set.serialize()) + with track_progress( + description="Copying index sets", total=benchmark.split.n_test_sets + 1 + ) as (progress_index_sets, task_index_sets): + for label, index_set in benchmark.split: + logger.info(f"Copying index set {label} to the Hub.") + storage.set_file(label, index_set.serialize()) + progress_index_sets.update(task_index_sets, advance=1, refresh=True) benchmark_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location")) - progress_indicator.update_success_msg( - f"Your benchmark has been successfully uploaded to the Hub. 
View it here: {benchmark_url}" + progress.log( + f"[green]Your benchmark has been successfully uploaded to the Hub. View it here: {benchmark_url}" ) def get_competition(self, artifact_id: str) -> CompetitionSpecification: @@ -901,12 +874,7 @@ def submit_competition_predictions( competition: The competition to evaluate the predictions for. competition_predictions: The predictions and associated metadata to be submitted to the Hub. """ - with ProgressIndicator( - start_msg="Submitting competition predictions...", - success_msg="Submitted competition predictions.", - error_msg="Failed to submit competition predictions.", - ) as progress_indicator: - # + with track_progress(description="Submitting competition predictions", total=1): # Prepare prediction payload for submission prediction_json = competition_predictions.model_dump(by_alias=True, exclude_none=True) prediction_payload = { @@ -920,9 +888,4 @@ def submit_competition_predictions( method="POST", json=prediction_payload, ) - - # Log success and return submission response - progress_indicator.update_success_msg( - "Your competition predictions have been successfully uploaded to the Hub for evaluation." - ) return response diff --git a/polaris/hub/external_client.py b/polaris/hub/external_client.py index 278a4f55..30f191ff 100644 --- a/polaris/hub/external_client.py +++ b/polaris/hub/external_client.py @@ -1,3 +1,4 @@ +import logging import webbrowser from typing import Literal, Optional, TypeAlias @@ -6,12 +7,13 @@ from authlib.integrations.httpx_client import OAuth2Client from authlib.oauth2 import OAuth2Error, TokenAuth from authlib.oauth2.rfc6749 import OAuth2Token -from loguru import logger from polaris.hub.oauth import ExternalCachedTokenAuth from polaris.hub.settings import PolarisHubSettings from polaris.utils.errors import PolarisHubError, PolarisUnauthorizedError +logger = logging.getLogger(__name__) + Scope: TypeAlias = Literal["read", "write"] @@ -155,4 +157,4 @@ def interactive_login(self, overwrite: bool = False, auto_open_browser: bool = T # Step 3: Exchange authorization code for an access token self.fetch_token(code=authorization_code, grant_type="authorization_code") - logger.success(f"Successfully authenticated to the Polaris Hub as `{self.user_info['email']}`! 🎉") + logger.info(f"Successfully authenticated to the Polaris Hub as `{self.user_info['email']}`! 
🎉") diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index 2704875a..4e225e52 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -20,6 +20,7 @@ from zarr.util import buffer_size from polaris.hub.oauth import BenchmarkV2Paths, DatasetV1Paths, DatasetV2Paths, HubStorageOAuth2Token +from polaris.utils.context import track_progress from polaris.utils.errors import PolarisHubError from polaris.utils.types import ArtifactUrn, ZarrConflictResolution @@ -168,25 +169,33 @@ def copy_to_destination( number_source_keys = len(self) - batch_iter = iter(self) - while batch := tuple(islice(batch_iter, self._batch_size)): - to_put = batch if if_exists == "replace" else filter(lambda key: key not in destination, batch) - skipped = len(batch) - len(to_put) + with track_progress(description="Copying Zarr keys", total=number_source_keys) as ( + progress_keys, + task_keys, + ): + batch_iter = iter(self) + while batch := tuple(islice(batch_iter, self._batch_size)): + to_put = ( + batch if if_exists == "replace" else filter(lambda key: key not in destination, batch) + ) + skipped = len(batch) - len(to_put) - if skipped > 0 and if_exists == "raise": - raise CopyError(f"keys {to_put} exist in destination") + if skipped > 0: + if if_exists == "raise": + raise CopyError(f"keys {to_put} exist in destination") + else: + progress_keys.log(f"Skipped {skipped} keys that already exists") - items = self.getitems(to_put, contexts={}) - for key, content in items.items(): - destination[key] = content - total_bytes_copied += buffer_size(content) + items = self.getitems(to_put, contexts={}) + for key, content in items.items(): + destination[key] = content - total_copied += len(to_put) - total_skipped += skipped + size_copied = buffer_size(content) + total_bytes_copied += size_copied + total_copied += 1 - log( - f"Copied {total_copied} ({total_bytes_copied / (1024**2):.2f} MiB), skipped {total_skipped}, of {number_source_keys} keys. {(total_copied + total_skipped) / number_source_keys * 100:.2f}% completed." - ) + total_skipped += skipped + progress_keys.update(task_keys, advance=len(batch), refresh=True) return total_copied, total_skipped, total_bytes_copied @@ -259,24 +268,29 @@ def copy_key(key: str, source: Store, if_exists: ZarrConflictResolution) -> tupl number_source_keys = len(source) - # Batch the keys, otherwise we end up with too many files open at the same time - batch_iter = iter(source.keys()) - while batch := tuple(islice(batch_iter, self._batch_size)): - # Create a future for each key to copy - future_to_key = [ - executor.submit(copy_key, source_key, source, if_exists) for source_key in batch - ] - - # As each future completes, collect the results - for future in as_completed(future_to_key): - result_copied, result_skipped, result_bytes_copied = future.result() - total_copied += result_copied - total_skipped += result_skipped - total_bytes_copied += result_bytes_copied - - log( - f"Copied {total_copied} ({total_bytes_copied / (1024**2):.2f} MiB), skipped {total_skipped}, of {number_source_keys} keys. {(total_copied + total_skipped) / number_source_keys * 100:.2f}% completed." 
-            )
+                total_skipped += skipped
+                progress_keys.update(task_keys, advance=len(batch), refresh=True)

         return total_copied, total_skipped, total_bytes_copied

@@ -259,24 +268,29 @@ def copy_key(key: str, source: Store, if_exists: ZarrConflictResolution) -> tupl

             number_source_keys = len(source)

-            # Batch the keys, otherwise we end up with too many files open at the same time
-            batch_iter = iter(source.keys())
-            while batch := tuple(islice(batch_iter, self._batch_size)):
-                # Create a future for each key to copy
-                future_to_key = [
-                    executor.submit(copy_key, source_key, source, if_exists) for source_key in batch
-                ]
-
-                # As each future completes, collect the results
-                for future in as_completed(future_to_key):
-                    result_copied, result_skipped, result_bytes_copied = future.result()
-                    total_copied += result_copied
-                    total_skipped += result_skipped
-                    total_bytes_copied += result_bytes_copied
-
-                log(
-                    f"Copied {total_copied} ({total_bytes_copied / (1024**2):.2f} MiB), skipped {total_skipped}, of {number_source_keys} keys. {(total_copied + total_skipped) / number_source_keys * 100:.2f}% completed."
-                )
+            with track_progress(description="Copying Zarr keys", total=number_source_keys) as (
+                progress_keys,
+                task_keys,
+            ):
+                # Batch the keys, otherwise we end up with too many files open at the same time
+                batch_iter = iter(source.keys())
+                while batch := tuple(islice(batch_iter, self._batch_size)):
+                    # Create a future for each key to copy
+                    future_to_key = [
+                        executor.submit(copy_key, source_key, source, if_exists) for source_key in batch
+                    ]
+
+                    # As each future completes, collect the results
+                    for future in as_completed(future_to_key):
+                        result_copied, result_skipped, result_bytes_copied = future.result()
+                        total_copied += result_copied
+                        progress_keys.update(task_keys, advance=result_copied, refresh=True)
+
+                        total_skipped += result_skipped
+                        if result_skipped > 0:
+                            progress_keys.log(f"Skipped {result_skipped} keys that already exists")
+
+                        total_bytes_copied += result_bytes_copied

         return total_copied, total_skipped, total_bytes_copied
diff --git a/polaris/mixins/_checksum.py b/polaris/mixins/_checksum.py
index 8fccac35..9e05d8a6 100644
--- a/polaris/mixins/_checksum.py
+++ b/polaris/mixins/_checksum.py
@@ -1,11 +1,13 @@
 import abc
+import logging
 import re

-from loguru import logger
 from pydantic import BaseModel, PrivateAttr, computed_field

 from polaris.utils.errors import PolarisChecksumError

+logger = logging.getLogger(__name__)
+

 class ChecksumMixin(BaseModel, abc.ABC):
     """
diff --git a/polaris/utils/context.py b/polaris/utils/context.py
index 582032dd..edabd0cb 100644
--- a/polaris/utils/context.py
+++ b/polaris/utils/context.py
@@ -1,31 +1,61 @@
-from halo import Halo
-
-from polaris.mixins import FormattingMixin
-
-
-class ProgressIndicator(FormattingMixin):
-    def __init__(self, success_msg: str, error_msg: str, start_msg: str = "In progress..."):
-        self._start_msg = start_msg
-        self._success_msg = success_msg
-        self._error_msg = error_msg
-
-        self._spinner = Halo(text=self._start_msg, spinner="dots")
-
-    def __enter__(self):
-        self._spinner.start()
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self._spinner.text = ""
-
-        if exc_type:
-            self._spinner.text_color = "red"
-            self._spinner.fail(f"ERROR: {self._error_msg}")
-        else:
-            self._spinner.text_color = "green"
-            self._spinner.succeed(f"SUCCESS: {self.format(self._success_msg, self.BOLD)}\n")
-
-        self._spinner.stop()
-
-    def update_success_msg(self, msg: str):
-        self._success_msg = msg
+from contextlib import contextmanager
+from contextvars import ContextVar
+from itertools import cycle
+
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    TimeElapsedColumn,
+)
+
+# Singleton Progress instance to be used for all calls to `track_progress`
+progress_instance = ContextVar(
+    "progress",
+    default=Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        BarColumn(),
+        MofNCompleteColumn(),
+        TimeElapsedColumn(),
+    ),
+)
+
+colors = cycle(
+    {
+        "green",
+        "cyan",
+        "magenta",
+    }
+)
+
+
+@contextmanager
+def track_progress(description: str, total: float | None = 1.0):
+    """
+    Use the Progress instance to track a task's progress
+    """
+    progress = progress_instance.get()
+
+    # Make sure the Progress is started
+    progress.start()
+
+    task = progress.add_task(f"[{next(colors)}]{description}", total=total)
+
+    try:
+        # Yield the task and Progress instance, for more granular control
+        yield progress, task
+
+        # Mark the task as completed
+        progress.update(task, completed=total, refresh=True)
+        progress.log(f"[green] Success: {description}")
+    except Exception:
+        progress.log(f"[red] Error: {description}")
+        raise
+    finally:
+        # Remove the task from the UI, and stop the progress bar if all tasks are completed
+        progress.remove_task(task)
+        if progress.finished:
+            progress.stop()
diff --git a/pyproject.toml b/pyproject.toml
index ec11e58f..4d5ec5ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,9 +39,7 @@ dependencies = [
     "datamol >=0.12.1",
     "fastpdb",
     "fsspec[http]",
-    "halo",
     "httpx",
-    "loguru",
     "numcodecs[msgpack]>=0.13.1",
     "numpy < 2", # We need to pin numpy to avoid issues with fastpdb/biotite.
     "pandas",
@@ -50,10 +48,10 @@ dependencies = [
     "pydantic-settings >=2",
     "pyroaring",
     "pyyaml",
+    "rich>=13.9.4",
     "scikit-learn",
     "scipy",
     "seaborn",
-    "tqdm",
     "typer",
     "typing-extensions>=4.12.0",
     "zarr >=2,<3",
diff --git a/uv.lock b/uv.lock
index 53bce2a7..13664469 100644
--- a/uv.lock
+++ b/uv.lock
@@ -890,19 +890,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 },
 ]

-[[package]]
-name = "halo"
-version = "0.0.31"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "colorama" },
-    { name = "log-symbols" },
-    { name = "six" },
-    { name = "spinners" },
-    { name = "termcolor" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/ee/48/d53580d30b1fabf25d0d1fcc3f5b26d08d2ac75a1890ff6d262f9f027436/halo-0.0.31.tar.gz", hash = "sha256:7b67a3521ee91d53b7152d4ee3452811e1d2a6321975137762eb3d70063cc9d6", size = 11666 }
-
 [[package]]
 name = "httpcore"
 version = "1.0.7"
@@ -1414,18 +1401,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/3a/1d/50ad811d1c5dae091e4cf046beba925bcae0a610e79ae4c538f996f63ed5/kiwisolver-1.4.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:65ea09a5a3faadd59c2ce96dc7bf0f364986a315949dc6374f04396b0d60e09b", size = 71762 },
 ]

-[[package]]
-name = "log-symbols"
-version = "0.0.14"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "colorama" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/45/87/e86645d758a4401c8c81914b6a88470634d1785c9ad09823fa4a1bd89250/log_symbols-0.0.14.tar.gz", hash = "sha256:cf0bbc6fe1a8e53f0d174a716bc625c4f87043cc21eb55dd8a740cfe22680556", size = 3211 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/28/5d/d710c38be68b0fb54e645048fe359c3904cc3cb64b2de9d40e1712bf110c/log_symbols-0.0.14-py3-none-any.whl", hash = "sha256:4952106ff8b605ab7d5081dd2c7e6ca7374584eff7086f499c06edd1ce56dcca", size = 3081 },
-]
-
 [[package]]
 name = "loguru"
 version = "0.7.3"
@@ -2244,9 +2219,7 @@ dependencies = [
     { name = "datamol" },
     { name = "fastpdb" },
     { name = "fsspec", extra = ["http"] },
-    { name = "halo" },
     { name = "httpx" },
-    { name = "loguru" },
     { name = "numcodecs", version = "0.13.1", source = { registry = "https://pypi.org/simple" }, extra = ["msgpack"], marker = "python_full_version < '3.11'" },
     { name = "numcodecs", version = "0.14.1", source = { registry = "https://pypi.org/simple" }, extra = ["msgpack"], marker = "python_full_version >= '3.11'" },
     { name = "numpy" },
@@ -2256,10 +2229,10 @@ dependencies = [
     { name = "pydantic-settings" },
     { name = "pyroaring" },
     { name = "pyyaml" },
+    { name = "rich" },
     { name = "scikit-learn" },
     { name = "scipy" },
     { name = "seaborn" },
-    { name = "tqdm" },
     { name = "typer" },
     { name = "typing-extensions" },
     { name = "zarr" },
@@ -2297,9 +2270,7 @@ requires-dist = [
     { name = "datamol", specifier = ">=0.12.1" },
     { name = "fastpdb" },
     { name = "fsspec", extras = ["http"] },
-    { name = "halo" },
     { name = "httpx" },
-    { name = "loguru" },
     { name = "numcodecs", extras = ["msgpack"], specifier = ">=0.13.1" },
     { name = "numpy", specifier = "<2" },
     { name = "pandas" },
@@ -2308,10 +2279,10 @@ requires-dist = [
     { name = "pydantic-settings", specifier = ">=2" },
     { name = "pyroaring" },
     { name = "pyyaml" },
+    { name = "rich", specifier = ">=13.9.4" },
    { name = "scikit-learn" },
     { name = "scipy" },
     { name = "seaborn" },
-    { name = "tqdm" },
     { name = "typer" },
     { name = "typing-extensions", specifier = ">=4.12.0" },
     { name = "zarr", specifier = ">=2,<3" },
@@ -3257,15 +3228,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d1/c2/fe97d779f3ef3b15f05c94a2f1e3d21732574ed441687474db9d342a7315/soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9", size = 36186 },
 ]

-[[package]]
-name = "spinners"
-version = "0.0.24"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/d3/91/bb331f0a43e04d950a710f402a0986a54147a35818df0e1658551c8d12e1/spinners-0.0.24.tar.gz", hash = "sha256:1eb6aeb4781d72ab42ed8a01dcf20f3002bf50740d7154d12fb8c9769bf9e27f", size = 5308 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/9f/8e/3310207a68118000ca27ac878b8386123628b335ecb3d4bec4743357f0d1/spinners-0.0.24-py3-none-any.whl", hash = "sha256:2fa30d0b72c9650ad12bbe031c9943b8d441e41b4f5602b0ec977a19f3290e98", size = 5499 },
-]
-
 [[package]]
 name = "stack-data"
 version = "0.6.3"
@@ -3280,15 +3242,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 },
 ]

-[[package]]
-name = "termcolor"
-version = "2.3.0"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/b8/85/147a0529b4e80b6b9d021ca8db3a820fcac53ec7374b87073d004aaf444c/termcolor-2.3.0.tar.gz", hash = "sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a", size = 12163 }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/67/e1/434566ffce04448192369c1a282931cf4ae593e91907558eaecd2e9f2801/termcolor-2.3.0-py3-none-any.whl", hash = "sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475", size = 6872 },
-]
-
 [[package]]
 name = "terminado"
 version = "0.18.1"