diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py
index 5144f1f7..4ea99704 100644
--- a/polaris/dataset/_dataset_v2.py
+++ b/polaris/dataset/_dataset_v2.py
@@ -14,10 +14,91 @@
 from polaris.dataset._adapters import Adapter
 from polaris.dataset._base import BaseDataset
 from polaris.dataset.zarr._manifest import calculate_file_md5, generate_zarr_manifest
+from polaris.dataset.zarr._utils import load_zarr_group_to_memory
 from polaris.utils.errors import InvalidDatasetError
 from polaris.utils.types import AccessType, ChecksumStrategy, HubOwner, ZarrConflictResolution
 
 _INDEX_ARRAY_KEY = "__index__"
+_GROUP_FORMAT_METADATA_KEY = "polaris:group_format"
+
+
+def _get_group_format(group: zarr.Group) -> Literal["subgroups", "arrays"]:
+    """
+    Returns the format of the group.
+
+    A group can either be a collection of subgroups or a collection of arrays, but not both.
+    """
+    group_format = group.attrs.get(_GROUP_FORMAT_METADATA_KEY, "subgroups")
+    if group_format not in ["subgroups", "arrays"]:
+        raise ValueError(f"Invalid group format: {group_format}. Should be 'subgroups' or 'arrays'.")
+    return group_format
+
+
+def _verify_group_with_subgroups(group: zarr.Group):
+    """
+    Verifies the structure of a group with subgroups
+
+    Since the keys for subgroups can not be assumed to be ordered, we have no easy way to index these groups.
+    Any subgroup should therefore have a special array that defines the index for that group.
+
+    NOTE: We're not checking the length since that's already done in the Dataset model validator.
+    """
+
+    if _INDEX_ARRAY_KEY not in group.array_keys():
+        raise InvalidDatasetError(f"Group {group.basename} does not have an index array.")
+
+    index_arr = group[_INDEX_ARRAY_KEY]
+    if len(index_arr) != len(group) - 1:
+        raise InvalidDatasetError(
+            f"Length of index array for group {group.basename} does not match the size of the group. Expected {len(group) - 1}, found {len(index_arr)}. {set(group.group_keys())}, {set(group.array_keys())}"
+        )
+    if any(x not in group for x in index_arr):
+        raise InvalidDatasetError(
+            f"Keys of index array for group {group.basename} does not match the group members."
+        )
+
+    array_keys = list(group.array_keys())
+    if len(array_keys) > 1:
+        raise InvalidDatasetError(
+            f"Group {group.basename} should only have the special '{_INDEX_ARRAY_KEY}' array, found the following arrays: {array_keys}"
+        )
+
+
+def _verify_group_with_arrays(group: zarr.Group):
+    """
+    Verifies the structure of a group with arrays.
+
+    In this case, the nth datapoint can be constructed by taking the nth element from each array.
+
+    NOTE: We're not checking the length since that's already done in the Dataset model validator.
+    """
+    subgroup_keys = list(group.group_keys())
+    if len(subgroup_keys) > 0:
+        raise InvalidDatasetError(f"Group {group.basename} should have no subgroups, found: {subgroup_keys}")
+
+
+def _verify_group(group: zarr.Group):
+    """
+    Verifies the structure of a group in a Zarr archive.
+ """ + match _get_group_format(group): + case "subgroups": + _verify_group_with_subgroups(group) + case "arrays": + _verify_group_with_arrays(group) + + +def _get_group_length(group: zarr.Group): + match _get_group_format(group): + case "subgroups": + lengths = {len(list(group.group_keys())), len(group[_INDEX_ARRAY_KEY])} + case "arrays": + lengths = {len(group[k]) for k in group.array_keys()} + if len(lengths) > 1: + raise InvalidDatasetError( + f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}" + ) + return lengths class DatasetV2(BaseDataset): @@ -51,26 +132,15 @@ class DatasetV2(BaseDataset): def _validate_v2_dataset_model(self) -> Self: """Verifies some dependencies between properties""" - # Since the keys for subgroups are not ordered, we have no easy way to index these groups. - # Any subgroup should therefore have a special array that defines the index for that group. - for group in self.zarr_root.group_keys(): - if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys(): - raise InvalidDatasetError(f"Group {group} does not have an index array.") - - index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY] - if len(index_arr) != len(self.zarr_root[group]) - 1: - raise InvalidDatasetError( - f"Length of index array for group {group} does not match the size of the group." - ) - if any(x not in self.zarr_root[group] for x in index_arr): - raise InvalidDatasetError( - f"Keys of index array for group {group} does not match the group members." - ) + for _, group in self.zarr_root.groups(): + _verify_group(group) # Check the structure of the Zarr archive # All arrays or groups in the root should have the same length. lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()} - lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()}) + for _, group in self.zarr_root.groups(): + lengths.update(_get_group_length(group)) + if len(lengths) > 1: raise InvalidDatasetError( f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}" @@ -176,19 +246,29 @@ def get_data(self, row: int, col: str, adapters: dict[str, Adapter] | None = Non # Get the data group_or_array = self.zarr_data[col] - # If it is a group, there is no deterministic order for the child keys. - # We therefore use a special array that defines the index. - # If loaded to memory, the group is represented by a dictionary. - if isinstance(group_or_array, zarr.Group) or isinstance(group_or_array, dict): - # Indices in a group should always be strings + # For each column, the Zarr archive can take three formats: + # 1. An array: Each index in the array corresponds to a data point in the dataset. + # 2. A group with subgroups: Each subgroup corresponds to a data point in the dataset. + # The special __index__ array specifies the ordering of the subgroups. + # 3. A group with arrays: The nth datapoint is constructed from indexing the nth element in each array. + + if isinstance(group_or_array, zarr.Array) or isinstance(group_or_array, np.ndarray): + # Option 1: An array + data = group_or_array[row] + elif _get_group_format(self._zarr_root[col]) == "subgroups": + # Option 2: A group with subgroups + # The __index__ array specifies the ordering. An index is always a string. 
             row = str(group_or_array[_INDEX_ARRAY_KEY][row])
-            arr = group_or_array[row]
+            data = load_zarr_group_to_memory(group_or_array[row])
+        else:
+            # Option 3: A group with arrays
+            data = {k: group_or_array[k][row] for k in group_or_array.keys()}
 
         # Adapt the input to the specified format
         if adapter is not None:
-            arr = adapter(arr)
+            data = adapter(data)
 
-        return arr
+        return data
 
     def upload_to_hub(self, access: AccessType = "private", owner: HubOwner | str | None = None):
         """
diff --git a/tests/conftest.py b/tests/conftest.py
index 53d16284..03837d45 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,6 +13,7 @@
 )
 from polaris.competition import CompetitionSpecification
 from polaris.dataset import ColumnAnnotation, DatasetFactory, DatasetV1, DatasetV2
+from polaris.dataset._dataset_v2 import _GROUP_FORMAT_METADATA_KEY, _INDEX_ARRAY_KEY
 from polaris.dataset.converters import SDFConverter
 from polaris.utils.types import HubOwner
 
@@ -126,7 +127,7 @@ def test_dataset(test_data, test_org_owner) -> DatasetV1:
 
 
 @pytest.fixture(scope="function")
-def test_dataset_v2(zarr_archive, test_org_owner) -> DatasetV2:
+def test_dataset_v2(zarr_archive_v2, test_org_owner) -> DatasetV2:
     dataset = DatasetV2(
         name="test-dataset-v2",
         source="https://www.example.com",
@@ -136,7 +137,7 @@ def test_dataset_v2(zarr_archive, test_org_owner) -> DatasetV2:
         owner=test_org_owner,
         license="CC-BY-4.0",
         curation_reference="https://www.example.com",
-        zarr_root_path=zarr_archive,
+        zarr_root_path=zarr_archive_v2,
     )
     check_version(dataset)
     return dataset
@@ -152,6 +153,40 @@ def zarr_archive(tmp_path):
     return tmp_path
 
 
+@pytest.fixture(scope="function")
+def zarr_archive_v2(tmp_path):
+    tmp_path = fs.join(tmp_path, "data.zarr")
+    root = zarr.open(tmp_path, mode="w")
+
+    # The root can have 3 types of children, each corresponding to a column:
+    # 1. An array: Each index in the array corresponds to a data point in the dataset.
+    data = np.arange(256 * 100).reshape((100, 256))
+    root.array("A", data=data, chunks=(1, None))
+
+    # 2. A group with subgroups: Each subgroup corresponds to a data point in the dataset.
+    #    The special __index__ array specifies the ordering of the subgroups.
+    group = root.create_group("B")
+    group.attrs[_GROUP_FORMAT_METADATA_KEY] = "subgroups"
+    for idx in range(100):
+        subgroup = group.create_group(str(idx))
+        subgroup.array("x", data=np.arange(32) + (idx * 32), chunks=None)
+        subgroup.array("y", data=np.arange(32) + (idx * 32), chunks=None)
+        subgroup.array("z", data=np.arange(32) + (idx * 32), chunks=None)
+    group.array(_INDEX_ARRAY_KEY, data=np.arange(100), chunks=None)
+
+    # 3. A group with arrays: The nth datapoint is constructed from indexing the nth element in each array.
+ group = root.create_group("C") + group.attrs[_GROUP_FORMAT_METADATA_KEY] = "arrays" + + data = np.arange(32 * 100).reshape((100, 32)) + group.array("x", data=data, chunks=(1, None)) + group.array("y", data=data, chunks=(1, None)) + group.array("z", data=data, chunks=(1, None)) + + zarr.consolidate_metadata(root.store) + return tmp_path + + @pytest.fixture() def regression_metrics(): return [ diff --git a/tests/test_benchmark_v2.py b/tests/test_benchmark_v2.py index 732c4b67..5bf29c4d 100644 --- a/tests/test_benchmark_v2.py +++ b/tests/test_benchmark_v2.py @@ -113,7 +113,7 @@ def test_benchmark_v2_specification(valid_split, test_dataset_v2, tmp_path): BenchmarkV2Specification(**config) -def test_benchmark_v2_invalid_indices(valid_split, test_dataset_v2): +def test_benchmark_v2_invalid_indices(test_dataset_v2): """ Test validation of indices against dataset length """ diff --git a/tests/test_dataset_v2.py b/tests/test_dataset_v2.py index c664ded0..19dad3cc 100644 --- a/tests/test_dataset_v2.py +++ b/tests/test_dataset_v2.py @@ -1,5 +1,4 @@ import os -from copy import deepcopy from time import perf_counter import numcodecs @@ -10,36 +9,46 @@ from pydantic import ValidationError from polaris.dataset import DatasetV2, Subset -from polaris.dataset._dataset_v2 import _INDEX_ARRAY_KEY -from polaris.dataset._factory import DatasetFactory -from polaris.dataset.converters._pdb import PDBConverter +from polaris.dataset._dataset_v2 import _GROUP_FORMAT_METADATA_KEY, _INDEX_ARRAY_KEY from polaris.dataset.zarr._manifest import generate_zarr_manifest +# Helper methods + + +def _check_column_b_or_c_data(data, idx): + return all(np.array_equal(data[prop], np.arange(32) + (32 * idx)) for prop in ["x", "y", "z"]) + + +def _check_column_a_data(data, idx): + return np.array_equal(data, np.arange(256) + (256 * idx)) + + +# Test cases + def test_dataset_v2_get_columns(test_dataset_v2): - assert set(test_dataset_v2.columns) == {"A", "B"} + assert set(test_dataset_v2.columns) == {"A", "B", "C"} def test_dataset_v2_get_rows(test_dataset_v2): assert set(test_dataset_v2.rows) == set(range(100)) -def test_dataset_v2_get_data(test_dataset_v2, zarr_archive): - root = zarr.open(zarr_archive, "r") +def test_dataset_v2_get_data(test_dataset_v2): indices = np.random.randint(0, len(test_dataset_v2), 5) for idx in indices: - assert np.array_equal(test_dataset_v2.get_data(row=idx, col="A"), root["A"][idx]) - assert np.array_equal(test_dataset_v2.get_data(row=idx, col="B"), root["B"][idx]) + assert _check_column_a_data(test_dataset_v2.get_data(row=idx, col="A"), idx) + assert _check_column_b_or_c_data(test_dataset_v2.get_data(row=idx, col="B"), idx) + assert _check_column_b_or_c_data(test_dataset_v2.get_data(row=idx, col="C"), idx) -def test_dataset_v2_with_subset(test_dataset_v2, zarr_archive): - root = zarr.open(zarr_archive, "r") +def test_dataset_v2_with_subset(test_dataset_v2): indices = np.random.randint(0, len(test_dataset_v2), 5) - subset = Subset(test_dataset_v2, indices, "A", "B") + subset = Subset(test_dataset_v2, indices, "A", ["B", "C"]) for i, (x, y) in enumerate(subset): - idx = indices[i] - assert np.array_equal(x, root["A"][idx]) - assert np.array_equal(y, root["B"][idx]) + assert _check_column_a_data(x, indices[i]) + assert _check_column_b_or_c_data(y["B"], indices[i]) + assert _check_column_b_or_c_data(y["C"], indices[i]) def test_dataset_v2_load_to_memory(test_dataset_v2): @@ -47,7 +56,7 @@ def test_dataset_v2_load_to_memory(test_dataset_v2): dataset=test_dataset_v2, indices=range(100), input_cols=["A"], - 
-        target_cols=["B"],
+        target_cols=["B", "C"],
     )
 
     t1 = perf_counter()
@@ -70,8 +79,9 @@ def test_dataset_v2_serialization(test_dataset_v2, tmp_path):
     path = test_dataset_v2.to_json(save_dir)
     new_dataset = DatasetV2.from_json(path)
     for i in range(5):
-        assert np.array_equal(new_dataset.get_data(i, "A"), test_dataset_v2.get_data(i, "A"))
-        assert np.array_equal(new_dataset.get_data(i, "B"), test_dataset_v2.get_data(i, "B"))
+        assert _check_column_a_data(new_dataset.get_data(row=i, col="A"), i)
+        assert _check_column_b_or_c_data(new_dataset.get_data(row=i, col="B"), i)
+        assert _check_column_b_or_c_data(new_dataset.get_data(row=i, col="C"), i)
 
 
 def test_dataset_v2_caching(test_dataset_v2, tmp_path):
@@ -108,122 +118,75 @@ def test_dataset_v1_v2_compatibility(test_dataset, tmp_path):
         assert y1 == y2
 
 
-def test_dataset_v2_with_pdbs(pdb_paths, tmp_path):
-    # The PDB example is interesting because it creates a more complex Zarr archive
-    # that includes subgroups
-    zarr_root_path = str(tmp_path / "pdbs.zarr")
-    factory = DatasetFactory(zarr_root_path)
-
-    # Build a V1 dataset
-    converter = PDBConverter()
-    factory.register_converter("pdb", converter)
-    factory.add_from_files(pdb_paths, axis=0)
-    dataset_v1 = factory.build()
-
-    # Build a V2 dataset based on the V1 dataset
-
-    # Add the magic index column to the Zarr subgroup
-    root = zarr.open(zarr_root_path, "a")
-    ordered_keys = [v.split("/")[-1] for v in dataset_v1.table["pdb"].values]
-    root["pdb"].array(_INDEX_ARRAY_KEY, data=ordered_keys, dtype=object, object_codec=numcodecs.VLenUTF8())
-    zarr.consolidate_metadata(zarr_root_path)
-
-    # Update annotations to no longer have pointer columns
-    annotations = deepcopy(dataset_v1.annotations)
-    for anno in annotations.values():
-        anno.is_pointer = False
-
-    # Create the dataset
-    dataset_v2 = DatasetV2(
-        zarr_root_path=zarr_root_path,
-        annotations=annotations,
-        default_adapters=dataset_v1.default_adapters,
-    )
-
-    assert len(dataset_v1) == len(dataset_v2)
-    for idx in range(len(dataset_v1)):
-        pdb_1 = dataset_v1.get_data(idx, "pdb")
-        pdb_2 = dataset_v2.get_data(idx, "pdb")
-        assert pdb_1 == pdb_2
-
-
-def test_dataset_v2_indexing(zarr_archive):
-    # Create a subgroup with 100 arrays
-    root = zarr.open(zarr_archive, "a")
-    subgroup = root.create_group("X")
-    for i in range(100):
-        subgroup.array(f"{i}", data=np.array([i] * 100))
-
-    # Index it in reverse (element 0 is the last element in the array)
-    indices = [f"{idx}" for idx in range(100)][::-1]
-    subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8())
-    zarr.consolidate_metadata(zarr_archive)
-
-    # Create the dataset
-    dataset = DatasetV2(zarr_root_path=zarr_archive)
-
-    # Check that the dataset is indexed correctly
-    for idx in range(100):
-        assert np.array_equal(dataset.get_data(row=idx, col="X"), np.array([99 - idx] * 100))
+def test_dataset_v2_with_pdbs(pdb_paths, tmp_path): ...
 
 
-def test_dataset_v2_validation_index_array(zarr_archive):
-    root = zarr.open(zarr_archive, "a")
+def test_dataset_v2_validation_index_array(zarr_archive_v2):
+    root = zarr.open(zarr_archive_v2, "a")
 
     # Create subgroup that lacks the index array
-    subgroup = root.create_group("X")
-    zarr.consolidate_metadata(zarr_archive)
+    subgroup = root.create_group("D")
+    subgroup.attrs[_GROUP_FORMAT_METADATA_KEY] = "subgroups"
+    zarr.consolidate_metadata(zarr_archive_v2)
 
     with pytest.raises(ValidationError, match="does not have an index array"):
-        DatasetV2(zarr_root_path=zarr_archive)
+        DatasetV2(zarr_root_path=zarr_archive_v2)
 
     indices = [f"{idx}" for idx in range(100)]
     indices[-1] = "invalid"
 
     # Create index array that doesn't match group length (zero arrays in group, but 100 indices)
     subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8())
-    zarr.consolidate_metadata(zarr_archive)
+    zarr.consolidate_metadata(zarr_archive_v2)
 
     with pytest.raises(ValidationError, match="Length of index array"):
-        DatasetV2(zarr_root_path=zarr_archive)
+        DatasetV2(zarr_root_path=zarr_archive_v2)
 
     for i in range(100):
-        subgroup.array(f"{i}", data=np.random.random(100))
-    zarr.consolidate_metadata(zarr_archive)
+        subgroup.create_group(str(i))
+    zarr.consolidate_metadata(zarr_archive_v2)
 
     # Create index array that has invalid keys (last keys = 'invalid' rather than '99')
    with pytest.raises(ValidationError, match="Keys of index array"):
-        DatasetV2(zarr_root_path=zarr_archive)
+        DatasetV2(zarr_root_path=zarr_archive_v2)
 
 
-def test_dataset_v2_validation_consistent_lengths(zarr_archive):
-    root = zarr.open(zarr_archive, "a")
+def test_dataset_v2_validation_consistent_lengths(zarr_archive_v2):
+    root = zarr.open(zarr_archive_v2, "a")
 
     # Change the length of one of the arrays
-    root["A"].append(np.random.random((1, 2048)))
-    zarr.consolidate_metadata(zarr_archive)
+    root["A"].append(np.random.random((1, 256)))
+    zarr.consolidate_metadata(zarr_archive_v2)
 
     # Subgroup has a false number of indices
     with pytest.raises(ValidationError, match="should have the same length"):
-        DatasetV2(zarr_root_path=zarr_archive)
+        DatasetV2(zarr_root_path=zarr_archive_v2)
+
+    # Make the length of the different columns equal again
+    subgroup = root["B"].create_group("100")
+    subgroup.array("x", data=np.arange(32))
+    subgroup.array("y", data=np.arange(32))
+    subgroup.array("z", data=np.arange(32))
 
-    # Make the length of the two arrays equal again
-    # shouldn't crash
-    root["B"].append(np.random.random((1, 2048)))
-    zarr.consolidate_metadata(zarr_archive)
-    DatasetV2(zarr_root_path=zarr_archive)
+    # Directly appending a single element fails, likely because of a bug in the Zarr Array API
+    root["B"][_INDEX_ARRAY_KEY] = root["B"][_INDEX_ARRAY_KEY][:].tolist() + [100]
+
+    root["C"]["x"].append(np.arange(32).reshape(1, 32))
+    root["C"]["y"].append(np.arange(32).reshape(1, 32))
+    root["C"]["z"].append(np.arange(32).reshape(1, 32))
+
+    zarr.consolidate_metadata(zarr_archive_v2)
+    DatasetV2(zarr_root_path=zarr_archive_v2)
 
     # Create subgroup with inconsistent length
-    subgroup = root.create_group("X")
-    for i in range(100):
-        subgroup.array(f"{i}", data=np.random.random(100))
-    indices = [f"{idx}" for idx in range(100)]
-    subgroup.array(_INDEX_ARRAY_KEY, data=indices, dtype=object, object_codec=numcodecs.VLenUTF8())
-    zarr.consolidate_metadata(zarr_archive)
+    subgroup = root.create_group("D")
+    subgroup.create_group("0")
+    subgroup.array(_INDEX_ARRAY_KEY, data=["0"], dtype=object, object_codec=numcodecs.VLenUTF8())
+    zarr.consolidate_metadata(zarr_archive_v2)
 
     # Subgroup has a false number of indices
     with pytest.raises(ValidationError, match="should have the same length"):
-        DatasetV2(zarr_root_path=zarr_archive)
+        DatasetV2(zarr_root_path=zarr_archive_v2)
 
 
 def test_zarr_manifest(test_dataset_v2):
@@ -231,17 +194,20 @@
     assert test_dataset_v2.zarr_manifest_path is not None
     assert os.path.isfile(test_dataset_v2.zarr_manifest_path)
 
-    # Assert the manifest contains 204 rows (the number "204" is chosen because
-    # the Zarr archive defined in `conftest.py` contains 204 unique files)
+    # The root has 2 files (.zmetadata, .zgroup)
+    # The A array has 1 array with 100 chunks = 100 + 1 = 101
+    # The B group has 100 groups with 3 single-chunk arrays + 1 single-chunk array = 100 * (3 * 2 + 1) + 2 + 2 = 704
+    # The C group has 3 arrays with 100 chunks = 3 * (100 + 1) + 2 = 305
+    # Total number of files: 2 + 101 + 704 + 305 = 1112
     df = pd.read_parquet(test_dataset_v2.zarr_manifest_path)
-    assert len(df) == 204
+    assert len(df) == 1112
 
     # Assert the manifest hash is calculated
     assert test_dataset_v2.zarr_manifest_md5sum is not None
 
     # Add array to Zarr archive to change the number of chunks in the dataset
     root = zarr.open(test_dataset_v2.zarr_root_path, "a")
-    root.array("C", data=np.random.random((100, 2048)), chunks=(1, None))
+    root.array("D", data=np.random.random((100, 256)), chunks=(1, None))
 
     generate_zarr_manifest(test_dataset_v2.zarr_root_path, test_dataset_v2._cache_dir)
 
@@ -249,23 +215,37 @@
     post_change_manifest_length = len(pd.read_parquet(test_dataset_v2.zarr_manifest_path))
 
     # Ensure Zarr manifest has an additional 100 chunks + 1 array metadata file
-    assert post_change_manifest_length == 305
+    assert post_change_manifest_length == 1213
 
 
-def test_dataset_v2__get_item__(test_dataset_v2, zarr_archive):
+def test_dataset_v2__get_item__(test_dataset_v2):
     """Test the __getitem__() interface for the dataset V2."""
 
-    # Ground truth
-    root = zarr.open(zarr_archive)
-
     # Get a specific cell
-    assert np.array_equal(test_dataset_v2[0, "A"], root["A"][0, :])
+    assert np.array_equal(test_dataset_v2[0, "A"], np.arange(256))
 
     # Get a specific row
-    def _check_row_equality(d1, d2):
+    def _check_dict_equality(d1, d2):
         assert len(d1) == len(d2)
-        for k in d1:
-            assert np.array_equal(d1[k], d2[k])
-
-    _check_row_equality(test_dataset_v2[0], {"A": root["A"][0, :], "B": root["B"][0, :]})
-    _check_row_equality(test_dataset_v2[10], {"A": root["A"][10, :], "B": root["B"][10, :]})
+        for k, v in d1.items():
+            if isinstance(v, dict):
+                _check_dict_equality(v, d2[k])
+            else:
+                assert np.array_equal(d1[k], d2[k])
+
+    _check_dict_equality(
+        test_dataset_v2[0],
+        {
+            "A": np.arange(256),
+            "B": {"x": np.arange(32), "y": np.arange(32), "z": np.arange(32)},
+            "C": {"x": np.arange(32), "y": np.arange(32), "z": np.arange(32)},
+        },
+    )
+    _check_dict_equality(
+        test_dataset_v2[10],
+        {
+            "A": np.arange(256) + 2560,
+            "B": {"x": np.arange(32) + 320, "y": np.arange(32) + 320, "z": np.arange(32) + 320},
+            "C": {"x": np.arange(32) + 320, "y": np.arange(32) + 320, "z": np.arange(32) + 320},
        },
    )