Commit f9cceda: Merge branch 'main' into main

henchaves authored Oct 31, 2024
2 parents: d1b240b + 4f88d80
Showing 14 changed files with 2,153 additions and 2,078 deletions.
giskard/rag/knowledge_base.py (11 additions, 0 deletions)

@@ -300,6 +300,17 @@ def get_failure_plot(self, question_evaluation: Sequence[dict] = None):
     def get_random_document(self):
         return self._rng.choice(self._documents)
 
+    def get_random_documents(self, n: int, with_replacement=False):
+        if with_replacement:
+            return list(self._rng.choice(self._documents, n, replace=True))
+
+        docs = list(self._rng.choice(self._documents, min(n, len(self._documents)), replace=False))
+
+        if len(docs) <= n:
+            docs.extend(self._rng.choice(self._documents, n - len(docs), replace=True))
+
+        return docs
+
     def get_neighbors(self, seed_document: Document, n_neighbors: int = 4, similarity_threshold: float = 0.2):
         seed_embedding = seed_document.embeddings
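The new get_random_documents prefers sampling without replacement and only tops up with repeated documents when the knowledge base holds fewer than n entries, so a batch covers as many distinct documents as possible. A minimal standalone sketch of that behavior (the toy pool and helper name are illustrative, not from the repository; the committed code checks len(docs) <= n, which behaves identically to < since extending by zero documents is a no-op):

    import numpy as np

    rng = np.random.default_rng(0)
    pool = ["doc_a", "doc_b", "doc_c"]

    def sample_documents(n, with_replacement=False):
        if with_replacement:
            return list(rng.choice(pool, n, replace=True))
        # Draw as many unique documents as the pool allows...
        docs = list(rng.choice(pool, min(n, len(pool)), replace=False))
        # ...then top up with repeats only if the pool was too small.
        if len(docs) < n:
            docs.extend(rng.choice(pool, n - len(docs), replace=True))
        return docs

    print(sample_documents(2))  # two distinct documents
    print(sample_documents(5))  # all three documents, plus two repeats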
giskard/rag/metrics/ragas_metrics.py (1 addition, 1 deletion)

@@ -12,9 +12,9 @@
 try:
     from langchain_core.outputs import LLMResult
     from langchain_core.outputs.generation import Generation
+    from langchain_core.prompt_values import PromptValue
     from ragas.embeddings import BaseRagasEmbeddings
     from ragas.llms import BaseRagasLLM
-    from ragas.llms.prompt import PromptValue
     from ragas.metrics import answer_relevancy, context_precision, context_recall, faithfulness
     from ragas.metrics.base import Metric as BaseRagasMetric
     from ragas.run_config import RunConfig
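This hunk tracks newer ragas releases, which stopped shipping PromptValue under ragas.llms.prompt; the class now comes straight from langchain-core. If both old and new ragas versions had to be supported, a tolerant import could look like the following sketch (an assumption about version ranges, not code from the commit):

    try:
        # Newer ragas versions: PromptValue is provided by langchain-core.
        from langchain_core.prompt_values import PromptValue
    except ImportError:
        # Older ragas versions re-exported it from ragas.llms.prompt.
        from ragas.llms.prompt import PromptValue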
giskard/rag/question_generators/base.py (4 additions, 2 deletions)

@@ -52,9 +52,11 @@ class GenerateFromSingleQuestionMixin:
     _question_type: str
 
     def generate_questions(self, knowledge_base: KnowledgeBase, num_questions: int, *args, **kwargs) -> Iterator[Dict]:
-        for _ in range(num_questions):
+        docs = knowledge_base.get_random_documents(num_questions)
+
+        for doc in docs:
             try:
-                yield self.generate_single_question(knowledge_base, *args, **kwargs)
+                yield self.generate_single_question(knowledge_base, *args, **kwargs, seed_document=doc)
             except Exception as e:  # @TODO: specify exceptions
                 logger.error(f"Encountered error in question generation: {e}. Skipping.")
                 logger.exception(e)
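The mixin now draws all seed documents up front with get_random_documents and threads one into each generate_single_question call, instead of letting every call draw independently; combined with the sampling change above, n questions spread over n distinct documents whenever the knowledge base is large enough. A self-contained stub of the new control flow (stub classes, not the real giskard ones):

    from typing import Dict, Iterator, List

    class StubKnowledgeBase:
        def __init__(self, documents: List[str]):
            self.documents = documents

        def get_random_documents(self, n: int) -> List[str]:
            # Deterministic stand-in for the real sampling logic above.
            return [self.documents[i % len(self.documents)] for i in range(n)]

    class StubGenerator:
        def generate_questions(self, kb: StubKnowledgeBase, num_questions: int) -> Iterator[Dict]:
            # Seed documents are sampled once, up front, then threaded through.
            for doc in kb.get_random_documents(num_questions):
                yield self.generate_single_question(kb, seed_document=doc)

        def generate_single_question(self, kb: StubKnowledgeBase, seed_document=None) -> Dict:
            return {"question": f"What does {seed_document!r} cover?", "seed": seed_document}

    kb = StubKnowledgeBase(["doc_a", "doc_b"])
    for q in StubGenerator().generate_questions(kb, num_questions=3):
        print(q)  # seeds cycle: doc_a, doc_b, doc_a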
giskard/rag/question_generators/double_questions.py (2 additions, 2 deletions)

@@ -92,9 +92,9 @@ class DoubleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio
     _question_type = "double"
 
     def generate_single_question(
-        self, knowledge_base: KnowledgeBase, agent_description: str, language: str
+        self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
     ) -> QuestionSample:
-        seed_document = knowledge_base.get_random_document()
+        seed_document = seed_document or knowledge_base.get_random_document()
         context_documents = knowledge_base.get_neighbors(
             seed_document, self._context_neighbors, self._context_similarity_threshold
         )
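The seed_document or knowledge_base.get_random_document() fallback keeps the old behavior for direct callers while letting the batched generate_questions inject a pre-sampled document. One subtlety: or falls back on any falsy value, not just None; that is harmless here because Document instances are truthy, but an explicit None check is the stricter spelling, as in this small illustration (plain strings stand in for Document objects):

    import random

    documents = ["doc_a", "doc_b", "doc_c"]

    def pick_seed(seed_document=None):
        # Explicitly treat only None as "not provided", unlike `or`,
        # which would also replace falsy values such as "" or 0.
        if seed_document is None:
            return random.choice(documents)
        return seed_document

    print(pick_seed("doc_b"))  # the caller-provided seed wins
    print(pick_seed())         # falls back to a random document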
giskard/rag/question_generators/oos_questions.py (2 additions, 2 deletions)

@@ -68,7 +68,7 @@ class OutOfScopeGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestionGene
     _question_type = "out of scope"
 
     def generate_single_question(
-        self, knowledge_base: KnowledgeBase, agent_description: str, language: str
+        self, knowledge_base: KnowledgeBase, agent_description: str, language: str, seed_document=None
     ) -> QuestionSample:
         """
         Generate a question from a list of context documents.

@@ -87,7 +87,7 @@ def generate_single_question(
         Tuple[dict, dict]
             The generated question and the metadata of the question.
         """
-        seed_document = knowledge_base.get_random_document()
+        seed_document = seed_document or knowledge_base.get_random_document()
 
         context_documents = knowledge_base.get_neighbors(
             seed_document, self._context_neighbors, self._context_similarity_threshold
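Because every generator now accepts an explicit seed_document, a single question can be pinned to a chosen document, which helps when debugging one problematic document: the context neighborhood stays fixed even though the LLM output may still vary. A hedged sketch, assuming a knowledge base and generator built earlier in the session (argument values are placeholders):

    # `knowledge_base` and `generator` are assumed to exist from prior setup.
    doc = knowledge_base.get_random_document()  # or any specific Document

    question = generator.generate_single_question(
        knowledge_base,
        "a documentation assistant",  # agent_description (placeholder)
        "en",                         # language
        seed_document=doc,            # pins the context neighborhood to `doc`
    )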
giskard/rag/question_generators/simple_questions.py (9 additions, 2 deletions)

@@ -66,7 +66,13 @@ class SimpleQuestionsGenerator(GenerateFromSingleQuestionMixin, _LLMBasedQuestio
 
     _question_type = "simple"
 
-    def generate_single_question(self, knowledge_base: KnowledgeBase, agent_description: str, language: str) -> dict:
+    def generate_single_question(
+        self,
+        knowledge_base: KnowledgeBase,
+        agent_description: str,
+        language: str,
+        seed_document=None,
+    ) -> dict:
         """
         Generate a question from a list of context documents.

@@ -80,7 +86,8 @@ def generate_single_question(self, knowledge_base: KnowledgeBase, agent_descript
         QuestionSample
             The generated question and the metadata of the question.
         """
-        seed_document = knowledge_base.get_random_document()
+        seed_document = seed_document or knowledge_base.get_random_document()
+
         context_documents = knowledge_base.get_neighbors(
             seed_document, self._context_neighbors, self._context_similarity_threshold
         )
giskard/rag/report.py (2 additions, 2 deletions)

@@ -152,7 +152,7 @@ def save(self, folder_path: str):
         path = Path(folder_path)
         path.mkdir(exist_ok=True, parents=True)
         self.to_html(path / "report.html")
-        self._testset.save(path / "testset.json")
+        self._testset.save(path / "testset.jsonl")
 
         report_details = {"recommendation": self._recommendation}
         with open(path / "report_details.json", "w", encoding="utf-8") as f:

@@ -195,7 +195,7 @@ def load(
         path = Path(folder_path)
         knowledge_base_meta = json.load(open(path / "knowledge_base_meta.json", "r"))
         knowledge_base_data = pd.read_json(path / "knowledge_base.jsonl", orient="records", lines=True)
-        testset = QATestset.load(path / "testset.json")
+        testset = QATestset.load(path / "testset.jsonl")
 
         answers = json.load(open(path / "agent_answer.json", "r"))
         model_outputs = [AgentAnswer(**answer) for answer in answers]
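Save and load now agree on testset.jsonl, matching the JSON Lines format already used for knowledge_base.jsonl. Reports saved before this change still contain a testset.json, so a loader that bridges old artifacts could fall back on the legacy name; a hypothetical helper, not part of the commit (the QATestset import path is assumed):

    from pathlib import Path

    from giskard.rag import QATestset  # import path assumed

    def load_testset_compat(folder_path: str) -> QATestset:
        """Load a saved test set, accepting old and new file names."""
        path = Path(folder_path)
        new_path = path / "testset.jsonl"  # written from this commit onward
        old_path = path / "testset.json"   # written by earlier versions
        return QATestset.load(new_path if new_path.exists() else old_path)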
(The remaining 7 changed files are not shown.)