From fee8c099eb7b05ed6057535adfd8c447fbaa97d9 Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Mon, 18 Nov 2024 14:30:07 +0000 Subject: [PATCH] Another nit, opening an empty image would crash --- python/go_types.py | 11 ++++++++++- python/test_datago_db.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/go_types.py b/python/go_types.py index 7e2df09..11296ee 100644 --- a/python/go_types.py +++ b/python/go_types.py @@ -7,6 +7,9 @@ def uint8_array_to_numpy(go_array): + if len(go_array.Data) == 0: + return None + # By convention, arrays which are already serialized as jpg or png are not reshaped # We export them from Go with a Channels dimension of -1 to mark them as dimensionless. # Anything else is a valid number of channels and will thus lead to a reshape @@ -29,6 +32,9 @@ def uint8_array_to_numpy(go_array): def go_array_to_numpy(go_array) -> Optional[np.ndarray]: + if len(go_array.Data) == 0: + return None + # Generic numpy-serialized array bytes_buffer = bytes(go.Slice_byte(go_array.Data)) try: @@ -39,7 +45,10 @@ def go_array_to_numpy(go_array) -> Optional[np.ndarray]: return None -def go_array_to_pil_image(go_array): +def go_array_to_pil_image(go_array) -> Optional[Image.Image]: + if len(go_array.Data) == 0: + return None + # Zero copy conversion of the image buffer from Go to PIL.Image np_array = uint8_array_to_numpy(go_array) if go_array.Channels <= 0: diff --git a/python/test_datago_db.py b/python/test_datago_db.py index f648f1a..a78a85c 100644 --- a/python/test_datago_db.py +++ b/python/test_datago_db.py @@ -147,6 +147,16 @@ def test_has_tags(): assert "v4_trainset_hq" in sample.Tags, "v4_trainset_hq should be in the tags" +def test_empty_image(): + client_config = get_json_config() + client_config["source_config"]["require_images"] = False + + dataset = get_dataset(client_config) + + # Just check that accessing the sample in python does not crash + _ = next(iter(dataset)) + + def no_test_jpg_compression(): # Check that the images are compressed as expected client_config = get_json_config()