diff --git a/lightrag/llm.py b/lightrag/llm.py
index 6a191a0f..12a4d5a6 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -29,7 +29,11 @@ from pydantic import BaseModel, Field
 from typing import List, Dict, Callable, Any
 
 from .base import BaseKVStorage
-from .utils import compute_args_hash, wrap_embedding_func_with_attrs
+from .utils import (
+    compute_args_hash,
+    wrap_embedding_func_with_attrs,
+    locate_json_string_body_from_string,
+)
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -66,9 +70,14 @@ async def openai_complete_if_cache(
         if if_cache_return is not None:
             return if_cache_return["return"]
 
-    response = await openai_async_client.chat.completions.create(
-        model=model, messages=messages, **kwargs
-    )
+    if "response_format" in kwargs:
+        response = await openai_async_client.beta.chat.completions.parse(
+            model=model, messages=messages, **kwargs
+        )
+    else:
+        response = await openai_async_client.chat.completions.create(
+            model=model, messages=messages, **kwargs
+        )
     content = response.choices[0].message.content
     if r"\u" in content:
         content = content.encode("utf-8").decode("unicode_escape")
@@ -301,7 +310,7 @@ async def ollama_model_if_cache(
     model, prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     kwargs.pop("max_tokens", None)
-    kwargs.pop("response_format", None)
+    # kwargs.pop("response_format", None)  # allow json
     host = kwargs.pop("host", None)
     timeout = kwargs.pop("timeout", None)
 
@@ -345,9 +354,9 @@ def initialize_lmdeploy_pipeline(
         backend_config=TurbomindEngineConfig(
             tp=tp, model_format=model_format, quant_policy=quant_policy
         ),
-        chat_template_config=ChatTemplateConfig(model_name=chat_template)
-        if chat_template
-        else None,
+        chat_template_config=(
+            ChatTemplateConfig(model_name=chat_template) if chat_template else None
+        ),
         log_level="WARNING",
     )
     return lmdeploy_pipe
@@ -458,9 +467,16 @@ async def lmdeploy_model_if_cache(
     return response
 
 
+class GPTKeywordExtractionFormat(BaseModel):
+    high_level_keywords: List[str]
+    low_level_keywords: List[str]
+
+
 async def gpt_4o_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o",
         prompt,
@@ -471,8 +487,10 @@
 
 
 async def gpt_4o_mini_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["response_format"] = GPTKeywordExtractionFormat
     return await openai_complete_if_cache(
         "gpt-4o-mini",
         prompt,
@@ -483,45 +501,56 @@
 
 
 async def azure_openai_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await azure_openai_complete_if_cache(
+    result = await azure_openai_complete_if_cache(
         "conversation-4o-mini",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result
 
 
 async def bedrock_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
-    return await bedrock_complete_if_cache(
+    result = await bedrock_complete_if_cache(
         "anthropic.claude-3-haiku-20240307-v1:0",
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result
 
 
 async def hf_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
-    return await hf_model_if_cache(
+    result = await hf_model_if_cache(
         model_name,
         prompt,
         system_prompt=system_prompt,
         history_messages=history_messages,
         **kwargs,
     )
+    if keyword_extraction:  # TODO: use JSON API
+        return locate_json_string_body_from_string(result)
+    return result
 
 
 async def ollama_model_complete(
-    prompt, system_prompt=None, history_messages=[], **kwargs
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
+    if keyword_extraction:
+        kwargs["format"] = "json"
     model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
     return await ollama_model_if_cache(
         model_name,
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 94cd412b..74c8e5f6 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -17,7 +17,6 @@
     split_string_by_multi_markers,
     truncate_list_by_token_size,
     process_combine_contexts,
-    locate_json_string_body_from_string,
 )
 from .base import (
     BaseGraphStorage,
@@ -461,12 +460,12 @@ async def kg_query(
     use_model_func = global_config["llm_model_func"]
     kw_prompt_temp = PROMPTS["keywords_extraction"]
     kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
-    result = await use_model_func(kw_prompt)
+    result = await use_model_func(kw_prompt, keyword_extraction=True)
 
     logger.info("kw_prompt result:")
     print(result)
     try:
-        json_text = locate_json_string_body_from_string(result)
-        keywords_data = json.loads(json_text)
+        # json_text = locate_json_string_body_from_string(result)  # handled in use_model_func
+        keywords_data = json.loads(result)
         hl_keywords = keywords_data.get("high_level_keywords", [])
         ll_keywords = keywords_data.get("low_level_keywords", [])
diff --git a/lightrag/utils.py b/lightrag/utils.py
index bdd1aa9e..8997b651 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -54,7 +54,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
             maybe_json_str = maybe_json_str.replace("\\n", "")
             maybe_json_str = maybe_json_str.replace("\n", "")
             maybe_json_str = maybe_json_str.replace("'", '"')
-            json.loads(maybe_json_str)
+            # json.loads(maybe_json_str)  # don't check here, cannot validate schema after all
             return maybe_json_str
         except Exception:
             pass
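
Usage note (illustrative sketch, not part of the patch): with this change, kg_query passes keyword_extraction=True and calls json.loads directly on the model output, so every llm_model_func must return a bare JSON string on that path. The stand-in fake_model_func below is invented for the example; the real backends satisfy the contract via response_format=GPTKeywordExtractionFormat (OpenAI), format="json" (Ollama), or the locate_json_string_body_from_string fallback (Azure/Bedrock/HF).

import asyncio
import json


async def fake_model_func(prompt, keyword_extraction=False, **kwargs):
    # Stand-in for gpt_4o_complete / ollama_model_complete / hf_model_complete.
    # On the keyword_extraction path it must return a bare JSON string.
    return '{"high_level_keywords": ["graph RAG"], "low_level_keywords": ["entity", "relation"]}'


async def main():
    # Mirrors the updated call site in kg_query: json.loads on the raw result,
    # with no locate_json_string_body_from_string step in between.
    result = await fake_model_func("keywords_extraction prompt ...", keyword_extraction=True)
    keywords_data = json.loads(result)
    print(keywords_data.get("high_level_keywords", []))
    print(keywords_data.get("low_level_keywords", []))


asyncio.run(main())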