diff --git a/docs/concepts/metrics/faithfulness.md b/docs/concepts/metrics/faithfulness.md index fc67bf9df..df09a1d31 100644 --- a/docs/concepts/metrics/faithfulness.md +++ b/docs/concepts/metrics/faithfulness.md @@ -9,7 +9,6 @@ The generated answer is regarded as faithful if all the claims made in the answe \text{Faithfulness score} = {|\text{Number of claims in the generated answer that can be inferred from given context}| \over |\text{Total number of claims in the generated answer}|} ``` - ```{hint} **Question**: Where and when was Einstein born? @@ -23,47 +22,49 @@ The generated answer is regarded as faithful if all the claims made in the answe ## Example ```{code-block} python -:caption: Faithfulness -from datasets import Dataset +:caption: Faithfulness +from datasets import Dataset from ragas.metrics import faithfulness from ragas import evaluate data_samples = { 'question': ['When was the first super bowl?', 'Who won the most super bowls?'], 'answer': ['The first superbowl was held on Jan 15, 1967', 'The most super bowls have been won by The New England Patriots'], - 'contexts' : [['The First AFL–NFL World Championship Game was an American football game played on January 15, 1967, at the Los Angeles Memorial Coliseum in Los Angeles,'], + 'contexts' : [['The First AFL–NFL World Championship Game was an American football game played on January 15, 1967, at the Los Angeles Memorial Coliseum in Los Angeles,'], ['The Green Bay Packers...Green Bay, Wisconsin.','The Packers compete...Football Conference']], + 'ground_truth': ['The first superbowl was held on January 15, 1967', 'The New England Patriots have won the Super Bowl a record six times'] } dataset = Dataset.from_dict(data_samples) score = evaluate(dataset,metrics=[faithfulness]) score.to_pandas() ``` -## Calculation +## Calculation Let's examine how faithfulness was calculated using the low faithfulness answer: - **Step 1:** Break the generated answer into individual statements. - - Statements: - - Statement 1: "Einstein was born in Germany." - - Statement 2: "Einstein was born on 20th March 1879." + + - Statements: + - Statement 1: "Einstein was born in Germany." + - Statement 2: "Einstein was born on 20th March 1879." - **Step 2:** For each of the generated statements, verify if it can be inferred from the given context. - - Statement 1: Yes - - Statement 2: No -- **Step 3:** Use the formula depicted above to calculate faithfulness. - ```{math} - \text{Faithfulness} = { \text{1} \over \text{2} } = 0.5 - ``` + - Statement 1: Yes + - Statement 2: No +- **Step 3:** Use the formula depicted above to calculate faithfulness. + ```{math} + \text{Faithfulness} = { \text{1} \over \text{2} } = 0.5 + ``` ## Faithfullness with HHEM-2.1-Open [Vectara's HHEM-2.1-Open](https://vectara.com/blog/hhem-2-1-a-better-hallucination-detection-model/) is a classifier model (T5) that is trained to detect hallucinations from LLM generated text. This model can be used in the second step of calculating faithfulness, i.e. when claims are cross-checked with the given context to determine if it can be inferred from the context. The model is free, small, and open-source, making it very efficient in production use cases. You can load the model onto a specified device by setting the `device` argument and adjust the batch size for inference using the `batch_size` parameter. By default, the model is loaded on the CPU with a batch size of 10. 
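+Conceptually, HHEM-2.1-Open only replaces the claim-verification step: each claim from the answer is checked against the retrieved context, and the resulting yes/no verdicts are aggregated with the formula shown at the top of this page. A minimal sketch of that aggregation (plain Python, illustrative values only, not the ragas implementation):
+
+```{code-block} python
+# Verdicts for the low-faithfulness Einstein answer from the Calculation section:
+# claim 1 ("Einstein was born in Germany.") is supported, claim 2 (the birth date) is not.
+verdicts = [1, 0]
+faithfulness_score = sum(verdicts) / len(verdicts)  # 0.5
+```
+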
To use the model to calculate faithfulness, you can use the following code snippet:

```{code-block} python
-from datasets import Dataset 
+from datasets import Dataset
from ragas.metrics import FaithulnesswithHHEM
from ragas import evaluate
@@ -75,11 +76,45 @@ faithfulness_with_hhem = FaithulnesswithHHEM(device=my_device, batch_size=my_bat
data_samples = {
    'question': ['When was the first super bowl?', 'Who won the most super bowls?'],
    'answer': ['The first superbowl was held on Jan 15, 1967', 'The most super bowls have been won by The New England Patriots'],
-    'contexts' : [['The First AFL–NFL World Championship Game was an American football game played on January 15, 1967, at the Los Angeles Memorial Coliseum in Los Angeles,'], 
+    'contexts' : [['The First AFL–NFL World Championship Game was an American football game played on January 15, 1967, at the Los Angeles Memorial Coliseum in Los Angeles,'],
    ['The Green Bay Packers...Green Bay, Wisconsin.','The Packers compete...Football Conference']],
+    'ground_truth': ['The first superbowl was held on January 15, 1967', 'The New England Patriots have won the Super Bowl a record six times']
}
dataset = Dataset.from_dict(data_samples)
score = evaluate(dataset,metrics=[faithfulness_with_hhem])
score.to_pandas()
-```
\ No newline at end of file
+```
+
+## Faithfulness with Bespoke-MiniCheck-7B
+
+[Bespoke Labs' Bespoke-MiniCheck-7B](https://huggingface.co/bespokelabs/Bespoke-MiniCheck-7B) is a fact-checking model developed by [Bespoke Labs](https://bespokelabs.ai). The model takes a document and a sentence as input and determines whether the sentence is supported by the document.
+
+For faster and more reliable performance, you can call the Bespoke Labs API. Alternatively, you can run the model locally through Hugging Face `transformers`.
+
+```{code-block} python
+import os
+
+from datasets import Dataset
+from ragas.metrics import FaithfulnesswithMiniCheck
+from ragas import evaluate
+
+device = "cpu"  # use "cuda" for NVIDIA GPUs
+
+# Approach 1: Use the model through the Bespoke Labs API
+os.environ['BESPOKE_API_KEY'] = 'bespoke-demo-key'  # limited for demos only, please request a key via https://console.bespokelabs.ai.
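+# The key can also be passed directly to the metric instead of using the
+# environment variable: FaithfulnesswithMiniCheck(use_api=True, bespoke_api_key="...")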
+faithfulness_with_minicheck = FaithfulnesswithMiniCheck(use_api=True) + +# Approach 2: Use the model locally through HuggingFace transformers +faithfulness_with_minicheck = FaithfulnesswithMiniCheck(device=device, batch_size=10) + +data_samples = { + 'question': ['When was the first super bowl?', 'Who won the most super bowls?'], + 'answer': ['The first superbowl was held on Jan 15, 1967', 'The most super bowls have been won by The New England Patriots'], + 'contexts' : [['The First AFL–NFL World Championship Game was an American football game played on January 15, 1967, at the Los Angeles Memorial Coliseum in Los Angeles,'], + ['The Green Bay Packers...Green Bay, Wisconsin.','The Packers compete...Football Conference']], + 'ground_truth': ['The first superbowl was held on January 15, 1967', 'The New England Patriots have won the Super Bowl a record six times'] +} +dataset = Dataset.from_dict(data_samples) +score = evaluate(dataset,metrics=[faithfulness_with_minicheck]) +score.to_pandas() + +``` diff --git a/pyproject.toml b/pyproject.toml index 00f65a491..0e16229dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,9 @@ dynamic = ["version", "readme"] all = [ "sentence-transformers", "transformers", + "torch==2.4.0", + "einops>=0.7.0", + "sentencepiece>=0.2.0", "nltk", "rouge_score", "fuzzywuzzy", diff --git a/requirements/dev.txt b/requirements/dev.txt index df664c6b5..b15f7bc4d 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -14,5 +14,8 @@ fuzzywuzzy rouge_score nltk rapidfuzz +einops>=0.7.0 +torch==2.4.0 +sentencepiece>=0.2.0 pandas -datacompy \ No newline at end of file +datacompy diff --git a/src/ragas/metrics/__init__.py b/src/ragas/metrics/__init__.py index 6b476675d..295f165e8 100644 --- a/src/ragas/metrics/__init__.py +++ b/src/ragas/metrics/__init__.py @@ -16,13 +16,18 @@ context_utilization, ) from ragas.metrics._context_recall import ContextRecall, context_recall +from ragas.metrics._faithfulness import ( + Faithfulness, + FaithfulnesswithMiniCheck, + FaithulnesswithHHEM, + faithfulness, +) from ragas.metrics._domain_specific_rubrics import ( RubricsScoreWithoutReference, RubricsScoreWithReference, rubrics_score_with_reference, rubrics_score_without_reference, ) -from ragas.metrics._faithfulness import Faithfulness, FaithulnesswithHHEM, faithfulness from ragas.metrics._noise_sensitivity import ( NoiseSensitivity, noise_sensitivity_irrelevant, @@ -36,6 +41,7 @@ "Faithfulness", "faithfulness", "FaithulnesswithHHEM", + "FaithfulnesswithMiniCheck", "AnswerSimilarity", "answer_similarity", "ContextPrecision", diff --git a/src/ragas/metrics/_faithfulness.py b/src/ragas/metrics/_faithfulness.py index 489d07c26..75db145ad 100644 --- a/src/ragas/metrics/_faithfulness.py +++ b/src/ragas/metrics/_faithfulness.py @@ -1,11 +1,15 @@ from __future__ import annotations +import asyncio import json import logging +import os import typing as t from dataclasses import dataclass, field +from string import Template import numpy as np +import requests from langchain_core.pydantic_v1 import BaseModel, Field from ragas.dataset_schema import SingleTurnSample @@ -401,4 +405,178 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float: return sum(scores) / len(scores) +MINICHECK_SYSTEM_PROMPT = ( + "Determine whether the provided claim is consistent with the " + "corresponding document. Consistency in this context implies that all " + "information presented in the claim is substantiated by the document. " + "If not, it should be considered inconsistent. 
Please assess the " + "claim's consistency with the document by responding with either \"Yes\" " + "or \"No\"." +) +MINICHECK_USER_PROMPT_TEMPLATE = Template("Document: $document\nClaim: $claim") + + +@dataclass +class MiniCheckExample: + document: str = "" + claim: str = "" + + +@dataclass +class FaithfulnesswithMiniCheck(Faithfulness): + name: str = "faithfulness_with_minicheck" # type: ignore + device: str = "cpu" + batch_size: int = 10 + max_sequence_len: int = 10000 # max sequence can be 32768 + use_api: bool = False + bespoke_api_key: str = "" + max_concurrent_requests: int = 10 + + def __post_init__(self): + if self.use_api: + self.bespoke_api_key = (self.bespoke_api_key if self.bespoke_api_key else + os.environ.get("BESPOKE_API_KEY", "")) + if not self.bespoke_api_key: + raise ValueError( + "No API key found for bespokelabs API. Please get your key " + "at https://console.bespokelabs.ai, then provide it " + "by passing the bespoke_api_key parameter to the " + "constructor or set the BESPOKE_API_KEY environment variable.") + self._semaphore = asyncio.Semaphore(self.max_concurrent_requests) + else: + try: + import einops as einops + from transformers import AutoModelForCausalLM, AutoTokenizer + except ImportError: + raise ImportError( + "einops, torch, and transformers must be installed to use this feature, " + " try `pip install .[all]` to install the dependencies.") + self._minicheck = AutoModelForCausalLM.from_pretrained( + "bespokelabs/Bespoke-MiniCheck-7B", trust_remote_code=True + ) + self._tokenizer = AutoTokenizer.from_pretrained( + "bespokelabs/Bespoke-MiniCheck-7B", + trust_remote_code=True) + self._minicheck.to(self.device) + super().__post_init__() + + def _create_examples( + self, row: t.Dict, statements: t.List[str] + ) -> t.List[MiniCheckExample]: + document = "\n".join(row["retrieved_contexts"]) + return [MiniCheckExample(document=document, claim=statement) + for statement in statements] + + def _decode(self, prompts: t.List[str]): + import torch + inputs = self._tokenizer( + prompts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=self.max_sequence_len) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self._minicheck.generate( + **inputs, + max_new_tokens=1, + return_dict_in_generate=True, + output_scores=True) + + return outputs + + def _extract_scores(self, outputs): + import torch + logits = outputs.scores[0] + probs = torch.softmax(logits, dim=-1) + top_5_probs, top_5_indices = torch.topk(probs, 5, dim=-1) + scores = [] + for i in range(logits.shape[0]): + top_5_tokens = [ + self._tokenizer.decode( + [idx]).lower() for idx in top_5_indices[i]] + yes_prob = sum( + prob for token, + prob in zip( + top_5_tokens, + top_5_probs[i]) if token == 'yes') + scores.append(int(yes_prob > 0.5)) + + return scores + + def _score_examples_locally( + self, examples: t.List[MiniCheckExample]) -> t.List[float]: + prompts = [] + for example in examples: + user_prompt = MINICHECK_USER_PROMPT_TEMPLATE.substitute( + document=example.document, claim=example.claim) + message = [ + {"role": "system", "content": MINICHECK_SYSTEM_PROMPT}, + {"role": "user", "content": user_prompt}, + ] + prompt = self._tokenizer.apply_chat_template( + message, add_generation_prompt=True, tokenize=False) + prompts.append(prompt) + scores = [] + for i in range(0, len(prompts), self.batch_size): + logits = self._decode(prompts[i:i + self.batch_size]) + scores_batch = self._extract_scores(logits) + scores.extend(scores_batch) + return 
scores
+
+    async def _score_examples_api(
+            self,
+            examples: t.List[MiniCheckExample]) -> t.List[float]:
+        async def request_minicheck(example: MiniCheckExample) -> float:
+            def sync_request_minicheck(example: MiniCheckExample) -> float:
+                try:
+                    response = requests.post(
+                        "https://api.bespokelabs.ai/v0/minicheck/factcheck",
+                        json={
+                            "context": example.document,
+                            "claim": example.claim
+                        },
+                        headers={"api_key": self.bespoke_api_key}
+                    )
+                    response.raise_for_status()
+                    return int(response.json()['support_prob'] > 0.5)
+                except requests.RequestException as e:
+                    logger.warning(f"Bespoke API request failed: {str(e)}")
+                    return np.nan
+            # Run the blocking HTTP call in a worker thread and use the semaphore
+            # created in __post_init__ to cap the number of concurrent requests.
+            async with self._semaphore:
+                loop = asyncio.get_running_loop()
+                return await loop.run_in_executor(
+                    None,
+                    sync_request_minicheck,
+                    example
+                )
+        return await asyncio.gather(*[
+            request_minicheck(example) for example in examples])
+
+    async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
+        assert self.llm is not None, "LLM is not set"
+
+        p_value = self._create_statements_prompt(row)
+        statements = await self.llm.generate(
+            p_value,
+            callbacks=callbacks,
+        )
+        statements = await _statements_output_parser.aparse(
+            statements.generations[0][0].text, p_value, self.llm, self.max_retries
+        )
+
+        if statements is None:
+            return np.nan
+
+        statements = [item["simpler_statements"] for item in statements.dicts()]
+        statements = [item for sublist in statements for item in sublist]
+        if not statements:
+            # Nothing to verify; avoid dividing by zero below.
+            return np.nan
+
+        examples = self._create_examples(row, statements)
+        if not self.use_api:
+            scores = self._score_examples_locally(examples)
+        else:
+            scores = await self._score_examples_api(examples)
+        return sum(scores) / len(scores)
+
+
faithfulness = Faithfulness()
diff --git a/tests/e2e/test_amnesty_in_ci.py b/tests/e2e/test_amnesty_in_ci.py
index 42b44fc20..5d2b4997d 100644
--- a/tests/e2e/test_amnesty_in_ci.py
+++ b/tests/e2e/test_amnesty_in_ci.py
@@ -35,4 +35,4 @@ def test_amnesty_e2e():

@pytest.mark.ragas_ci
def test_assert_in_range():
-    assert_in_range(0.5, value=0.1, plus_or_minus=0.1)
+    assert_in_range(0.5, value=0.1, plus_or_minus=0.1)
\ No newline at end of file