diff --git a/docs/concepts/metrics/index.md b/docs/concepts/metrics/index.md index 34d56a626..afe32d583 100644 --- a/docs/concepts/metrics/index.md +++ b/docs/concepts/metrics/index.md @@ -15,6 +15,7 @@ Just like in any machine learning system, the performance of individual componen - [Context precision](context_precision.md) - [Context utilization](context_utilization.md) - [Context entity recall](context_entities_recall.md) +- [Noise Sensitivity](noise_sensitivity.md) - [Summarization Score](summarization_score.md) ```{toctree} @@ -36,6 +37,7 @@ context_precision context_utilization context_recall context_entities_recall +noise_sensitivity semantic_similarity answer_correctness critique diff --git a/docs/concepts/metrics/noise_sensitivity.md b/docs/concepts/metrics/noise_sensitivity.md new file mode 100644 index 000000000..a07309779 --- /dev/null +++ b/docs/concepts/metrics/noise_sensitivity.md @@ -0,0 +1,101 @@ + + +# Noise Sensitivity + +Noise sensitivity measures how often a system makes errors by providing incorrect responses when utilizing either relevant or irrelevant retrieved documents. The score ranges from 0 to 1, with lower values indicating better performance. Noise sensitivity is computed using the question, ground truth, answer, and the retrieved context. + +To estimate noise sensitivity, each claim in the generated answer is examined to determine whether it is correct based on the ground truth and whether it can be attributed to the relevant (or irrelevant) retrieved context. Ideally, all claims in the answer should be supported by the relevant retrieved context. + + +```{math} +\text{noise sensitivity (relevant)} = {|\text{Number of incorrect claims in answer}| \over |\text{Number of claims in the Answer}|} +``` + +```{Hint} + +Question: What is the Life Insurance Corporation of India (LIC) known for? + +Ground truth: The Life Insurance Corporation of India (LIC) is the largest insurance company in India, established in 1956 through the nationalization of the insurance industry. It is known for managing a large portfolio of investments. + +Relevant Retrieval: + - The Life Insurance Corporation of India (LIC) was established in 1956 following the nationalization of the insurance industry in India. + - LIC is the largest insurance company in India, with a vast network of policyholders and a significant role in the financial sector. + - As the largest institutional investor in India, LIC manages a substantial life fund, contributing to the financial stability of the country. + +Irrelevant Retrieval: + - The Indian economy is one of the fastest-growing major economies in the world, thanks to the secors like finance, technology, manufacturing etc. +``` + + +## Example + +```{code-block} python +:caption: Noise Sensitivity +from datasets import Dataset +from ragas.metrics import noise_sensitivity_relevant, noise_sensitivity_irrelevant +from ragas import evaluate + +data_sample = { + "question": ["What is the Life Insurance Corporation of India (LIC) known for?"], + "ground_truth": ["The Life Insurance Corporation of India (LIC) is the largest insurance company in India, established in 1956 through the nationalization of the insurance industry. It is known for managing a large portfolio of investments."], + "answer": ["The Life Insurance Corporation of India (LIC) is the largest insurance company in India, known for its vast portfolio of investments. LIC contributs to the financial stability of the country."], + "contexts": [[ + "The Life Insurance Corporation of India (LIC) was established in 1956 following the nationalization of the insurance industry in India.", + "LIC is the largest insurance company in India, with a vast network of policyholders and a huge investments.", + "As the largest institutional investor in India, LIC manages a substantial funds, contributing to the financial stability of the country.", + "The Indian economy is one of the fastest-growing major economies in the world, thanks to the secors like finance, technology, manufacturing etc" + ]] +} + +dataset = Dataset.from_dict(data_sample) +metrics = [noise_sensitivity_relevant, noise_sensitivity_irrelevant] +score = evaluate(dataset,metrics=metrics) +score.to_pandas() +``` + +## Calculation + +Let's examine how noise sensitivity in relevant context was calculated: + +- **Step 1:** Identify the relevant contexts from which the ground truth can be inferred. + + - Ground Truth: + The Life Insurance Corporation of India (LIC) is the largest insurance company in India, established in 1956 through the nationalization of the insurance industry. It is known for managing a large portfolio of investments. + + - Contexts: + - Context 1: `The Life Insurance Corporation of India (LIC) was established in 1956` following the nationalization of the insurance industry in India. + - Context 2: `LIC is the largest insurance company in India`, with a vast network of policyholders and a significant role in the financial sector. + - Context 3: `As the largest institutional investor in India, LIC manages a substantial funds`, contributing to the financial stability of the country. + +- **Step 2:** Verify if the claims in the generated answer can be inferred from the relevant context. + + - Answer: + The Life Insurance Corporation of India (LIC) is the largest insurance company in India, known for its vast portfolio of investments. LIC contributs to the financial stability of the country. + + - Contexts: + - Context 1: The Life Insurance Corporation of India (LIC) was established in 1956 following the nationalization of the insurance industry in India. + - Context 2: `LIC is the largest insurance company in India`, with a vast network of policyholders and a significant role in the financial sector. + - Context 3: `As the largest institutional investor in India, LIC manages a substantial funds`, `contributing to the financial stability of the country`. + + +- **Step 3:** Identify any incorrect claims in the answer (i.e., answer statements that are not supported by the ground truth). + + - Ground Truth: + The Life Insurance Corporation of India (LIC) is the largest insurance company in India, established in 1956 through the nationalization of the insurance industry. It is known for managing a large portfolio of investments. + + - Answer: + The Life Insurance Corporation of India (LIC) is the largest insurance company in India, known for its vast portfolio of investments. `LIC contributs to the financial stability of the country`. + + Explanation: The ground truth does not mention anything about LIC contributing to the financial stability of the country. Therefore, this statement in the answer is incorrect. + + Incorrect Statement: 1 + Total claims: 3 + +- **Step 4:** Calculate noise sensitivity using the formula: + ```{math} + \text{noise sensitivity} = { \text{1} \over \text{3} } = 0.333 + ``` +This results in a noise sensitivity score of 0.333, indicating that one out of three claims in the answer was incorrect. + + +Credits: Noise senstivity was introduced in [RAGChecker](https://github.com/amazon-science/RAGChecker/tree/main/ragchecker) \ No newline at end of file diff --git a/src/ragas/async_utils.py b/src/ragas/async_utils.py index c365ac808..6937b4617 100644 --- a/src/ragas/async_utils.py +++ b/src/ragas/async_utils.py @@ -1,4 +1,5 @@ """Async utils.""" + import asyncio from typing import Any, Coroutine, List diff --git a/src/ragas/integrations/langchain.py b/src/ragas/integrations/langchain.py index 44279187f..99f03b2ed 100644 --- a/src/ragas/integrations/langchain.py +++ b/src/ragas/integrations/langchain.py @@ -48,9 +48,9 @@ def __init__(self, metric: Metric, **kwargs: t.Any): t.cast(MetricWithLLM, self.metric).llm = LangchainLLMWrapper(llm) if isinstance(self.metric, MetricWithEmbeddings): embeddings = get_or_init(kwargs, "embeddings", OpenAIEmbeddings) - t.cast( - MetricWithEmbeddings, self.metric - ).embeddings = LangchainEmbeddingsWrapper(embeddings) + t.cast(MetricWithEmbeddings, self.metric).embeddings = ( + LangchainEmbeddingsWrapper(embeddings) + ) self.metric.init(run_config) @property diff --git a/src/ragas/metrics/__init__.py b/src/ragas/metrics/__init__.py index 7236d95e8..cd2f566e2 100644 --- a/src/ragas/metrics/__init__.py +++ b/src/ragas/metrics/__init__.py @@ -13,6 +13,11 @@ ) from ragas.metrics._context_recall import ContextRecall, context_recall from ragas.metrics._faithfulness import Faithfulness, FaithulnesswithHHEM, faithfulness +from ragas.metrics._noise_sensitivity import ( + NoiseSensitivity, + noise_sensitivity_irrelevant, + noise_sensitivity_relevant, +) from ragas.metrics._rubrics_based import ( LabelledRubricsScore, ReferenceFreeRubricsScore, @@ -43,6 +48,9 @@ "context_entity_recall", "SummarizationScore", "summarization_score", + "NoiseSensitivity", + "noise_sensitivity_irrelevant", + "noise_sensitivity_relevant", "labelled_rubrics_score", "reference_free_rubrics_score", "ReferenceFreeRubricsScore", diff --git a/src/ragas/metrics/_noise_sensitivity.py b/src/ragas/metrics/_noise_sensitivity.py new file mode 100644 index 000000000..79b2412e3 --- /dev/null +++ b/src/ragas/metrics/_noise_sensitivity.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import inspect +import json +import logging +import typing as t +from dataclasses import dataclass, field + +import numpy as np + +from ragas.llms.prompt import Prompt +from ragas.metrics._faithfulness import ( + LONG_FORM_ANSWER_PROMPT, + NLI_STATEMENTS_MESSAGE, + HasSegmentMethod, + StatementFaithfulnessAnswers, + _faithfulness_output_parser, + _statements_output_parser, +) +from ragas.metrics.base import EvaluationMode, MetricWithLLM, ensembler, get_segmenter + +if t.TYPE_CHECKING: + from langchain_core.callbacks import Callbacks + + from ragas.llms.prompt import PromptValue + + +logger = logging.getLogger(__name__) + + +@dataclass +class NoiseSensitivity(MetricWithLLM): + name: str = "noise_sensitivity" # type: ignore + focus: str = "relevant" + evaluation_mode: EvaluationMode = EvaluationMode.qga # type: ignore + nli_statements_message: Prompt = field( + default_factory=lambda: NLI_STATEMENTS_MESSAGE + ) + statement_prompt: Prompt = field(default_factory=lambda: LONG_FORM_ANSWER_PROMPT) + sentence_segmenter: t.Optional[HasSegmentMethod] = None + max_retries: int = 1 + _reproducibility: int = 1 + + @property + def reproducibility(self): + return self._reproducibility + + @reproducibility.setter + def reproducibility(self, value): + if value < 1: + logger.warning("reproducibility cannot be less than 1, setting to 1") + value = 1 + elif value % 2 == 0: + logger.warning( + "reproducibility level cannot be set to even number, setting to odd" + ) + value += 1 + self._reproducibility = value + + def __post_init__(self): + if self.sentence_segmenter is None: + language = self.nli_statements_message.language + self.sentence_segmenter = get_segmenter(language=language, clean=False) + if self.focus not in {"relevant", "irrelevant"}: + raise ValueError( + f"Invalid argument passed for 'focus': {self.focus}. Must be 'relevant' or 'irrelevant'." + ) + self.name = f"{self.name}_{self.focus}" # type: ignore + + def _create_nli_prompt(self, contexts: str, statements: t.List[str]) -> PromptValue: + assert self.llm is not None, "llm must be set to compute score" + + statements_str: str = json.dumps(statements) + prompt_value = self.nli_statements_message.format( + context=contexts, statements=statements_str + ) + return prompt_value + + def _create_statements_prompt(self, text: str, question: str) -> PromptValue: + assert self.sentence_segmenter is not None, "sentence_segmenter is not set" + # contexts = row["contexts"] + sentences = self.sentence_segmenter.segment(text) + sentences = [ + sentence for sentence in sentences if sentence.strip().endswith(".") + ] + sentences = "\n".join([f"{i}:{x}" for i, x in enumerate(sentences)]) + prompt_value = self.statement_prompt.format( + question=question, answer=text, sentences=sentences + ) + return prompt_value + + async def _evaluate_statement_faithfulness( + self, statements, context: str, callbacks: Callbacks + ): + assert self.llm is not None, "LLM is not set" + + p_value = self._create_nli_prompt(context, statements) + nli_result = await self.llm.generate( + p_value, + callbacks=callbacks, + n=self._reproducibility, + ) + + nli_result_text = [ + nli_result.generations[0][i].text for i in range(self._reproducibility) + ] + faithfulness_list = [ + await _faithfulness_output_parser.aparse( + text, p_value, self.llm, self.max_retries + ) + for text in nli_result_text + ] + + faithfulness_list = [ + faith.dicts() for faith in faithfulness_list if faith is not None + ] + + if faithfulness_list: + faithfulness_list = ensembler.from_discrete( + faithfulness_list, + "verdict", + ) + + faithfulness_list = StatementFaithfulnessAnswers.parse_obj( + faithfulness_list + ) + + verdict_list = [ + 1 if statement.verdict else 0 + for statement in faithfulness_list.__root__ + ] + return np.array(verdict_list) + else: + return np.nan + + async def _decompose_answer_into_statements( + self, text: str, question: str, callbacks: Callbacks + ): + assert self.llm is not None, "LLM is not set" + + p_value = self._create_statements_prompt(text, question) + + if inspect.iscoroutinefunction(self.llm.generate): + statements_gen = await self.llm.generate( + p_value, + callbacks=callbacks, + ) + else: + statements_gen = self.llm.generate( + p_value, + callbacks=callbacks, + ) + + # Await the aparse method + statements = await _statements_output_parser.aparse( + statements_gen.generations[0][0].text, p_value, self.llm, self.max_retries # type: ignore + ) + + if statements is None: + return np.nan + + # Ensure statements is not a coroutine before calling dicts() + if inspect.iscoroutine(statements): + statements = await statements + + # Add error handling and logging + if not hasattr(statements, "dicts"): + logging.error(f"Unexpected type for statements: {type(statements)}") + logging.error(f"Statements content: {statements}") + raise AttributeError( + f"'statements' object of type {type(statements)} has no attribute 'dicts'" + ) + + statements = [item["simpler_statements"] for item in statements.dicts()] + statements = [item for sublist in statements for item in sublist] + + return statements + + def _compute_score(self, answers: t.Dict) -> float: + # relevant retrievals + relevant_retrieved = np.max( + answers["retrieved2ground_truth"], axis=0, keepdims=True + ) + relevant_faithful = np.max( + relevant_retrieved & answers["retrieved2answer"], axis=1 + ) + + # irrelevant retrievals + irrelevant_retrieved = ~np.max( + answers["retrieved2ground_truth"], axis=0, keepdims=True + ) + irrelevant_faithful = np.max( + irrelevant_retrieved & answers["retrieved2answer"], axis=1 + ) + + # to keep them exclusive + irrelevant_faithful &= ~relevant_faithful + + incorrect = ~answers["ground_truth2answer"] + noise_sensitivity_in_relevant = np.mean(relevant_faithful & incorrect) + noise_sensitivity_in_irrelevant = np.mean(irrelevant_faithful & incorrect) + + if self.focus == "irrelevant": + return noise_sensitivity_in_irrelevant + + return noise_sensitivity_in_relevant + + async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float: + """ + returns the NLI score for each (q, c, a) pair + """ + assert self.llm is not None, "LLM is not set" + + gt_statements = await self._decompose_answer_into_statements( + row["ground_truth"], row["question"], callbacks + ) + ans_statements = await self._decompose_answer_into_statements( + row["answer"], row["question"], callbacks + ) + gt_verdictslist = [] + ans_verdictslist = [] + + for ctx in row["contexts"]: + verdicts = await self._evaluate_statement_faithfulness( + gt_statements, ctx, callbacks + ) + gt_verdictslist.append(verdicts) + + verdicts = await self._evaluate_statement_faithfulness( + ans_statements, ctx, callbacks + ) + ans_verdictslist.append(verdicts) + + answers = {} + answers["retrieved2ground_truth"] = np.array(gt_verdictslist).T + answers["retrieved2answer"] = np.array(ans_verdictslist).T + answers["ground_truth2answer"] = await self._evaluate_statement_faithfulness( + ans_statements, row["ground_truth"], callbacks + ) + answers["ground_truth2answer"] = np.array([answers["ground_truth2answer"]]) + answers = {k: v.astype(bool) for k, v in answers.items()} + return self._compute_score(answers) + + def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None: + assert self.llm is not None, "LLM is not set" + + logger.info(f"Adapting Faithfulness metric to {language}") + + self.nli_statements_message = self.nli_statements_message.adapt( + language, self.llm, cache_dir + ) + self.statement_prompt = self.statement_prompt.adapt( + language, self.llm, cache_dir + ) + + self.sentence_segmenter = get_segmenter(language=language, clean=False) + + def save(self, cache_dir: t.Optional[str] = None) -> None: + self.nli_statements_message.save(cache_dir) + self.statement_prompt.save(cache_dir) + + +noise_sensitivity_relevant = NoiseSensitivity() +noise_sensitivity_irrelevant = NoiseSensitivity(focus="irrelevant")