diff --git a/libs/community/langchain_community/tools/zenguard/__init__.py b/libs/community/langchain_community/tools/zenguard/__init__.py index 579f58003f2ff..398d14dbc3869 100644 --- a/libs/community/langchain_community/tools/zenguard/__init__.py +++ b/libs/community/langchain_community/tools/zenguard/__init__.py @@ -1,16 +1,10 @@ -from importlib import metadata - -from langchain_zenguard.tools import Detector, ZenGuardInput, ZenGuardTool - -try: - __version__ = metadata.version(__package__) -except metadata.PackageNotFoundError: - # Case where package metadata is not available. - __version__ = "" -del metadata # optional, avoids polluting the results of dir(__package__) +from langchain_community.tools.zenguard.tools import ( + Detector, + ZenGuardInput, + ZenGuardTool, +) __all__ = [ - "__version__", "ZenGuardTool", "Detector", "ZenGuardInput", diff --git a/libs/community/langchain_community/tools/zenguard/tools.py b/libs/community/langchain_community/tools/zenguard/tools.py index 69c4a62a5acac..05fe91f0c1dd1 100644 --- a/libs/community/langchain_community/tools/zenguard/tools.py +++ b/libs/community/langchain_community/tools/zenguard/tools.py @@ -16,6 +16,7 @@ class Detector(str, Enum): SECRETS = "secrets" TOXICITY = "toxicity" + class DetectorAPI(str, Enum): ALLOWED_TOPICS = "v1/detect/topics/allowed" BANNED_TOPICS = "v1/detect/topics/banned" @@ -25,6 +26,7 @@ class DetectorAPI(str, Enum): SECRETS = "v1/detect/secrets" TOXICITY = "v1/detect/toxicity" + class ZenGuardInput(BaseModel): prompts: List[str] = Field( ..., @@ -42,6 +44,7 @@ class ZenGuardInput(BaseModel): description="Run prompt detection by the detector in parallel or sequentially", ) + class ZenGuardTool(BaseTool): name = "ZenGuard" description = ( @@ -52,11 +55,11 @@ class ZenGuardTool(BaseTool): zenguard_api_key: Optional[str] = Field(default=None) - _ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/" + _ZENGUARD_API_URL_ROOT = "https://dummyai-backend-gwlrf6iakq-uc.a.run.app/" _ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY" @validator("zenguard_api_key", pre=True, always=True, check_fields=False) - def set_api_key(cls, v): + def set_api_key(cls, v: str) -> str: if v is None: v = os.getenv(cls._ZENGUARD_API_KEY_ENV_NAME) if v is None: @@ -66,16 +69,16 @@ def set_api_key(cls, v): f"the f{cls._ZENGUARD_API_KEY_ENV_NAME} environment variable" ) return v - + def _run( - self, - prompts: List[str], - detectors: List[Detector], - in_parallel: bool = True, + self, + prompts: List[str], + detectors: List[Detector], + in_parallel: bool = True, ) -> Dict[str, Any]: try: postfix = None - json = None + json: Optional[Dict[str, Any]] = None if len(detectors) == 1: postfix = self._convert_detector_to_api(detectors[0]) json = {"messages": prompts} @@ -96,6 +99,6 @@ def _run( return response.json() except (requests.HTTPError, requests.Timeout) as e: return {"error": str(e)} - - def _convert_detector_to_api(self, detector: Detector): + + def _convert_detector_to_api(self, detector: Detector) -> str: return DetectorAPI[detector.name].value diff --git a/libs/community/tests/integration_tests/tools/zenguard/test_zenguard.py b/libs/community/tests/integration_tests/tools/zenguard/test_zenguard.py index 66cc38ca9489d..a46c71c6451c3 100644 --- a/libs/community/tests/integration_tests/tools/zenguard/test_zenguard.py +++ b/libs/community/tests/integration_tests/tools/zenguard/test_zenguard.py @@ -11,6 +11,7 @@ def zenguard_tool(): raise ValueError("ZENGUARD_API_KEY is not set in enviroment varibale") return ZenGuardTool() + def assert_successful_response_not_detected(response): assert response is not None assert "error" not in response, f"API returned an error: {response.get('error')}" @@ -20,13 +21,19 @@ def assert_successful_response_not_detected(response): def assert_detectors_response(response, detectors): assert response is not None for detector in detectors: - common_response = next(( - resp["common_response"] - for resp in response["responses"] - if resp["detector"] == detector.value - )) - assert "err" not in common_response, f"API returned an error: {common_response.get('err')}" # noqa: E501 - assert common_response.get("is_detected") is False, f"Prompt was detected: {common_response}" # noqa: E501 + common_response = next( + ( + resp["common_response"] + for resp in response["responses"] + if resp["detector"] == detector.value + ) + ) + assert ( + "err" not in common_response + ), f"API returned an error: {common_response.get('err')}" # noqa: E501 + assert ( + common_response.get("is_detected") is False + ), f"Prompt was detected: {common_response}" # noqa: E501 def test_prompt_injection(zenguard_tool): @@ -77,6 +84,7 @@ def test_toxicity(zenguard_tool): response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]}) assert_successful_response_not_detected(response) + def test_all_detectors(zenguard_tool): prompt = "Simple all detectors test" detectors = [