From 544a52e6ae438871024123b5de92b2314ddb3f78 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 4 Dec 2024 23:53:58 +0800 Subject: [PATCH] [pyspark] LTR: distribute the features with same group into same partition (#11047) --- python-package/xgboost/spark/core.py | 55 ++++++++----------- .../test_with_spark/test_spark_local.py | 15 ++++- 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 32d7c1e490c8..689e747e8a5c 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -475,10 +475,7 @@ def _validate_params(self) -> None: ) if self.getOrDefault("early_stopping_rounds") is not None: - if not ( - self.isDefined(self.validationIndicatorCol) - and self.getOrDefault(self.validationIndicatorCol) != "" - ): + if not self._col_is_defined_not_empty(self.validationIndicatorCol): raise ValueError( "If 'early_stopping_rounds' param is set, you need to set " "'validation_indicator_col' param as well." @@ -517,6 +514,9 @@ def _run_on_gpu(self) -> bool: or self.getOrDefault(self.getParam("tree_method")) == "gpu_hist" ) + def _col_is_defined_not_empty(self, param: "Param[str]") -> bool: + return self.isDefined(param) and self.getOrDefault(param) != "" + def _validate_and_convert_feature_col_as_float_col_list( dataset: DataFrame, features_col_names: List[str] @@ -805,16 +805,13 @@ def _prepare_input_columns_and_feature_prop( ) select_cols.append(features_array_col) - if self.isDefined(self.weightCol) and self.getOrDefault(self.weightCol) != "": + if self._col_is_defined_not_empty(self.weightCol): select_cols.append( col(self.getOrDefault(self.weightCol)).alias(alias.weight) ) has_validation_col = False - if ( - self.isDefined(self.validationIndicatorCol) - and self.getOrDefault(self.validationIndicatorCol) != "" - ): + if self._col_is_defined_not_empty(self.validationIndicatorCol): select_cols.append( col(self.getOrDefault(self.validationIndicatorCol)).alias(alias.valid) ) @@ -823,15 +820,12 @@ def _prepare_input_columns_and_feature_prop( # which will cause exception or hanging issue when creating DMatrix. has_validation_col = True - if ( - self.isDefined(self.base_margin_col) - and self.getOrDefault(self.base_margin_col) != "" - ): + if self._col_is_defined_not_empty(self.base_margin_col): select_cols.append( col(self.getOrDefault(self.base_margin_col)).alias(alias.margin) ) - if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col) != "": + if self._col_is_defined_not_empty(self.qid_col): select_cols.append(col(self.getOrDefault(self.qid_col)).alias(alias.qid)) feature_prop = FeatureProp( @@ -862,17 +856,22 @@ def _prepare_input(self, dataset: DataFrame) -> Tuple[DataFrame, FeatureProp]: ) if self._repartition_needed(dataset): - # If validationIndicatorCol defined, and if user unionise train and validation - # dataset, users must set force_repartition to true to force repartition. - # Or else some partitions might contain only train or validation dataset. - if self.getOrDefault(self.repartition_random_shuffle): - # In some cases, spark round-robin repartition might cause data skew - # use random shuffle can address it. - dataset = dataset.repartition(num_workers, rand(1)) + if self._col_is_defined_not_empty(self.qid_col): + # For ranking problem, we need to try best the put the instances with + # same group into the same partition + dataset = dataset.repartitionByRange(num_workers, alias.qid) else: - dataset = dataset.repartition(num_workers) + # If validationIndicatorCol defined, and if user unionise train and validation + # dataset, users must set force_repartition to true to force repartition. + # Or else some partitions might contain only train or validation dataset. + if self.getOrDefault(self.repartition_random_shuffle): + # In some cases, spark round-robin repartition might cause data skew + # use random shuffle can address it. + dataset = dataset.repartition(num_workers, rand(1)) + else: + dataset = dataset.repartition(num_workers) - if self.isDefined(self.qid_col) and self.getOrDefault(self.qid_col) != "": + if self._col_is_defined_not_empty(self.qid_col): # XGBoost requires qid to be sorted for each partition dataset = dataset.sortWithinPartitions(alias.qid, ascending=True) @@ -1306,10 +1305,7 @@ def _get_feature_col( def _get_pred_contrib_col_name(self) -> Optional[str]: """Return the pred_contrib_col col name""" pred_contrib_col_name = None - if ( - self.isDefined(self.pred_contrib_col) - and self.getOrDefault(self.pred_contrib_col) != "" - ): + if self._col_is_defined_not_empty(self.pred_contrib_col): pred_contrib_col_name = self.getOrDefault(self.pred_contrib_col) return pred_contrib_col_name @@ -1413,10 +1409,7 @@ def _transform(self, dataset: DataFrame) -> DataFrame: xgb_sklearn_model = self._xgb_sklearn_model base_margin_col = None - if ( - self.isDefined(self.base_margin_col) - and self.getOrDefault(self.base_margin_col) != "" - ): + if self._col_is_defined_not_empty(self.base_margin_col): base_margin_col = col(self.getOrDefault(self.base_margin_col)).alias( alias.margin ) diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index 1f8374e06d11..79569c7fd373 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -4,7 +4,7 @@ import tempfile import uuid from collections import namedtuple -from typing import Generator, Sequence +from typing import Generator, Iterable, List, Sequence import numpy as np import pytest @@ -1794,3 +1794,16 @@ def test_ranker_qid_sorted(self, ltr_data: LTRData) -> None: assert ranker.getOrDefault(ranker.objective) == "rank:ndcg" model = ranker.fit(ltr_data.df_train_1) model.transform(ltr_data.df_test).collect() + + def test_ranker_same_qid_in_same_partition(self, ltr_data: LTRData) -> None: + ranker = SparkXGBRanker(qid_col="qid", num_workers=4, force_repartition=True) + df, _ = ranker._prepare_input(ltr_data.df_train_1) + + def f(iterator: Iterable) -> List[int]: + yield list(set(iterator)) + + rows = df.select("qid").rdd.mapPartitions(f).collect() + assert len(rows) == 4 + for row in rows: + assert len(row) == 1 + assert row[0].qid in [6, 7, 8, 9]