Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Feb 13, 2025
1 parent 167606b commit ba54c1b
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 13 deletions.
3 changes: 2 additions & 1 deletion docs/Examples/01.AdaSTEM_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -740,6 +740,7 @@
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
" task='hurdle',\n",
" save_gridding_plot = True,\n",
" ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
" min_ensemble_required=7, # Only points covered by > 7 ensembles will be predicted\n",
Expand Down
3 changes: 2 additions & 1 deletion docs/Examples/02.AdaSTEM_learning_curve_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -195,6 +195,7 @@
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
" ),\n",
" task='hurdle',\n",
" save_gridding_plot = True,\n",
" ensemble_fold=10, \n",
" min_ensemble_required=7,\n",
Expand Down
3 changes: 2 additions & 1 deletion docs/Examples/04.SphereAdaSTEM_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -731,6 +731,7 @@
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
" task='hurdle',\n",
" save_gridding_plot = True,\n",
" ensemble_fold=10, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
" min_ensemble_required=7, # Only points covered by > 7 stixels will be predicted\n",
Expand Down
3 changes: 2 additions & 1 deletion docs/Examples/05.Hurdle_in_ada_or_ada_in_hurdle.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -660,6 +660,7 @@
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
" ),\n",
" task='hurdle',\n",
" save_gridding_plot = True,\n",
" ensemble_fold=10, \n",
" min_ensemble_required=7,\n",
Expand Down
1 change: 1 addition & 0 deletions docs/Examples/06.Base_model_choices.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@
" classifier=base_model_dict[base_model_name]['classifier'],\n",
" regressor=base_model_dict[base_model_name]['regressor']\n",
" ),\n",
" task='hurdle',\n",
" save_gridding_plot = True,\n",
" ensemble_fold=10, \n",
" min_ensemble_required=7,\n",
Expand Down
3 changes: 2 additions & 1 deletion docs/Examples/07.Optimizing_stixel_size.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1154,6 +1154,7 @@
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
" ),\n",
" task='hurdle',\n",
" save_gridding_plot = True,\n",
" ensemble_fold=10, \n",
" min_ensemble_required=7,\n",
Expand Down
4 changes: 3 additions & 1 deletion docs/Examples/08.Lazy_loading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -691,6 +691,7 @@
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
" task='hurdle',\n",
" save_gridding_plot = True,\n",
" ensemble_fold=ensemble_fold, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
" min_ensemble_required=ensemble_fold-2, # Only points covered by > 7 ensembles will be predicted\n",
Expand Down Expand Up @@ -720,6 +721,7 @@
" classifier=XGBClassifier(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1),\n",
" regressor=XGBRegressor(tree_method='hist',random_state=42, verbosity = 0, n_jobs=1)\n",
" ), # hurdel model for zero-inflated problem (e.g., count)\n",
" task='hurdle',\n",
" save_gridding_plot = True,\n",
" ensemble_fold=ensemble_fold, # data are modeled 10 times, each time with jitter and rotation in Quadtree algo\n",
" min_ensemble_required=ensemble_fold-2, # Only points covered by > 7 ensembles will be predicted\n",
Expand Down
7 changes: 3 additions & 4 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,9 @@ def eval_STEM_res(
elif task == "hurdle":
cls_threshold = 0

if not task == "regression":
if task == "regression":
auc, kappa, f1, precision, recall, average_precision = [np.nan] * 6
else:
a = pd.DataFrame({"y_true": np.array(y_test).flatten(), "pred": np.array(y_pred).flatten()}).dropna()

y_test_b = np.where(a.y_true > cls_threshold, 1, 0)
Expand All @@ -1039,9 +1041,6 @@ def eval_STEM_res(
recall = recall_score(y_test_b, y_pred_b)
average_precision = average_precision_score(y_test_b, y_pred_b)

else:
auc, kappa, f1, precision, recall, average_precision = [np.nan] * 6

if not task == "classification":
a = pd.DataFrame({"y_true": y_test, "pred": y_pred}).dropna()
s_r, _ = spearmanr(np.array(a.y_true), np.array(a.pred))
Expand Down
7 changes: 4 additions & 3 deletions stemflow/model/static_func_AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,12 @@ def predict_one_stixel(

if pred is None:
# Still haven't found the pred function
if task == "regression":
pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]])
else:
if task == "classification":
pred = model_x_names_tuple[0].predict_proba(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)
pred = pred[:,1]
else:
pred = model_x_names_tuple[0].predict(X_test_stixel[model_x_names_tuple[1]], **base_model_prediction_param)


res = pd.DataFrame({"index": list(X_test_stixel.index), "pred": np.array(pred).flatten()}).set_index("index")

Expand Down

0 comments on commit ba54c1b

Please sign in to comment.