
[pyspark] LTR: distribute the features with same group into same partition (#11047)
wbo4958 authored Dec 4, 2024
1 parent d5693bd commit 544a52e
Showing 2 changed files with 38 additions and 32 deletions.
55 changes: 24 additions & 31 deletions python-package/xgboost/spark/core.py
@@ -475,10 +475,7 @@ def _validate_params(self) -> None:
                 )
 
         if self.getOrDefault("early_stopping_rounds") is not None:
-            if not (
-                self.isDefined(self.validationIndicatorCol)
-                and self.getOrDefault(self.validationIndicatorCol) != ""
-            ):
+            if not self._col_is_defined_not_empty(self.validationIndicatorCol):
                 raise ValueError(
                     "If 'early_stopping_rounds' param is set, you need to set "
                     "'validation_indicator_col' param as well."
@@ -517,6 +514,9 @@ def _run_on_gpu(self) -> bool:
             or self.getOrDefault(self.getParam("tree_method")) == "gpu_hist"
         )
 
+    def _col_is_defined_not_empty(self, param: "Param[str]") -> bool:
+        return self.isDefined(param) and self.getOrDefault(param) != ""
+
 
 def _validate_and_convert_feature_col_as_float_col_list(
     dataset: DataFrame, features_col_names: List[str]
@@ -805,16 +805,13 @@ def _prepare_input_columns_and_feature_prop(
             )
             select_cols.append(features_array_col)
 
-        if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol) != "":
+        if self._col_is_defined_not_empty(self.weightCol):
             select_cols.append(
                 col(self.getOrDefault(self.weightCol)).alias(alias.weight)
             )
 
         has_validation_col = False
-        if (
-            self.isDefined(self.validationIndicatorCol)
-            and self.getOrDefault(self.validationIndicatorCol) != ""
-        ):
+        if self._col_is_defined_not_empty(self.validationIndicatorCol):
             select_cols.append(
                 col(self.getOrDefault(self.validationIndicatorCol)).alias(alias.valid)
             )
@@ -823,15 +820,12 @@ def _prepare_input_columns_and_feature_prop(
             # which will cause exception or hanging issue when creating DMatrix.
             has_validation_col = True
 
-        if (
-            self.isDefined(self.base_margin_col)
-            and self.getOrDefault(self.base_margin_col) != ""
-        ):
+        if self._col_is_defined_not_empty(self.base_margin_col):
             select_cols.append(
                 col(self.getOrDefault(self.base_margin_col)).alias(alias.margin)
             )
 
-        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col) != "":
+        if self._col_is_defined_not_empty(self.qid_col):
             select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid))
 
         feature_prop = FeatureProp(
@@ -862,17 +856,22 @@ def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]:
             )
 
         if self._repartition_needed(dataset):
-            # If validationIndicatorCol defined, and if user unionise train and validation
-            # dataset, users must set force_repartition to true to force repartition.
-            # Or else some partitions might contain only train or validation dataset.
-            if self.getOrDefault(self.repartition_random_shuffle):
-                # In some cases, spark round-robin repartition might cause data skew
-                # use random shuffle can address it.
-                dataset = dataset.repartition(num_workers, rand(1))
+            if self._col_is_defined_not_empty(self.qid_col):
+                # For ranking problems, we need to try our best to put instances with
+                # the same group into the same partition
+                dataset = dataset.repartitionByRange(num_workers, alias.qid)
             else:
-                dataset = dataset.repartition(num_workers)
+                # If validationIndicatorCol defined, and if user unionise train and validation
+                # dataset, users must set force_repartition to true to force repartition.
+                # Or else some partitions might contain only train or validation dataset.
+                if self.getOrDefault(self.repartition_random_shuffle):
+                    # In some cases, spark round-robin repartition might cause data skew
+                    # use random shuffle can address it.
+                    dataset = dataset.repartition(num_workers, rand(1))
+                else:
+                    dataset = dataset.repartition(num_workers)
 
-        if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col) != "":
+        if self._col_is_defined_not_empty(self.qid_col):
             # XGBoost requires qid to be sorted for each partition
             dataset = dataset.sortWithinPartitions(alias.qid, ascending=True)
 
Expand Down Expand Up @@ -1306,10 +1305,7 @@ def _get_feature_col(
def _get_pred_contrib_col_name(self) -> Optional[str]:
"""Return the pred_contrib_col col name"""
pred_contrib_col_name = None
if (
self.isDefined(self.pred_contrib_col)
and self.getOrDefault(self.pred_contrib_col) != ""
):
if self._col_is_defined_not_empty(self.pred_contrib_col):
pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col)

return pred_contrib_col_name
@@ -1413,10 +1409,7 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
         xgb_sklearn_model = self._xgb_sklearn_model
 
         base_margin_col = None
-        if (
-            self.isDefined(self.base_margin_col)
-            and self.getOrDefault(self.base_margin_col) != ""
-        ):
+        if self._col_is_defined_not_empty(self.base_margin_col):
             base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias(
                 alias.margin
             )
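Below is a minimal, hypothetical sketch (not part of the commit) illustrating the partitioning strategy introduced in _prepare_input: range-partition by the query id so that rows sharing a qid land in the same partition, then sort within each partition as XGBoost requires. The local SparkSession and the toy qid/feature columns are assumptions made purely for illustration.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").getOrCreate()

# Toy ranking data: four query groups (qid 6-9), ten rows each.
df = spark.createDataFrame(
    [(qid, float(i)) for qid in (6, 7, 8, 9) for i in range(10)],
    schema="qid int, feature double",
)

num_workers = 4
# Rows with the same qid tend to land in the same partition (range partitioning
# samples the column, so the grouping is best-effort rather than guaranteed).
df = df.repartitionByRange(num_workers, "qid")
# XGBoost expects qid to be sorted within each partition.
df = df.sortWithinPartitions("qid", ascending=True)

# Inspect which qids ended up in each partition.
print(df.rdd.map(lambda r: r.qid).glom().map(set).collect())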
15 changes: 14 additions & 1 deletion tests/test_distributed/test_with_spark/test_spark_local.py
@@ -4,7 +4,7 @@
 import tempfile
 import uuid
 from collections import namedtuple
-from typing import Generator, Sequence
+from typing import Generator, Iterable, List, Sequence
 
 import numpy as np
 import pytest
@@ -1794,3 +1794,16 @@ def test_ranker_qid_sorted(self, ltr_data: LTRData) -> None:
         assert ranker.getOrDefault(ranker.objective) == "rank:ndcg"
         model = ranker.fit(ltr_data.df_train_1)
         model.transform(ltr_data.df_test).collect()
+
+    def test_ranker_same_qid_in_same_partition(self, ltr_data: LTRData) -> None:
+        ranker = SparkXGBRanker(qid_col="qid", num_workers=4, force_repartition=True)
+        df, _ = ranker._prepare_input(ltr_data.df_train_1)
+
+        def f(iterator: Iterable) -> List[int]:
+            yield list(set(iterator))
+
+        rows = df.select("qid").rdd.mapPartitions(f).collect()
+        assert len(rows) == 4
+        for row in rows:
+            assert len(row) == 1
+            assert row[0].qid in [6, 7, 8, 9]
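For context, a hedged usage sketch of the public API this change affects; the train_df and test_df DataFrames (with "features", "label" and "qid" columns) are assumptions, not the test fixtures above. With qid_col set, fitting the ranker now routes rows of the same query group to the same partition before the DMatrix is built.

from xgboost.spark import SparkXGBRanker

ranker = SparkXGBRanker(
    qid_col="qid",              # query-group column used for range partitioning
    num_workers=4,
    force_repartition=True,     # force the qid-aware repartition in _prepare_input
)
model = ranker.fit(train_df)       # train_df assumed to exist
model.transform(test_df).show()    # test_df assumed to exist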
