From c49305dbb0a5205c1e01468c5418c0f78fa59358 Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Tue, 21 Jan 2025 11:38:49 +0100 Subject: [PATCH] [BUGFIX] Prevent index error with empty values (#5787) # Description This PR fixes an error when indexing records with missing values for chat fields. Noticed in [this discord thread](https://discord.com/channels/879548962464493619/1325817229954252841) **Type of change** - Bug fix (non-breaking change which fixes an issue) **How Has This Been Tested** **Checklist** - I added relevant documentation - I followed the style guidelines of this project - I did a self-review of my code - I made corresponding changes to the documentation - I confirm my changes generate no new warnings - I have added tests that prove my fix is effective or that my feature works - I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --- argilla-server/CHANGELOG.md | 5 ++ .../argilla_server/search_engine/commons.py | 12 +-- .../tests/unit/search_engine/test_commons.py | 74 +++++++++++++++++++ argilla/tests/integration/test_add_records.py | 46 ++++++++++-- 4 files changed, 125 insertions(+), 12 deletions(-) diff --git a/argilla-server/CHANGELOG.md b/argilla-server/CHANGELOG.md index 0b0c212cd2..ff5685d606 100644 --- a/argilla-server/CHANGELOG.md +++ b/argilla-server/CHANGELOG.md @@ -21,6 +21,11 @@ These are the section headers that we use: - Added support to create users with predefined ids. ([#5786](https://github.com/argilla-io/argilla/pull/5786)) - Added support to create workspaces with predefined ids. ([#5786](https://github.com/argilla-io/argilla/pull/5786)) +### Fixed + +- Fixed error when indexing records with missing chat fields. ([#5787](https://github.com/argilla-io/argilla/pull/5787)) +- Prevent storing empty custom fields as the string `"None"`. 
([#5787](https://github.com/argilla-io/argilla/pull/5787)) + ## [2.6.0](https://github.com/argilla-io/argilla/compare/v2.5.0...v2.6.0) ### Added diff --git a/argilla-server/src/argilla_server/search_engine/commons.py b/argilla-server/src/argilla_server/search_engine/commons.py index 73f2d902c1..aa37f88739 100644 --- a/argilla-server/src/argilla_server/search_engine/commons.py +++ b/argilla-server/src/argilla_server/search_engine/commons.py @@ -873,11 +873,13 @@ def _map_record_response_to_es(response: Response) -> Dict[str, Any]: def _map_record_fields_to_es(cls, fields: dict, dataset_fields: List[Field]) -> dict: for field in dataset_fields: if field.is_image: - fields[field.name] = None - elif field.is_custom: - fields[field.name] = str(fields.get(field.name, "")) - else: - fields[field.name] = fields.get(field.name, "") + continue + + value = fields.get(field.name) + if field.is_custom and value is not None: + value = str(value) + + fields[field.name] = value return fields diff --git a/argilla-server/tests/unit/search_engine/test_commons.py b/argilla-server/tests/unit/search_engine/test_commons.py index 9ecd1cf42d..d7b6f99099 100644 --- a/argilla-server/tests/unit/search_engine/test_commons.py +++ b/argilla-server/tests/unit/search_engine/test_commons.py @@ -983,6 +983,80 @@ async def test_index_records(self, search_engine: BaseElasticAndOpenSearchEngine for record in records ] + async def test_index_records_with_none_field_values( + self, search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch + ): + text_field = await TextFieldFactory.create(name="text", required=False) + image_field = await ImageFieldFactory.create(name="image", required=False) + chat_field = await ChatFieldFactory.create(name="chat", required=False) + custom_field = await CustomFieldFactory.create(name="custom", required=False) + + dataset = await DatasetFactory.create( + fields=[text_field, image_field, chat_field, custom_field], + questions=[], + ) + + record = await 
RecordFactory.create( + dataset=dataset, + fields={ + text_field.name: None, + image_field.name: None, + chat_field.name: None, + custom_field.name: None, + }, + responses=[], + ) + + other_record = await RecordFactory.create( + dataset=dataset, + fields={ + text_field.name: "This is the value for text", + image_field.name: "https://random.url/image", + chat_field.name: [{"role": "user", "content": "Hello world"}, {"role": "bot", "content": "Hi"}], + custom_field.name: {"a": "This is a value", "b": 100}, + }, + responses=[], + ) + + records = [record, other_record] + + await refresh_dataset(dataset) + await refresh_records(records) + + await search_engine.create_index(dataset) + await search_engine.index_records(dataset, records) + + index_name = es_index_name_for_dataset(dataset) + + es_docs = [hit["_source"] for hit in opensearch.search(index=index_name)["hits"]["hits"]] + assert es_docs == [ + { + "id": str(record.id), + "fields": { + text_field.name: None, + # image_field.name: None, # image fields are not indexed + chat_field.name: None, + custom_field.name: None, + }, + "external_id": record.external_id, + "status": RecordStatus.pending, + "inserted_at": record.inserted_at.isoformat(), + "updated_at": record.updated_at.isoformat(), + }, + { + "id": str(other_record.id), + "fields": { + text_field.name: other_record.fields[text_field.name], + chat_field.name: other_record.fields[chat_field.name], + custom_field.name: other_record.fields[custom_field.name], + }, + "external_id": other_record.external_id, + "status": RecordStatus.pending, + "inserted_at": other_record.inserted_at.isoformat(), + "updated_at": other_record.updated_at.isoformat(), + }, + ] + async def test_configure_metadata_property( self, search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch ): diff --git a/argilla/tests/integration/test_add_records.py b/argilla/tests/integration/test_add_records.py index 11b9652125..4a360e2050 100644 --- 
a/argilla/tests/integration/test_add_records.py +++ b/argilla/tests/integration/test_add_records.py @@ -569,9 +569,7 @@ def test_add_record_resources(client): assert dataset_records[2].suggestions["topics"].score == [0.9, 0.8] -def test_add_record_with_chat_field(client): - user_id = client.users[0].id - mock_dataset_name = f"test_add_record_with_chat_field{datetime.now().strftime('%Y%m%d%H%M%S')}" +def test_add_record_with_chat_field(client: rg.Argilla, dataset_name: str): mock_resources = [ rg.Record( fields={ @@ -604,23 +602,57 @@ def test_add_record_with_chat_field(client): ] settings = rg.Settings( fields=[ - rg.ChatField(name="chat", required=True), + rg.ChatField(name="chat", required=False), ], questions=[ rg.TextQuestion(name="comment", use_markdown=False), ], ) dataset = rg.Dataset( - name=mock_dataset_name, + name=dataset_name, settings=settings, client=client, ) dataset.create() dataset.records.log(records=mock_resources) + list(dataset.records) - dataset_records = list(dataset.records) + assert dataset.name == dataset_name - assert dataset.name == mock_dataset_name + +def test_add_records_with_optional_chat_field(client: rg.Argilla, dataset_name: str): + mock_resources = [ + rg.Record( + fields={ + "text": "This a text", + "chat": None, + }, + ), + rg.Record( + fields={ + "text": "This a text", + }, + ), + ] + settings = rg.Settings( + fields=[ + rg.TextField(name="text", required=True), + rg.ChatField(name="chat", required=False), + ], + questions=[ + rg.TextQuestion(name="comment", use_markdown=False), + ], + ) + dataset = rg.Dataset( + name=dataset_name, + settings=settings, + client=client, + ) + dataset.create() + dataset.records.log(records=mock_resources) + assert len(list(dataset.records(query="this"))) == 2 # Forcing search to check the records indexation + + assert dataset.name == dataset_name def test_add_records_with_responses_and_same_schema_name(client: Argilla, username: str):