fix: improved answer relevancy (#346)

shahules786 authored Dec 8, 2023
1 parent af55f18 · commit 0430e8f

Showing 6 changed files with 80 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/ragas/llms/langchain.py
@@ -25,6 +25,7 @@ def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool:
 def isBedrock(llm: BaseLLM | BaseChatModel) -> bool:
     return isinstance(llm, Bedrock) or isinstance(llm, BedrockChat)
 
+
 def isAmazonAPIGateway(llm: BaseLLM | BaseChatModel) -> bool:
     return isinstance(llm, AmazonAPIGateway)
1 change: 1 addition & 0 deletions src/ragas/llms/llamaindex.py
@@ -10,6 +10,7 @@
 if t.TYPE_CHECKING:
     from langchain.callbacks.base import Callbacks
     from langchain.prompts import ChatPromptTemplate
+
 try:
     from llama_index.llms.base import LLM as LiLLM
 except ImportError:
72 changes: 54 additions & 18 deletions src/ragas/metrics/_answer_relevance.py
@@ -12,6 +12,7 @@
 from ragas.embeddings.base import embedding_factory
 from ragas.exceptions import OpenAIKeyNotFound
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
+from ragas.utils import load_as_json
 
 if t.TYPE_CHECKING:
     from langchain.callbacks.manager import CallbackManager

@@ -21,13 +22,46 @@

 QUESTION_GEN = HumanMessagePromptTemplate.from_template(
     """
-Generate question for the given answer.
-Answer:\nThe PSLV-C56 mission is scheduled to be launched on Sunday, 30 July 2023 at 06:30 IST / 01:00 UTC. It will be launched from the Satish Dhawan Space Centre, Sriharikota, Andhra Pradesh, India
-Question: When is the scheduled launch date and time for the PSLV-C56 mission, and where will it be launched from?
-
-Answer:{answer}
-Question:
-"""  # noqa: E501
+Generate a question for the given answer and identify if the answer is noncommittal
+
+Answer:
+Albert Einstein was born in Germany.
+Context:
+Albert Einstein was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time
+Output:
+{{"question":"Where was Albert Einstein born?","noncommittal":false}}
+
+Answer:
+It can change its skin color based on the temperature of its environment.
+Context:
+A recent scientific study has discovered a new species of frog in the Amazon rainforest that has the unique ability to change its skin color based on the temperature of its environment.
+Output:
+{{"question":"What unique ability does the newly discovered species of frog have?","noncommittal":false}}
+
+Answer:
+Everest
+Context:
+The tallest mountain on Earth, measured from sea level, is a renowned peak located in the Himalayas.
+Output:
+{{"question":"What is the tallest mountain on Earth?","noncommittal":false}}
+
+Answer:
+I don't know about the groundbreaking feature of the smartphone invented in 2023, as I am unaware of information beyond 2022.
+Context:
+In 2023, a groundbreaking invention was announced: a smartphone with a battery life of one month, revolutionizing the way people use mobile technology.
+Output:
+{{"question":"What was the groundbreaking feature of the smartphone invented in 2023?", "noncommittal":true}}
+
+Answer:
+{answer}
+Context:
+{context}
+Output:"""  # noqa: E501
 )


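The template now asks for a JSON object per generation instead of a bare question string, which is why `load_as_json` is imported above. A minimal sketch of the expected round-trip, using plain `json.loads` as a stand-in for `load_as_json` (whose exact fallback behaviour isn't shown in this diff):

```python
import json

# Hypothetical completion text, shaped like the few-shot examples above.
completion = '{"question": "Where was Albert Einstein born?", "noncommittal": false}'

# load_as_json (from ragas.utils) is assumed to do roughly this, presumably
# with extra robustness against malformed model output.
parsed = json.loads(completion)
print(parsed["question"])      # -> Where was Albert Einstein born?
print(parsed["noncommittal"])  # -> False
```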
@@ -53,7 +87,7 @@ class AnswerRelevancy(MetricWithLLM):
     """
 
     name: str = "answer_relevancy"
-    evaluation_mode: EvaluationMode = EvaluationMode.qa
+    evaluation_mode: EvaluationMode = EvaluationMode.qac
     batch_size: int = 15
     strictness: int = 3
     embeddings: RagasEmbeddings = field(default_factory=embedding_factory)
@@ -71,29 +105,31 @@ def _score_batch(
         callbacks: t.Optional[CallbackManager] = None,
         callback_group_name: str = "batch",
     ) -> list[float]:
-        questions, answers = dataset["question"], dataset["answer"]
+        questions, answers, contexts = (
+            dataset["question"],
+            dataset["answer"],
+            dataset["contexts"],
+        )
         with trace_as_chain_group(
             callback_group_name, callback_manager=callbacks
         ) as batch_group:
             prompts = []
-            for ans in answers:
-                human_prompt = QUESTION_GEN.format(answer=ans)
+            for ans, ctx in zip(answers, contexts):
+                human_prompt = QUESTION_GEN.format(answer=ans, context="\n".join(ctx))
                 prompts.append(ChatPromptTemplate.from_messages([human_prompt]))
 
             results = self.llm.generate(
                 prompts,
                 n=self.strictness,
                 callbacks=batch_group,
            )
-            results = [[i.text for i in r] for r in results.generations]
-
+            results = [[load_as_json(i.text) for i in r] for r in results.generations]
             scores = []
-            for question, gen_questions in zip(questions, results):
-                if question is not None and question != "" and len(gen_questions) > 0:
-                    cosine_sim = self.calculate_similarity(question, gen_questions)
-                    scores.append(cosine_sim.mean())
-                else:
-                    scores.append(0.0)
+            for question, result in zip(questions, results):
+                gen_questions = [item.get("question", "") for item in result]
+                committal = np.any([item.get("noncommittal", False) for item in result])
+                cosine_sim = self.calculate_similarity(question, gen_questions)
+                scores.append(cosine_sim.mean() * int(not committal))
 
         return scores

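The reworked `_score_batch` regenerates `strictness` questions per answer (with the retrieved contexts now in the prompt, matching the switch from `EvaluationMode.qa` to `EvaluationMode.qac` above) and zeroes the score whenever any generation flags the answer as noncommittal. A self-contained sketch of that scoring rule; the function and variable names here are illustrative, not part of the ragas API:

```python
import numpy as np


def answer_relevancy_score(cosine_sims, noncommittal_flags):
    """Illustrative only: mean similarity between the original question and the
    regenerated questions, zeroed out when ANY generation was noncommittal.
    (In the diff above, the variable named `committal` holds this
    any-noncommittal flag.)"""
    any_noncommittal = bool(np.any(noncommittal_flags))
    return float(np.mean(cosine_sims)) * int(not any_noncommittal)


# With strictness=3 there are three regenerated questions per answer:
print(answer_relevancy_score(np.array([0.91, 0.88, 0.93]), [False, False, False]))  # ~0.9067
print(answer_relevancy_score(np.array([0.91, 0.88, 0.93]), [True, False, False]))   # 0.0
```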
31 changes: 21 additions & 10 deletions src/ragas/testset/testset_generator.py
@@ -58,7 +58,16 @@
     "conditional": "_condition_question",
 }
 
-DataRow = namedtuple("DataRow", ["question", "ground_truth_context", "ground_truth", "question_type", "episode_done"])
+DataRow = namedtuple(
+    "DataRow",
+    [
+        "question",
+        "ground_truth_context",
+        "ground_truth",
+        "question_type",
+        "episode_done",
+    ],
+)
 
 
 @dataclass
@@ -73,11 +82,11 @@ def to_pandas(self) -> pd.DataFrame:
         data_samples = []
         for data in self.test_data:
             data = {
-                "question": data.question,
-                "ground_truth_context": data.ground_truth_context,
-                "ground_truth": data.ground_truth,
-                "question_type": data.question_type,
-                "episode_done": data.episode_done,
+                "question": data.question,
+                "ground_truth_context": data.ground_truth_context,
+                "ground_truth": data.ground_truth,
+                "question_type": data.question_type,
+                "episode_done": data.episode_done,
             }
             data_samples.append(data)
@@ -394,11 +403,13 @@ def generate(
                     context = self._generate_context(question, text_chunk)
                     is_conv = len(context) > 1
                     answer = self._generate_answer(question, context)
-                    for i, (qstn, ctx, ans) in enumerate(zip(question.split("\n"), context, answer)):
-                        episode_done = False if is_conv and i==0 else True
+                    for i, (qstn, ctx, ans) in enumerate(
+                        zip(question.split("\n"), context, answer)
+                    ):
+                        episode_done = False if is_conv and i == 0 else True
                         samples.append(
-                            DataRow(qstn, [ctx], [ans], evolve_type, episode_done)
-                        )
+                            DataRow(qstn, [ctx], [ans], evolve_type, episode_done)
+                        )
                     count += 1
                     pbar.update(count)

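For readers of the reformatted loop above: a question whose text contains several newline-separated turns is treated as conversational, and only its first row leaves the episode open. A small illustration with made-up data (`"conversational"` is a placeholder for an actual `evolve_type` value):

```python
from collections import namedtuple

DataRow = namedtuple(
    "DataRow",
    ["question", "ground_truth_context", "ground_truth", "question_type", "episode_done"],
)

# Hypothetical two-turn conversational sample; contexts/answers are placeholders.
question = "What is the PSLV?\nWhen was it first launched?"
context = ["ctx for turn 1", "ctx for turn 2"]
answer = ["ans for turn 1", "ans for turn 2"]
is_conv = len(context) > 1

samples = []
for i, (qstn, ctx, ans) in enumerate(zip(question.split("\n"), context, answer)):
    episode_done = False if is_conv and i == 0 else True  # first turn keeps the episode open
    samples.append(DataRow(qstn, [ctx], [ans], "conversational", episode_done))

print([row.episode_done for row in samples])  # -> [False, True]
```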
6 changes: 2 additions & 4 deletions src/ragas/validation.py
@@ -9,11 +9,9 @@ def remap_column_names(dataset: Dataset, column_map: dict[str, str]) -> Dataset:
     """
     Remap the column names in case dataset uses different column names
     """
-
     inverse_column_map = {v: k for k, v in column_map.items()}
-    return dataset.rename_columns(
-        inverse_column_map
-    )
+    return dataset.rename_columns(inverse_column_map)
 
 
 def validate_column_dtypes(ds: Dataset):
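As context for the cleanup above: `column_map` maps ragas' canonical column names to the names the dataset actually uses, so the function inverts it before delegating to `datasets.Dataset.rename_columns`. A quick sketch with a made-up dataset:

```python
from datasets import Dataset

from ragas.validation import remap_column_names

ds = Dataset.from_dict({"query": ["..."], "response": ["..."]})

# {"question": "query"} reads as: ragas' "question" column is stored as "query",
# so the inverted map {"query": "question"} renames it back to the canonical name.
remapped = remap_column_names(ds, column_map={"question": "query", "answer": "response"})
print(remapped.column_names)  # -> ['question', 'answer']
```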
4 changes: 1 addition & 3 deletions tests/unit/test_validation.py
@@ -103,9 +103,7 @@ def test_column_remap(column_map):
         }
     )
     remapped_dataset = remap_column_names(TEST_DATASET, column_map)
-    assert all(
-        col in remapped_dataset.column_names for col in column_map.keys()
-    )
+    assert all(col in remapped_dataset.column_names for col in column_map.keys())
 
 
 def test_column_remap_omit():
