-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support text2gql search for GraphRAG (#2227)
Co-authored-by: aries_ckt <[email protected]>
- Loading branch information
1 parent
9336e80
commit e0081e6
Showing
23 changed files
with
791 additions
and
165 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""AgenticIntentTranslator class.""" | ||
import logging | ||
|
||
from dbgpt.rag.transformer.base import TranslatorBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AgenticIntentTranslator(TranslatorBase):
    """AgenticIntentTranslator class.

    Subclass of ``TranslatorBase`` that declares no members of its own here;
    everything it exposes is inherited from the base class.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
"""AwelIntentTranslator class.""" | ||
import logging | ||
|
||
from dbgpt.rag.transformer.base import TranslatorBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AwelIntentTranslator(TranslatorBase):
    """AwelIntentTranslator class.

    Subclass of ``TranslatorBase`` that declares no members of its own here;
    everything it exposes is inherited from the base class.
    """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""LLMTranslator class.""" | ||
|
||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import Dict, List | ||
|
||
from dbgpt.core import BaseMessage, LLMClient, ModelMessage, ModelRequest | ||
from dbgpt.rag.transformer.base import TranslatorBase | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LLMTranslator(TranslatorBase, ABC):
    """Base class for translators that delegate the translation to an LLM.

    Subclasses provide the prompt construction (``_format_messages``) and the
    decoding of the raw model output (``_parse_response``).
    """

    def __init__(self, llm_client: LLMClient, model_name: str, prompt_template: str):
        """Initialize the LLMTranslator.

        Args:
            llm_client: Client used to list models and generate responses.
            model_name: Model to query. When falsy, the first model reported
                by the client is selected on the first translation.
            prompt_template: Template text kept for subclasses to build
                their messages from.
        """
        self._prompt_template = prompt_template
        self._model_name = model_name
        self._llm_client = llm_client

    async def translate(self, text: str) -> Dict:
        """Build the messages for ``text`` and translate them with the LLM."""
        return await self._translate(self._format_messages(text))

    async def _translate(self, messages: List[BaseMessage]) -> Dict:
        """Send ``messages`` to the LLM and parse its answer.

        Returns:
            The result of ``_parse_response`` on success, or an empty dict
            when the LLM reports a failed request.

        Raises:
            Exception: If no model name was given and the client lists none.
        """
        # Lazily fall back to the first available model and remember it.
        if not self._model_name:
            available = await self._llm_client.models()
            if not available:
                raise Exception("No models available")
            self._model_name = available[0].model
            logger.info(f"Using model {self._model_name} to extract")

        request = ModelRequest(
            model=self._model_name,
            messages=ModelMessage.from_base_messages(messages),
        )
        response = await self._llm_client.generate(request=request)

        if response.success:
            return self._parse_response(response.text)

        # A failed request is logged and reported as "no translation".
        logger.error(
            f"request llm failed ({str(response.error_code)}) {response.text}"
        )
        return {}

    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""

    @abstractmethod
    def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]:
        """Build the prompt messages for ``text`` (and optional ``history``)."""

    @abstractmethod
    def _parse_response(self, text: str) -> Dict:
        """Turn the raw LLM response text into the translation dict."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
"""SimpleIntentTranslator class.""" | ||
import json | ||
import logging | ||
import re | ||
from typing import Dict, List, Union | ||
|
||
from dbgpt.core import BaseMessage, HumanPromptTemplate, LLMClient | ||
from dbgpt.rag.transformer.llm_translator import LLMTranslator | ||
|
||
# Few-shot prompt handed to SimpleIntentTranslator. ``{text}`` is the only
# placeholder; the doubled braces around each example are meant to render as
# literal braces (assumes str.format-style substitution by the template class).
# Every example after "Return:" must be valid JSON once rendered, because the
# model's answer is parsed with json.loads.
INTENT_INTERPRET_PT = """
A question is provided below. Given the question, analyze and classify it into one of the following categories:
1. Single Entity Search: search for the detail of the given entity.
2. One Hop Entity Search: given one entity and one relation, search for all entities that have the relation with the given entity.
3. One Hop Relation Search: given two entities, serach for the relation between them.
4. Two Hop Entity Search: given one entity and one relation, break that relation into two consecutive relation, then search all entities that have the two hop relation with the given entity.
5. Freestyle Question: questions that are not in above four categories. Search all related entities and two-hop subgraphs centered on them.
After classfied the given question, rewrite the question in a graph query language style, return the category of the given question, the rewrited question in json format.
Also return entities and relations that might be used for query generation in json format. Here are some examples to guide your classification:
---------------------
Example:
Question: Introduce TuGraph.
Return:
{{"category": "Single Entity Search", "rewritten_question": "Query the entity named TuGraph then return the entity.", "entities": ["TuGraph"], "relations": []}}
Question: Who commits code to TuGraph.
Return:
{{"category": "One Hop Entity Search", "rewritten_question": "Query all one hop paths that has a entity named TuGraph and a relation named commit, then return them.", "entities": ["TuGraph"], "relations": ["commit"]}}
Question: What is the relation between Alex and TuGraph?
Return:
{{"category": "One Hop Relation Search", "rewritten_question": "Query all one hop paths between the entity named Alex and the entity named TuGraph, then return them.", "entities": ["Alex", "TuGraph"], "relations": []}}
Question: Who is the colleague of Bob?
Return:
{{"category": "Two Hop Entity Search", "rewritten_question": "Query all entities that have a two hop path between them and the entity named Bob, both entities should have a work for relation with the middle entity.", "entities": ["Bob"], "relations": ["work for"]}}
Question: Introduce TuGraph and DBGPT seperately.
Return:
{{"category": "Freestyle Question", "rewritten_question": "Query the entity named TuGraph and the entity named DBGPT, then return two-hop subgraphs centered on them.", "entities": ["TuGraph", "DBGPT"], "relations": []}}
---------------------
Text: {text}
Return:
"""  # noqa: E501
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SimpleIntentTranslator(LLMTranslator):
    """Classify and rewrite a user question with a single LLM prompt."""

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Initialize the SimpleIntentTranslator."""
        super().__init__(llm_client, model_name, INTENT_INTERPRET_PT)

    def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]:
        """Render the intent prompt for ``text``.

        ``history`` is forwarded to the template only when it is not None.
        NOTE(review): INTENT_INTERPRET_PT declares no ``{history}``
        placeholder; confirm how the template treats the extra keyword.
        """
        # interpret intention with a single prompt only.
        template = HumanPromptTemplate.from_template(self._prompt_template)

        messages: List[BaseMessage] = (
            template.format_messages(text=text, history=history)
            if history is not None
            else template.format_messages(text=text)
        )

        return messages

    def truncate(self):
        """Do nothing by default."""

    def drop(self):
        """Do nothing by default."""

    def _parse_response(self, text: str) -> Dict:
        """Parse the LLM response into an intention dict.

        The first ```json fenced block is preferred when present; the first
        ``{...}`` span (non-greedy, so nested objects are not supported) is
        then decoded as JSON. The returned dict always has these keys, with
        empty defaults for any the model did not supply:

            {
                "category": "Type of the given question.",
                "original_question": "The original question provided by user.",
                "rewritten_question": "Rewritten question in graph query style.",
                "entities": ["entities", "that", "might", "be", "used"],
                "relations": ["relations", "that", "might", "be", "used"]
            }

        A response with no JSON object, or with malformed JSON, is logged and
        yields the all-defaults dict instead of raising.
        """
        code_block_pattern = re.compile(r"```json(.*?)```", re.S)
        json_pattern = re.compile(r"{.*?}", re.S)

        fenced = code_block_pattern.findall(text)
        if fenced:
            text = fenced[0]
        candidates = json_pattern.findall(text)

        intention: Dict[str, Union[str, List[str]]] = {}
        if candidates:
            try:
                # The pattern starts at "{", so a successful decode is a dict.
                intention = json.loads(candidates[0])
            except json.JSONDecodeError as e:
                logger.error("failed to decode intention json (%s): %s", e, text)
        else:
            logger.error("no json object found in llm response: %s", text)

        intention.setdefault("category", "")
        intention.setdefault("original_question", "")
        intention.setdefault("rewritten_question", "")
        intention.setdefault("entities", [])
        intention.setdefault("relations", [])

        return intention
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,106 @@ | ||
"""Text2GQL class.""" | ||
import json | ||
import logging | ||
import re | ||
from typing import Dict, List, Union | ||
|
||
from dbgpt.rag.transformer.base import TranslatorBase | ||
from dbgpt.core import BaseMessage, HumanPromptTemplate, LLMClient | ||
from dbgpt.rag.transformer.llm_translator import LLMTranslator | ||
|
||
# Few-shot prompt handed to Text2GQL. Placeholders: {schema}, {question},
# {category}, {entities}, {relations}. Each example lists the same four fields
# (Question / Category / entities / relations) that the final, unanswered
# block supplies, and its category names follow the intent-classification
# prompt so both stages agree.
TEXT_TO_GQL_PT = """
A question written in graph query language style is provided below. The category of this question, entities and relations that might be used in the cypher query are also provided.
Given the question, translate the question into a cypher query that can be executed on the given knowledge graph. Make sure the syntax of the translated cypher query is correct.
To help query generation, the schema of the knowledge graph is:
{schema}
---------------------
Example:
Question: Query the entity named TuGraph then return the entity.
Category: Single Entity Search
entities: ["TuGraph"]
relations: []
Query:
Match (n) WHERE n.id="TuGraph" RETURN n
Question: Query all one hop paths between the entity named Alex and the entity named TuGraph, then return them.
Category: One Hop Relation Search
entities: ["Alex", "TuGraph"]
relations: []
Query:
MATCH p=(n)-[r]-(m) WHERE n.id="Alex" AND m.id="TuGraph" RETURN p
Question: Query all entities that have a two hop path between them and the entity named Bob, both entities should have a work for relation with the middle entity.
Category: Two Hop Entity Search
entities: ["Bob"]
relations: ["work for"]
Query:
MATCH p=(n)-[r1]-(m)-[r2]-(l) WHERE n.id="Bob" AND r1.id="work for" AND r2.id="work for" RETURN p
Question: Introduce TuGraph and DBGPT seperately.
Category: Freestyle Question
entities: ["TuGraph", "DBGPT"]
relations: []
Query:
MATCH p=(n)-[r:relation*2]-(m) WHERE n.id IN ["TuGraph", "DBGPT"] RETURN p
---------------------
Question: {question}
Category: {category}
entities: {entities}
relations: {relations}
Query:
"""  # noqa: E501
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Text2GQL(LLMTranslator):
    """Translate a structured intention into a graph query with an LLM."""

    def __init__(self, llm_client: LLMClient, model_name: str):
        """Initialize the Text2GQL."""
        super().__init__(llm_client, model_name, TEXT_TO_GQL_PT)

    def _format_messages(self, text: str, history: str = None) -> List[BaseMessage]:
        """Render the text2gql prompt from a JSON-encoded intention.

        Args:
            text: JSON object string. The keys ``rewritten_question``,
                ``category``, ``entities``, ``relations`` and ``schema`` are
                read; a missing key is rendered as an empty string.
            history: Forwarded to the template only when it is not None.

        Raises:
            json.JSONDecodeError: If ``text`` is not valid JSON.
        """
        # translate intention to gql with single prompt only.
        intention: Dict[str, Union[str, List[str]]] = json.loads(text)
        fields = {
            "schema": intention.get("schema", ""),
            "question": intention.get("rewritten_question", ""),
            "category": intention.get("category", ""),
            "entities": intention.get("entities", ""),
            "relations": intention.get("relations", ""),
        }
        if history is not None:
            fields["history"] = history

        template = HumanPromptTemplate.from_template(self._prompt_template)
        return template.format_messages(**fields)

    def _parse_response(self, text: str) -> Dict:
        """Extract the query from the LLM response.

        A fenced block tagged ``cypher`` wins when present. Otherwise the
        first fenced block of any kind (untagged, or with a tag on its own
        fence line such as ``gql``) is used, so the backticks never leak into
        the query. With no fence at all the whole response is the query.

        Returns:
            ``{"query": <stripped query text>}``.
        """
        cypher_fence = re.compile(r"```cypher(.*?)```", re.S)
        # Optional "<tag>\n" right after the opening fence, then the payload.
        any_fence = re.compile(r"```(?:[ \t]*[A-Za-z]*[ \t]*\n)?(.*?)```", re.S)

        found = cypher_fence.findall(text) or any_fence.findall(text)
        query = found[0] if found else text

        translation: Dict[str, str] = {"query": query.strip()}
        return translation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.