-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add ground truth columns to prediction output (#504)
<!-- Thanks for sending a pull request! Here are some tips for you: 1. Run unit tests and ensure that they are passing 2. If your change introduces any API changes, make sure to update the e2e tests 3. Make sure documentation is updated for your PR! --> **What this PR does / why we need it**: <!-- Explain here the context and why you're making the change. What is the problem you're trying to solve. ---> Changes introduced: - Using a string enum type as opposed to int, so that the JSON output for the enum field will be the string value rather than an int, which makes the JSON content easier to understand. This is important for the use case of providing the enum values as input to Flyte tasks. - Instead of having all possible classes as separate attributes within the inference schema class, they are now combined within a single field, model_prediction_output. In addition, the prediction output is now expected to implement the preprocess method, which will be used to preprocess the dataframe before sending the result to Arize. - Due to the existence of models that might have non-float prediction values (such as ranking models), the subclasses of PredictionOutput are expected to provide prediction types as well. - Changes required for the observation publisher and batch publisher, due to the changes made to the Merlin SDK, will be in a separate pull request. **Which issue(s) this PR fixes**: <!-- *Automatically closes linked issue when PR is merged. Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`. --> Fixes # **Does this PR introduce a user-facing change?**: <!-- If no, just write "NONE" in the release-note block below. If yes, a release note is required. Enter your extended release note in the block below. If the PR requires additional action from users switching to the new release, include the string "action required". 
For more information about release notes, see Kubernetes' guide here: http://git.k8s.io/community/contributors/guide/release-notes.md --> ```release-note NONE ``` **Checklist** - [x] Added unit, integration, and/or e2e tests - [x] Tested locally - [x] Updated documentation - [ ] Update Swagger spec if the PR introduces API changes - [ ] Regenerated Golang and Python client if the PR introduces API changes
- Loading branch information
1 parent
8c33cac
commit 6dd5d81
Showing
2 changed files
with
321 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,111 +1,252 @@ | ||
from dataclasses import dataclass | ||
from enum import unique, Enum | ||
from typing import Dict, Optional, List | ||
import abc | ||
from dataclasses import dataclass, field | ||
from enum import Enum, unique | ||
from typing import Dict, List, Optional | ||
|
||
from dataclasses_json import dataclass_json | ||
import numpy as np | ||
import pandas as pd | ||
from dataclasses_json import DataClassJsonMixin, config, dataclass_json | ||
|
||
|
||
class ObservationType(Enum):
    """
    Supported observation types.

    Each member's value is the lowercase string form of its name.
    """

    FEATURE = "feature"
    PREDICTION = "prediction"
    GROUND_TRUTH = "ground_truth"
|
||
|
||
@unique | ||
class ValueType(Enum): | ||
FLOAT64 = 1 | ||
INT64 = 2 | ||
BOOLEAN = 3 | ||
STRING = 4 | ||
""" | ||
Supported feature value types. | ||
""" | ||
|
||
FLOAT64 = "float64" | ||
INT64 = "int64" | ||
BOOLEAN = "boolean" | ||
STRING = "string" | ||
|
||
|
||
class PredictionOutput(abc.ABC):
    """
    Abstract base class for model prediction output schemas.

    Every subclass is recorded in ``subclass_registry`` under its class name.
    That name is written under ``discriminator_field`` when an instance is
    encoded, so that ``decode`` can pick the matching subclass again.
    """

    # Maps subclass name -> subclass; populated by __init_subclass__.
    subclass_registry: dict = {}
    # Key that holds the subclass name in the encoded dictionary.
    discriminator_field: str = "output_class"

    def __init_subclass__(cls, **kwargs):
        """
        Register subclasses of PredictionOutput.
        """
        super().__init_subclass__(**kwargs)
        PredictionOutput.subclass_registry[cls.__name__] = cls

    @classmethod
    def encode_with_discriminator(cls, x):
        """
        Given a subclass of PredictionOutput, which is assumed to have the dataclass json mixin,
        encode the object with a discriminator field used to differentiate the different subclasses.

        :param x: Instance of a registered PredictionOutput subclass.
        :return: ``x.to_dict()`` plus the discriminator field holding the class name of ``x``.
        :raises ValueError: if the exact type of ``x`` is not a registered subclass, or if ``x``
            is not an instance of DataClassJsonMixin.
        """
        # Exact type match on purpose: the class name is what decode() looks up later.
        if type(x) not in cls.subclass_registry.values():
            raise ValueError(
                f"Input must be a subclass of {cls.__name__}, got {type(x)}"
            )
        if not isinstance(x, DataClassJsonMixin):
            raise ValueError(
                f"Input must be a virtual subclass of DataClassJsonMixin, got {type(x)}"
            )

        return {
            cls.discriminator_field: type(x).__name__,
            **x.to_dict(),
        }

    @classmethod
    def decode(cls, input: Dict):
        """
        Given a dictionary, encoded using :func:`merlin.observability.inference.PredictionOutput.encode_with_discriminator`,
        decode the dictionary back to the correct subclass of PredictionOutput.

        :param input: Dictionary that contains the discriminator field.
        :return: Result of ``from_dict(input)`` on the subclass named by the discriminator field.
        :raises KeyError: if the discriminator field is missing or names an unregistered class.
        """
        return PredictionOutput.subclass_registry[
            input[cls.discriminator_field]
        ].from_dict(input)

    @abc.abstractmethod
    def preprocess(
        self, df: pd.DataFrame, observation_types: "List[ObservationType]"
    ) -> pd.DataFrame:
        """
        Given an input dataframe, return a new dataframe with the necessary columns for observability,
        along with a schema for the observability backend to parse the dataframe. In place changes might
        be made to the input dataframe.

        :param df: Input dataframe.
        :param observation_types: Types of observations to be included in the output dataframe.
        :return: output dataframe
        """
        raise NotImplementedError

    @abc.abstractmethod
    def prediction_types(self) -> "Dict[str, ValueType]":
        """
        Return a dictionary mapping the name of the prediction output column to its value type.
        """
        raise NotImplementedError
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class RegressionOutput: | ||
class RegressionOutput(PredictionOutput): | ||
""" | ||
Regression model prediction output schema. | ||
Attributes: | ||
prediction_score_column: Name of the column containing the prediction score. | ||
actual_score_column: Name of the column containing the actual score. | ||
""" | ||
|
||
prediction_score_column: str | ||
actual_score_column: str | ||
|
||
@property | ||
def column_types(self) -> Dict[str, ValueType]: | ||
return {self.prediction_score_column: ValueType.FLOAT64} | ||
def preprocess( | ||
self, df: pd.DataFrame, observation_types: List[ObservationType] | ||
) -> pd.DataFrame: | ||
return df | ||
|
||
def prediction_types(self) -> Dict[str, ValueType]: | ||
return { | ||
self.prediction_score_column: ValueType.FLOAT64, | ||
self.actual_score_column: ValueType.FLOAT64, | ||
} | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class BinaryClassificationOutput: | ||
class BinaryClassificationOutput(PredictionOutput): | ||
""" | ||
Binary classification model prediction output schema. | ||
Attributes: | ||
prediction_score_column: Name of the column containing the prediction score. | ||
Prediction score must be between 0.0 and 1.0. | ||
actual_label_column: Name of the column containing the actual class. | ||
positive_class_label: Label for positive class. | ||
negative_class_label: Label for negative class. | ||
score_threshold: Score threshold for prediction to be considered as positive class. | ||
""" | ||
|
||
prediction_score_column: str | ||
prediction_label_column: Optional[str] = None | ||
actual_label_column: str | ||
positive_class_label: str | ||
negative_class_label: str | ||
score_threshold: float = 0.5 | ||
|
||
@property | ||
def column_types(self) -> Dict[str, ValueType]: | ||
column_types_mapping = {self.prediction_score_column: ValueType.FLOAT64} | ||
if self.prediction_label_column is not None: | ||
column_types_mapping[self.prediction_label_column] = ValueType.STRING | ||
return column_types_mapping | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class MulticlassClassificationOutput: | ||
prediction_score_columns: List[str] | ||
prediction_label_columns: Optional[List[str]] = None | ||
def prediction_label_column(self) -> str: | ||
return "_prediction_label" | ||
|
||
@property | ||
def column_types(self) -> Dict[str, ValueType]: | ||
column_types_mapping = { | ||
label_column: ValueType.FLOAT64 | ||
for label_column in self.prediction_score_columns | ||
def actual_score_column(self) -> str: | ||
return "_actual_score" | ||
|
||
def prediction_label(self, prediction_score: float) -> str: | ||
""" | ||
Returns either positive or negative class label based on prediction score. | ||
:param prediction_score: Probability of positive class, between 0.0 and 1.0. | ||
:return: prediction label | ||
""" | ||
return ( | ||
self.positive_class_label | ||
if prediction_score >= self.score_threshold | ||
else self.negative_class_label | ||
) | ||
|
||
def actual_score(self, actual_label: str) -> float: | ||
""" | ||
Derive actual score from actual label. | ||
:param actual_label: Actual label. | ||
:return: actual score. Either 0.0 for negative class or 1.0 for positive class. | ||
""" | ||
if actual_label not in [self.positive_class_label, self.negative_class_label]: | ||
raise ValueError( | ||
f"Actual label must be one of the classes, got {actual_label}" | ||
) | ||
return 1.0 if actual_label == self.positive_class_label else 0.0 | ||
|
||
def preprocess( | ||
self, df: pd.DataFrame, observation_types: List[ObservationType] | ||
) -> pd.DataFrame: | ||
if ObservationType.PREDICTION in observation_types: | ||
df[self.prediction_label_column] = df[self.prediction_score_column].apply( | ||
self.prediction_label | ||
) | ||
if ObservationType.GROUND_TRUTH in observation_types: | ||
df[self.actual_score_column] = df[self.actual_label_column].apply( | ||
self.actual_score | ||
) | ||
return df | ||
|
||
def prediction_types(self) -> Dict[str, ValueType]: | ||
return { | ||
self.prediction_score_column: ValueType.FLOAT64, | ||
self.prediction_label_column: ValueType.STRING, | ||
self.actual_score_column: ValueType.FLOAT64, | ||
self.actual_label_column: ValueType.STRING, | ||
} | ||
if self.prediction_label_columns is not None: | ||
for column_name in self.prediction_label_columns: | ||
column_types_mapping[column_name] = ValueType.STRING | ||
return column_types_mapping | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class RankingOutput: | ||
rank_column: str | ||
class RankingOutput(PredictionOutput): | ||
rank_score_column: str | ||
prediction_group_id_column: str | ||
relevance_score_column: str | ||
""" | ||
Ranking model prediction output schema. | ||
Attributes: | ||
rank_score_column: Name of the column containing the ranking score of the prediction. | ||
prediction_group_id_column: Name of the column containing the prediction group id. | ||
relevance_score_column: Name of the column containing the relevance score of the prediction. | ||
""" | ||
|
||
@property | ||
def column_types(self) -> Dict[str, ValueType]: | ||
def rank_column(self) -> str: | ||
return "_rank" | ||
|
||
def preprocess( | ||
self, df: pd.DataFrame, observation_types: List[ObservationType] | ||
) -> pd.DataFrame: | ||
if ObservationType.PREDICTION in observation_types: | ||
df[self.rank_column] = df.groupby(self.prediction_group_id_column)[ | ||
self.rank_score_column | ||
].rank(method="first", ascending=False).astype(np.int_) | ||
return df | ||
|
||
def prediction_types(self) -> Dict[str, ValueType]: | ||
return { | ||
self.rank_column: ValueType.INT64, | ||
self.prediction_group_id_column: ValueType.STRING, | ||
self.relevance_score_column: ValueType.FLOAT64, | ||
} | ||
|
||
|
||
@unique | ||
class InferenceType(Enum): | ||
BINARY_CLASSIFICATION = 1 | ||
MULTICLASS_CLASSIFICATION = 2 | ||
REGRESSION = 3 | ||
RANKING = 4 | ||
|
||
|
||
@dataclass_json | ||
@dataclass | ||
class InferenceSchema: | ||
feature_types: Dict[str, ValueType] | ||
type: InferenceType | ||
binary_classification: Optional[BinaryClassificationOutput] = None | ||
multiclass_classification: Optional[MulticlassClassificationOutput] = None | ||
regression: Optional[RegressionOutput] = None | ||
ranking: Optional[RankingOutput] = None | ||
prediction_id_column: Optional[str] = "prediction_id" | ||
model_prediction_output: PredictionOutput = field( | ||
metadata=config( | ||
encoder=PredictionOutput.encode_with_discriminator, | ||
decoder=PredictionOutput.decode, | ||
) | ||
) | ||
prediction_id_column: str = "prediction_id" | ||
tag_columns: Optional[List[str]] = None | ||
|
||
@property | ||
def feature_columns(self) -> List[str]: | ||
return list(self.feature_types.keys()) | ||
|
||
@property | ||
def prediction_column_types(self) -> Dict[str, ValueType]: | ||
if self.type == InferenceType.BINARY_CLASSIFICATION: | ||
assert self.binary_classification is not None | ||
return self.binary_classification.column_types | ||
elif self.type == InferenceType.MULTICLASS_CLASSIFICATION: | ||
assert self.multiclass_classification is not None | ||
return self.multiclass_classification.column_types | ||
elif self.type == InferenceType.REGRESSION: | ||
assert self.regression is not None | ||
return self.regression.column_types | ||
elif self.type == InferenceType.RANKING: | ||
assert self.ranking is not None | ||
return self.ranking.column_types | ||
else: | ||
raise ValueError(f"Unknown prediction type: {self.type}") |
Oops, something went wrong.