From c5843c945a8c51ab76da692c7c0b64be9b2df59d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?El=C3=ADas=20Snorrason?= Date: Tue, 11 Feb 2025 05:01:03 +0000 Subject: [PATCH] formatting --- src/cleanlab_codex/codex_backup.py | 13 ++++++++----- src/cleanlab_codex/response_validation.py | 21 +++++++++++---------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/cleanlab_codex/codex_backup.py b/src/cleanlab_codex/codex_backup.py index d0f95d5..70cc1f0 100644 --- a/src/cleanlab_codex/codex_backup.py +++ b/src/cleanlab_codex/codex_backup.py @@ -119,11 +119,14 @@ def run( response, query=query, context=context, - config=cast(BadResponseDetectionConfig, { - "tlm": self._tlm, - "fallback_answer": self._fallback_answer, - **_is_bad_response_kwargs, - }), + config=cast( + BadResponseDetectionConfig, + { + "tlm": self._tlm, + "fallback_answer": self._fallback_answer, + **_is_bad_response_kwargs, + }, + ), ): return response diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index c1c2eb9..f239397 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -21,21 +21,18 @@ def get_trustworthiness_score( prompt: Union[str, Sequence[str]], response: Union[str, Sequence[str]], **kwargs: Any, - ) -> Dict[str, Any]: - ... + ) -> Dict[str, Any]: ... def prompt( self, prompt: Union[str, Sequence[str]], /, **kwargs: Any, - ) -> Dict[str, Any]: - ... + ) -> Dict[str, Any]: ... TLM = _TLMProtocol - DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." DEFAULT_PARTIAL_RATIO_THRESHOLD = 70 DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 @@ -53,6 +50,7 @@ class BadResponseDetectionConfig(TypedDict, total=False): unhelpfulness_confidence_threshold: Optional confidence threshold (0.0-1.0) for unhelpful classification tlm: TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks) """ + # Fallback check config fallback_answer: str partial_ratio_threshold: int @@ -67,6 +65,7 @@ class BadResponseDetectionConfig(TypedDict, total=False): # Shared config (for untrustworthiness and unhelpfulness checks) tlm: Optional[TLM] + def get_bad_response_config() -> BadResponseDetectionConfig: """Get the default configuration for bad response detection functions. @@ -119,11 +118,13 @@ def is_bad_response( validation_checks: list[Callable[[], bool]] = [] # All required inputs are available for checking fallback responses - validation_checks.append(lambda: is_fallback_response( - response, - cfg["fallback_answer"], - threshold=cfg["partial_ratio_threshold"], - )) + validation_checks.append( + lambda: is_fallback_response( + response, + cfg["fallback_answer"], + threshold=cfg["partial_ratio_threshold"], + ) + ) can_run_untrustworthy_check = query is not None and context is not None and cfg["tlm"] is not None if can_run_untrustworthy_check: