From c048a969db1940d1dfefdbd6d46f468a326465a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Raul=20Casa=C3=B1a=20Eslava?= Date: Thu, 9 May 2024 16:22:45 +0200 Subject: [PATCH] Added the option to select the number of boosting rounds in KaplanMeyerTree --- xgbse/_kaplan_neighbors.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/xgbse/_kaplan_neighbors.py b/xgbse/_kaplan_neighbors.py index 284eab2..3fddb99 100644 --- a/xgbse/_kaplan_neighbors.py +++ b/xgbse/_kaplan_neighbors.py @@ -386,6 +386,7 @@ def fit( ci_width=0.683, enable_categorical: bool = False, feature_types: Optional[Sequence[str]] = None, + num_boost_round: int = None, **xgb_kwargs ): """ @@ -422,6 +423,8 @@ def fit( feature_types (Sequence[str]): Seq indicating the column type c or q, for categorical or numerical respect. + num_boost_round (int): Number of boosting rounds + Returns: XGBSEKaplanTree: Trained instance of XGBSEKaplanTree """ @@ -435,9 +438,19 @@ def fit( dtrain = convert_data_to_xgb_format( X, y, self.xgb_params["objective"], enable_categorical=enable_categorical, feature_types=feature_types ) - + if num_boost_round is not None and "num_boost_round" in self.xgb_params: + warnings.warn( + f"'num_boost_round' is also in xgb_params. Taking value from main argument: {num_boost_round}" + ) + elif num_boost_round is None and "num_boost_round" in self.xgb_params: + num_boost_round = self.xgb_params.pop("num_boost_round") + elif num_boost_round is None and "num_boost_round" not in self.xgb_params: + warnings.warn( + f"'num_boost_round' is not set, and it is assigned to 1" + ) + num_boost_round = 1 # training XGB - self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=1, **xgb_kwargs) + self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=num_boost_round, **xgb_kwargs) self.feature_importances_ = self.bst.get_score() # getting leaves