Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support a new group structure for columns #241

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 104 additions & 24 deletions polaris/dataset/_dataset_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,91 @@
from polaris.dataset._adapters import Adapter
from polaris.dataset._base import BaseDataset
from polaris.dataset.zarr._manifest import calculate_file_md5, generate_zarr_manifest
from polaris.dataset.zarr._utils import load_zarr_group_to_memory
from polaris.utils.errors import InvalidDatasetError
from polaris.utils.types import AccessType, ChecksumStrategy, HubOwner, ZarrConflictResolution

_INDEX_ARRAY_KEY = "__index__"
_GROUP_FORMAT_METADATA_KEY = "polaris:group_format"


def _get_group_format(group: zarr.Group) -> Literal["subgroups", "arrays"]:
    """
    Return the storage format declared for a Zarr group.

    The format is read from the group's attributes and defaults to "subgroups"
    when the attribute is absent. A group is either a collection of subgroups
    or a collection of arrays, but never both.
    """
    fmt = group.attrs.get(_GROUP_FORMAT_METADATA_KEY, "subgroups")
    if fmt != "subgroups" and fmt != "arrays":
        raise ValueError(f"Invalid group format: {fmt}. Should be 'subgroups' or 'arrays'.")
    return fmt


def _verify_group_with_subgroups(group: zarr.Group):
    """
    Verifies the structure of a group with subgroups.

    Since the keys for subgroups can not be assumed to be ordered, we have no easy way to index these groups.
    Any subgroup should therefore have a special array that defines the index for that group.

    NOTE: We're not checking the length since that's already done in the Dataset model validator.

    Raises:
        InvalidDatasetError: If the index array is missing, if the group contains
            any array other than the index array, if the index array's length does
            not match the number of subgroups, or if an index entry names a
            non-existent member.
    """
    if _INDEX_ARRAY_KEY not in group.array_keys():
        raise InvalidDatasetError(f"Group {group.basename} does not have an index array.")

    # The index array must be the only array in the group; every other member
    # should be a subgroup. Check this before the length check so a stray array
    # yields a precise error instead of a confusing length mismatch.
    array_keys = list(group.array_keys())
    if len(array_keys) > 1:
        raise InvalidDatasetError(
            f"Group {group.basename} should only have the special '{_INDEX_ARRAY_KEY}' array, found the following arrays: {array_keys}"
        )

    # len(group) counts all members, including the index array itself,
    # hence the `- 1` when comparing against the index length.
    index_arr = group[_INDEX_ARRAY_KEY]
    if len(index_arr) != len(group) - 1:
        raise InvalidDatasetError(
            f"Length of index array for group {group.basename} does not match the size of the group. "
            f"Expected {len(group) - 1}, found {len(index_arr)}."
        )

    # Every entry in the index array must name an existing member of the group.
    if any(x not in group for x in index_arr):
        raise InvalidDatasetError(
            f"Keys of index array for group {group.basename} does not match the group members."
        )


def _verify_group_with_arrays(group: zarr.Group):
    """
    Verifies the structure of a group with arrays.

    In this case, the nth datapoint can be constructed by taking the nth element from each array.

    NOTE: We're not checking the length since that's already done in the Dataset model validator.
    """
    # A group in the "arrays" format may hold arrays only; nested groups are invalid.
    subgroups = [key for key in group.group_keys()]
    if subgroups:
        raise InvalidDatasetError(f"Group {group.basename} should have no subgroups, found: {subgroups}")


def _verify_group(group: zarr.Group):
    """
    Verifies the structure of a group in a Zarr archive.

    Dispatches to the checker matching the group's declared format.
    """
    group_format = _get_group_format(group)
    if group_format == "subgroups":
        _verify_group_with_subgroups(group)
    elif group_format == "arrays":
        _verify_group_with_arrays(group)


def _get_group_length(group: zarr.Group):
    """
    Collect the candidate lengths of a group and verify they agree.

    For the "subgroups" format the candidates are the number of subgroups and the
    length of the index array; for the "arrays" format, the length of each member
    array. Returns the set of candidates (a single-element set when consistent).

    Raises:
        InvalidDatasetError: If the candidate lengths disagree.
    """
    # _get_group_format only ever returns "subgroups" or "arrays",
    # so a plain if/else covers both formats.
    if _get_group_format(group) == "subgroups":
        candidates = {len(list(group.group_keys())), len(group[_INDEX_ARRAY_KEY])}
    else:
        candidates = {len(group[key]) for key in group.array_keys()}

    if len(candidates) > 1:
        raise InvalidDatasetError(
            f"All arrays or groups in the root should have the same length, found the following lengths: {candidates}"
        )
    return candidates


class DatasetV2(BaseDataset):
Expand Down Expand Up @@ -51,26 +132,15 @@ class DatasetV2(BaseDataset):
def _validate_v2_dataset_model(self) -> Self:
"""Verifies some dependencies between properties"""

# Since the keys for subgroups are not ordered, we have no easy way to index these groups.
# Any subgroup should therefore have a special array that defines the index for that group.
for group in self.zarr_root.group_keys():
if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys():
raise InvalidDatasetError(f"Group {group} does not have an index array.")

index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY]
if len(index_arr) != len(self.zarr_root[group]) - 1:
raise InvalidDatasetError(
f"Length of index array for group {group} does not match the size of the group."
)
if any(x not in self.zarr_root[group] for x in index_arr):
raise InvalidDatasetError(
f"Keys of index array for group {group} does not match the group members."
)
for _, group in self.zarr_root.groups():
_verify_group(group)

# Check the structure of the Zarr archive
# All arrays or groups in the root should have the same length.
lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()}
lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()})
for _, group in self.zarr_root.groups():
lengths.update(_get_group_length(group))

if len(lengths) > 1:
raise InvalidDatasetError(
f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}"
Expand Down Expand Up @@ -176,19 +246,29 @@ def get_data(self, row: int, col: str, adapters: dict[str, Adapter] | None = Non
# Get the data
group_or_array = self.zarr_data[col]

# If it is a group, there is no deterministic order for the child keys.
# We therefore use a special array that defines the index.
# If loaded to memory, the group is represented by a dictionary.
if isinstance(group_or_array, zarr.Group) or isinstance(group_or_array, dict):
# Indices in a group should always be strings
# For each column, the Zarr archive can take three formats:
# 1. An array: Each index in the array corresponds to a data point in the dataset.
# 2. A group with subgroups: Each subgroup corresponds to a data point in the dataset.
# The special __index__ array specifies the ordering of the subgroups.
# 3. A group with arrays: The nth datapoint is constructed from indexing the nth element in each array.

if isinstance(group_or_array, zarr.Array) or isinstance(group_or_array, np.ndarray):
# Option 1: An array
data = group_or_array[row]
elif _get_group_format(self._zarr_root[col]) == "subgroups":
# Option 2: A group with subgroups
# The __index__ array specifies the ordering. An index is always a string.
row = str(group_or_array[_INDEX_ARRAY_KEY][row])
arr = group_or_array[row]
data = load_zarr_group_to_memory(group_or_array[row])
else:
# Option 3: A group with arrays
data = {k: group_or_array[k][row] for k in group_or_array.keys()}

# Adapt the input to the specified format
if adapter is not None:
arr = adapter(arr)
data = adapter(data)

return arr
return data

def upload_to_hub(self, access: AccessType = "private", owner: HubOwner | str | None = None):
"""
Expand Down
39 changes: 37 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from polaris.competition import CompetitionSpecification
from polaris.dataset import ColumnAnnotation, DatasetFactory, DatasetV1, DatasetV2
from polaris.dataset._dataset_v2 import _GROUP_FORMAT_METADATA_KEY, _INDEX_ARRAY_KEY
from polaris.dataset.converters import SDFConverter
from polaris.utils.types import HubOwner

Expand Down Expand Up @@ -126,7 +127,7 @@ def test_dataset(test_data, test_org_owner) -> DatasetV1:


@pytest.fixture(scope="function")
def test_dataset_v2(zarr_archive, test_org_owner) -> DatasetV2:
def test_dataset_v2(zarr_archive_v2, test_org_owner) -> DatasetV2:
dataset = DatasetV2(
name="test-dataset-v2",
source="https://www.example.com",
Expand All @@ -136,7 +137,7 @@ def test_dataset_v2(zarr_archive, test_org_owner) -> DatasetV2:
owner=test_org_owner,
license="CC-BY-4.0",
curation_reference="https://www.example.com",
zarr_root_path=zarr_archive,
zarr_root_path=zarr_archive_v2,
)
check_version(dataset)
return dataset
Expand All @@ -152,6 +153,40 @@ def zarr_archive(tmp_path):
return tmp_path


@pytest.fixture(scope="function")
def zarr_archive_v2(tmp_path):
    """Create a consolidated Zarr archive exercising all three column layouts."""
    archive_path = fs.join(tmp_path, "data.zarr")
    root = zarr.open(archive_path, mode="w")

    num_rows = 100

    # The root can have 3 types of children, each corresponding to a column.
    # 1. An array: each index in the array corresponds to a data point in the dataset.
    root.array("A", data=np.arange(256 * num_rows).reshape((num_rows, 256)), chunks=(1, None))

    # 2. A group with subgroups: each subgroup corresponds to a data point,
    #    ordered by the special __index__ array.
    subgroup_col = root.create_group("B")
    subgroup_col.attrs[_GROUP_FORMAT_METADATA_KEY] = "subgroups"
    for idx in range(num_rows):
        member = subgroup_col.create_group(str(idx))
        for name in ("x", "y", "z"):
            member.array(name, data=np.arange(32) + (idx * 32), chunks=None)
    subgroup_col.array(_INDEX_ARRAY_KEY, data=np.arange(num_rows), chunks=None)

    # 3. A group with arrays: the nth datapoint is constructed from the
    #    nth element of each array.
    array_col = root.create_group("C")
    array_col.attrs[_GROUP_FORMAT_METADATA_KEY] = "arrays"
    shared = np.arange(32 * num_rows).reshape((num_rows, 32))
    for name in ("x", "y", "z"):
        array_col.array(name, data=shared, chunks=(1, None))

    zarr.consolidate_metadata(root.store)
    return archive_path


@pytest.fixture()
def regression_metrics():
return [
Expand Down
2 changes: 1 addition & 1 deletion tests/test_benchmark_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_benchmark_v2_specification(valid_split, test_dataset_v2, tmp_path):
BenchmarkV2Specification(**config)


def test_benchmark_v2_invalid_indices(valid_split, test_dataset_v2):
def test_benchmark_v2_invalid_indices(test_dataset_v2):
"""
Test validation of indices against dataset length
"""
Expand Down
Loading