-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathprecision_recall_curve.py
47 lines (36 loc) · 1.33 KB
/
precision_recall_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# %%
import numpy as np
import pandas as pd
import pymatviz as pmv
# Select pymatviz's "pymatviz_dark" Plotly template (presumably applied as the
# default styling for the figures created below — confirm in pymatviz docs).
pmv.set_plotly_template("pymatviz_dark")
# Synthetic binary-classification data: true 0/1 labels plus noisy scores.
# The fixed seed keeps every figure in this script reproducible.
np_rng = np.random.default_rng(seed=0)
rand_clf_size = 100
y_binary = np_rng.choice([0, 1], size=rand_clf_size)
# Perturb the labels with Gaussian noise (std 5, weighted by 0.1) and clip into
# [0.2, 0.9] to mimic positive-class probabilities from an imperfect model.
label_noise = 0.1 * np_rng.normal(scale=5, size=rand_clf_size)
y_proba = np.clip(y_binary - label_noise, 0.2, 0.9)
# Same data in tabular form, for the column-name based API below.
df_clf = pd.DataFrame(
    {
        "target": y_binary,
        "probability": y_proba,
    }
)
# %% Plotly version - basic usage
# Labels and scores passed as arrays, using the same keywords as the
# multi-classifier example.
fig = pmv.precision_recall_curve_plotly(targets=y_binary, probs_positive=y_proba)
fig.show()
# Uncomment to (re)generate the SVG asset for this figure:
# pmv.io.save_and_compress_svg(fig, "precision-recall-curve-plotly")
# %% Plotly version - with DataFrame
# Pass column names of df_clf instead of the arrays themselves.
fig = pmv.precision_recall_curve_plotly(
    targets="target", probs_positive="probability", df=df_clf
)
fig.show()
# %% Multiple PR curves on same plot
# Generate data for multiple classifiers


def _noisy_probs(
    noise_weight: float, noise_std: float, low: float, high: float
) -> np.ndarray:
    """Return synthetic positive-class scores for ``y_binary``.

    Draws ``rand_clf_size`` Gaussian samples with standard deviation
    ``noise_std`` from the shared ``np_rng``, scales them by ``noise_weight``,
    subtracts them from the true labels and clips the result to ``[low, high]``.

    Each call advances ``np_rng``, so the call order determines the output.
    """
    noise = noise_weight * np_rng.normal(scale=noise_std, size=rand_clf_size)
    return np.clip(y_binary - noise, low, high)


# Arguments are (noise_weight, noise_std, low, high). Dict displays evaluate
# left to right, so random numbers are drawn in A -> B -> C order; keep that
# order stable for reproducible figures.
classifiers = {
    "Classifier A": _noisy_probs(0.1, 5, 0.2, 0.9),
    "Classifier B": _noisy_probs(0.2, 3, 0.1, 0.95),
    "Classifier C": _noisy_probs(0.15, 4, 0.15, 0.85),
}
fig = pmv.precision_recall_curve_plotly(targets=y_binary, probs_positive=classifiers)
fig.show()
# NOTE: unlike the basic example above, this call is active, so running the
# script writes the figure out as a compressed SVG asset.
pmv.io.save_and_compress_svg(fig, "precision-recall-curve-plotly-multiple")