diff --git a/pyproject.toml b/pyproject.toml index c7fa840..a2b53a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,8 @@ extra-dependencies = [ "pytest", "llama-index-core", "smolagents", + "cleanlab-studio", + "thefuzz", "langchain-core", ] [tool.hatch.envs.types.scripts] @@ -54,6 +56,8 @@ allow-direct-references = true extra-dependencies = [ "llama-index-core", "smolagents; python_version >= '3.10'", + "cleanlab-studio", + "thefuzz", "langchain-core", ] diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py new file mode 100644 index 0000000..dcc15d5 --- /dev/null +++ b/src/cleanlab_codex/response_validation.py @@ -0,0 +1,299 @@ +""" +This module provides validation functions for evaluating LLM responses and determining if they should be replaced with Codex-generated alternatives. +""" + +from __future__ import annotations + +from typing import ( + Any, + Callable, + Dict, + Optional, + Protocol, + Sequence, + Union, + cast, + runtime_checkable, +) + +from pydantic import BaseModel, ConfigDict, Field + +from cleanlab_codex.utils.errors import MissingDependencyError +from cleanlab_codex.utils.prompt import default_format_prompt + + +@runtime_checkable +class TLM(Protocol): + def get_trustworthiness_score( + self, + prompt: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], + **kwargs: Any, + ) -> Dict[str, Any]: ... + + def prompt( + self, + prompt: Union[str, Sequence[str]], + /, + **kwargs: Any, + ) -> Dict[str, Any]: ... + + +DEFAULT_FALLBACK_ANSWER: str = ( + "Based on the available information, I cannot provide a complete answer to this question." +) +DEFAULT_FALLBACK_SIMILARITY_THRESHOLD: int = 70 +DEFAULT_TRUSTWORTHINESS_THRESHOLD: float = 0.5 + +Query = str +Context = str +Prompt = str + + +class BadResponseDetectionConfig(BaseModel): + """Configuration for bad response detection functions.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Fallback check config + fallback_answer: str = Field( + default=DEFAULT_FALLBACK_ANSWER, description="Known unhelpful response to compare against" + ) + fallback_similarity_threshold: int = Field( + default=DEFAULT_FALLBACK_SIMILARITY_THRESHOLD, + description="Fuzzy matching similarity threshold (0-100). Higher values mean responses must be more similar to fallback_answer to be considered bad.", + ) + + # Untrustworthy check config + trustworthiness_threshold: float = Field( + default=DEFAULT_TRUSTWORTHINESS_THRESHOLD, + description="Score threshold (0.0-1.0). Lower values allow less trustworthy responses.", + ) + format_prompt: Callable[[Query, Context], Prompt] = Field( + default=default_format_prompt, + description="Function to format (query, context) into a prompt string.", + ) + + # Unhelpful check config + unhelpfulness_confidence_threshold: Optional[float] = Field( + default=None, + description="Optional confidence threshold (0.0-1.0) for unhelpful classification.", + ) + + # Shared config (for untrustworthiness and unhelpfulness checks) + tlm: Optional[TLM] = Field( + default=None, + description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).", + ) + + +DEFAULT_CONFIG = BadResponseDetectionConfig() + + +def is_bad_response( + response: str, + *, + context: Optional[str] = None, + query: Optional[str] = None, + config: Union[BadResponseDetectionConfig, Dict[str, Any]] = DEFAULT_CONFIG, +) -> bool: + """Run a series of checks to determine if a response is bad. + + If any check detects an issue (i.e. 
fails), the function returns True, indicating the response is bad.
+
+    This function runs three possible validation checks:
+    1. **Fallback check**: Detects if response is too similar to a known fallback answer.
+    2. **Untrustworthy check**: Assesses response trustworthiness based on the given context and query.
+    3. **Unhelpful check**: Predicts whether the response answers the query in a useful way.
+
+    Note:
+        Each validation check runs conditionally based on whether the required arguments are provided.
+        As soon as any validation check fails, the function returns True.
+
+    Args:
+        response: The response to check.
+        context: Optional context/documents used for answering. Required for untrustworthy check.
+        query: Optional user question. Required for untrustworthy and unhelpful checks.
+        config: Optional configuration, given as a `BadResponseDetectionConfig` instance or an equivalent dictionary of fields. See `BadResponseDetectionConfig` for details.
+
+    Returns:
+        bool: True if any validation check fails, False if all pass.
+    """
+    config = BadResponseDetectionConfig.model_validate(config)
+
+    validation_checks: list[Callable[[], bool]] = []
+
+    # All required inputs are available for checking fallback responses
+    validation_checks.append(
+        lambda: is_fallback_response(
+            response,
+            config.fallback_answer,
+            threshold=config.fallback_similarity_threshold,
+        )
+    )
+
+    can_run_untrustworthy_check = query is not None and context is not None and config.tlm is not None
+    if can_run_untrustworthy_check:
+        # The if condition guarantees these are not None
+        validation_checks.append(
+            lambda: is_untrustworthy_response(
+                response=response,
+                context=cast(str, context),
+                query=cast(str, query),
+                tlm=cast(TLM, config.tlm),
+                trustworthiness_threshold=config.trustworthiness_threshold,
+                format_prompt=config.format_prompt,
+            )
+        )
+
+    can_run_unhelpful_check = query is not None and config.tlm is not None
+    if can_run_unhelpful_check:
+        validation_checks.append(
+            lambda: is_unhelpful_response(
+                response=response,
+                query=cast(str, query),
+                tlm=cast(TLM, config.tlm),
+                trustworthiness_score_threshold=cast(float, config.unhelpfulness_confidence_threshold),
+            )
+        )
+
+    return any(check() for check in validation_checks)
+
+
+def is_fallback_response(
+    response: str,
+    fallback_answer: str = DEFAULT_FALLBACK_ANSWER,
+    threshold: int = DEFAULT_FALLBACK_SIMILARITY_THRESHOLD,
+) -> bool:
+    """Check if a response is too similar to a known fallback answer.
+
+    Uses fuzzy string matching to compare the response against a known fallback answer.
+    Returns True if the response is similar enough to be considered unhelpful.
+
+    Args:
+        response: The response to check.
+        fallback_answer: A known unhelpful/fallback response to compare against.
+        threshold: Similarity threshold (0-100). Higher values require more similarity.
+            Default 70 means responses that are 70% or more similar are considered bad.
+ + Returns: + bool: True if the response is too similar to the fallback answer, False otherwise + """ + try: + from thefuzz import fuzz # type: ignore + except ImportError as e: + raise MissingDependencyError( + import_name=e.name or "thefuzz", + package_url="https://github.com/seatgeek/thefuzz", + ) from e + + partial_ratio: int = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) + return bool(partial_ratio >= threshold) + + +def is_untrustworthy_response( + response: str, + context: str, + query: str, + tlm: TLM, + trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, + format_prompt: Callable[[str, str], str] = default_format_prompt, +) -> bool: + """Check if a response is untrustworthy. + + Uses TLM to evaluate whether a response is trustworthy given the context and query. + Returns True if TLM's trustworthiness score falls below the threshold, indicating + the response may be incorrect or unreliable. + + Args: + response: The response to check from the assistant + context: The context information available for answering the query + query: The user's question or request + tlm: The TLM model to use for evaluation + trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. + Default 0.5, meaning responses with scores less than 0.5 are considered untrustworthy. + format_prompt: Function that takes (query, context) and returns a formatted prompt string. + Users should provide their RAG app's own prompt formatting function here + to match how their LLM is prompted. + + Returns: + bool: True if the response is deemed untrustworthy by TLM, False otherwise + """ + try: + from cleanlab_studio import Studio # type: ignore[import-untyped] # noqa: F401 + except ImportError as e: + raise MissingDependencyError( + import_name=e.name or "cleanlab_studio", + package_name="cleanlab-studio", + package_url="https://github.com/cleanlab/cleanlab-studio", + ) from e + + prompt = format_prompt(query, context) + result = tlm.get_trustworthiness_score(prompt, response) + score: float = result["trustworthiness_score"] + return score < trustworthiness_threshold + + +def is_unhelpful_response( + response: str, + query: str, + tlm: TLM, + trustworthiness_score_threshold: Optional[float] = None, +) -> bool: + """Check if a response is unhelpful by asking TLM to evaluate it. + + Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment. + The evaluation considers both the TLM's binary classification of helpfulness and its + confidence score. Returns True only if TLM classifies the response as unhelpful AND + is sufficiently confident in that assessment (if a threshold is provided). + + Args: + response: The response to check + query: User query that will be used to evaluate if the response is helpful + tlm: The TLM model to use for evaluation + trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0) + If provided and the response is marked as unhelpful, + the confidence score must exceed this threshold for + the response to be considered truly unhelpful. 
+ + Returns: + bool: True if TLM determines the response is unhelpful with sufficient confidence, + False otherwise + """ + try: + from cleanlab_studio import Studio # noqa: F401 + except ImportError as e: + raise MissingDependencyError( + import_name=e.name or "cleanlab_studio", + package_name="cleanlab-studio", + package_url="https://github.com/cleanlab/cleanlab-studio", + ) from e + + # If editing `question`, make sure `expected_unhelpful_response` is still correct: + # - When asking "is helpful?" -> "no" means unhelpful + # - When asking "is unhelpful?" -> "yes" means unhelpful + question = ( + "Does the AI Assistant Response seem unhelpful? " + "Things that are not helpful include answers that:\n" + "- Are not useful, incomplete, incorrect, uncertain or unclear.\n" + "- Abstain or refuse to answer the question\n" + "- Statements which are similar to 'I don't know', 'Sorry', or 'No information available'.\n" + "- Leave the original question unresolved\n" + "- Are irrelevant to the question\n" + "Answer Yes/No only." + ) + expected_unhelpful_response = "yes" + + prompt = ( + "Consider the following User Query and AI Assistant Response.\n\n" + f"User Query: {query}\n\n" + f"AI Assistant Response: {response}\n\n" + f"{question}" + ) + + output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) + response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response + is_trustworthy = trustworthiness_score_threshold is None or ( + output["trustworthiness_score"] > trustworthiness_score_threshold + ) + return response_marked_unhelpful and is_trustworthy diff --git a/src/cleanlab_codex/utils/prompt.py b/src/cleanlab_codex/utils/prompt.py new file mode 100644 index 0000000..2717ef5 --- /dev/null +++ b/src/cleanlab_codex/utils/prompt.py @@ -0,0 +1,21 @@ +""" +Helper functions for processing prompts in RAG applications. +""" + + +def default_format_prompt(query: str, context: str) -> str: + """Default function for formatting RAG prompts. + + Args: + query: The user's question + context: The context/documents to use for answering + + Returns: + str: A formatted prompt combining the query and context + """ + template = ( + "Using only information from the following Context, answer the following Query.\n\n" + "Context:\n{context}\n\n" + "Query: {query}" + ) + return template.format(context=context, query=query) diff --git a/tests/test_response_validation.py b/tests/test_response_validation.py new file mode 100644 index 0000000..cbc1d29 --- /dev/null +++ b/tests/test_response_validation.py @@ -0,0 +1,210 @@ +"""Unit tests for validation module functions.""" + +from __future__ import annotations + +from typing import Any, Dict, Sequence, Union +from unittest.mock import Mock, patch + +import pytest + +from cleanlab_codex.response_validation import ( + is_bad_response, + is_fallback_response, + is_unhelpful_response, + is_untrustworthy_response, +) + +# Mock responses for testing +GOOD_RESPONSE = "This is a helpful and specific response that answers the question completely." +BAD_RESPONSE = "Based on the available information, I cannot provide a complete answer." +QUERY = "What is the capital of France?" +CONTEXT = "Paris is the capital and largest city of France." 
+ + +class MockTLM(Mock): + _trustworthiness_score: float = 0.8 + _response: str = "No" + + @property + def trustworthiness_score(self) -> float: + return self._trustworthiness_score + + @trustworthiness_score.setter + def trustworthiness_score(self, value: float) -> None: + self._trustworthiness_score = value + + @property + def response(self) -> str: + return self._response + + @response.setter + def response(self, value: str) -> None: + self._response = value + + def get_trustworthiness_score( + self, + prompt: Union[str, Sequence[str]], # noqa: ARG002 + response: Union[str, Sequence[str]], # noqa: ARG002 + **kwargs: Any, # noqa: ARG002 + ) -> Dict[str, Any]: + return {"trustworthiness_score": self._trustworthiness_score} + + def prompt( + self, + prompt: Union[str, Sequence[str]], # noqa: ARG002 + /, + **kwargs: Any, # noqa: ARG002 + ) -> Dict[str, Any]: + return {"response": self._response, "trustworthiness_score": self._trustworthiness_score} + + +@pytest.fixture +def mock_tlm() -> MockTLM: + return MockTLM() + + +@pytest.mark.parametrize( + ("response", "threshold", "fallback_answer", "expected"), + [ + # Test threshold variations + (GOOD_RESPONSE, 30, None, True), + (GOOD_RESPONSE, 55, None, False), + # Test default behavior (BAD_RESPONSE should be flagged) + (BAD_RESPONSE, None, None, True), + # Test default behavior for different response (GOOD_RESPONSE should not be flagged) + (GOOD_RESPONSE, None, None, False), + # Test custom fallback answer + (GOOD_RESPONSE, 80, "This is an unhelpful response", False), + ], +) +def test_is_fallback_response( + response: str, + threshold: float | None, + fallback_answer: str | None, + *, + expected: bool, +) -> None: + """Test fallback response detection.""" + kwargs: dict[str, float | str] = {} + if threshold is not None: + kwargs["threshold"] = threshold + if fallback_answer is not None: + kwargs["fallback_answer"] = fallback_answer + + assert is_fallback_response(response, **kwargs) is expected # type: ignore + + +def test_is_untrustworthy_response(mock_tlm: Mock) -> None: + """Test untrustworthy response detection.""" + # Test trustworthy response + mock_tlm.trustworthiness_score = 0.8 + assert is_untrustworthy_response(GOOD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is False + + # Test untrustworthy response + mock_tlm.trustworthiness_score = 0.3 + assert is_untrustworthy_response(BAD_RESPONSE, CONTEXT, QUERY, mock_tlm, trustworthiness_threshold=0.5) is True + + +@pytest.mark.parametrize( + ("response", "tlm_response", "tlm_score", "threshold", "expected"), + [ + # Test helpful response + (GOOD_RESPONSE, "No", 0.9, 0.5, False), + # Test unhelpful response + (BAD_RESPONSE, "Yes", 0.9, 0.5, True), + # Test unhelpful response but low trustworthiness score + (BAD_RESPONSE, "Yes", 0.3, 0.5, False), + # Test without threshold - Yes prediction + (BAD_RESPONSE, "Yes", 0.3, None, True), + (GOOD_RESPONSE, "Yes", 0.3, None, True), + # Test without threshold - No prediction + (BAD_RESPONSE, "No", 0.3, None, False), + (GOOD_RESPONSE, "No", 0.3, None, False), + ], +) +def test_is_unhelpful_response( + mock_tlm: Mock, + response: str, + tlm_response: str, + tlm_score: float, + threshold: float | None, + *, + expected: bool, +) -> None: + """Test unhelpful response detection.""" + mock_tlm.response = tlm_response + mock_tlm.trustworthiness_score = tlm_score + assert is_unhelpful_response(response, QUERY, mock_tlm, trustworthiness_score_threshold=threshold) is expected + + +@pytest.mark.parametrize( + ("response", 
"trustworthiness_score", "prompt_response", "prompt_score", "expected"), + [ + # Good response passes all checks + (GOOD_RESPONSE, 0.8, "No", 0.9, False), + # Bad response fails at least one check + (BAD_RESPONSE, 0.3, "Yes", 0.9, True), + ], +) +def test_is_bad_response( + mock_tlm: Mock, + response: str, + trustworthiness_score: float, + prompt_response: str, + prompt_score: float, + *, + expected: bool, +) -> None: + """Test the main is_bad_response function.""" + mock_tlm.trustworthiness_score = trustworthiness_score + mock_tlm.response = prompt_response + mock_tlm.trustworthiness_score = prompt_score + + assert ( + is_bad_response( + response, + context=CONTEXT, + query=QUERY, + config={"tlm": mock_tlm}, + ) + is expected + ) + + +@pytest.mark.parametrize( + ("response", "fuzz_ratio", "prompt_response", "prompt_score", "query", "tlm", "expected"), + [ + # Test with only fallback check (no context/query/tlm) + (BAD_RESPONSE, 90, None, None, None, None, True), + # Test with fallback and unhelpful checks (no context) + (GOOD_RESPONSE, 30, "No", 0.9, QUERY, "mock_tlm", False), + ], +) +def test_is_bad_response_partial_inputs( + mock_tlm: Mock, + response: str, + fuzz_ratio: int, + prompt_response: str, + prompt_score: float, + query: str, + tlm: Mock, + *, + expected: bool, +) -> None: + """Test is_bad_response with partial inputs (some checks disabled).""" + mock_fuzz = Mock() + mock_fuzz.partial_ratio.return_value = fuzz_ratio + with patch.dict("sys.modules", {"thefuzz": Mock(fuzz=mock_fuzz)}): + if prompt_response is not None: + mock_tlm.response = prompt_response + mock_tlm.trustworthiness_score = prompt_score + tlm = mock_tlm + + assert ( + is_bad_response( + response, + query=query, + config={"tlm": tlm}, + ) + is expected + )