Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CodexBackup class #32

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cleanlab_codex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from cleanlab_codex.codex_tool import CodexTool
from cleanlab_codex.project import Project

__all__ = ["Client", "CodexTool", "Project"]
__all__ = ["Client", "CodexTool", "CodexBackup", "Project"]
114 changes: 114 additions & 0 deletions src/cleanlab_codex/codex_backup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Enables connecting RAG applications to Codex as a Backup system.

This module provides functionality to use Codex as a fallback when a primary
RAG (Retrieval-Augmented Generation) system fails to provide adequate responses.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional

from cleanlab_codex.response_validation import BadResponseDetectionConfig, is_bad_response

if TYPE_CHECKING:
from cleanlab_codex.project import Project
from cleanlab_codex.types.backup import BackupHandler
from cleanlab_codex.types.tlm import TLM


def handle_backup_default(codex_response: str, primary_system: Any) -> None: # noqa: ARG001
"""Default implementation is a no-op."""
return None


class CodexBackup:
"""A backup decorator that connects to a Codex project to answer questions that
cannot be adequately answered by the existing agent.

Args:
project: The Codex project to use for backup responses
fallback_answer: The fallback answer to use if the primary system fails to provide an adequate response
backup_handler: A callback function that processes Codex's response and updates the primary RAG system. This handler is called whenever Codex provides a backup response after the primary system fails. By default, the backup handler is a no-op.
primary_system: The existing RAG system that needs to be backed up by Codex
tlm: The client for the Trustworthy Language Model, which evaluates the quality of responses from the primary system
is_bad_response_kwargs: Additional keyword arguments to pass to the is_bad_response function, for detecting inadequate responses from the primary system
"""

DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question."

def __init__(
self,
*,
project: Project,
fallback_answer: str = DEFAULT_FALLBACK_ANSWER,
backup_handler: BackupHandler = handle_backup_default,
primary_system: Optional[Any] = None,
tlm: Optional[TLM] = None,
is_bad_response_kwargs: Optional[dict[str, Any]] = None,
Comment on lines +46 to +47
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reasoning for tlm being a separate argument but the rest of the arguments for is_bad_response being grouped into is_bad_response_kwargs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should now be is_bad_response_config: BadResponseDetectionConfig. You're right, the tlm argument (and fallback_answer) should be fetched from the config instead.

):
self._project = project
self._fallback_answer = fallback_answer
self._backup_handler = backup_handler
self._primary_system: Optional[Any] = primary_system
self._tlm = tlm
self._is_bad_response_kwargs = is_bad_response_kwargs or {}

@classmethod
def from_project(cls, project: Project, **kwargs: Any) -> CodexBackup:
return cls(project=project, **kwargs)

@property
def primary_system(self) -> Any:
if self._primary_system is None:
error_message = "Primary system not set. Please set a primary system using the `add_primary_system` method."
raise ValueError(error_message)
return self._primary_system

@primary_system.setter
def primary_system(self, primary_system: Any) -> None:
"""Set the primary RAG system that will be used to generate responses."""
self._primary_system = primary_system

def run(
self,
response: str,
query: str,
context: Optional[str] = None,
) -> str:
"""Check if a response is adequate and provide a backup from Codex if needed.

Args:
primary_system: The system that generated the original response
response: The response to evaluate
query: The original query that generated the response
context: Optional context used to generate the response

Returns:
str: Either the original response if adequate, or a backup response from Codex
"""

is_bad = is_bad_response(
response,
query=query,
context=context,
config=BadResponseDetectionConfig.model_validate(
{
"tlm": self._tlm,
"fallback_answer": self._fallback_answer,
**self._is_bad_response_kwargs,
},
),
)
if not is_bad:
return response

codex_response = self._project.query(query, fallback_answer=self._fallback_answer)[0]
if not codex_response:
return response

if self._primary_system is not None:
self._backup_handler(
codex_response=codex_response,
primary_system=self._primary_system,
)
Comment on lines +109 to +113
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I was creating a RAG system for a production app, I think I'd be much more likely to implement handling replacement of original response either just within the chat method or as a separate instance method of the class. Without this CodexBackup class, I'd probably do something along the lines of:

class RAGChatWithCodexBackup(RAGChat):
  def __init__(self, client: OpenAI, assistant_id: str, codex_access_key: str):
    super().__init__(client, assistant_id)
    self._codex_project = Project.from_access_key(access_key)

  def replace_latest_message(self, new_message: str) -> None:
    <code from your handle_backup_for_openai_assistants method>
   ...

  def chat(self, user_message: str) -> str:
    response = super().chat(user_message)
    codex_response: str | None = None
    if is_bad_response(response=response, query=user_message):
      codex_response = self._codex_project.query(user_message)

    if codex_response is None:
      return response

    self.replace_latest_message(codex_response)
    return codex_response

Very unlikely that I'd define a method that fits the expected signature for _backup_handler

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the CodexBackup class you've defined without providing backup_handler and primary_system, I'd probably end up with the following:

class RAGChatWithCodexBackup(RAGChat):
  def __init__(self, client: OpenAI, assistant_id: str, codex_access_key: str):
    super().__init__(client, assistant_id)
    self._codex_backup = CodexBackup.from_project(Project.from_access_key(codex_access_key))
  
  def replace_latest_message(self, new_message: str) -> None:
    <code from your handle_backup_for_openai_assistants method>
   ...
  
  def chat(self, user_message: str) -> str:
    response = super().chat(user_message)
    backup_response: str | None = self._codex_backup.run(response=response, query=user_message)
    if backup_response is not None and backup_response != response:
      self.replace_latest_message(backup_response)
      return backup_response
    return response

which does save a couple lines of code, but not much.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering:
a) Whether it's worth trying to do the backup_handler stuff within this class (maybe could modify the expected function signature to allow for class instance methods).
b) Whether it's worth having this class right now (doesn't really save much code). But could potentially make the second example a little cleaner by modifying CodexBackup.run() to return a pair of backup_response, codex_used and then could just do if codex_used: self.replace_latest_message(backup_response)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed on slack. Will leave this implementation for now to avoid extra work of updating tutorials before soft launch deadline.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm putting this PR on hold then.
Moving the code into the tutorial now.

return codex_response
22 changes: 1 addition & 21 deletions src/cleanlab_codex/response_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,16 @@
Callable,
Dict,
Optional,
Protocol,
Sequence,
Union,
cast,
runtime_checkable,
)

from pydantic import BaseModel, ConfigDict, Field

from cleanlab_codex.types.tlm import TLM
from cleanlab_codex.utils.errors import MissingDependencyError
from cleanlab_codex.utils.prompt import default_format_prompt


@runtime_checkable
class TLM(Protocol):
def get_trustworthiness_score(
self,
prompt: Union[str, Sequence[str]],
response: Union[str, Sequence[str]],
**kwargs: Any,
) -> Dict[str, Any]: ...

def prompt(
self,
prompt: Union[str, Sequence[str]],
/,
**kwargs: Any,
) -> Dict[str, Any]: ...


DEFAULT_FALLBACK_ANSWER: str = (
"Based on the available information, I cannot provide a complete answer to this question."
)
Expand Down
30 changes: 30 additions & 0 deletions src/cleanlab_codex/types/backup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Types for Codex Backup."""

from __future__ import annotations

from typing import Any, Protocol


class BackupHandler(Protocol):
"""Protocol defining how to handle backup responses from Codex.

This protocol defines a callable interface for processing Codex responses that are
retrieved when the primary response system (e.g., a RAG system) fails to provide
an adequate answer. Implementations of this protocol can be used to:

- Update the primary system's context or knowledge base
- Log Codex responses for analysis
- Trigger system improvements or retraining
- Perform any other necessary side effects

Args:
codex_response (str): The response received from Codex
primary_system (Any): The instance of the primary RAG system that
generated the inadequate response. This allows the handler to
update or modify the primary system if needed.

Returns:
None: The handler performs side effects but doesn't return a value
"""

def __call__(self, codex_response: str, primary_system: Any) -> None: ...
22 changes: 22 additions & 0 deletions src/cleanlab_codex/types/tlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Protocol for a Trustworthy Language Model."""

from __future__ import annotations

from typing import Any, Dict, Protocol, Sequence, Union, runtime_checkable


@runtime_checkable
class TLM(Protocol):
def get_trustworthiness_score(
self,
prompt: Union[str, Sequence[str]],
response: Union[str, Sequence[str]],
**kwargs: Any,
) -> Dict[str, Any]: ...

def prompt(
self,
prompt: Union[str, Sequence[str]],
/,
**kwargs: Any,
) -> Dict[str, Any]: ...
71 changes: 71 additions & 0 deletions tests/test_codex_backup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest.mock import MagicMock

import pytest

from cleanlab_codex.codex_backup import CodexBackup

MOCK_BACKUP_RESPONSE = "This is a test response"
FALLBACK_MESSAGE = "Based on the available information, I cannot provide a complete answer to this question."
TEST_MESSAGE = "Hello, world!"


class MockApp:
def chat(self, user_message: str) -> str:
# Just echo the user message
return user_message


@pytest.fixture
def mock_app() -> MockApp:
return MockApp()


def test_codex_backup(mock_app: MockApp) -> None:
# Create a mock project directly
mock_project = MagicMock()
mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,)

# Echo works well
query = TEST_MESSAGE
response = mock_app.chat(query)
assert response == query

# Backup works well for fallback responses
codex_backup = CodexBackup.from_project(mock_project)
query = FALLBACK_MESSAGE
response = mock_app.chat(query)
assert response == query
response = codex_backup.run(response, query=query)
assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}"


def test_backup_handler(mock_app: MockApp) -> None:
mock_project = MagicMock()
mock_project.query.return_value = (MOCK_BACKUP_RESPONSE,)

mock_handler = MagicMock()
mock_handler.return_value = None

codex_backup = CodexBackup.from_project(mock_project, primary_system=mock_app, backup_handler=mock_handler)

query = TEST_MESSAGE
response = mock_app.chat(query)
assert response == query

response = codex_backup.run(response, query=query)
assert response == query, f"Response was {response}"

# Handler should not be called for good responses
assert mock_handler.call_count == 0

query = FALLBACK_MESSAGE
response = mock_app.chat(query)
assert response == query
response = codex_backup.run(response, query=query)
assert response == MOCK_BACKUP_RESPONSE, f"Response was {response}"

# Handler should be called for bad responses
assert mock_handler.call_count == 1
# The MockApp is the second argument to the handler, i.e. it has the necessary context
# to handle the new response
assert mock_handler.call_args.kwargs["primary_system"] == mock_app