Misc changes (#230)
* start jogging with issue #212

* Replace all occurrences of tmpdir with tmp_path

* add comment and fix small typo

* Update list_datasets and list_benchmarks to include support for v2 artifacts.

Co-authored-by: Julien St-Laurent <[email protected]>

* Apply review feedback

Co-authored-by: Cas Wognum <[email protected]>

* more changes to cover edge cases

* Refactoring

Co-authored-by: Cas Wognum <[email protected]>

---------

Co-authored-by: Julien St-Laurent <[email protected]>
Co-authored-by: Cas Wognum <[email protected]>
3 people authored Dec 19, 2024
1 parent abb9b29 commit f72ff53
Showing 8 changed files with 100 additions and 57 deletions.
100 changes: 71 additions & 29 deletions polaris/hub/client.py
@@ -278,7 +278,8 @@ def login(self, overwrite: bool = False, auto_open_browser: bool = True):
# =========================

def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
"""List all available datasets on the Polaris Hub.
"""List all available datasets (v1 and v2) on the Polaris Hub.
We prioritize v2 datasets over v1 datasets.
Args:
limit: The maximum number of datasets to return.
@@ -288,17 +289,40 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
A list of dataset names in the format `owner/dataset_name`.
"""
with ProgressIndicator(
start_msg="Fetching artifacts...",
success_msg="Fetched artifacts.",
start_msg="Fetching datasets...",
success_msg="Fetched datasets.",
error_msg="Failed to fetch datasets.",
):
response = self._base_request_to_hub(
url="/v1/dataset", method="GET", params={"limit": limit, "offset": offset}
)
response_data = response.json()
dataset_list = [bm["artifactId"] for bm in response_data["data"]]
# Step 1: Fetch enough v2 datasets to cover the offset and limit
v2_json_response = self._base_request_to_hub(
url="/v2/dataset", method="GET", params={"limit": limit, "offset": offset}
).json()
v2_data = v2_json_response["data"]
v2_datasets = [dataset["artifactId"] for dataset in v2_data]

# If v2 datasets satisfy the limit, return them
if len(v2_datasets) == limit:
return v2_datasets

# Step 2: Calculate the remaining limit and fetch v1 datasets
remaining_limit = max(0, limit - len(v2_datasets))

v1_json_response = self._base_request_to_hub(
url="/v1/dataset",
method="GET",
params={
"limit": remaining_limit,
"offset": max(0, offset - v2_json_response["metadata"]["total"]),
},
).json()
v1_data = v1_json_response["data"]
v1_datasets = [dataset["artifactId"] for dataset in v1_data]

# Combine the v2 and v1 datasets
combined_datasets = v2_datasets + v1_datasets

return dataset_list
# Ensure the final combined list respects the limit
return combined_datasets

def get_dataset(
self,
@@ -323,7 +347,7 @@ def get_dataset(
error_msg="Failed to fetch dataset.",
):
try:
return self._get_v1_dataset(owner, name, ArtifactSubtype.STANDARD.value, verify_checksum)
return self._get_v1_dataset(owner, name, ArtifactSubtype.STANDARD, verify_checksum)
except PolarisRetrieveArtifactError:
# If the v1 dataset is not found, try to load a v2 dataset
return self._get_v2_dataset(owner, name)
@@ -348,7 +372,7 @@ def _get_v1_dataset(
"""
url = (
f"/v1/dataset/{owner}/{name}"
if artifact_type == ArtifactSubtype.STANDARD.value
if artifact_type == ArtifactSubtype.STANDARD
else f"/v2/competition/dataset/{owner}/{name}"
)
response = self._base_request_to_hub(url=url, method="GET")
@@ -408,18 +432,40 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]:
A list of benchmark names in the format `owner/benchmark_name`.
"""
with ProgressIndicator(
start_msg="Fetching artifacts...",
success_msg="Fetched artifacts.",
start_msg="Fetching benchmarks...",
success_msg="Fetched benchmarks.",
error_msg="Failed to fetch benchmarks.",
):
# TODO (cwognum): What to do with pagination, i.e. limit and offset?
response = self._base_request_to_hub(
url="/v1/benchmark", method="GET", params={"limit": limit, "offset": offset}
)
response_data = response.json()
benchmarks_list = [f"{HubOwner(**bm['owner'])}/{bm['name']}" for bm in response_data["data"]]
# Step 1: Fetch enough v2 benchmarks to cover the offset and limit
v2_json_response = self._base_request_to_hub(
url="/v2/benchmark", method="GET", params={"limit": limit, "offset": offset}
).json()
v2_data = v2_json_response["data"]
v2_benchmarks = [f"{HubOwner(**benchmark['owner'])}/{benchmark['name']}" for benchmark in v2_data]

# If v2 benchmarks satisfy the limit, return them
if len(v2_benchmarks) == limit:
return v2_benchmarks

# Step 2: Calculate the remaining limit and fetch v1 benchmarks
remaining_limit = max(0, limit - len(v2_benchmarks))

v1_json_response = self._base_request_to_hub(
url="/v1/benchmark",
method="GET",
params={
"limit": remaining_limit,
"offset": max(0, offset - v2_json_response["metadata"]["total"]),
},
).json()
v1_data = v1_json_response["data"]
v1_benchmarks = [f"{HubOwner(**benchmark['owner'])}/{benchmark['name']}" for benchmark in v1_data]

# Combine the v2 and v1 benchmarks
combined_benchmarks = v2_benchmarks + v1_benchmarks

return benchmarks_list
# Ensure the final combined list respects the limit
return combined_benchmarks

def get_benchmark(
self,
@@ -582,9 +628,7 @@ def upload_dataset(
)

if isinstance(dataset, DatasetV1):
self._upload_v1_dataset(
dataset, ArtifactSubtype.STANDARD.value, timeout, access, owner, if_exists
)
self._upload_v1_dataset(dataset, ArtifactSubtype.STANDARD, timeout, access, owner, if_exists)
elif isinstance(dataset, DatasetV2):
self._upload_v2_dataset(dataset, timeout, access, owner, if_exists)

@@ -631,7 +675,7 @@ def _upload_v1_dataset(
# We do so separately for the Zarr archive and Parquet file.
url = (
f"/v1/dataset/{dataset.artifact_id}"
if artifact_type == ArtifactSubtype.STANDARD.value
if artifact_type == ArtifactSubtype.STANDARD
else f"/v2/competition/dataset/{dataset.owner}/{dataset.name}"
)
self._base_request_to_hub(
@@ -674,7 +718,7 @@ def _upload_v1_dataset(
)

base_artifact_url = (
"datasets" if artifact_type == ArtifactSubtype.STANDARD.value else "/competition/datasets"
"datasets" if artifact_type == ArtifactSubtype.STANDARD else "/competition/datasets"
)
progress_indicator.update_success_msg(
f"Your {artifact_type} dataset has been successfully uploaded to the Hub. "
@@ -774,7 +818,7 @@ def upload_benchmark(
"""
match benchmark:
case BenchmarkV1Specification():
self._upload_v1_benchmark(benchmark, ArtifactSubtype.STANDARD.value, access, owner)
self._upload_v1_benchmark(benchmark, ArtifactSubtype.STANDARD, access, owner)
case BenchmarkV2Specification():
self._upload_v2_benchmark(benchmark, access, owner)

@@ -819,9 +863,7 @@ def _upload_v1_benchmark(
benchmark_json["datasetArtifactId"] = benchmark.dataset.artifact_id
benchmark_json["access"] = access

path_params = (
"/v1/benchmark" if artifact_type == ArtifactSubtype.STANDARD.value else "/v2/competition"
)
path_params = "/v1/benchmark" if artifact_type == ArtifactSubtype.STANDARD else "/v2/competition"
url = f"{path_params}/{benchmark.owner}/{benchmark.name}"
self._base_request_to_hub(url=url, method="PUT", json=benchmark_json)

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -374,7 +374,7 @@ def test_multi_task_benchmark_multiple_test_sets(test_dataset, regression_metric
@pytest.fixture(scope="function")
def test_docking_dataset(tmp_path, sdf_files, test_org_owner):
# toy docking dataset
factory = DatasetFactory(tmp_path / "ligands.zarr")
factory = DatasetFactory(str(tmp_path / "ligands.zarr"))

converter = SDFConverter(mol_prop_as_cols=True)
factory.register_converter("sdf", converter)
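The tmpdir-to-tmp_path replacement mentioned in the commit message follows the usual pytest pattern: tmp_path is a pathlib.Path fixture, so paths are composed with the / operator and converted with str() only where a plain string is needed, as in the DatasetFactory call above. A minimal sketch of that pattern, using a hypothetical helper that accepts a string path:

from pathlib import Path


def write_marker(path: str) -> None:
    # Hypothetical stand-in for code that accepts a plain string path.
    Path(path).write_text("ok")


def test_write_marker(tmp_path: Path) -> None:
    # tmp_path is a per-test pathlib.Path; compose paths with "/" and
    # convert to str only at the boundary of a string-based API.
    target = tmp_path / "data" / "marker.txt"
    target.parent.mkdir(parents=True, exist_ok=True)
    write_marker(str(target))
    assert target.read_text() == "ok"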
3 changes: 2 additions & 1 deletion tests/test_benchmark.py
@@ -198,11 +198,12 @@ def _check_for_failure(_kwargs):
kwargs.pop("target_types", None) # Reset target types that matches deleted target column
_check_for_failure(kwargs)

# Input columns
kwargs = benchmark.model_dump()
kwargs["input_cols"] = kwargs["input_cols"][1:] + ["iupac"]
_check_for_failure(kwargs)

# --- Don't fail if not checksum is provided ---
# --- Don't fail if no checksum is provided ---
kwargs["md5sum"] = None
dataset = cls(dataset=benchmark.dataset, **kwargs)
assert dataset.md5sum is not None
14 changes: 7 additions & 7 deletions tests/test_dataset.py
@@ -88,7 +88,7 @@ def test_dataset_checksum(test_dataset):
def test_dataset_from_zarr(zarr_archive, tmp_path):
"""Test whether loading works when the zarr archive contains a single array or multiple arrays."""
archive = zarr_archive
dataset = create_dataset_from_file(archive, tmp_path / "data")
dataset = create_dataset_from_file(archive, str(tmp_path / "data"))

assert len(dataset.table) == 100
for i in range(100):
@@ -115,8 +115,8 @@ def test_dataset_from_zarr_to_json_and_back(zarr_archive, tmp_path):
can be saved to and loaded from json.
"""

json_dir = tmp_path / "json"
zarr_dir = tmp_path / "zarr"
json_dir = str(tmp_path / "json")
zarr_dir = str(tmp_path / "zarr")

archive = zarr_archive
dataset = create_dataset_from_file(archive, zarr_dir)
@@ -132,8 +132,8 @@ def test_dataset_caching(zarr_archive, tmp_path):
def test_dataset_caching(zarr_archive, tmp_path):
"""Test whether the dataset remains the same after caching."""

original_dataset = create_dataset_from_file(zarr_archive, tmp_path / "original1")
cached_dataset = create_dataset_from_file(zarr_archive, tmp_path / "original2")
original_dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "original1"))
cached_dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "original2"))
assert original_dataset == cached_dataset

cached_dataset._cache_dir = str(tmp_path / "cached")
@@ -153,7 +153,7 @@ def test_dataset_index():

def test_dataset_in_memory_optimization(zarr_archive, tmp_path):
"""Check if optimization makes a default Zarr archive faster."""
dataset = create_dataset_from_file(zarr_archive, tmp_path / "dataset")
dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "dataset"))
subset = Subset(dataset=dataset, indices=range(100), input_cols=["A"], target_cols=["B"])

t1 = perf_counter()
@@ -208,7 +208,7 @@ def test_dataset__get_item__():
def test_dataset__get_item__with_pointer_columns(zarr_archive, tmp_path):
"""Test the __getitem__() interface for a dataset with pointer columns (i.e. part of the data stored in Zarr)."""

dataset = create_dataset_from_file(zarr_archive, tmp_path / "data")
dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "data"))
root = zarr.open(zarr_archive)

# Get a specific cell
6 changes: 3 additions & 3 deletions tests/test_dataset_v2.py
@@ -66,7 +66,7 @@ def test_dataset_v2_load_to_memory(test_dataset_v2):


def test_dataset_v2_serialization(test_dataset_v2, tmp_path):
save_dir = tmp_path / "save_dir"
save_dir = str(tmp_path / "save_dir")
path = test_dataset_v2.to_json(save_dir)
new_dataset = DatasetV2.from_json(path)
for i in range(5):
@@ -86,7 +86,7 @@ def test_dataset_v1_v2_compatibility(test_dataset, tmp_path):
# We can thus also saved these same arrays to a Zarr archive
df = test_dataset.table

path = tmp_path / "data/v1v2.zarr"
path = str(tmp_path / "data" / "v1v2.zarr")

root = zarr.open(path, "w")
root.array("smiles", data=df["smiles"].values, dtype=object, object_codec=numcodecs.VLenUTF8())
Expand All @@ -96,7 +96,7 @@ def test_dataset_v1_v2_compatibility(test_dataset, tmp_path):
zarr.consolidate_metadata(path)

kwargs = test_dataset.model_dump(exclude=["table", "zarr_root_path"])
dataset = DatasetV2(**kwargs, zarr_root_path=str(path))
dataset = DatasetV2(**kwargs, zarr_root_path=path)

subset_1 = Subset(dataset=test_dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"])
subset_2 = Subset(dataset=dataset, indices=range(5), input_cols=["smiles"], target_cols=["calc"])
6 changes: 3 additions & 3 deletions tests/test_evaluate.py
@@ -1,4 +1,4 @@
import os
from pathlib import Path

import datamol as dm
import numpy as np
@@ -17,7 +17,7 @@
from polaris.utils.types import HubOwner


def test_result_to_json(tmp_path: str, test_user_owner: HubOwner):
def test_result_to_json(tmp_path: Path, test_user_owner: HubOwner):
scores = pd.DataFrame(
{
"Test set": ["A", "A", "A", "A", "B", "B", "B", "B"],
@@ -41,7 +41,7 @@ def test_result_to_json(tmp_path: str, test_user_owner: HubOwner):
contributors=["my-user", "other-user"],
)

path = os.path.join(tmp_path, "result.json")
path = str(tmp_path / "result.json")
result.to_json(path)

BenchmarkResults.from_json(path)
20 changes: 10 additions & 10 deletions tests/test_factory.py
@@ -39,15 +39,15 @@ def _check_dataset(dataset, ground_truth, mol_props_as_col):

def test_sdf_zarr_conversion(sdf_file, caffeine, tmp_path):
"""Test conversion between SDF and Zarr with utility function"""
dataset = create_dataset_from_file(sdf_file, tmp_path / "archive.zarr")
dataset = create_dataset_from_file(sdf_file, str(tmp_path / "archive.zarr"))
_check_dataset(dataset, [caffeine], True)


@pytest.mark.parametrize("mol_props_as_col", [True, False])
def test_factory_sdf_with_prop_as_col(sdf_file, caffeine, tmp_path, mol_props_as_col):
"""Test conversion between SDF and Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "archive.zarr")
factory = DatasetFactory(str(tmp_path / "archive.zarr"))

converter = SDFConverter(mol_prop_as_cols=mol_props_as_col)
factory.register_converter("sdf", converter)
@@ -60,7 +60,7 @@ def test_factory_sdf_with_prop_as_col(sdf_file, caffeine, tmp_path, mol_props_as

def test_zarr_to_zarr_conversion(zarr_archive, tmp_path):
"""Test conversion between Zarr and Zarr with utility function"""
dataset = create_dataset_from_file(zarr_archive, tmp_path / "archive.zarr")
dataset = create_dataset_from_file(zarr_archive, str(tmp_path / "archive.zarr"))
assert len(dataset) == 100
assert len(dataset.columns) == 2
assert all(c in dataset.columns for c in ["A", "B"])
@@ -71,7 +71,7 @@ def test_zarr_to_zarr_conversion(zarr_archive, tmp_path):
def test_zarr_with_factory_pattern(zarr_archive, tmp_path):
"""Test conversion between Zarr and Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "archive.zarr")
factory = DatasetFactory(str(tmp_path / "archive.zarr"))
converter = ZarrConverter()
factory.register_converter("zarr", converter)
factory.add_from_file(zarr_archive)
Expand All @@ -90,7 +90,7 @@ def test_zarr_with_factory_pattern(zarr_archive, tmp_path):

def test_factory_pdb(pdbs_structs, pdb_paths, tmp_path):
"""Test conversion between PDB file and Zarr with factory pattern"""
factory = DatasetFactory(tmp_path / "pdb.zarr")
factory = DatasetFactory(str(tmp_path / "pdb.zarr"))

converter = PDBConverter()
factory.register_converter("pdb", converter)
@@ -104,7 +104,7 @@ def test_factory_pdb(pdbs_structs, pdb_paths, tmp_path):
def test_factory_pdbs(pdbs_structs, pdb_paths, tmp_path):
"""Test conversion between PDB files and Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "pdbs.zarr")
factory = DatasetFactory(str(tmp_path / "pdbs.zarr"))

converter = PDBConverter()
factory.register_converter("pdb", converter)
@@ -119,7 +119,7 @@ def test_factory_pdbs(pdbs_structs, pdb_paths, tmp_path):
def test_pdbs_zarr_conversion(pdbs_structs, pdb_paths, tmp_path):
"""Test conversion between PDBs and Zarr with utility function"""

dataset = create_dataset_from_files(pdb_paths, tmp_path / "pdbs_2.zarr", axis=0)
dataset = create_dataset_from_files(pdb_paths, str(tmp_path / "pdbs_2.zarr"), axis=0)

assert dataset.table.shape[0] == len(pdb_paths)
_check_pdb_dataset(dataset, pdbs_structs)
@@ -128,7 +128,7 @@ def test_pdbs_zarr_conversion(pdbs_structs, pdb_paths, tmp_path):
def test_factory_sdfs(sdf_files, caffeine, ibuprofen, tmp_path):
"""Test conversion between SDF and Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "sdfs.zarr")
factory = DatasetFactory(str(tmp_path / "sdfs.zarr"))

converter = SDFConverter(mol_prop_as_cols=True)
factory.register_converter("sdf", converter)
@@ -142,7 +142,7 @@ def test_factory_sdfs(sdf_files, caffeine, ibuprofen, tmp_path):
def test_factory_sdf_pdb(sdf_file, pdb_paths, caffeine, pdbs_structs, tmp_path):
"""Test conversion between SDF and PDB from files to Zarr with factory pattern"""

factory = DatasetFactory(tmp_path / "sdf_pdb.zarr")
factory = DatasetFactory(str(tmp_path / "sdf_pdb.zarr"))

sdf_converter = SDFConverter(mol_prop_as_cols=False)
factory.register_converter("sdf", sdf_converter)
@@ -158,7 +158,7 @@ def test_factory_sdf_pdb(sdf_file, pdb_paths, caffeine, pdbs_structs, tmp_path):


def test_factory_from_files_same_column(sdf_files, pdb_paths, tmp_path):
factory = DatasetFactory(tmp_path / "files.zarr")
factory = DatasetFactory(str(tmp_path / "files.zarr"))

sdf_converter = SDFConverter(mol_prop_as_cols=False)
factory.register_converter("sdf", sdf_converter)
