Skip to content

Commit

Permalink
Make var names clearer as to whether holds multi or single value
Browse files Browse the repository at this point in the history
  • Loading branch information
hermidalc committed Aug 18, 2021
1 parent b030e9d commit 86694cf
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
34 changes: 17 additions & 17 deletions generate_roc_pr_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,20 @@
fig, ax = plt.subplots(figsize=(fig_dim, fig_dim), dpi=fig_dpi)
for ridx, _ in enumerate(split_results):
tprs, roc_scores = [], []
mean_fpr = np.linspace(0, 1, 1000)
mean_fprs = np.linspace(0, 1, 1000)
for split_result in split_results[ridx]:
if split_result is None:
continue
tprs.append(np.interp(
mean_fpr, split_result['scores']['te']['fpr'],
mean_fprs, split_result['scores']['te']['fpr'],
split_result['scores']['te']['tpr']))
tprs[-1][0] = 0.0
roc_scores.append(split_result['scores']['te']['roc_auc'])
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
mean_tprs = np.mean(tprs, axis=0)
mean_tprs[-1] = 1.0
std_tprs = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tprs + std_tprs, 1)
tprs_lower = np.maximum(mean_tprs - std_tprs, 0)
if data_type == 'combo':
dtype_label = ('Combo' if ridx == 0 else
'Expression' if ridx == 1 else 'Microbiome')
Expand All @@ -128,10 +128,10 @@
label = 'Clinical'
color = colors[-1]
zorder = 2
ax.plot(mean_fpr, mean_tpr, alpha=0.8, color=color, lw=2,
ax.plot(mean_fprs, mean_tprs, alpha=0.8, color=color, lw=2,
label=r'{} AUROC = $\bf{{{:.2f}}}$'.format(
label, np.mean(roc_scores)), zorder=zorder)
ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.1,
ax.fill_between(mean_fprs, tprs_lower, tprs_upper, alpha=0.1,
color=color, zorder=zorder)
ax.plot([0, 1], [0, 1], alpha=0.2, color='darkgrey',
linestyle='--', lw=1.5, zorder=1)
Expand Down Expand Up @@ -179,18 +179,18 @@
fig, ax = plt.subplots(figsize=(fig_dim, fig_dim), dpi=fig_dpi)
for ridx, _ in enumerate(split_results):
pres, pr_scores = [], []
mean_rec = np.linspace(0, 1, 1000)
mean_recs = np.linspace(0, 1, 1000)
for split_result in split_results[ridx]:
if split_result is None:
continue
pres.append(np.interp(
mean_rec, split_result['scores']['te']['rec'][::-1],
mean_recs, split_result['scores']['te']['rec'][::-1],
split_result['scores']['te']['pre'][::-1]))
pr_scores.append(split_result['scores']['te']['pr_auc'])
mean_pre = np.mean(pres, axis=0)
std_pre = np.std(pres, axis=0)
pres_upper = np.minimum(mean_pre + std_pre, 1)
pres_lower = np.maximum(mean_pre - std_pre, 0)
mean_pres = np.mean(pres, axis=0)
std_pres = np.std(pres, axis=0)
pres_upper = np.minimum(mean_pres + std_pres, 1)
pres_lower = np.maximum(mean_pres - std_pres, 0)
if data_type == 'combo':
dtype_label = ('Combo' if ridx == 0 else
'Expression' if ridx == 1 else 'Microbiome')
Expand All @@ -206,11 +206,11 @@
label = 'Clinical'
color = colors[-1]
zorder = 2
ax.step(mean_rec, mean_pre, alpha=0.8, color=color, lw=2,
ax.step(mean_recs, mean_pres, alpha=0.8, color=color, lw=2,
label=r'{} AUPRC = $\bf{{{:.2f}}}$'.format(
label, np.mean(pr_scores)), where='post',
zorder=zorder)
ax.fill_between(mean_rec, pres_lower, pres_upper, alpha=0.1,
ax.fill_between(mean_recs, pres_lower, pres_upper, alpha=0.1,
color=color, zorder=zorder)
ax.set_xlabel('Recall', fontsize=axis_fontsize, labelpad=5)
ax.set_ylabel('Precision', fontsize=axis_fontsize, labelpad=5)
Expand Down
14 changes: 7 additions & 7 deletions generate_td_auc_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,10 @@ def get_cv_split_idxs(X, y, groups, group_weights):
for time, auc in zip(times, aucs):
interp_aucs.append(np.interp(mean_times, time, auc))
mean_times = mean_times / days_per_year
mean_auc = np.mean(interp_aucs, axis=0)
std_auc = np.std(interp_aucs, axis=0)
aucs_upper = np.minimum(mean_auc + std_auc, 1)
aucs_lower = np.maximum(mean_auc - std_auc, 0)
mean_aucs = np.mean(interp_aucs, axis=0)
std_aucs = np.std(interp_aucs, axis=0)
aucs_upper = np.minimum(mean_aucs + std_aucs, 1)
aucs_lower = np.maximum(mean_aucs - std_aucs, 0)
if data_type == 'combo':
dtype_label = ('Combo' if ridx == 0 else
'Expression' if ridx == 1 else 'Microbiome')
Expand All @@ -255,12 +255,12 @@ def get_cv_split_idxs(X, y, groups, group_weights):
label = 'Clinical'
color = colors[-1]
zorder = 2
ax.plot(mean_times, mean_auc, alpha=0.8, color=color, lw=2,
ax.plot(mean_times, mean_aucs, alpha=0.8, color=color, lw=2,
label=r'{} AUC = $\bf{{{:.2f}}}$'.format(
label, np.mean(mean_auc)), zorder=zorder)
label, np.mean(mean_aucs)), zorder=zorder)
ax.fill_between(mean_times, aucs_lower, aucs_upper, alpha=0.1,
color=color, zorder=zorder)
ax.axhline(np.mean(mean_auc), alpha=0.5, color=color,
ax.axhline(np.mean(mean_aucs), alpha=0.5, color=color,
linestyle='--', lw=1.5, zorder=1)
xaxis_tick_base = (3 if max(mean_times) > 20 else
2 if max(mean_times) > 10 else 1)
Expand Down

0 comments on commit 86694cf

Please sign in to comment.