diff --git a/docs/Examples/01.AdaSTEM_demo.ipynb b/docs/Examples/01.AdaSTEM_demo.ipynb index 4e32294..adff63a 100644 --- a/docs/Examples/01.AdaSTEM_demo.ipynb +++ b/docs/Examples/01.AdaSTEM_demo.ipynb @@ -730,7 +730,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -740,6 +740,7 @@ " classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n", " regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n", " ), # hurdel model for zero-inflated problem (e.g., count)\n", + " task='hurdle',\n", " save_gridding_plot = True,\n", " ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n", " min_ensemble_required=7, # Only points covered by > 7 ensembles will be predicted\n", diff --git a/docs/Examples/02.AdaSTEM_learning_curve_analysis.ipynb b/docs/Examples/02.AdaSTEM_learning_curve_analysis.ipynb index c9c1c07..a0f93f6 100644 --- a/docs/Examples/02.AdaSTEM_learning_curve_analysis.ipynb +++ b/docs/Examples/02.AdaSTEM_learning_curve_analysis.ipynb @@ -136,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -195,6 +195,7 @@ " classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n", " regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n", " ),\n", + " task='hurdle',\n", " save_gridding_plot = True,\n", " ensemble_fold=10, \n", " min_ensemble_required=7,\n", diff --git a/docs/Examples/04.SphereAdaSTEM_demo.ipynb b/docs/Examples/04.SphereAdaSTEM_demo.ipynb index 0be187d..be22d58 100644 --- a/docs/Examples/04.SphereAdaSTEM_demo.ipynb +++ b/docs/Examples/04.SphereAdaSTEM_demo.ipynb @@ -721,7 +721,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -731,6 +731,7 @@ " classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n", " regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n", " ), # hurdel model for zero-inflated problem (e.g., count)\n", + " task='hurdle',\n", " save_gridding_plot = True,\n", " ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n", " min_ensemble_required=7, # Only points covered by > 7 stixels will be predicted\n", diff --git a/docs/Examples/05.Hurdle_in_ada_or_ada_in_hurdle.ipynb b/docs/Examples/05.Hurdle_in_ada_or_ada_in_hurdle.ipynb index f10c20f..af01a64 100644 --- a/docs/Examples/05.Hurdle_in_ada_or_ada_in_hurdle.ipynb +++ b/docs/Examples/05.Hurdle_in_ada_or_ada_in_hurdle.ipynb @@ -651,7 +651,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -660,6 +660,7 @@ " classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n", " regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n", " ),\n", + " task='hurdle',\n", " save_gridding_plot = True,\n", " ensemble_fold=10, \n", " min_ensemble_required=7,\n", diff --git a/docs/Examples/06.Base_model_choices.ipynb b/docs/Examples/06.Base_model_choices.ipynb index ba6a533..44ecd79 100644 --- a/docs/Examples/06.Base_model_choices.ipynb +++ b/docs/Examples/06.Base_model_choices.ipynb @@ -710,6 +710,7 @@ " classifier=base_model_dict[base_model_name]['classifier'],\n", " regressor=base_model_dict[base_model_name]['regressor']\n", " ),\n", + " task='hurdle',\n", " save_gridding_plot = True,\n", " ensemble_fold=10, \n", " min_ensemble_required=7,\n", diff --git a/docs/Examples/07.Optimizing_stixel_size.ipynb b/docs/Examples/07.Optimizing_stixel_size.ipynb index 6889824..eee7168 100644 --- a/docs/Examples/07.Optimizing_stixel_size.ipynb +++ b/docs/Examples/07.Optimizing_stixel_size.ipynb @@ -581,7 +581,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1154,6 +1154,7 @@ " classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n", " regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n", " ),\n", + " task='hurdle',\n", " save_gridding_plot = True,\n", " ensemble_fold=10, \n", " min_ensemble_required=7,\n", diff --git a/docs/Examples/08.Lazy_loading.ipynb b/docs/Examples/08.Lazy_loading.ipynb index af060a3..3e0f363 100644 --- a/docs/Examples/08.Lazy_loading.ipynb +++ b/docs/Examples/08.Lazy_loading.ipynb @@ -678,7 +678,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -691,6 +691,7 @@ " classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n", " regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n", " ), # hurdel model for zero-inflated problem (e.g., count)\n", + " task='hurdle',\n", " save_gridding_plot = True,\n", " ensemble_fold=ensemble_fold, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n", " min_ensemble_required=ensemble_fold-2, # Only points covered by > 7 ensembles will be predicted\n", @@ -720,6 +721,7 @@ " classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n", " regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n", " ), # hurdel model for zero-inflated problem (e.g., count)\n", + " task='hurdle',\n", " save_gridding_plot = True,\n", " ensemble_fold=ensemble_fold, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n", " min_ensemble_required=ensemble_fold-2, # Only points covered by > 7 ensembles will be predicted\n", diff --git a/stemflow/model/AdaSTEM.py b/stemflow/model/AdaSTEM.py index 011d3db..13841be 100644 --- a/stemflow/model/AdaSTEM.py +++ b/stemflow/model/AdaSTEM.py @@ -1022,7 +1022,9 @@ def eval_STEM_res( elif task == "hurdle": cls_threshold = 0 - if not task == "regression": + if task == "regression": + auc, kappa, f1, precision, recall, average_precision = [np.nan] * 6 + else: a = pd.DataFrame({"y_true": np.array(y_test).flatten(), "pred": np.array(y_pred).flatten()}).dropna() y_test_b = np.where(a.y_true > cls_threshold, 1, 0) @@ -1039,9 +1041,6 @@ def eval_STEM_res( recall = recall_score(y_test_b, y_pred_b) average_precision = average_precision_score(y_test_b, y_pred_b) - else: - auc, kappa, f1, precision, recall, average_precision = [np.nan] * 6 - if not task == "classification": a = pd.DataFrame({"y_true": y_test, "pred": y_pred}).dropna() s_r, _ = spearmanr(np.array(a.y_true), np.array(a.pred)) diff --git a/stemflow/model/static_func_AdaSTEM.py b/stemflow/model/static_func_AdaSTEM.py index 55ed3c8..349bffe 100644 --- a/stemflow/model/static_func_AdaSTEM.py +++ b/stemflow/model/static_func_AdaSTEM.py @@ -467,11 +467,12 @@ def predict_one_stixel( if pred is None: # Still haven't found the pred function - if task == "regression": - pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]]) - else: + if task == "classification": pred = model_x_names_tuple[0].predict_proba(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param) pred = pred[:,1] + else: + pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param) + res = pd.DataFrame({"index": list(X_test_stixel.index), "pred": np.array(pred).flatten()}).set_index("index")