Skip to content

Commit

Permalink
fix lint typing and import errors
Browse files Browse the repository at this point in the history
  • Loading branch information
yaksh0nti committed Jun 15, 2024
1 parent faee3d3 commit 822873b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 28 deletions.
16 changes: 5 additions & 11 deletions libs/community/langchain_community/tools/zenguard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
from importlib import metadata

from langchain_zenguard.tools import Detector, ZenGuardInput, ZenGuardTool

try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
del metadata # optional, avoids polluting the results of dir(__package__)
from langchain_community.tools.zenguard.tools import (
Detector,
ZenGuardInput,
ZenGuardTool,
)

__all__ = [
"__version__",
"ZenGuardTool",
"Detector",
"ZenGuardInput",
Expand Down
23 changes: 13 additions & 10 deletions libs/community/langchain_community/tools/zenguard/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Detector(str, Enum):
SECRETS = "secrets"
TOXICITY = "toxicity"


class DetectorAPI(str, Enum):
ALLOWED_TOPICS = "v1/detect/topics/allowed"
BANNED_TOPICS = "v1/detect/topics/banned"
Expand All @@ -25,6 +26,7 @@ class DetectorAPI(str, Enum):
SECRETS = "v1/detect/secrets"
TOXICITY = "v1/detect/toxicity"


class ZenGuardInput(BaseModel):
prompts: List[str] = Field(
...,
Expand All @@ -42,6 +44,7 @@ class ZenGuardInput(BaseModel):
description="Run prompt detection by the detector in parallel or sequentially",
)


class ZenGuardTool(BaseTool):
name = "ZenGuard"
description = (
Expand All @@ -52,11 +55,11 @@ class ZenGuardTool(BaseTool):

zenguard_api_key: Optional[str] = Field(default=None)

_ZENGUARD_API_URL_ROOT = "https://api.zenguard.ai/"
_ZENGUARD_API_URL_ROOT = "https://dummyai-backend-gwlrf6iakq-uc.a.run.app/"
_ZENGUARD_API_KEY_ENV_NAME = "ZENGUARD_API_KEY"

@validator("zenguard_api_key", pre=True, always=True, check_fields=False)
def set_api_key(cls, v):
def set_api_key(cls, v: str) -> str:
if v is None:
v = os.getenv(cls._ZENGUARD_API_KEY_ENV_NAME)
if v is None:
Expand All @@ -66,16 +69,16 @@ def set_api_key(cls, v):
f"the f{cls._ZENGUARD_API_KEY_ENV_NAME} environment variable"
)
return v

def _run(
self,
prompts: List[str],
detectors: List[Detector],
in_parallel: bool = True,
self,
prompts: List[str],
detectors: List[Detector],
in_parallel: bool = True,
) -> Dict[str, Any]:
try:
postfix = None
json = None
json: Optional[Dict[str, Any]] = None
if len(detectors) == 1:
postfix = self._convert_detector_to_api(detectors[0])
json = {"messages": prompts}
Expand All @@ -96,6 +99,6 @@ def _run(
return response.json()
except (requests.HTTPError, requests.Timeout) as e:
return {"error": str(e)}
def _convert_detector_to_api(self, detector: Detector):

def _convert_detector_to_api(self, detector: Detector) -> str:
return DetectorAPI[detector.name].value
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def zenguard_tool():
raise ValueError("ZENGUARD_API_KEY is not set in enviroment varibale")
return ZenGuardTool()


def assert_successful_response_not_detected(response):
assert response is not None
assert "error" not in response, f"API returned an error: {response.get('error')}"
Expand All @@ -20,13 +21,19 @@ def assert_successful_response_not_detected(response):
def assert_detectors_response(response, detectors):
assert response is not None
for detector in detectors:
common_response = next((
resp["common_response"]
for resp in response["responses"]
if resp["detector"] == detector.value
))
assert "err" not in common_response, f"API returned an error: {common_response.get('err')}" # noqa: E501
assert common_response.get("is_detected") is False, f"Prompt was detected: {common_response}" # noqa: E501
common_response = next(
(
resp["common_response"]
for resp in response["responses"]
if resp["detector"] == detector.value
)
)
assert (
"err" not in common_response
), f"API returned an error: {common_response.get('err')}" # noqa: E501
assert (
common_response.get("is_detected") is False
), f"Prompt was detected: {common_response}" # noqa: E501


def test_prompt_injection(zenguard_tool):
Expand Down Expand Up @@ -77,6 +84,7 @@ def test_toxicity(zenguard_tool):
response = zenguard_tool.run({"detectors": detectors, "prompts": [prompt]})
assert_successful_response_not_detected(response)


def test_all_detectors(zenguard_tool):
prompt = "Simple all detectors test"
detectors = [
Expand Down

0 comments on commit 822873b

Please sign in to comment.