chores: cleanup metrics (#1348)
shahules786 authored Sep 24, 2024
1 parent 495bbd1 commit b5514c6
Showing 10 changed files with 91 additions and 46 deletions.
69 changes: 54 additions & 15 deletions src/ragas/metrics/__init__.py
@@ -2,33 +2,56 @@
import sys

from ragas.metrics._answer_correctness import AnswerCorrectness, answer_correctness
from ragas.metrics._answer_relevance import AnswerRelevancy, answer_relevancy
from ragas.metrics._answer_similarity import AnswerSimilarity, answer_similarity
from ragas.metrics._answer_relevance import (
AnswerRelevancy,
ResponseRelevancy,
answer_relevancy,
)
from ragas.metrics._answer_similarity import (
AnswerSimilarity,
SemanticSimilarity,
answer_similarity,
)
from ragas.metrics._aspect_critic import AspectCritic
from ragas.metrics._bleu_score import BleuScore
from ragas.metrics._context_entities_recall import (
ContextEntityRecall,
context_entity_recall,
)
from ragas.metrics._context_precision import (
ContextPrecision,
ContextUtilization,
LLMContextPrecisionWithoutReference,
NonLLMContextPrecisionWithReference,
context_precision,
context_utilization,
)
from ragas.metrics._context_recall import ContextRecall, context_recall
from ragas.metrics._context_recall import (
ContextRecall,
LLMContextRecall,
NonLLMContextRecall,
context_recall,
)
from ragas.metrics._datacompy_score import DataCompyScore
from ragas.metrics._domain_specific_rubrics import (
RubricsScoreWithoutReference,
RubricsScoreWithReference,
rubrics_score_with_reference,
rubrics_score_without_reference,
)
from ragas.metrics._factual_correctness import FactualCorrectness
from ragas.metrics._faithfulness import Faithfulness, FaithulnesswithHHEM, faithfulness
from ragas.metrics._noise_sensitivity import (
NoiseSensitivity,
noise_sensitivity_irrelevant,
noise_sensitivity_relevant,
from ragas.metrics._goal_accuracy import (
AgentGoalAccuracyWithoutReference,
AgentGoalAccuracyWithReference,
)
from ragas.metrics._instance_specific_rubrics import (
InstanceRubricsScoreWithoutReference,
InstanceRubricsWithReference,
)
from ragas.metrics._noise_sensitivity import NoiseSensitivity
from ragas.metrics._rogue_score import RougeScore
from ragas.metrics._sql_semantic_equivalence import LLMSQLEquivalence
from ragas.metrics._string import ExactMatch, NonLLMStringSimilarity, StringPresence, DistanceMeasure
from ragas.metrics._summarization import SummarizationScore, summarization_score
from ragas.metrics._tool_call_accuracy import ToolCallAccuracy

__all__ = [
"AnswerCorrectness",
@@ -41,7 +64,6 @@
"ContextPrecision",
"context_precision",
"ContextUtilization",
"context_utilization",
"ContextRecall",
"context_recall",
"AspectCritic",
@@ -52,12 +74,29 @@
"SummarizationScore",
"summarization_score",
"NoiseSensitivity",
"noise_sensitivity_irrelevant",
"noise_sensitivity_relevant",
"rubrics_score_with_reference",
"rubrics_score_without_reference",
"RubricsScoreWithoutReference",
"RubricsScoreWithReference",
"LLMContextPrecisionWithoutReference",
"NonLLMContextPrecisionWithReference",
"LLMContextPrecisionWithoutReference",
"LLMContextRecall",
"NonLLMContextRecall",
"FactualCorrectness",
"InstanceRubricsScoreWithoutReference",
"InstanceRubricsWithReference",
"NonLLMStringSimilarity",
"ExactMatch",
"StringPresence",
"BleuScore",
"RougeScore",
"DataCompyScore",
"LLMSQLEquivalence",
"AgentGoalAccuracyWithoutReference",
"AgentGoalAccuracyWithReference",
"ToolCallAccuracy",
"ResponseRelevancy",
"SemanticSimilarity",
"DistanceMeasure",
]

current_module = sys.modules[__name__]
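A minimal sketch of how a downstream script might import the reworked names after this commit, assuming the imports and `__all__` shown above land as-is (the copied diff does not show +/- markers):

```python
# Minimal sketch (assumed downstream usage) of the import surface after this
# commit; names are taken from the updated imports and __all__ shown above.
from ragas.metrics import (
    AnswerRelevancy,     # kept for backward compatibility
    ResponseRelevancy,   # new name behind AnswerRelevancy
    SemanticSimilarity,  # new name behind AnswerSimilarity
    FactualCorrectness,
    LLMSQLEquivalence,
    ToolCallAccuracy,
)
```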
23 changes: 14 additions & 9 deletions src/ragas/metrics/_answer_relevance.py
@@ -25,15 +25,15 @@
from ragas.llms.prompt import PromptValue


class AnswerRelevanceClassification(BaseModel):
class ResponseRelevanceClassification(BaseModel):
question: str
noncommittal: int


_output_instructions = get_json_format_instructions(
pydantic_object=AnswerRelevanceClassification
pydantic_object=ResponseRelevanceClassification
)
_output_parser = RagasoutputParser(pydantic_object=AnswerRelevanceClassification)
_output_parser = RagasoutputParser(pydantic_object=ResponseRelevanceClassification)


QUESTION_GEN = Prompt(
@@ -44,7 +44,7 @@ class AnswerRelevanceClassification(BaseModel):
{
"answer": """Albert Einstein was born in Germany.""",
"context": """Albert Einstein was a German-born theoretical physicist who is widely held to be one of the greatest and most influential scientists of all time""",
"output": AnswerRelevanceClassification.parse_obj(
"output": ResponseRelevanceClassification.parse_obj(
{
"question": "Where was Albert Einstein born?",
"noncommittal": 0,
@@ -54,7 +54,7 @@
{
"answer": """It can change its skin color based on the temperature of its environment.""",
"context": """A recent scientific study has discovered a new species of frog in the Amazon rainforest that has the unique ability to change its skin color based on the temperature of its environment.""",
"output": AnswerRelevanceClassification.parse_obj(
"output": ResponseRelevanceClassification.parse_obj(
{
"question": "What unique ability does the newly discovered species of frog have?",
"noncommittal": 0,
@@ -64,7 +64,7 @@
{
"answer": """Everest""",
"context": """The tallest mountain on Earth, measured from sea level, is a renowned peak located in the Himalayas.""",
"output": AnswerRelevanceClassification.parse_obj(
"output": ResponseRelevanceClassification.parse_obj(
{
"question": "What is the tallest mountain on Earth?",
"noncommittal": 0,
@@ -74,7 +74,7 @@
{
"answer": """I don't know about the groundbreaking feature of the smartphone invented in 2023 as am unaware of information beyond 2022. """,
"context": """In 2023, a groundbreaking invention was announced: a smartphone with a battery life of one month, revolutionizing the way people use mobile technology.""",
"output": AnswerRelevanceClassification.parse_obj(
"output": ResponseRelevanceClassification.parse_obj(
{
"question": "What was the groundbreaking feature of the smartphone invented in 2023?",
"noncommittal": 1,
@@ -89,7 +89,7 @@


@dataclass
class AnswerRelevancy(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
class ResponseRelevancy(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
"""
Scores the relevancy of the answer according to the given question.
Answers with incomplete, redundant or unnecessary information is penalized.
@@ -139,7 +139,7 @@ def calculate_similarity(
)

def _calculate_score(
self, answers: t.Sequence[AnswerRelevanceClassification], row: t.Dict
self, answers: t.Sequence[ResponseRelevanceClassification], row: t.Dict
) -> float:
question = row["user_input"]
gen_questions = [answer.question for answer in answers]
@@ -197,4 +197,9 @@ def save(self, cache_dir: str | None = None) -> None:
self.question_generation.save(cache_dir)


class AnswerRelevancy(ResponseRelevancy):
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return await super()._ascore(row, callbacks)


answer_relevancy = AnswerRelevancy()
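The rename keeps `AnswerRelevancy` as a thin subclass that only delegates `_ascore`, so existing code that type-checks or instantiates the old name keeps working. A small hedged check of that relationship:

```python
# Hedged check: the old class is now a thin subclass of the renamed metric,
# so instances satisfy isinstance() against either name.
from ragas.metrics import AnswerRelevancy, ResponseRelevancy, answer_relevancy

metric = AnswerRelevancy()
assert isinstance(metric, ResponseRelevancy)
assert isinstance(answer_relevancy, ResponseRelevancy)
```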
7 changes: 6 additions & 1 deletion src/ragas/metrics/_answer_similarity.py
@@ -23,7 +23,7 @@


@dataclass
class AnswerSimilarity(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
class SemanticSimilarity(MetricWithLLM, MetricWithEmbeddings, SingleTurnMetric):
"""
Scores the semantic similarity of ground truth with generated answer.
cross encoder score is used to quantify semantic similarity.
@@ -91,4 +91,9 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
return score.tolist()[0]


class AnswerSimilarity(SemanticSimilarity):
async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
return await super()._ascore(row, callbacks)


answer_similarity = AnswerSimilarity()
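The same pattern applies here: `AnswerSimilarity` now derives from `SemanticSimilarity` and the `answer_similarity` singleton is unchanged. A brief sketch, assuming both names stay importable from `ragas.metrics`:

```python
# Sketch, assuming both names stay importable from ragas.metrics: the
# singleton is unchanged but now derives from SemanticSimilarity.
from ragas.metrics import AnswerSimilarity, SemanticSimilarity, answer_similarity

assert isinstance(answer_similarity, AnswerSimilarity)
assert isinstance(answer_similarity, SemanticSimilarity)
```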
3 changes: 1 addition & 2 deletions src/ragas/metrics/_context_entities_recall.py
@@ -139,7 +139,6 @@ class ContextEntityRecall(MetricWithLLM, SingleTurnMetric):
context_entity_recall_prompt: Prompt = field(
default_factory=lambda: TEXT_ENTITY_EXTRACTION
)
batch_size: int = 15
max_retries: int = 1

def _compute_score(
@@ -195,4 +194,4 @@ def save(self, cache_dir: str | None = None) -> None:
return self.context_entity_recall_prompt.save(cache_dir)


context_entity_recall = ContextEntityRecall(batch_size=15)
context_entity_recall = ContextEntityRecall()
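With the `batch_size` field gone, the metric is constructed without that keyword; passing it would now fail because the dataclass no longer defines the field. A hedged migration sketch:

```python
# Hedged migration sketch: batch_size is no longer a dataclass field.
from ragas.metrics import ContextEntityRecall

metric = ContextEntityRecall()          # new form, as in the hunk above
# ContextEntityRecall(batch_size=15)    # would now raise TypeError: unexpected keyword
```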
4 changes: 0 additions & 4 deletions src/ragas/metrics/_domain_specific_rubrics.py
@@ -307,7 +307,3 @@ def _create_single_turn_prompt(self, row: t.Dict) -> SingleTurnWithReferenceInpu
reference=ground_truth,
rubrics=self.rubrics,
)


rubrics_score_with_reference = RubricsScoreWithReference()
rubrics_score_without_reference = RubricsScoreWithoutReference()
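Since the module-level `rubrics_score_*` instances are removed, callers instantiate the rubric classes directly. A sketch, assuming the classes keep usable default rubrics:

```python
# Sketch of the replacement for the removed module-level instances
# (assumes the classes keep usable default rubrics).
from ragas.metrics import RubricsScoreWithoutReference, RubricsScoreWithReference

rubrics_with_reference = RubricsScoreWithReference()
rubrics_without_reference = RubricsScoreWithoutReference()
```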
2 changes: 1 addition & 1 deletion src/ragas/metrics/_factual_correctness.py
@@ -9,11 +9,11 @@
from numpy.typing import NDArray
from pydantic import BaseModel, Field

from ragas.experimental.prompt import PydanticPrompt
from ragas.experimental.metrics._faithfulness import (
NLIStatementInput,
NLIStatementPrompt,
)
from ragas.experimental.prompt import PydanticPrompt
from ragas.metrics.base import (
MetricType,
MetricWithLLM,
8 changes: 1 addition & 7 deletions src/ragas/metrics/_noise_sensitivity.py
@@ -38,7 +38,7 @@
@dataclass
class NoiseSensitivity(MetricWithLLM, SingleTurnMetric):
name: str = "noise_sensitivity" # type: ignore
focus: str = "relevant"
focus: t.Literal["relevant", "irrelevant"] = "relevant"
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
default_factory=lambda: {
MetricType.SINGLE_TURN: {
@@ -266,8 +266,6 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
assert self.llm is not None, "LLM is not set"

logger.info(f"Adapting Faithfulness metric to {language}")

self.nli_statements_message = self.nli_statements_message.adapt(
language, self.llm, cache_dir
)
@@ -280,7 +278,3 @@ def adapt(self, language: str, cache_dir: t.Optional[str] = None) -> None:
def save(self, cache_dir: t.Optional[str] = None) -> None:
self.nli_statements_message.save(cache_dir)
self.statement_prompt.save(cache_dir)


noise_sensitivity_relevant = NoiseSensitivity()
noise_sensitivity_irrelevant = NoiseSensitivity(focus="irrelevant")
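The removed singletons map onto explicit constructor calls, and `focus` is now typed as `Literal["relevant", "irrelevant"]`. A hedged sketch of the replacement:

```python
# Hedged sketch: explicit constructor calls replace the removed singletons;
# focus defaults to "relevant" and is now constrained to a Literal.
from ragas.metrics import NoiseSensitivity

noise_sensitivity_relevant = NoiseSensitivity()                      # focus="relevant"
noise_sensitivity_irrelevant = NoiseSensitivity(focus="irrelevant")
```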
15 changes: 10 additions & 5 deletions src/ragas/metrics/_rogue_score.py
@@ -2,7 +2,6 @@
from dataclasses import dataclass, field

from langchain_core.callbacks import Callbacks
from rouge_score import rouge_scorer

from ragas.dataset_schema import SingleTurnSample
from ragas.metrics.base import MetricType, SingleTurnMetric
@@ -18,6 +17,15 @@ class RougeScore(SingleTurnMetric):
rogue_type: t.Literal["rouge1", "rougeL"] = "rougeL"
measure_type: t.Literal["fmeasure", "precision", "recall"] = "fmeasure"

def __post_init__(self):
try:
from rouge_score import rouge_scorer
except ImportError as e:
raise ImportError(
f"{e.name} is required for rouge score. Please install it using `pip install {e.name}"
)
self.rouge_scorer = rouge_scorer

def init(self, run_config: RunConfig):
pass

@@ -26,12 +34,9 @@ async def _single_turn_ascore(
) -> float:
assert isinstance(sample.reference, str), "Sample reference must be a string"
assert isinstance(sample.response, str), "Sample response must be a string"
scorer = rouge_scorer.RougeScorer([self.rogue_type], use_stemmer=True)
scorer = self.rouge_scorer.RougeScorer([self.rogue_type], use_stemmer=True)
scores = scorer.score(sample.reference, sample.response)
return getattr(scores[self.rogue_type], self.measure_type)

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return await self._single_turn_ascore(SingleTurnSample(**row), callbacks)


rouge_score = RougeScore()
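`rouge_score` becomes an optional, lazily imported dependency checked in `__post_init__`. A sketch of the caller's view, assuming the package is installed and that `SingleTurnSample` accepts just `reference` and `response`:

```python
# Sketch of the caller's view (assumed usage): RougeScore can be constructed
# only when the optional rouge_score package is importable; otherwise
# __post_init__ raises the ImportError shown in the hunk above.
from ragas.dataset_schema import SingleTurnSample
from ragas.metrics import RougeScore

metric = RougeScore(rogue_type="rougeL", measure_type="fmeasure")
sample = SingleTurnSample(
    reference="The cat sat on the mat.",
    response="A cat was sitting on the mat.",
)
# score = await metric._single_turn_ascore(sample, callbacks=None)  # in an async context
```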
2 changes: 1 addition & 1 deletion src/ragas/metrics/_sql_semantic_equivalence.py
@@ -63,7 +63,7 @@ class EquivalencePrompt(PydanticPrompt[EquivalenceInput, EquivalenceOutput]):


@dataclass
class LLMSqlEquivalenceWithReference(MetricWithLLM, SingleTurnMetric):
class LLMSQLEquivalence(MetricWithLLM, SingleTurnMetric):
name: str = "llm_sql_equivalence_with_reference" # type: ignore
_required_columns: t.Dict[MetricType, t.Set[str]] = field(
default_factory=lambda: {
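Only the class name changes here (`LLMSqlEquivalenceWithReference` becomes `LLMSQLEquivalence`); the metric's `name` string is untouched by this hunk. A brief sketch of the updated import, assuming the metric can be constructed before an LLM is attached, as with other `MetricWithLLM` metrics:

```python
# Brief sketch of the rename's effect on imports; the metric's name string
# stays "llm_sql_equivalence_with_reference" per the hunk above.
from ragas.metrics import LLMSQLEquivalence  # was LLMSqlEquivalenceWithReference

metric = LLMSQLEquivalence()
print(metric.name)
```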
4 changes: 3 additions & 1 deletion src/ragas/metrics/_tool_call_accuracy.py
@@ -27,7 +27,9 @@ class ToolCallAccuracy(MultiTurnMetric):
}
)

arg_comparison_metric: SingleTurnMetric = ExactMatch()
arg_comparison_metric: SingleTurnMetric = field(
default_factory=lambda: ExactMatch()
)

def init(self, run_config):
pass
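The last hunk swaps a class-level default instance for `field(default_factory=...)`, which builds a fresh `ExactMatch` per `ToolCallAccuracy` instead of sharing one object across all instances. A generic illustration with hypothetical classes (not ragas code):

```python
# Generic illustration (hypothetical classes, not ragas code) of why the diff
# switches to field(default_factory=...). A plain default instance is created
# once at class-definition time and shared by every instance; default_factory
# builds a fresh object per instance. (If the default's class were itself an
# eq=True dataclass, @dataclass would reject the shared-instance form outright
# as a mutable default.)
from dataclasses import dataclass, field


class Comparator:
    """Stand-in for a metric object such as ExactMatch()."""


@dataclass
class Shared:
    cmp: Comparator = Comparator()  # one object, reused by all Shared() instances


@dataclass
class PerInstance:
    cmp: Comparator = field(default_factory=Comparator)  # new object each time


assert Shared().cmp is Shared().cmp                 # same shared object
assert PerInstance().cmp is not PerInstance().cmp   # distinct objects
```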
