Skip to content

Commit

Permalink
add json retry
Browse files Browse the repository at this point in the history
  • Loading branch information
shahules786 committed Dec 7, 2023
1 parent 2487430 commit f3e6068
Showing 1 changed file with 123 additions and 0 deletions.
123 changes: 123 additions & 0 deletions src/ragas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,17 @@

import json
import os
import typing as t
import warnings
from dataclasses import dataclass
from functools import lru_cache

from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

if t.TYPE_CHECKING:
from ragas.llms import RagasLLM

DEBUG_ENV_VAR = "RAGAS_DEBUG"
# constant to tell us that there is no key passed to the llm/embeddings
NO_KEY = "no-key"
Expand All @@ -29,3 +37,118 @@ def load_as_json(text):
warnings.warn(f"Invalid json: {e}")

return {}


# Few-shot prompt used by JsonLoader._fix_to_json to ask an LLM to rewrite
# malformed text as valid JSON. `{input}` is the only placeholder; it is filled
# via `JSON_PROMPT.format(input=...)`. The doubled braces (`{{` / `}}`) are
# escapes that render as literal braces in the formatted prompt.
# The "Input" examples are intentionally malformed: the first is missing commas
# and has a trailing comma, the second has unescaped quotes inside a string.
# NOTE: this text is sent to the model verbatim, so edits change behavior.
JSON_PROMPT = HumanMessagePromptTemplate.from_template(
"""
Rewrite the input into valid json
Input:
{{
"name": "John Doe",
"age": 30,
"isStudent": false
"address": {{
"street": "123 Main St",
"city": "Anytown",
"state": "CA",
}}
"hobbies": ["reading", "swimming", "cycling"]
}}
Output:
{{
"name": "John Doe",
"age": 30,
"isStudent": false,
"address": {{
"street": "123 Main St",
"city": "Anytown",
"state": "CA"
}},
"hobbies": ["reading", "swimming", "cycling"]
}}
Input:
{{
"statement": "The Earth is also known as "Terra" "
}}
Output:
{{
"statement": "The Earth is also known as 'Terra'"
}}
Input:
{input}
Output:
"""
)


@dataclass
class JsonLoader:
    """Extract and parse JSON from free-form text, with LLM-assisted repair.

    ``safe_load`` locates the outermost ``{...}`` / ``[...]`` span in the
    text and parses it. When parsing fails, the text is handed to an LLM
    (via ``JSON_PROMPT``) to be rewritten, and parsing is attempted again.
    """

    # Maximum number of LLM repair calls ``safe_load`` makes after the
    # initial parse attempt fails.
    max_retries: int = 2

    def safe_load(self, text: str, llm: "RagasLLM"):
        """Parse the outermost JSON value found in ``text``.

        Args:
            text: Text expected to contain a JSON object or array.
            llm: Model used to rewrite ``text`` when it cannot be parsed.
                Only used after a parse failure.

        Returns:
            The parsed JSON value, or ``{}`` if the text is still not
            parseable after ``max_retries`` repair calls.
        """
        # One initial attempt plus up to `max_retries` repaired attempts.
        # The counter must move towards the bound so the loop terminates
        # (decrementing it, as before, retried forever on bad input).
        for attempt in range(self.max_retries + 1):
            if attempt > 0:
                # Previous attempt failed: ask the LLM for a corrected text.
                text = self._fix_to_json(text, llm)
            try:
                start, end = self._find_outermost_json(text)
                # When nothing is found, (-1, -1) slices to "" and
                # json.loads raises JSONDecodeError (a ValueError).
                return json.loads(text[start:end])
            except ValueError:
                continue

        return {}

    def _fix_to_json(
        self,
        text,
        llm,
        callbacks: "t.Optional[CallbackManager]" = None,
        callback_group_name: str = "batch",
    ):
        """Ask ``llm`` to rewrite ``text`` as valid JSON.

        Args:
            text: The malformed text, substituted into ``JSON_PROMPT``.
            llm: Object exposing ``generate(prompts, n=..., callbacks=...)``.
            callbacks: Optional callback manager passed to
                ``trace_as_chain_group``.
            callback_group_name: Name given to the tracing chain group.

        Returns:
            The ``text`` of the first generation for the single prompt.
        """
        with trace_as_chain_group(
            callback_group_name, callback_manager=callbacks
        ) as batch_group:
            human_prompt = ChatPromptTemplate.from_messages(
                [JSON_PROMPT.format(input=text)]
            )
            results = llm.generate(
                [human_prompt],
                n=1,
                callbacks=batch_group,
            )
        return results.generations[0][0].text

    def _find_outermost_json(self, text):
        """Locate the first balanced top-level ``{...}`` or ``[...]`` span.

        Brackets are matched purely by character; brackets that appear
        inside string literals are counted like any other.

        Args:
            text: The text to scan.

        Returns:
            ``(start, end)`` such that ``text[start:end]`` is the span
            (``end`` is exclusive), or ``(-1, -1)`` if no balanced span is
            found or a mismatched closer is encountered.
        """
        closer_to_opener = {"}": "{", "]": "["}
        stack = []
        start_index = -1

        for i, char in enumerate(text):
            if char in "{[":
                if not stack:
                    # First opener at depth zero marks the span start.
                    start_index = i
                stack.append(char)
            elif char in closer_to_opener:
                if stack and stack.pop() != closer_to_opener[char]:
                    # Mismatched closing brace/bracket: give up.
                    break
                if not stack and start_index != -1:
                    # Depth returned to zero: include the closer itself.
                    return start_index, i + 1

        return -1, -1  # No balanced span found


# Module-level JsonLoader instance built with the default settings
# (max_retries=2).
json_loader = JsonLoader()

0 comments on commit f3e6068

Please sign in to comment.