Skip to content

Commit

Permalink
patching some metrics and errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Raul committed Jul 8, 2022
1 parent 495697d commit bbad373
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
10 changes: 7 additions & 3 deletions xgbse/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def approx_brier_score(y_true, survival, aggregate="mean"):
# adding censoring distribution survival at event
event_time_windows = _match_times_to_windows(times, survival.columns)
scoring_df["cens_at_event"] = censoring_dist[event_time_windows].iloc[0].values
# TODO: Something breaks when this is used with sklearn.model_selection.cross_val_score

# list of window results
window_results = []

Expand Down Expand Up @@ -177,11 +177,15 @@ def approx_brier_score(y_true, survival, aggregate="mean"):
)

# adding and taking average
result = (first_term + second_term).sum() / scoring_df.shape[0]
# OLD CODE:
# result = (first_term + second_term).sum() / scoring_df.shape[0]

added_terms = (first_term + second_term)
result = np.nanmean(added_terms[np.isfinite(added_terms)])
window_results.append(result)

if aggregate == "mean":
return np.array(window_results).mean()
return np.nanmean(window_results)
elif aggregate is None:
return np.array(window_results)
else:
Expand Down
20 changes: 14 additions & 6 deletions xgbse/non_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ def get_time_bins(T, E, size=12):
"""
Method to automatically define time bins
"""

lower_bound = max(T[E == 0].min(), T[E == 1].min()) + 1
upper_bound = min(T[E == 0].max(), T[E == 1].max()) - 1

# OLD CODE:
# lower_bound = max(T[E == 0].min(), T[E == 1].min()) + 1
# upper_bound = min(T[E == 0].max(), T[E == 1].max()) - 1
lower_bound = np.nanmin(T[E == 1]) + 1
upper_bound = np.nanmax(T[E == 1]) - 1
return np.linspace(lower_bound, upper_bound, size, dtype=int)


Expand Down Expand Up @@ -74,12 +75,19 @@ def sample_time_bins(surv_array, T_neighs, time_bins):
"""

surv_df = []

for t in time_bins:
survival_at_t = (surv_array + (T_neighs > t)).min(axis=1)
# OLD CODE:
# survival_at_t = (surv_array + (T_neighs > t)).min(axis=1)

mask = np.zeros_like(T_neighs)
mask[T_neighs > t] = np.inf
survival_at_t = (surv_array + mask).min(axis=1)
# survival_at_t = np.where(~np.isfinite(survival_at_t), survival_at_t, 1.0)

surv_df.append(survival_at_t)

surv_df = pd.DataFrame(surv_df, index=time_bins).T
surv_df.where(np.isfinite(surv_df), 1.0, inplace=True)
return surv_df


Expand Down

0 comments on commit bbad373

Please sign in to comment.