diff --git a/doc/api.rst b/doc/api.rst index fa0fed7..b371d46 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -34,6 +34,13 @@ This is the full API documentation of the `sharp` package. qoi.RankScoreQoI qoi.TopKQoI +.. autosummary:: + :toctree: _generated/ + :template: function.rst + + qoi.get_qoi + qoi.get_qoi_names + :mod:`sharp.visualization` -------------------------- diff --git a/sharp/qoi/__init__.py b/sharp/qoi/__init__.py index d8aa086..30dfc60 100644 --- a/sharp/qoi/__init__.py +++ b/sharp/qoi/__init__.py @@ -9,7 +9,8 @@ RankQoI, RankScoreQoI, TopKQoI, - QOI_OBJECTS, + get_qoi, + get_qoi_names, ) __all__ = [ @@ -19,5 +20,6 @@ "RankQoI", "RankScoreQoI", "TopKQoI", - "QOI_OBJECTS", + "get_qoi", + "get_qoi_names", ] diff --git a/sharp/qoi/_qoi.py b/sharp/qoi/_qoi.py index 9443bbb..e8af92d 100644 --- a/sharp/qoi/_qoi.py +++ b/sharp/qoi/_qoi.py @@ -1,3 +1,4 @@ +import copy from .base import BaseQoI, BaseRankQoI @@ -138,7 +139,7 @@ def _calculate(self, rows1, rows2): return (self.estimate(rows1) - self.estimate(rows2)).mean() -QOI_OBJECTS = { +_QOI_OBJECTS = { "diff": DiffQoI, "flip": FlipQoI, "likelihood": LikelihoodQoI, @@ -146,3 +147,65 @@ def _calculate(self, rows1, rows2): "rank_score": RankScoreQoI, "top_k": TopKQoI, } + + +def get_qoi_names(): + """Get the names of all available quantities of interest. + + These names can be passed to :func:`~sharp.qoi.get_qoi` to + retrieve the QoI object. + + Returns + ------- + list of str + Names of all available quantities of interest. + + Examples + -------- + >>> from sharp.qoi import get_qoi_names + >>> all_qois = get_qoi_names() + >>> type(all_qois) + + >>> all_qois[:3] + ['diff', 'flip', 'likelihood'] + >>> "ranking" in all_qois + True + """ + return sorted(_QOI_OBJECTS.keys()) + + +def get_qoi(qoi): + """Get a quantity of interest from string. + + :func:`~sharp.qoi.get_qoi_names` can be used to retrieve the names + of all available quantities of interest. + + Parameters + ---------- + qoi : str, callable or None + Quantity of interest as string. If callable it is returned as is. + If None, returns None. + + Returns + ------- + quantity : callable + The quantity of interest. + + Notes + ----- + When passed a string, this function always returns a copy of the scorer + object. Calling `get_qoi` twice for the same scorer results in two + separate QoI objects. + """ + if isinstance(qoi, str): + try: + quantity = copy.deepcopy(_QOI_OBJECTS[qoi]) + except KeyError: + raise ValueError( + "%r is not a valid scoring value. " + "Use sklearn.metrics.get_scorer_names() " + "to get valid options." % qoi + ) + else: + quantity = qoi + return quantity diff --git a/sharp/tests/test_basic_usage.py b/sharp/tests/test_basic_usage.py index 000ba5b..5af47b8 100644 --- a/sharp/tests/test_basic_usage.py +++ b/sharp/tests/test_basic_usage.py @@ -3,7 +3,7 @@ import numpy as np from sklearn.utils import check_random_state from sharp import ShaRP -from sharp.qoi import QOI_OBJECTS +from sharp.qoi import get_qoi from sharp._measures import MEASURES # Set up some envrionment variables @@ -12,10 +12,10 @@ rng = check_random_state(RNG_SEED) rank_qois_str = ["rank", "rank_score", "top_k"] -rank_qois_obj = [QOI_OBJECTS[qoi] for qoi in rank_qois_str] +rank_qois_obj = [get_qoi(qoi) for qoi in rank_qois_str] clf_qois_str = ["diff", "flip", "likelihood"] -clf_qois_obj = [QOI_OBJECTS[qoi] for qoi in clf_qois_str] +clf_qois_obj = [get_qoi(qoi) for qoi in clf_qois_str] measures = list(MEASURES.keys()) diff --git a/sharp/utils/_checks.py b/sharp/utils/_checks.py index af664f6..da9de4a 100644 --- a/sharp/utils/_checks.py +++ b/sharp/utils/_checks.py @@ -1,7 +1,7 @@ import numpy as np from sklearn.utils.validation import check_array, _get_feature_names -from sharp.qoi import QOI_OBJECTS +from sharp.qoi import get_qoi from sharp._measures import MEASURES @@ -53,7 +53,7 @@ def check_qoi(qoi, target_function=None, X=None): msg = "If `qoi` is of type `str`, `target_function` cannot be None." raise TypeError(msg) - if QOI_OBJECTS[qoi]._qoi_type == "rank": + if get_qoi(qoi)._qoi_type == "rank": # Add dataset to list of parameters if QoI is rank-based params["X"] = X @@ -72,5 +72,5 @@ def check_qoi(qoi, target_function=None, X=None): else: return qoi - qoi = QOI_OBJECTS[qoi](**params) + qoi = get_qoi(qoi)(**params) return qoi