-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add class to configure a decorator that treats Codex as a backup
- Loading branch information
Showing
4 changed files
with
252 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# SPDX-License-Identifier: MIT | ||
from cleanlab_codex.codex import Codex | ||
from cleanlab_codex.codex_tool import CodexTool | ||
from cleanlab_codex.codex_backup import CodexBackup | ||
|
||
# Public API of the package. The earlier two-name assignment was immediately
# overwritten by this one, so only the final export list is kept.
__all__ = ["Codex", "CodexTool", "CodexBackup"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any, Callable, Optional | ||
from functools import wraps | ||
|
||
from cleanlab_codex.codex import Codex | ||
from cleanlab_codex.utils.response_validators import is_bad_response | ||
|
||
def handle_backup_default(backup_response: str, decorated_instance: Any) -> None:
    """Ignore the backup response; used when no custom handler is supplied.

    Args:
        backup_response: The backup answer handed to the handler.
        decorated_instance: The object whose chat method was decorated.

    Returns:
        None. Both arguments are accepted only to match the handler signature.
    """
|
||
|
||
class CodexBackup:
    """A backup decorator that connects to a Codex project to answer questions that
    cannot be adequately answered by the existing agent.

    Build an instance with the constructor, `from_access_key` or `from_client`,
    then call `to_decorator()` and apply the result to a chat method.
    """

    DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question."

    def __init__(
        self,
        codex_client: Codex,
        *,
        project_id: Optional[str] = None,
        fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
        backup_handler: Callable[[str, Any], None] = handle_backup_default,
    ) -> None:
        """Store the client and configuration used by the generated decorator.

        Args:
            codex_client: Client whose `query` method is called for backup answers.
            project_id: Project to query; forwarded to `query` on every backup lookup.
            fallback_answer: Stored on the instance.
                NOTE(review): not currently used by the decorator - confirm whether it
                should be forwarded to `Codex.query` or compared against responses.
            backup_handler: Callback invoked with keyword arguments `backup_response`
                and `decorated_instance` whenever a backup answer is used.
        """
        self._codex_client = codex_client
        self._project_id = project_id
        self._fallback_answer = fallback_answer
        self._backup_handler = backup_handler

    @classmethod
    def from_access_key(
        cls,
        access_key: str,
        *,
        project_id: Optional[str] = None,
        fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
        backup_handler: Callable[[str, Any], None] = handle_backup_default,
    ) -> CodexBackup:
        """Creates a CodexBackup from an access key. The project ID that the CodexBackup will use is the one that is associated with the access key.

        A `Codex` client is constructed from `access_key`; the remaining keyword
        arguments are passed to the constructor unchanged.
        """
        return cls(
            codex_client=Codex(key=access_key),
            project_id=project_id,
            fallback_answer=fallback_answer,
            backup_handler=backup_handler,
        )

    @classmethod
    def from_client(
        cls,
        codex_client: Codex,
        *,
        project_id: Optional[str] = None,
        fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER,
        backup_handler: Callable[[str, Any], None] = handle_backup_default,
    ) -> CodexBackup:
        """Creates a CodexBackup from a Codex client.

        If the Codex client is initialized with a project access key, the CodexBackup will use the project ID that is associated with the access key.
        If the Codex client is initialized with a user API key, a project ID must be provided.
        """
        return cls(
            codex_client=codex_client,
            project_id=project_id,
            fallback_answer=fallback_answer,
            backup_handler=backup_handler,
        )

    def to_decorator(self) -> Callable[[Callable[[Any, str], str]], Callable[[Any, str], str]]:
        """Factory that creates a backup decorator using the provided Codex client.

        Returns:
            A decorator for methods with signature `(self, user_message: str) -> str`.
        """

        def decorator(chat_method: Callable[[Any, str], str]) -> Callable[[Any, str], str]:
            """Decorator for RAG chat methods that adds backup response handling.

            If the original chat method returns an inadequate response, attempts to get
            a backup response from Codex. Returns the backup response if available,
            otherwise returns the original response.

            Args:
                chat_method: Method with signature (self, user_message: str) -> str
                    where 'self' refers to the instance being decorated, not an
                    instance of CodexBackup.
            """

            @wraps(chat_method)
            def wrapper(decorated_instance: Any, user_message: str) -> str:
                # Call the original chat method
                assistant_response = chat_method(decorated_instance, user_message)

                # Return original response if it's adequate
                if not is_bad_response(assistant_response):
                    return assistant_response

                # Query Codex for a backup response. The configured project ID is
                # forwarded so that clients created from a user API key (which have
                # no implicit project) query the intended project. The first element
                # of the query result is treated as the answer.
                cache_result = self._codex_client.query(
                    user_message,
                    project_id=self._project_id,
                )[0]
                if not cache_result:
                    return assistant_response

                # Let the handler react to the backup response (the default handler
                # is a no-op) before returning it in place of the original.
                self._backup_handler(
                    backup_response=cache_result,
                    decorated_instance=decorated_instance,
                )
                return cache_result

            return wrapper

        return decorator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
This module provides validation functions for checking if an LLM response is inadequate/unhelpful. | ||
The default implementation checks for common fallback phrases, but alternative implementations | ||
are provided below as examples that can be adapted for specific needs. | ||
""" | ||
|
||
|
||
def is_bad_response(response: str) -> bool:
    """Decide whether an assistant response is inadequate.

    This default simply delegates to `basic_validator`, which looks for common
    fallback phrases from LLM assistants.

    NOTE: YOU SHOULD MODIFY THIS METHOD YOURSELF.
    """
    looks_inadequate = basic_validator(response)
    return looks_inadequate
|
||
def basic_validator(response: str) -> bool:
    """Basic implementation that checks for common fallback phrases from LLM assistants.

    The comparison is a case-insensitive substring match: the response is flagged
    as soon as it contains any one of the known phrases.

    Args:
        response: The response from the assistant

    Returns:
        bool: True if the response appears to be a fallback/inadequate response
    """
    partial_fallback_responses = [
        "Based on the available information",
        "I cannot provide a complete answer to this question",
        # Add more substrings here to improve the recall of the check
    ]
    # Lowercase the response once instead of once per candidate phrase.
    lowered_response = response.lower()
    return any(
        partial_fallback_response.lower() in lowered_response
        for partial_fallback_response in partial_fallback_responses
    )
|
||
# Alternative Implementations | ||
# --------------------------- | ||
# The following implementations are provided as examples and inspiration. | ||
# They should be adapted to your specific needs. | ||
|
||
|
||
# Fuzzy String Matching | ||
""" | ||
from thefuzz import fuzz | ||
def fuzzy_match_validator(response: str, fallback_answer: str, threshold: int = 70) -> bool: | ||
partial_ratio = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) | ||
return partial_ratio >= threshold | ||
""" | ||
|
||
# TLM Score Thresholding | ||
""" | ||
from cleanlab_studio import Studio | ||
studio = Studio("<API_KEY>") | ||
tlm = studio.TLM() | ||
def tlm_score_validator(response: str, context: str, query: str, tlm: TLM, threshold: float = 0.5) -> bool: | ||
prompt = f"Context: {context}\n\n Query: {query}" | ||
resp = tlm.get_trustworthiness_score(prompt, response) | ||
score = resp['trustworthiness_score'] | ||
return score < threshold | ||
""" | ||
|
||
# TLM Binary Classification | ||
""" | ||
from typing import Optional | ||
from cleanlab_studio import Studio | ||
studio = Studio("<API_KEY>") | ||
tlm = studio.TLM() | ||
def tlm_binary_validator(response: str, tlm: TLM, query: Optional[str] = None) -> bool: | ||
if query is None: | ||
prompt = f"Here is a response from an AI assistant: {response}\n\n Is it helpful? Answer Yes/No only." | ||
else: | ||
prompt = f"Here is a response from an AI assistant: {response}\n\n Considering the following query: {query}\n\n Is the response helpful? Answer Yes/No only." | ||
output = tlm.prompt(prompt) | ||
return output["response"].lower() == "no" | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
from unittest.mock import MagicMock | ||
|
||
from cleanlab_codex.codex_backup import CodexBackup | ||
|
||
# Answer assigned to the mocked Codex query result; the tests expect it back
# whenever the backup path is taken.
MOCK_BACKUP_RESPONSE = "This is a test response"
# Message the tests send to trigger the backup path (expected result is the
# mocked backup answer rather than this text).
FALLBACK_MESSAGE = "Based on the available information, I cannot provide a complete answer to this question."
# Ordinary message the tests expect to be echoed back unchanged.
TEST_MESSAGE = "Hello, world!"
|
||
|
||
def test_codex_backup(mock_client: MagicMock):  # noqa: ARG001
    """Good responses pass through; fallback responses are replaced by the backup."""
    # Configure the mocked query so that it yields the backup answer.
    mock_client.projects.entries.query.return_value = MagicMock(answer=MOCK_BACKUP_RESPONSE)

    with_backup = CodexBackup.from_access_key("").to_decorator()

    class MockApp:
        @with_backup
        def chat(self, user_message: str) -> str:
            # Echo the incoming message straight back
            return user_message

    mock_app = MockApp()

    # A normal message is echoed untouched
    assert mock_app.chat(TEST_MESSAGE) == TEST_MESSAGE

    # A fallback-style message is swapped for the backup response
    assert mock_app.chat(FALLBACK_MESSAGE) == MOCK_BACKUP_RESPONSE
|
||
def test_backup_handler(mock_client: MagicMock):
    """The backup handler fires only for bad responses and receives the app instance."""
    # Configure the mocked query so that it yields the backup answer.
    mock_client.projects.entries.query.return_value = MagicMock(answer=MOCK_BACKUP_RESPONSE)

    handler = MagicMock(return_value=None)
    with_backup = CodexBackup.from_access_key("", backup_handler=handler).to_decorator()

    class MockApp:
        @with_backup
        def chat(self, user_message: str) -> str:
            # Echo the incoming message straight back
            return user_message

    mock_app = MockApp()

    assert mock_app.chat(TEST_MESSAGE) == TEST_MESSAGE

    # A good response must not reach the handler
    assert handler.call_count == 0

    assert mock_app.chat(FALLBACK_MESSAGE) == MOCK_BACKUP_RESPONSE

    # A bad response reaches the handler exactly once
    assert handler.call_count == 1
    # The handler is given the MockApp instance, i.e. the context it needs to
    # act on the new response
    _, handler_kwargs = handler.call_args
    assert handler_kwargs["decorated_instance"] == mock_app
|
||
|