fix: Replace Halo with rich for CLI spinner and progress tracking (#256)
* First pass, using rich for CLI spinner and logging

* Replace ProgressIndicator class

* Review feedback

Co-authored-by: Honoré Hounwanou <[email protected]>

* Consistency

Co-authored-by: Honoré Hounwanou <[email protected]>

* Tweak progress columns. Replace tqdm.

---------

Co-authored-by: Honoré Hounwanou <[email protected]>
jstlaurent and mercuryseries authored Jan 27, 2025
1 parent 41985ca commit 5520fe8
Showing 17 changed files with 282 additions and 291 deletions.
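At a high level, the change swaps Halo's standalone spinner for rich's console primitives and routes logging through the standard library with a RichHandler. A minimal before/after sketch of the spinner migration (illustrative only; the task name and body below are hypothetical placeholders, not code from this commit):

    import time

    from rich.console import Console

    def do_work() -> None:
        # Hypothetical stand-in for a long-running CLI task
        time.sleep(2)

    # Before (halo):
    #
    #     from halo import Halo
    #
    #     with Halo(text="Uploading dataset...", spinner="dots"):
    #         do_work()

    # After (rich): Console.status renders an animated spinner
    # for the duration of the block.
    console = Console()
    with console.status("Uploading dataset..."):
        do_work()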
4 changes: 1 addition & 3 deletions env.yml
@@ -4,17 +4,15 @@ channels:
 dependencies:
   - python >=3.10
   - pip
-  - tqdm
-  - loguru
   - typer
   - pyyaml
   - pydantic >=2
   - pydantic-settings >=2
   - fsspec
-  - halo
   - typing-extensions >=4.12.0
   - boto3 <1.36.0
   - pyroaring
+  - rich

   # Hub client
   - authlib
18 changes: 11 additions & 7 deletions polaris/__init__.py
@@ -1,14 +1,18 @@
-import os
-import sys
+import logging

-from loguru import logger
+from rich.logging import RichHandler

 from ._version import __version__
 from .loader import load_benchmark, load_competition, load_dataset

 __all__ = ["load_dataset", "load_benchmark", "load_competition", "__version__"]

-# Configure the default logging level
-os.environ["LOGURU_LEVEL"] = os.environ.get("LOGURU_LEVEL", "INFO")
-logger.remove()
-logger.add(sys.stderr, level=os.environ["LOGURU_LEVEL"])
+# Polaris specific logger
+logger = logging.getLogger(__name__)
+
+# Only add handler if the logger has not already been configured externally
+if not logger.handlers:
+    handler = RichHandler(rich_tracebacks=True)
+    handler.setFormatter(logging.Formatter("%(message)s", datefmt="[%Y-%m-%d %X]"))
+    logger.addHandler(handler)
+    logger.setLevel(logging.INFO)
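Note the guard on logger.handlers: polaris only installs its RichHandler when the "polaris" logger has no handlers yet, so an application that configures logging before importing polaris keeps full control. A sketch of that override path (assumed usage, not part of this diff):

    import logging

    # Attach a handler to the "polaris" logger before importing polaris;
    # the guard above then sees a configured logger and adds nothing.
    polaris_logger = logging.getLogger("polaris")
    polaris_logger.addHandler(logging.StreamHandler())
    polaris_logger.setLevel(logging.WARNING)

    import polaris  # skips RichHandler setup because logger.handlers is non-empty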
4 changes: 3 additions & 1 deletion polaris/_artifact.py
@@ -1,8 +1,8 @@
 import json
+import logging
 from typing import ClassVar, Literal

 import fsspec
-from loguru import logger
 from packaging.version import Version
 from pydantic import (
     BaseModel,
@@ -19,6 +19,8 @@
 from polaris.utils.misc import slugify
 from polaris.utils.types import ArtifactUrn, HubOwner, SlugCompatibleStringType, SlugStringType

+logger = logging.getLogger(__name__)
+

 class BaseArtifactModel(BaseModel):
     """
4 changes: 3 additions & 1 deletion polaris/benchmark/_split.py
@@ -1,13 +1,15 @@
+import logging
 from itertools import chain

-from loguru import logger
 from pydantic import BaseModel, computed_field, field_serializer, model_validator
 from typing_extensions import Self

 from polaris.utils.errors import InvalidBenchmarkError
 from polaris.utils.misc import listit
 from polaris.utils.types import SplitType

+logger = logging.getLogger(__name__)
+

 class SplitSpecificationV1Mixin(BaseModel):
     """
42 changes: 25 additions & 17 deletions polaris/dataset/_base.py
@@ -1,5 +1,6 @@
 import abc
 import json
+import logging
 from os import PathLike
 from pathlib import Path, PurePath
 from typing import Any, Iterable, MutableMapping
@@ -8,7 +9,6 @@
 import fsspec
 import numpy as np
 import zarr
-from loguru import logger
 from pydantic import (
     Field,
     PrivateAttr,
@@ -25,6 +25,7 @@
 from polaris.dataset.zarr import MemoryMappedDirectoryStore
 from polaris.dataset.zarr._utils import check_zarr_codecs, load_zarr_group_to_memory
 from polaris.utils.constants import DEFAULT_CACHE_DIR
+from polaris.utils.context import track_progress
 from polaris.utils.dict2html import dict2html
 from polaris.utils.errors import InvalidDatasetError
 from polaris.utils.types import (
@@ -37,6 +38,8 @@
     ZarrConflictResolution,
 )

+logger = logging.getLogger(__name__)
+
 # Constants
 _CACHE_SUBDIR = "datasets"

@@ -371,19 +374,24 @@ def _cache_zarr(self, destination: str | PathLike, if_exists: ZarrConflictResolu
         # Copy over Zarr data to the destination
         self._warn_about_remote_zarr = False

-        logger.info(f"Copying Zarr archive to {destination_zarr_root}. This may take a while.")
-        destination_store = zarr.open(str(destination_zarr_root), "w").store
-        source_store = self.zarr_root.store.store
-
-        if isinstance(source_store, S3Store):
-            source_store.copy_to_destination(destination_store, if_exists, logger.info)
-        else:
-            zarr.copy_store(
-                source=source_store,
-                dest=destination_store,
-                log=logger.info,
-                if_exists=if_exists,
-            )
-        self.zarr_root_path = str(destination_zarr_root)
-        self._zarr_root = None
-        self._zarr_data = None
+        with track_progress(description="Copying Zarr archive", total=1) as (
+            progress,
+            task,
+        ):
+            progress.log(f"[green]Copying to destination {destination_zarr_root}")
+            progress.log("[yellow]For large Zarr archives, this may take a while.")
+            destination_store = zarr.open(str(destination_zarr_root), "w").store
+            source_store = self.zarr_root.store.store
+
+            if isinstance(source_store, S3Store):
+                source_store.copy_to_destination(destination_store, if_exists, logger.info)
+            else:
+                zarr.copy_store(
+                    source=source_store,
+                    dest=destination_store,
+                    log=logger.info,
+                    if_exists=if_exists,
+                )
+            self.zarr_root_path = str(destination_zarr_root)
+            self._zarr_root = None
+            self._zarr_data = None
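The track_progress helper comes from polaris/utils/context.py, which is not among the files loaded in this view. Judging from the call sites (it yields a (progress, task) pair, and progress.log and progress.update are called on the results), a plausible minimal implementation would wrap rich.progress.Progress; the sketch below is an assumption, not the actual module, and the "Tweak progress columns" note in the commit message suggests the real helper also customizes the Progress column layout:

    from contextlib import contextmanager
    from typing import Iterator, Tuple

    from rich.progress import Progress, TaskID

    @contextmanager
    def track_progress(description: str, total: float = 1.0) -> Iterator[Tuple[Progress, TaskID]]:
        # Start a rich progress display and register a single task on it
        with Progress() as progress:
            task = progress.add_task(description, total=total)
            yield progress, task
            # On clean exit, mark the task as complete
            progress.update(task, completed=total, refresh=True)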
2 changes: 2 additions & 0 deletions polaris/dataset/_dataset.py
@@ -332,6 +332,8 @@ def cache(

         if verify_checksum:
             self.verify_checksum()
+        else:
+            self._md5sum = None

         return str(destination)
4 changes: 3 additions & 1 deletion polaris/dataset/_dataset_v2.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import re
 from os import PathLike
 from pathlib import Path
@@ -7,7 +8,6 @@
 import fsspec
 import numpy as np
 import zarr
-from loguru import logger
 from pydantic import PrivateAttr, computed_field, model_validator
 from typing_extensions import Self

@@ -17,6 +17,8 @@
 from polaris.utils.errors import InvalidDatasetError
 from polaris.utils.types import AccessType, ChecksumStrategy, HubOwner, ZarrConflictResolution

+logger = logging.getLogger(__name__)
+
 _INDEX_ARRAY_KEY = "__index__"
4 changes: 3 additions & 1 deletion polaris/dataset/_factory.py
@@ -1,15 +1,17 @@
+import logging
 import os
 from typing import Literal

 import datamol as dm
 import pandas as pd
 import zarr
-from loguru import logger

 from polaris.dataset import ColumnAnnotation, DatasetV1
 from polaris.dataset._adapters import Adapter
 from polaris.dataset.converters import Converter, PDBConverter, SDFConverter, ZarrConverter

+logger = logging.getLogger(__name__)
+

 def create_dataset_from_file(path: str, zarr_root_path: str | None = None) -> DatasetV1:
     """
71 changes: 39 additions & 32 deletions polaris/dataset/zarr/_checksum.py
@@ -45,8 +45,8 @@
 import zarr.errors
 from pydantic import BaseModel, ConfigDict
 from pydantic.alias_generators import to_camel
-from tqdm import tqdm

+from polaris.utils.context import track_progress
 from polaris.utils.errors import InvalidZarrChecksum

 ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)"
@@ -56,8 +56,8 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
     r"""
     Implements an algorithm to compute the Zarr checksum.

-    Warning: This checksum is sensitive to Zarr configuration.
-    This checksum is sensitive to changes in the Zarr structure. For example, if you change the chunk size,
+    Warning: This checksum is sensitive to Zarr configuration.
+    This checksum is sensitive to changes in the Zarr structure. For example, if you change the chunk size,
     the checksum will also change.

     To understand how this works, consider the following directory structure:
@@ -67,17 +67,17 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
       a   c
      /
     b

     Within zarr, this would for example be:

     - `root`: A Zarr Group with a single Array.
     - `a`: A Zarr Array
     - `b`: A single chunk of the Zarr Array
-    - `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup)
+    - `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup)

-    To compute the checksum, we first find all the leaves in the tree, in this case b and c.
+    To compute the checksum, we first find all the leaves in the tree, in this case b and c.
     We compute the hash of the content (the raw bytes) for each of these files.

     We then work our way up the tree. For any node (directory), we find all children of that node.
     In a sorted order, we then serialize a list with - for each of the children - the checksum, size, and number of children.
     The hash of the directory is then equal to the hash of the serialized JSON.
@@ -116,33 +116,40 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple["_ZarrDirectoryDigest",
     leaves = fs.find(zarr_root_path, detail=True)
     zarr_md5sum_manifest = []

-    for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"):
-        path = file["name"]
-
-        relpath = path.removeprefix(zarr_root_path)
-        relpath = relpath.lstrip("/")
-        relpath = Path(relpath)
-
-        size = file["size"]
-
-        # Compute md5sum of file
-        md5sum = hashlib.md5()
-        with fs.open(path, "rb") as f:
-            for chunk in iter(lambda: f.read(8192), b""):
-                md5sum.update(chunk)
-        digest = md5sum.hexdigest()
-
-        # Add a leaf to the tree
-        # (This actually adds the file's checksum to the parent directory's manifest)
-        tree.add_leaf(
-            path=relpath,
-            size=size,
-            digest=digest,
-        )
-
-        # We persist the checksums for leaf nodes separately,
-        # because this is what the Hub needs to verify data integrity.
-        zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))
+    files = leaves.values()
+    with track_progress(description="Finding all files in the Zarr archive", total=len(files)) as (
+        progress,
+        task,
+    ):
+        for file in files:
+            path = file["name"]
+
+            relpath = path.removeprefix(zarr_root_path)
+            relpath = relpath.lstrip("/")
+            relpath = Path(relpath)
+
+            size = file["size"]
+
+            # Compute md5sum of file
+            md5sum = hashlib.md5()
+            with fs.open(path, "rb") as f:
+                for chunk in iter(lambda: f.read(8192), b""):
+                    md5sum.update(chunk)
+            digest = md5sum.hexdigest()
+
+            # Add a leaf to the tree
+            # (This actually adds the file's checksum to the parent directory's manifest)
+            tree.add_leaf(
+                path=relpath,
+                size=size,
+                digest=digest,
+            )
+
+            # We persist the checksums for leaf nodes separately,
+            # because this is what the Hub needs to verify data integrity.
+            zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size))
+            progress.update(task, advance=1, refresh=True)

     # Compute digest
     return tree.process(), zarr_md5sum_manifest
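The docstring above fully specifies the directory digest: hash each file's raw bytes, then, walking upward, hash a JSON serialization of the children's (checksum, size, file count) entries in sorted order. A self-contained sketch of that recursion over a local directory (illustrative only, not the _checksum.py implementation):

    import hashlib
    import json
    from pathlib import Path

    def digest(path: Path) -> tuple[str, int, int]:
        # Returns (md5 hexdigest, total size, file count) for a file or directory
        if path.is_file():
            md5 = hashlib.md5()
            with path.open("rb") as f:
                for chunk in iter(lambda: f.read(8192), b""):
                    md5.update(chunk)
            return md5.hexdigest(), path.stat().st_size, 1

        # Directory: digest children in sorted order, then hash the manifest
        children = [digest(child) for child in sorted(path.iterdir())]
        manifest = json.dumps(children).encode()
        size = sum(c[1] for c in children)
        count = sum(c[2] for c in children)
        return hashlib.md5(manifest).hexdigest(), size, count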
4 changes: 3 additions & 1 deletion polaris/experimental/_split_v2.py
@@ -1,15 +1,17 @@
+import logging
 from functools import cached_property
 from hashlib import md5
 from typing import Generator, Sequence

-from loguru import logger
 from pydantic import BaseModel, ConfigDict, Field, computed_field, field_validator, model_validator
 from pydantic.alias_generators import to_camel
 from pyroaring import BitMap
 from typing_extensions import Self

 from polaris.utils.errors import InvalidBenchmarkError

+logger = logging.getLogger(__name__)
+

 class IndexSet(BaseModel):
     """