Skip to content

Commit

Permalink
Added the option to select the number of boosting rounds in KaplanMey…
Browse files Browse the repository at this point in the history
…erTree
  • Loading branch information
Raul Casaña Eslava committed May 9, 2024
1 parent 436585c commit c048a96
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions xgbse/_kaplan_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def fit(
ci_width=0.683,
enable_categorical: bool = False,
feature_types: Optional[Sequence[str]] = None,
num_boost_round: int = None,
**xgb_kwargs
):
"""
Expand Down Expand Up @@ -422,6 +423,8 @@ def fit(
feature_types (Sequence[str]): Seq indicating the column type c or q, for categorical or numerical respect.
num_boost_round (int): Number of boosting rounds
Returns:
XGBSEKaplanTree: Trained instance of XGBSEKaplanTree
"""
Expand All @@ -435,9 +438,19 @@ def fit(
dtrain = convert_data_to_xgb_format(
X, y, self.xgb_params["objective"], enable_categorical=enable_categorical, feature_types=feature_types
)

if num_boost_round is not None and "num_boost_round" in self.xgb_params:
warnings.warn(
f"'num_boost_round' is also in xgb_params. Taking value from main argument: {num_boost_round}"
)
elif num_boost_round is None and "num_boost_round" in self.xgb_params:
num_boost_round = self.xgb_params.pop("num_boost_round")
elif num_boost_round is None and "num_boost_round" not in self.xgb_params:
warnings.warn(
f"'num_boost_round' is not set, and it is assigned to 1"
)
num_boost_round = 1
# training XGB
self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=1, **xgb_kwargs)
self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=num_boost_round, **xgb_kwargs)
self.feature_importances_ = self.bst.get_score()

# getting leaves
Expand Down

0 comments on commit c048a96

Please sign in to comment.