From 3a0d99a5261a7c1cf7cd0da602d0a7de6ea571b7 Mon Sep 17 00:00:00 2001 From: Elron Bandel Date: Thu, 23 Jan 2025 13:51:18 +0200 Subject: [PATCH] Add mtrag benchmark (#1548) * Add mtrag benchmark Signed-off-by: elronbandel * Add multi_type_serializer for references and prediction fields in various JSON metrics Signed-off-by: elronbandel * Remove unused TempOperator class and delete obsolete multi_turn.json task file Signed-off-by: elronbandel --------- Signed-off-by: elronbandel --- prepare/cards/mtrag.py | 131 ++++++++++++++++++ prepare/metrics/rag.py | 3 + prepare/metrics/rag_answer_relevance.py | 3 + prepare/metrics/rag_context_relevance.py | 3 + prepare/metrics/rag_metrics_deprecated.py | 2 + prepare/tasks/rag/rag_end_to_end.py | 5 +- prepare/templates/rag/end_to_end.py | 11 +- src/unitxt/catalog/cards/rag/mtrag.json | 64 +++++++++ .../cards/rag/mtrag/documents/clapnq.json | 43 ++++++ .../cards/rag/mtrag/documents/cloud.json | 43 ++++++ .../cards/rag/mtrag/documents/fiqa.json | 43 ++++++ .../cards/rag/mtrag/documents/govt.json | 37 +++++ .../rag/answer_relevance/token_recall.json | 4 + .../catalog/metrics/rag/answer_reward.json | 4 + .../answer_relevance/answer_reward.json | 5 + .../answer_relevance/token_recall.json | 5 + .../metrics/rag/end_to_end/answer_reward.json | 5 + .../rag/end_to_end/context_relevance.json | 4 + .../perplexity_flan_t5_small.json | 4 + .../context_relevance/sentence_bert_bge.json | 4 + .../sentence_bert_mini_lm.json | 4 + .../context_relevance/token_precision.json | 4 + .../rag/external_rag/context_relevance.json | 4 + .../perplexity_flan_t5_small.json | 4 + .../context_relevance/sentence_bert_bge.json | 4 + .../sentence_bert_mini_lm.json | 4 + .../context_relevance/token_precision.json | 4 + .../answer_relevance/answer_reward.json | 5 + .../answer_relevance/token_recall.json | 5 + src/unitxt/catalog/tasks/rag/end_to_end.json | 2 +- .../rag/end_to_end/json_predictions.json | 12 +- src/unitxt/loaders.py | 45 ++++-- src/unitxt/serializers.py | 1 + src/unitxt/templates.py | 27 ++++ utils/.secrets.baseline | 4 +- 35 files changed, 534 insertions(+), 18 deletions(-) create mode 100644 prepare/cards/mtrag.py create mode 100644 src/unitxt/catalog/cards/rag/mtrag.json create mode 100644 src/unitxt/catalog/cards/rag/mtrag/documents/clapnq.json create mode 100644 src/unitxt/catalog/cards/rag/mtrag/documents/cloud.json create mode 100644 src/unitxt/catalog/cards/rag/mtrag/documents/fiqa.json create mode 100644 src/unitxt/catalog/cards/rag/mtrag/documents/govt.json diff --git a/prepare/cards/mtrag.py b/prepare/cards/mtrag.py new file mode 100644 index 0000000000..ee9887e6eb --- /dev/null +++ b/prepare/cards/mtrag.py @@ -0,0 +1,131 @@ +import json + +from unitxt import add_to_catalog +from unitxt.blocks import ( + TaskCard, +) +from unitxt.collections_operators import Dictify, Wrap +from unitxt.loaders import LoadCSV +from unitxt.operators import ( + Cast, + Copy, + MapInstanceValues, + Set, + ZipFieldValues, +) +from unitxt.templates import InputOutputTemplate +from unitxt.test_utils.card import test_card + +card = TaskCard( + loader=LoadCSV( + files={ + "test": "https://raw.githubusercontent.com/IBM/mt-rag-benchmark/refs/heads/main/human/generation_tasks/reference+RAG.jsonl" + }, + file_type="json", + lines=True, + data_classification_policy=["public"], + ), + preprocess_steps=[ + MapInstanceValues( + { + "Answerability": { + "['UNANSWERABLE']": False, + "['ANSWERABLE']": True, + "['PARTIAL']": True, + }, + } + ), + Copy( + field_to_field={ + "targets/*/text": 
"reference_answers", + "Answerability": "is_answerable_label", + "task_id": "question_id", + "contexts/*/document_id": "reference_context_ids", + "contexts/*/text": "reference_contexts", + "input/*/speaker": "roles", + "input/*/text": "contents", + }, + ), + ZipFieldValues( + fields=["roles", "contents"], + to_field="conversation", + ), + Dictify( + field="conversation", + with_keys=["role", "content"], + to_field="question", + process_every_value=True, + ), + ], + task="tasks.rag.end_to_end", + templates={"default": "templates.rag.end_to_end.json_predictions"}, + __tags__={"license": "apache-2.0"}, + __description__="""MTRAG: a comprehensive and diverse human-generated multi-turn RAG dataset, accompanied by four document corpora. To the best of our knowledge, MTRAG is the first end-to-end human-generated multi-turn RAG benchmark that reflects real-world properties of multi-turn conversations. +""", +) +wrong_answer = { + "contexts": ["hi"], + "is_answerable": True, + "answer": "Don't know", + "context_ids": ["id0"], +} + +test_card( + card, + strict=False, + full_mismatch_prediction_values=[json.dumps(wrong_answer)], + debug=False, + demos_taken_from="test", + demos_pool_size=5, +) + +add_to_catalog(card, "cards.rag.mtrag", overwrite=True) + + +for subset in ["clapnq", "cloud", "fiqa", "govt"]: + subset_operators = [] + if subset in ["fiqa", "clapnq"]: + subset_operators.append( + Cast( + field="_id", + to="str", + to_field="document_id", + ) + ) + if subset in ["cloud"]: + subset_operators.append(Set(fields={"title": ""})) + + card = TaskCard( + loader=LoadCSV( + files={ + "test": f"https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/{subset}.jsonl.zip" + }, + compression="zip", + file_type="json", + lines=True, + data_classification_policy=["public"], + ), + preprocess_steps=[ + *subset_operators, + Wrap(field="text", inside="list", to_field="passages"), + Set( + fields={ + "metadata_field": "", + } + ), + ], + task="tasks.rag.corpora", + templates={ + "empty": InputOutputTemplate( + input_format="", + output_format="", + ), + }, + ) + test_card( + card, + strict=False, + demos_taken_from="test", + ) + + add_to_catalog(card, f"cards.rag.mtrag.documents.{subset}", overwrite=True) diff --git a/prepare/metrics/rag.py b/prepare/metrics/rag.py index 1cc3fc2c98..7eea4d28f2 100644 --- a/prepare/metrics/rag.py +++ b/prepare/metrics/rag.py @@ -7,6 +7,7 @@ TokenOverlap, ) from unitxt.operators import Copy, ListFieldValues +from unitxt.serializers import MultiTypeSerializer from unitxt.test_utils.metrics import test_metric metrics = { @@ -494,6 +495,7 @@ "metrics.rag.end_to_end.answer_reward": [ copy_field_prediction_answer_to_prediction, copy_field_question_to_references_in_a_list, + MultiTypeSerializer(field="references", process_every_value=True), ], "metrics.rag.end_to_end.answer_faithfulness": [ copy_field_prediction_contexts_to_references, @@ -506,6 +508,7 @@ "metrics.rag.end_to_end.context_relevance": [ copy_field_prediction_contexts_to_references, copy_field_question_to_prediction, + MultiTypeSerializer(field="prediction"), ], } diff --git a/prepare/metrics/rag_answer_relevance.py b/prepare/metrics/rag_answer_relevance.py index d947898334..a092ac04ce 100644 --- a/prepare/metrics/rag_answer_relevance.py +++ b/prepare/metrics/rag_answer_relevance.py @@ -3,6 +3,7 @@ MetricPipeline, ) from unitxt.operators import Copy, ListFieldValues +from unitxt.serializers import MultiTypeSerializer task_names = ["external_rag", "response_generation", "end_to_end"] base = "metrics.rag" @@ 
-30,6 +31,7 @@ def get_preprocess_steps(task): "task_data/question": "references", } ), + MultiTypeSerializer(field="references", process_every_value=True), last_step, ] if task == "end_to_end": @@ -40,6 +42,7 @@ def get_preprocess_steps(task): "prediction/answer": "prediction", } ), + MultiTypeSerializer(field="references", process_every_value=True), last_step, ] raise ValueError(f"Unsupported rag task {task}") diff --git a/prepare/metrics/rag_context_relevance.py b/prepare/metrics/rag_context_relevance.py index 6833d8ce65..fb9afe0669 100644 --- a/prepare/metrics/rag_context_relevance.py +++ b/prepare/metrics/rag_context_relevance.py @@ -3,6 +3,7 @@ MetricPipeline, ) from unitxt.operators import Copy +from unitxt.serializers import MultiTypeSerializer base = "metrics.rag" tasks = ["external_rag", "end_to_end"] @@ -15,11 +16,13 @@ def get_preprocess_steps(task): return [ Copy(field="contexts", to_field="references"), Copy(field="question", to_field="prediction"), + MultiTypeSerializer(field="prediction"), ] if task == "end_to_end": return [ Copy(field="prediction/contexts", to_field="references"), Copy(field="task_data/question", to_field="prediction"), + MultiTypeSerializer(field="prediction"), ] raise ValueError(f"Unsupported rag task for {dimension}:{task}") diff --git a/prepare/metrics/rag_metrics_deprecated.py b/prepare/metrics/rag_metrics_deprecated.py index 7b769df48a..56d54af439 100644 --- a/prepare/metrics/rag_metrics_deprecated.py +++ b/prepare/metrics/rag_metrics_deprecated.py @@ -2,6 +2,7 @@ from unitxt.collections_operators import Wrap from unitxt.metrics import MetricPipeline from unitxt.operators import Copy, ListFieldValues +from unitxt.serializers import MultiTypeSerializer base = "metrics.rag" new_base = "metrics.rag.external_rag" @@ -100,6 +101,7 @@ def get_replacing_metric(depr_metric): # This metric compares the answer (as the prediction) to the question (as the reference). # We have to wrap the question by a list (otherwise it will be a string), # because references are expected to be lists + MultiTypeSerializer(field="references"), ListFieldValues(fields=["references"], to_field="references"), ] add_metric_pipeline_to_catalog( diff --git a/prepare/tasks/rag/rag_end_to_end.py b/prepare/tasks/rag/rag_end_to_end.py index 839698b21e..7a6a2b0d6b 100644 --- a/prepare/tasks/rag/rag_end_to_end.py +++ b/prepare/tasks/rag/rag_end_to_end.py @@ -2,7 +2,7 @@ from unitxt import add_to_catalog from unitxt.blocks import Task -from unitxt.types import RagResponse +from unitxt.types import Dialog, RagResponse add_to_catalog( Task( @@ -11,7 +11,7 @@ For details of RAG see: https://www.unitxt.ai/en/latest/docs/rag_support.html. 
""", input_fields={ - "question": str, + "question": Union[str, Dialog], "question_id": Any, "metadata_field": str, }, @@ -44,6 +44,7 @@ overwrite=True, ) + add_to_catalog( Task( input_fields={ diff --git a/prepare/templates/rag/end_to_end.py b/prepare/templates/rag/end_to_end.py index 1d90220b22..abbaa45363 100644 --- a/prepare/templates/rag/end_to_end.py +++ b/prepare/templates/rag/end_to_end.py @@ -1,7 +1,7 @@ from unitxt import add_to_catalog from unitxt.operator import SequentialOperator from unitxt.struct_data_operators import LoadJson -from unitxt.templates import InputOutputTemplate +from unitxt.templates import JsonOutputTemplate add_to_catalog( SequentialOperator( @@ -18,9 +18,14 @@ add_to_catalog( # For rag end-to-end tasks - InputOutputTemplate( + JsonOutputTemplate( input_format="", - output_format='{{"answer": "{reference_answers}", "contexts" : ["{reference_contexts}"], "context_ids" : ["{reference_context_ids}"]}}', + output_fields={ + "reference_answers": "answer", + "reference_contexts": "contexts", + "reference_context_ids": "context_ids", + }, + wrap_with_list_fields=["reference_contexts", "reference_context_ids"], postprocessors=["processors.load_json_predictions"], ), "templates.rag.end_to_end.json_predictions", diff --git a/src/unitxt/catalog/cards/rag/mtrag.json b/src/unitxt/catalog/cards/rag/mtrag.json new file mode 100644 index 0000000000..6c3f42934b --- /dev/null +++ b/src/unitxt/catalog/cards/rag/mtrag.json @@ -0,0 +1,64 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_csv", + "files": { + "test": "https://raw.githubusercontent.com/IBM/mt-rag-benchmark/refs/heads/main/human/generation_tasks/reference+RAG.jsonl" + }, + "file_type": "json", + "lines": true, + "data_classification_policy": [ + "public" + ] + }, + "preprocess_steps": [ + { + "__type__": "map_instance_values", + "mappers": { + "Answerability": { + "['UNANSWERABLE']": false, + "['ANSWERABLE']": true, + "['PARTIAL']": true + } + } + }, + { + "__type__": "copy", + "field_to_field": { + "targets/*/text": "reference_answers", + "Answerability": "is_answerable_label", + "task_id": "question_id", + "contexts/*/document_id": "reference_context_ids", + "contexts/*/text": "reference_contexts", + "input/*/speaker": "roles", + "input/*/text": "contents" + } + }, + { + "__type__": "zip_field_values", + "fields": [ + "roles", + "contents" + ], + "to_field": "conversation" + }, + { + "__type__": "dictify", + "field": "conversation", + "with_keys": [ + "role", + "content" + ], + "to_field": "question", + "process_every_value": true + } + ], + "task": "tasks.rag.end_to_end", + "templates": { + "default": "templates.rag.end_to_end.json_predictions" + }, + "__tags__": { + "license": "apache-2.0" + }, + "__description__": "MTRAG: a comprehensive and diverse human-generated multi-turn RAG dataset, accompanied by four document corpora. 
To the best of our knowledge, MTRAG is the first end-to-end human-generated multi-turn RAG benchmark that reflects real-world properties of multi-turn conversations.\n" +} diff --git a/src/unitxt/catalog/cards/rag/mtrag/documents/clapnq.json b/src/unitxt/catalog/cards/rag/mtrag/documents/clapnq.json new file mode 100644 index 0000000000..bafb046a5b --- /dev/null +++ b/src/unitxt/catalog/cards/rag/mtrag/documents/clapnq.json @@ -0,0 +1,43 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_csv", + "files": { + "test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/clapnq.jsonl.zip" + }, + "compression": "zip", + "file_type": "json", + "lines": true, + "data_classification_policy": [ + "public" + ] + }, + "preprocess_steps": [ + { + "__type__": "cast", + "field": "_id", + "to": "str", + "to_field": "document_id" + }, + { + "__type__": "wrap", + "field": "text", + "inside": "list", + "to_field": "passages" + }, + { + "__type__": "set", + "fields": { + "metadata_field": "" + } + } + ], + "task": "tasks.rag.corpora", + "templates": { + "empty": { + "__type__": "input_output_template", + "input_format": "", + "output_format": "" + } + } +} diff --git a/src/unitxt/catalog/cards/rag/mtrag/documents/cloud.json b/src/unitxt/catalog/cards/rag/mtrag/documents/cloud.json new file mode 100644 index 0000000000..a7f21df562 --- /dev/null +++ b/src/unitxt/catalog/cards/rag/mtrag/documents/cloud.json @@ -0,0 +1,43 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_csv", + "files": { + "test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/cloud.jsonl.zip" + }, + "compression": "zip", + "file_type": "json", + "lines": true, + "data_classification_policy": [ + "public" + ] + }, + "preprocess_steps": [ + { + "__type__": "set", + "fields": { + "title": "" + } + }, + { + "__type__": "wrap", + "field": "text", + "inside": "list", + "to_field": "passages" + }, + { + "__type__": "set", + "fields": { + "metadata_field": "" + } + } + ], + "task": "tasks.rag.corpora", + "templates": { + "empty": { + "__type__": "input_output_template", + "input_format": "", + "output_format": "" + } + } +} diff --git a/src/unitxt/catalog/cards/rag/mtrag/documents/fiqa.json b/src/unitxt/catalog/cards/rag/mtrag/documents/fiqa.json new file mode 100644 index 0000000000..47871fc58a --- /dev/null +++ b/src/unitxt/catalog/cards/rag/mtrag/documents/fiqa.json @@ -0,0 +1,43 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_csv", + "files": { + "test": "https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/fiqa.jsonl.zip" + }, + "compression": "zip", + "file_type": "json", + "lines": true, + "data_classification_policy": [ + "public" + ] + }, + "preprocess_steps": [ + { + "__type__": "cast", + "field": "_id", + "to": "str", + "to_field": "document_id" + }, + { + "__type__": "wrap", + "field": "text", + "inside": "list", + "to_field": "passages" + }, + { + "__type__": "set", + "fields": { + "metadata_field": "" + } + } + ], + "task": "tasks.rag.corpora", + "templates": { + "empty": { + "__type__": "input_output_template", + "input_format": "", + "output_format": "" + } + } +} diff --git a/src/unitxt/catalog/cards/rag/mtrag/documents/govt.json b/src/unitxt/catalog/cards/rag/mtrag/documents/govt.json new file mode 100644 index 0000000000..e72403aca6 --- /dev/null +++ b/src/unitxt/catalog/cards/rag/mtrag/documents/govt.json @@ -0,0 +1,37 @@ +{ + "__type__": "task_card", + "loader": { + "__type__": "load_csv", + "files": { + "test": 
"https://github.com/IBM/mt-rag-benchmark/raw/refs/heads/main/corpora/govt.jsonl.zip" + }, + "compression": "zip", + "file_type": "json", + "lines": true, + "data_classification_policy": [ + "public" + ] + }, + "preprocess_steps": [ + { + "__type__": "wrap", + "field": "text", + "inside": "list", + "to_field": "passages" + }, + { + "__type__": "set", + "fields": { + "metadata_field": "" + } + } + ], + "task": "tasks.rag.corpora", + "templates": { + "empty": { + "__type__": "input_output_template", + "input_format": "", + "output_format": "" + } + } +} diff --git a/src/unitxt/catalog/metrics/rag/answer_relevance/token_recall.json b/src/unitxt/catalog/metrics/rag/answer_relevance/token_recall.json index d5fe717938..3b31375293 100644 --- a/src/unitxt/catalog/metrics/rag/answer_relevance/token_recall.json +++ b/src/unitxt/catalog/metrics/rag/answer_relevance/token_recall.json @@ -17,6 +17,10 @@ }, "not_exist_do_nothing": true }, + { + "__type__": "multi_type_serializer", + "field": "references" + }, { "__type__": "list_field_values", "fields": [ diff --git a/src/unitxt/catalog/metrics/rag/answer_reward.json b/src/unitxt/catalog/metrics/rag/answer_reward.json index a10971373a..a8b7ee6e14 100644 --- a/src/unitxt/catalog/metrics/rag/answer_reward.json +++ b/src/unitxt/catalog/metrics/rag/answer_reward.json @@ -17,6 +17,10 @@ }, "not_exist_do_nothing": true }, + { + "__type__": "multi_type_serializer", + "field": "references" + }, { "__type__": "list_field_values", "fields": [ diff --git a/src/unitxt/catalog/metrics/rag/end_to_end/answer_relevance/answer_reward.json b/src/unitxt/catalog/metrics/rag/end_to_end/answer_relevance/answer_reward.json index 055dc249e9..7dca8b021a 100644 --- a/src/unitxt/catalog/metrics/rag/end_to_end/answer_relevance/answer_reward.json +++ b/src/unitxt/catalog/metrics/rag/end_to_end/answer_relevance/answer_reward.json @@ -9,6 +9,11 @@ "prediction/answer": "prediction" } }, + { + "__type__": "multi_type_serializer", + "field": "references", + "process_every_value": true + }, { "__type__": "list_field_values", "fields": [ diff --git a/src/unitxt/catalog/metrics/rag/end_to_end/answer_relevance/token_recall.json b/src/unitxt/catalog/metrics/rag/end_to_end/answer_relevance/token_recall.json index a8be78860a..cf8a9aca76 100644 --- a/src/unitxt/catalog/metrics/rag/end_to_end/answer_relevance/token_recall.json +++ b/src/unitxt/catalog/metrics/rag/end_to_end/answer_relevance/token_recall.json @@ -9,6 +9,11 @@ "prediction/answer": "prediction" } }, + { + "__type__": "multi_type_serializer", + "field": "references", + "process_every_value": true + }, { "__type__": "list_field_values", "fields": [ diff --git a/src/unitxt/catalog/metrics/rag/end_to_end/answer_reward.json b/src/unitxt/catalog/metrics/rag/end_to_end/answer_reward.json index 52a336c819..118931556e 100644 --- a/src/unitxt/catalog/metrics/rag/end_to_end/answer_reward.json +++ b/src/unitxt/catalog/metrics/rag/end_to_end/answer_reward.json @@ -17,6 +17,11 @@ "task_data/question" ], "to_field": "references" + }, + { + "__type__": "multi_type_serializer", + "field": "references", + "process_every_value": true } ], "metric": "metrics.reward.deberta_v3_large_v2[score_prefix=answer_reward_]" diff --git a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance.json b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance.json index cb281b2f74..6932803e89 100644 --- a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance.json +++ b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance.json @@ -19,6 +19,10 @@ 
"prediction" ] ] + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.perplexity_q.flan_t5_small[score_prefix=context_relevance_]" diff --git a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/perplexity_flan_t5_small.json b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/perplexity_flan_t5_small.json index 4786ef6690..1ad62720c3 100644 --- a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/perplexity_flan_t5_small.json +++ b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/perplexity_flan_t5_small.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "task_data/question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.perplexity_q.flan_t5_small", diff --git a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/sentence_bert_bge.json b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/sentence_bert_bge.json index a7f42e6b20..780fca1ea1 100644 --- a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/sentence_bert_bge.json +++ b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/sentence_bert_bge.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "task_data/question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.sentence_bert.bge_large_en_1_5", diff --git a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/sentence_bert_mini_lm.json b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/sentence_bert_mini_lm.json index 2d2a9e92e7..0db0bece9d 100644 --- a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/sentence_bert_mini_lm.json +++ b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/sentence_bert_mini_lm.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "task_data/question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.sentence_bert.minilm_l12_v2", diff --git a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/token_precision.json b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/token_precision.json index 7eec4bb7ea..c9cbe2d8ea 100644 --- a/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/token_precision.json +++ b/src/unitxt/catalog/metrics/rag/end_to_end/context_relevance/token_precision.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "task_data/question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.token_overlap", diff --git a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance.json b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance.json index f640ae4323..87560fc6ff 100644 --- a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance.json +++ b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.perplexity_q.flan_t5_small", diff --git a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/perplexity_flan_t5_small.json b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/perplexity_flan_t5_small.json index e1dd5609b1..527b1579f7 100644 --- a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/perplexity_flan_t5_small.json +++ 
b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/perplexity_flan_t5_small.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.perplexity_q.flan_t5_small", diff --git a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/sentence_bert_bge.json b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/sentence_bert_bge.json index 39238a9dfd..08f3edbc53 100644 --- a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/sentence_bert_bge.json +++ b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/sentence_bert_bge.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.sentence_bert.bge_large_en_1_5", diff --git a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/sentence_bert_mini_lm.json b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/sentence_bert_mini_lm.json index 30cf8600d7..0ff5ba0544 100644 --- a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/sentence_bert_mini_lm.json +++ b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/sentence_bert_mini_lm.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.sentence_bert.minilm_l12_v2", diff --git a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/token_precision.json b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/token_precision.json index d9dd52098d..73aac29f5d 100644 --- a/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/token_precision.json +++ b/src/unitxt/catalog/metrics/rag/external_rag/context_relevance/token_precision.json @@ -11,6 +11,10 @@ "__type__": "copy", "field": "question", "to_field": "prediction" + }, + { + "__type__": "multi_type_serializer", + "field": "prediction" } ], "metric": "metrics.token_overlap", diff --git a/src/unitxt/catalog/metrics/rag/response_generation/answer_relevance/answer_reward.json b/src/unitxt/catalog/metrics/rag/response_generation/answer_relevance/answer_reward.json index 537188c896..1151a2b406 100644 --- a/src/unitxt/catalog/metrics/rag/response_generation/answer_relevance/answer_reward.json +++ b/src/unitxt/catalog/metrics/rag/response_generation/answer_relevance/answer_reward.json @@ -8,6 +8,11 @@ "task_data/question": "references" } }, + { + "__type__": "multi_type_serializer", + "field": "references", + "process_every_value": true + }, { "__type__": "list_field_values", "fields": [ diff --git a/src/unitxt/catalog/metrics/rag/response_generation/answer_relevance/token_recall.json b/src/unitxt/catalog/metrics/rag/response_generation/answer_relevance/token_recall.json index c80c5a2d0c..9ea1d2bfc4 100644 --- a/src/unitxt/catalog/metrics/rag/response_generation/answer_relevance/token_recall.json +++ b/src/unitxt/catalog/metrics/rag/response_generation/answer_relevance/token_recall.json @@ -8,6 +8,11 @@ "task_data/question": "references" } }, + { + "__type__": "multi_type_serializer", + "field": "references", + "process_every_value": true + }, { "__type__": "list_field_values", "fields": [ diff --git a/src/unitxt/catalog/tasks/rag/end_to_end.json b/src/unitxt/catalog/tasks/rag/end_to_end.json index 86993c64d5..1f966412ab 100644 --- 
a/src/unitxt/catalog/tasks/rag/end_to_end.json +++ b/src/unitxt/catalog/tasks/rag/end_to_end.json @@ -2,7 +2,7 @@ "__type__": "task", "__description__": "This is a task corresponding to an end to end RAG evaluation. It assumes the user provides a question, and\n the RAG system returns an answer and a set of retrieved contexts (documents or passages).\n For details of RAG see: https://www.unitxt.ai/en/latest/docs/rag_support.html.\n", "input_fields": { - "question": "str", + "question": "Union[str, Dialog]", "question_id": "Any", "metadata_field": "str" }, diff --git a/src/unitxt/catalog/templates/rag/end_to_end/json_predictions.json b/src/unitxt/catalog/templates/rag/end_to_end/json_predictions.json index 29c61217fb..87e7b5459d 100644 --- a/src/unitxt/catalog/templates/rag/end_to_end/json_predictions.json +++ b/src/unitxt/catalog/templates/rag/end_to_end/json_predictions.json @@ -1,7 +1,15 @@ { - "__type__": "input_output_template", + "__type__": "json_output_template", "input_format": "", - "output_format": "{{\"answer\": \"{reference_answers}\", \"contexts\" : [\"{reference_contexts}\"], \"context_ids\" : [\"{reference_context_ids}\"]}}", + "output_fields": { + "reference_answers": "answer", + "reference_contexts": "contexts", + "reference_context_ids": "context_ids" + }, + "wrap_with_list_fields": [ + "reference_contexts", + "reference_context_ids" + ], "postprocessors": [ "processors.load_json_predictions" ] diff --git a/src/unitxt/loaders.py b/src/unitxt/loaders.py index 8f04e6df5b..5fa83d9bb8 100644 --- a/src/unitxt/loaders.py +++ b/src/unitxt/loaders.py @@ -39,7 +39,17 @@ from abc import abstractmethod from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Union +from typing import ( + Any, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Sequence, + Union, +) import pandas as pd import requests @@ -349,24 +359,43 @@ class LoadCSV(Loader): loader_limit: Optional[int] = None streaming: bool = True sep: str = "," + compression: Optional[str] = None + lines: Optional[bool] = None + file_type: Literal["csv", "json"] = "csv" def _maybe_set_classification_policy(self): self.set_default_data_classification( ["proprietary"], "when loading from local files" ) + def get_reader(self): + if self.file_type == "csv": + return pd.read_csv + if self.file_type == "json": + return pd.read_json + raise ValueError() + + def get_args(self): + args = {} + if self.file_type == "csv": + args["sep"] = self.sep + if self.compression is not None: + args["compression"] = self.compression + if self.lines is not None: + args["lines"] = self.lines + if self.get_limit() is not None: + args["nrows"] = self.get_limit() + return args + def load_iterables(self): iterables = {} for split_name, file_path in self.files.items(): + reader = self.get_reader() if self.get_limit() is not None: self.log_limited_loading() - iterables[split_name] = pd.read_csv( - file_path, nrows=self.get_limit(), sep=self.sep - ).to_dict("records") - else: - iterables[split_name] = pd.read_csv(file_path, sep=self.sep).to_dict( - "records" - ) + iterables[split_name] = reader(file_path, **self.get_args()).to_dict( + "records" + ) return iterables diff --git a/src/unitxt/serializers.py b/src/unitxt/serializers.py index b45e96827b..ebc4d34a9c 100644 --- a/src/unitxt/serializers.py +++ b/src/unitxt/serializers.py @@ -158,6 +158,7 @@ class MultiTypeSerializer(Serializer): serializers: List[SingleTypeSerializer] = Field( default_factory=lambda: [ 
DocumentSerializer(), + DialogSerializer(), MultiDocumentSerializer(), ImageSerializer(), VideoSerializer(), diff --git a/src/unitxt/templates.py b/src/unitxt/templates.py index d23819eec1..05fe25a9d0 100644 --- a/src/unitxt/templates.py +++ b/src/unitxt/templates.py @@ -272,6 +272,24 @@ def reference_fields_to_target_and_references( return target, references +class JsonOutputFormatTemplate(Template): + output_fields: Dict[str, str] + wrap_with_list_fields: List[str] + + def reference_fields_to_target_and_references( + self, reference_fields: Dict[str, object] + ) -> str: + data = {} + for field, target_field in self.output_fields.items(): + value = reference_fields[field] + if field in self.wrap_with_list_fields: + value = [value] + data[target_field] = value + target = json.dumps(data, ensure_ascii=False) + references = [target] + return target, references + + class InputOutputTemplate(InputFormatTemplate, OutputFormatTemplate): """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance. @@ -281,6 +299,15 @@ class InputOutputTemplate(InputFormatTemplate, OutputFormatTemplate): pass +class JsonOutputTemplate(InputFormatTemplate, JsonOutputFormatTemplate): + """Generate field 'source' from fields designated as input, and fields 'target' and 'references' from fields designated as output, of the processed instance. + + Args specify the formatting strings with which to glue together the input and reference fields of the processed instance into one string ('source' and 'target'), and into a list of strings ('references'). + """ + + pass + + class InputOutputTemplateWithCustomTarget(InputOutputTemplate): reference: str diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline index 837bb18aaf..fe45e1690c 100644 --- a/utils/.secrets.baseline +++ b/utils/.secrets.baseline @@ -151,7 +151,7 @@ "filename": "src/unitxt/loaders.py", "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742", "is_verified": false, - "line_number": 502, + "line_number": 531, "is_secret": false } ], @@ -184,5 +184,5 @@ } ] }, - "generated_at": "2025-01-22T09:13:29Z" + "generated_at": "2025-01-23T10:05:40Z" }
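
Usage sketch for the new card (illustrative only, not part of the applied diff): the catalog ids "cards.rag.mtrag" and "templates.rag.end_to_end.json_predictions" come from the patch above, while the keyword-argument form of load_dataset and the loader_limit value are assumptions that may vary between unitxt versions.

# Minimal smoke-test sketch, assuming a unitxt installation that already
# contains the catalog entries added by this patch.
from unitxt import load_dataset

dataset = load_dataset(
    card="cards.rag.mtrag",                                  # card added above
    template="templates.rag.end_to_end.json_predictions",    # template added above
    loader_limit=10,  # arbitrary small limit, only to keep the sketch fast
)

# With the JsonOutputTemplate introduced in this patch, each target is a JSON
# string of the form {"answer": "...", "contexts": ["..."], "context_ids": ["..."]}.
print(dataset["test"][0]["target"])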