Skip to content

Commit

Permalink
add class to configure a decorator that treats Codex as a backup
Browse files Browse the repository at this point in the history
  • Loading branch information
elisno committed Jan 24, 2025
1 parent e215c21 commit 3acb048
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/cleanlab_codex/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: MIT
from cleanlab_codex.codex import Codex
from cleanlab_codex.codex_tool import CodexTool
from cleanlab_codex.codex_backup import CodexBackup

__all__ = ["Codex", "CodexTool"]
__all__ = ["Codex", "CodexTool", "CodexBackup"]
105 changes: 105 additions & 0 deletions src/cleanlab_codex/codex_backup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import annotations

from typing import Any, Callable, Optional
from functools import wraps

from cleanlab_codex.codex import Codex
from cleanlab_codex.utils.response_validators import is_bad_response

def handle_backup_default(backup_response: str, decorated_instance: Any) -> None:
"""Default implementation is a no-op."""
return None


class CodexBackup:
"""A backup decorator that connects to a Codex project to answer questions that
cannot be adequately answered by the existing agent.
"""
DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question."

def __init__(
self,
codex_client: Codex,
*,
project_id: Optional[str] = None,
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
backup_handler: Callable[[str, Any], None] = handle_backup_default,
):
self._codex_client = codex_client
self._project_id = project_id
self._fallback_answer = fallback_answer
self._backup_handler = backup_handler

@classmethod
def from_access_key(
cls,
access_key: str,
*,
project_id: Optional[str] = None,
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
backup_handler: Callable[[str, Any], None] = handle_backup_default,
) -> CodexBackup:
"""Creates a CodexBackup from an access key. The project ID that the CodexBackup will use is the one that is associated with the access key."""
return cls(
codex_client=Codex(key=access_key),
project_id=project_id,
fallback_answer=fallback_answer,
backup_handler=backup_handler,
)

@classmethod
def from_client(
cls,
codex_client: Codex,
*,
project_id: Optional[str] = None,
fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
backup_handler: Callable[[str, Any], None] = handle_backup_default,
) -> CodexBackup:
"""Creates a CodexBackup from a Codex client.
If the Codex client is initialized with a project access key, the CodexBackup will use the project ID that is associated with the access key.
If the Codex client is initialized with a user API key, a project ID must be provided.
"""
return cls(
codex_client=codex_client,
project_id=project_id,
fallback_answer=fallback_answer,
backup_handler=backup_handler,
)

def to_decorator(self):
"""Factory that creates a backup decorator using the provided Codex client"""
def decorator(chat_method):
"""
Decorator for RAG chat methods that adds backup response handling.
If the original chat method returns an inadequate response, attempts to get
a backup response from Codex. Returns the backup response if available,
otherwise returns the original response.
Args:
chat_method: Method with signature (self, user_message: str) -> str
where 'self' refers to the instance being decorated, not an instance of CodexBackup.
"""
@wraps(chat_method)
def wrapper(decorated_instance, user_message):
# Call the original chat method
assistant_response = chat_method(decorated_instance, user_message)

# Return original response if it's adequate
if not is_bad_response(assistant_response):
return assistant_response

# Query Codex for a backup response
cache_result = self._codex_client.query(user_message)[0]
if not cache_result:
return assistant_response

# Handle backup response if handler exists
self._backup_handler(
backup_response=cache_result,
decorated_instance=decorated_instance,
)
return cache_result
return wrapper
return decorator
80 changes: 80 additions & 0 deletions src/cleanlab_codex/utils/response_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
This module provides validation functions for checking if an LLM response is inadequate/unhelpful.
The default implementation checks for common fallback phrases, but alternative implementations
are provided below as examples that can be adapted for specific needs.
"""


def is_bad_response(response: str) -> bool:
"""
Default implementation that checks for common fallback phrases from LLM assistants.
NOTE: YOU SHOULD MODIFY THIS METHOD YOURSELF.
"""
return basic_validator(response)

def basic_validator(response: str) -> bool:
"""Basic implementation that checks for common fallback phrases from LLM assistants.
Args:
response: The response from the assistant
Returns:
bool: True if the response appears to be a fallback/inadequate response
"""
partial_fallback_responses = [
"Based on the available information",
"I cannot provide a complete answer to this question",
# Add more substrings here to improve the recall of the check
]
return any(
partial_fallback_response.lower() in response.lower()
for partial_fallback_response in partial_fallback_responses
)

# Alternative Implementations
# ---------------------------
# The following implementations are provided as examples and inspiration.
# They should be adapted to your specific needs.


# Fuzzy String Matching
"""
from thefuzz import fuzz
def fuzzy_match_validator(response: str, fallback_answer: str, threshold: int = 70) -> bool:
partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower())
return partial_ratio >= threshold
"""

# TLM Score Thresholding
"""
from cleanlab_studio import Studio
studio = Studio("<API_KEY>")
tlm = studio.TLM()
def tlm_score_validator(response: str, context: str, query: str, tlm: TLM, threshold: float = 0.5) -> bool:
prompt = f"Context: {context}\n\n Query: {query}\n\n Query: {query}"
resp = tlm.get_trustworthiness_score(prompt, response)
score = resp['trustworthiness_score']
return score < threshold
"""

# TLM Binary Classification
"""
from typing import Optional
from cleanlab_studio import Studio
studio = Studio("<API_KEY>")
tlm = studio.TLM()
def tlm_binary_validator(response: str, tlm: TLM, query: Optional[str] = None) -> bool:
if query is None:
prompt = f"Here is a response from an AI assistant: {response}\n\n Is it helpful? Answer Yes/No only."
else:
prompt = f"Here is a response from an AI assistant: {response}\n\n Considering the following query: {query}\n\n Is the response helpful? Answer Yes/No only."
output = tlm.prompt(prompt)
return output["response"].lower() == "no"
"""
65 changes: 65 additions & 0 deletions tests/test_codex_backup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from unittest.mock import MagicMock

from cleanlab_codex.codex_backup import CodexBackup

MOCK_BACKUP_RESPONSE = "This is a test response"
FALLBACK_MESSAGE = "Based on the available information, I cannot provide a complete answer to this question."
TEST_MESSAGE = "Hello, world!"


def test_codex_backup(mock_client: MagicMock): # noqa: ARG001
mock_response = MagicMock()
mock_response.answer = MOCK_BACKUP_RESPONSE
mock_client.projects.entries.query.return_value = mock_response

codex_backup = CodexBackup.from_access_key("")

class MockApp:
@codex_backup.to_decorator()
def chat(self, user_message: str) -> str:
# Just echo the user message
return user_message

app = MockApp()

# Echo works well
response = app.chat(TEST_MESSAGE)
assert response == TEST_MESSAGE

# Backup works well for fallback responses
response = app.chat(FALLBACK_MESSAGE)
assert response == MOCK_BACKUP_RESPONSE

def test_backup_handler(mock_client: MagicMock):
mock_response = MagicMock()
mock_response.answer = MOCK_BACKUP_RESPONSE
mock_client.projects.entries.query.return_value = mock_response

mock_handler = MagicMock()
mock_handler.return_value = None
codex_backup = CodexBackup.from_access_key("", backup_handler=mock_handler)

class MockApp:
@codex_backup.to_decorator()
def chat(self, user_message: str) -> str:
# Just echo the user message
return user_message

app = MockApp()

response = app.chat(TEST_MESSAGE)
assert response == TEST_MESSAGE

# Handler should not be called for good responses
assert mock_handler.call_count == 0

response = app.chat(FALLBACK_MESSAGE)
assert response == MOCK_BACKUP_RESPONSE

# Handler should be called for bad responses
assert mock_handler.call_count == 1
# The MockApp is the second argument to the handler, i.e. it has the necessary context
# to handle the new response
assert mock_handler.call_args.kwargs["decorated_instance"] == app


0 comments on commit 3acb048

Please sign in to comment.