-
Notifications
You must be signed in to change notification settings - Fork 828
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: adding an implementation of abstractQA (#1359)
Once you have a kg with #1352 you can run this like:

```py
from ragas.experimental.testset.generators.abstract import AbstractGenerator

abstract_qa = AbstractGenerator()
dist = await abstract_qa.generate_scenarios(n=10, knowledge_graph=kg)
q = await abstract_qa.generate_user_input(dist[0])
```

Merge after merging #1352.
- Loading branch information
Showing
21 changed files
with
590 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .abstract import AbstractGenerator | ||
|
||
__all__ = ["AbstractGenerator"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import logging | ||
import math | ||
import random | ||
import typing as t | ||
from dataclasses import dataclass, field | ||
|
||
from ragas.executor import run_async_batch | ||
from ragas.experimental.prompt import PydanticPrompt, StringIO | ||
from ragas.experimental.testset.generators.base import ( | ||
BaseSimulator, | ||
BasicScenario, | ||
UserInputLength, | ||
UserInputStyle, | ||
) | ||
from ragas.experimental.testset.generators.prompts import ( | ||
AbstractQuestionFromTheme, | ||
CommonThemeFromSummaries, | ||
CriticUserInput, | ||
GenerateReference, | ||
ModifyUserInput, | ||
Summaries, | ||
ThemeAndContext, | ||
Themes, | ||
UserInputAndContext, | ||
UserInputWithStyleAndLength, | ||
extend_modify_input_prompt, | ||
) | ||
from ragas.experimental.testset.graph import KnowledgeGraph, Node | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AbstractQAScenario(BasicScenario):
    # Scenario for an abstract question: the fields inherited from
    # BasicScenario plus the theme the question should be about.
    theme: str
|
||
|
||
@dataclass
class AbstractGenerator(BaseSimulator):
    """Simulator that builds abstract questions from themes shared by chunk clusters.

    ``generate_scenarios`` clusters the knowledge graph, asks the LLM for the
    common themes of every cluster and pairs each theme with a random style
    and length. The remaining methods turn one scenario into a user input,
    judge it, restyle it and produce a reference answer.
    """

    # Prompt that writes a question for a theme, given the source context.
    generate_user_input_prompt: PydanticPrompt = field(
        default_factory=AbstractQuestionFromTheme
    )
    # Prompt that scores a user input; consumed by ``critic_user_input``.
    critic_user_input_prompt: PydanticPrompt = field(default_factory=CriticUserInput)
    # Base prompt that is extended with a style and length before rewriting.
    user_input_modification_prompt: PydanticPrompt = field(
        default_factory=ModifyUserInput
    )
    # Prompt that answers a user input from the supplied context.
    generate_reference_prompt: PydanticPrompt = field(default_factory=GenerateReference)

    def __post_init__(self):
        # Set here rather than declared as a dataclass field, so it is not a
        # constructor argument.
        self.common_theme_prompt = CommonThemeFromSummaries()

    async def generate_scenarios(
        self, n: int, knowledge_graph: KnowledgeGraph
    ) -> t.List[AbstractQAScenario]:
        """Build scenarios from the chunk clusters of ``knowledge_graph``.

        Args:
            n: Number of scenarios the caller wants. It is used to decide how
                many themes to request per cluster, so the result can hold
                more than ``n`` scenarios.
            knowledge_graph: Graph whose relationships carry a
                ``cosine_similarity`` property.

        Returns:
            One scenario per (cluster, theme) pair, or an empty list when
            ``n`` is not positive or no chunk-only cluster exists.
        """
        node_clusters = knowledge_graph.find_clusters(
            relationship_condition=lambda rel: bool(
                rel.get_property("cosine_similarity")
            )
        )
        logger.info("found %d clusters", len(node_clusters))

        # keep only the clusters made up entirely of chunk nodes
        node_clusters = [
            cluster
            for cluster in node_clusters
            if all(node.type == "chunk" for node in cluster)
        ]

        num_clusters = len(node_clusters)
        if n <= 0 or num_clusters == 0:
            # Nothing to generate; this also avoids dividing by zero below.
            logger.warning(
                "no scenarios generated (n=%d, chunk clusters=%d)", n, num_clusters
            )
            return []

        # find the number of themes to generate for the given n and the number
        # of clusters; rounding up means we may generate more than n themes
        num_themes = math.ceil(n / num_clusters)
        logger.info("generating %d themes", num_themes)

        kw_list = []
        for cluster in node_clusters:
            cluster_summaries = [
                summary
                for summary in (node.get_property("summary") for node in cluster)
                if summary is not None
            ]
            kw_list.append(
                {
                    "data": Summaries(
                        summaries=cluster_summaries,
                        num_themes=num_themes,
                    ),
                    "llm": self.llm,
                }
            )

        # NOTE(review): run_async_batch is called without await from inside a
        # coroutine - confirm it is safe to call with an event loop running.
        themes: t.List[Themes] = run_async_batch(
            desc="Generating common themes",
            func=self.common_theme_prompt.generate,
            kwargs_list=kw_list,
        )

        # pair every returned theme with the cluster it was extracted from
        clusters_sampled = []
        themes_sampled = []
        for cluster, cluster_themes in zip(node_clusters, themes):
            for theme in cluster_themes.themes:
                themes_sampled.append(theme)
                clusters_sampled.append(cluster)

        # Sample one style and one length per pair. Sizing by the number of
        # pairs actually produced (not num_clusters * num_themes) keeps the
        # zip below from dropping scenarios when the LLM returns a different
        # number of themes than requested.
        num_scenarios = len(themes_sampled)
        question_styles = random.choices(list(UserInputStyle), k=num_scenarios)
        question_lengths = random.choices(list(UserInputLength), k=num_scenarios)

        return [
            AbstractQAScenario(
                theme=theme.theme,
                nodes=cluster,
                style=style,
                length=length,
            )
            for cluster, theme, style, length in zip(
                clusters_sampled, themes_sampled, question_styles, question_lengths
            )
        ]

    async def generate_user_input(self, scenario: AbstractQAScenario) -> str:
        """Ask the LLM for a question about ``scenario.theme`` using its nodes as context."""
        question = await self.generate_user_input_prompt.generate(
            data=ThemeAndContext(
                theme=scenario.theme,
                context=self.make_source_text(scenario),
            ),
            llm=self.llm,
        )
        return question.text

    async def critic_user_input(self, user_input: str) -> bool:
        """Return True when the critic scores both independence and clear intent above 1."""
        critic = await self.critic_user_input_prompt.generate(
            data=StringIO(text=user_input), llm=self.llm
        )
        return critic.independence > 1 and critic.clear_intent > 1

    async def modify_user_input(
        self, user_input: str, scenario: AbstractQAScenario
    ) -> str:
        """Rewrite ``user_input`` to match the style and length of ``scenario``."""
        prompt = extend_modify_input_prompt(
            question_modification_prompt=self.user_input_modification_prompt,
            style=scenario.style,
            length=scenario.length,
        )
        modified_question = await prompt.generate(
            data=UserInputWithStyleAndLength(
                user_input=user_input,
                style=scenario.style,
                length=scenario.length,
            ),
            llm=self.llm,
        )
        return modified_question.text

    async def generate_reference(self, user_input: str, chunks: t.List[Node]) -> str:
        """Generate the reference answer for ``user_input`` from ``chunks``.

        Args:
            user_input: The question to answer.
            chunks: Nodes whose ``page_content`` property forms the context.

        Returns:
            The text of the generated reference.
        """
        # Build the context from the chunks directly: make_source_text reads
        # ``scenario.nodes`` and would raise AttributeError on a plain list.
        page_contents = [
            content
            for content in (node.get_property("page_content") for node in chunks)
            if content is not None
        ]
        reference = await self.generate_reference_prompt.generate(
            data=UserInputAndContext(
                user_input=user_input,
                context="\n\n".join(page_contents),
            ),
            llm=self.llm,
        )
        return reference.text
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import typing as t | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass, field | ||
from enum import Enum | ||
|
||
from pydantic import BaseModel | ||
|
||
from ragas.experimental.testset.graph import KnowledgeGraph, Node | ||
from ragas.llms import BaseRagasLLM, llm_factory | ||
|
||
|
||
class UserInputLength(str, Enum):
    """Target length of a generated user input."""

    # Declaration order is the iteration order of list(UserInputLength).
    LONG = "long"
    MEDIUM = "medium"
    SHORT = "short"
|
||
|
||
class UserInputStyle(str, Enum):
    """Writing style of a generated user input."""

    # Declaration order is the iteration order of list(UserInputStyle).
    MISSPELLED = "Misspelled queries"
    PERFECT_GRAMMAR = "Perfect grammar"
    POOR_GRAMMAR = "Poor grammar"
    WEB_SEARCH_LIKE = "Web search like queries"
|
||
|
||
class BasicScenario(BaseModel):
    # Common fields of every scenario: the source nodes plus the style and
    # length the generated user input should have.
    nodes: t.List[Node]
    style: UserInputStyle
    length: UserInputLength
|
||
|
||
Scenario = t.TypeVar("Scenario", bound=BasicScenario) | ||
|
||
|
||
@dataclass
class BaseSimulator(ABC, t.Generic[Scenario]):
    """Abstract base for simulators that turn scenarios into user inputs and references.

    Subclasses implement scenario generation plus the four per-sample steps
    (generate, critic, modify, reference). ``llm`` is the model passed to
    every prompt and defaults to the result of ``llm_factory()``.
    """

    llm: BaseRagasLLM = field(default_factory=llm_factory)

    @abstractmethod
    async def generate_user_input(
        self,
        scenario: Scenario,
    ) -> str:
        """Return the user input generated for ``scenario``."""

    @abstractmethod
    async def generate_reference(self, user_input: str, chunks: t.List[Node]) -> str:
        """Return the reference answer for ``user_input`` given ``chunks``."""

    @abstractmethod
    async def critic_user_input(self, user_input: str) -> bool:
        """Return whether ``user_input`` is acceptable."""

    @abstractmethod
    async def modify_user_input(self, user_input: str, scenario: Scenario) -> str:
        """Return ``user_input`` rewritten for ``scenario``."""

    @abstractmethod
    async def generate_scenarios(
        self, n: int, knowledge_graph: KnowledgeGraph
    ) -> t.List[Scenario]:
        """Return scenarios built from ``knowledge_graph`` for a request of ``n``."""

    @staticmethod
    def make_source_text(scenario: t.Union[Scenario, t.Sequence[Node]]) -> str:
        """Join the ``page_content`` of the given nodes into one context string.

        Args:
            scenario: Either a scenario exposing ``nodes`` or a plain sequence
                of nodes (the shape ``generate_reference`` receives).

        Returns:
            The page contents separated by blank lines. Nodes whose
            ``page_content`` property is ``None`` are skipped.
        """
        # Accept a bare node sequence as well as a scenario object.
        nodes = getattr(scenario, "nodes", scenario)
        page_contents = [
            content
            for content in (node.get_property("page_content") for node in nodes)
            # str.join raises TypeError on None, so drop nodes with no content
            if content is not None
        ]
        return "\n\n".join(page_contents)
Oops, something went wrong.