From 98765a6d5faa7d1dcaf3d6838956f03577f77168 Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Wed, 27 Nov 2024 10:27:58 +0800
Subject: [PATCH] [jvm-packages] LTR: distribute the features with same group into same partition

---
 .../scala/spark/GpuXGBoostPluginSuite.scala   | 49 +++++++++++++++++++
 .../xgboost4j/scala/spark/XGBoostRanker.scala | 17 +++++++
 .../scala/spark/XGBoostRankerSuite.scala      | 48 ++++++++++++++++++
 3 files changed, 114 insertions(+)

diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala
index 6559d90c7887..a5ff2ba0f589 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/spark/GpuXGBoostPluginSuite.scala
@@ -542,6 +542,55 @@ class GpuXGBoostPluginSuite extends GpuTestSuite {
     }
   }
 
+  test("Same group must be in the same partition") {
+    val num_workers = 3
+    withGpuSparkSession() { spark =>
+      import spark.implicits._
+      val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        (0.1, 1, 0),
+        (0.1, 1, 0),
+        (0.1, 1, 0),
+        (0.1, 1, 1),
+        (0.1, 1, 1),
+        (0.1, 1, 1),
+        (0.1, 1, 2),
+        (0.1, 1, 2),
+        (0.1, 1, 2)), 1)).toDF("label", "f1", "group")
+
+      // The original pattern repartitions df in a round-robin manner
+      val oriRows = df.repartition(num_workers)
+        .sortWithinPartitions(df.col("group"))
+        .select("group")
+        .mapPartitions { iter =>
+          val tmp: ArrayBuffer[Int] = ArrayBuffer.empty
+          while (iter.hasNext) {
+            val r = iter.next()
+            tmp.append(r.getInt(0))
+          }
+          Iterator.single(tmp.mkString(","))
+        }.collect()
+      assert(oriRows.length == 3)
+      assert(oriRows.contains("0,1,2"))
+
+      // The fix replaces repartition with repartitionByRange, which puts the
+      // instances with the same group into the same partition
+      val ranker = new XGBoostRanker().setGroupCol("group").setNumWorkers(num_workers)
+      val processedDf = ranker.getPlugin.get.asInstanceOf[GpuXGBoostPlugin].preprocess(ranker, df)
+      val rows = processedDf
+        .select("group")
+        .mapPartitions { iter =>
+          val tmp: ArrayBuffer[Int] = ArrayBuffer.empty
+          while (iter.hasNext) {
+            val r = iter.next()
+            tmp.append(r.getInt(0))
+          }
+          Iterator.single(tmp.mkString(","))
+        }.collect()
+
+      assert(rows.forall(Seq("0,0,0", "1,1,1", "2,2,2").contains))
+    }
+  }
+
   test("Ranker: XGBoost-Spark should match xgboost4j") {
     withGpuSparkSession() { spark =>
       import spark.implicits._
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala
index 14d13e34ff61..0265eac55979 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRanker.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultParamsReadable, Identifiable, MLReadable, MLReader}
 import org.apache.spark.ml.xgboost.SparkUtils
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 import ml.dmlc.xgboost4j.scala.Booster
@@ -62,6 +63,22 @@ class XGBoostRanker(override val uid: String,
     }
   }
 
+  /**
+   * Repartition the dataset to numWorkers partitions if needed.
+   *
+   * @param dataset the dataset to be repartitioned
+   * @return the repartitioned dataset
+   */
+  override private[spark] def repartitionIfNeeded(dataset: Dataset[_]) = {
+    val numPartitions = dataset.rdd.getNumPartitions
+    if (getForceRepartition || getNumWorkers != numPartitions) {
+      // Please note that the output of repartitionByRange is not deterministic
+      dataset.repartitionByRange(getNumWorkers, col(getGroupCol))
+    } else {
+      dataset
+    }
+  }
+
   /**
    * Sort partition for Ranker issue.
    *
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala
index 81a770bfe327..063836538931 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRankerSuite.scala
@@ -151,6 +151,54 @@ class XGBoostRankerSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite
     }}
   }
 
+  test("Same group must be in the same partition") {
+    val spark = ss
+    import spark.implicits._
+    val num_workers = 3
+    val df = ss.createDataFrame(sc.parallelize(Seq(
+      (0.1, Vectors.dense(1.0, 2.0, 3.0), 0),
+      (0.1, Vectors.dense(0.0, 0.0, 0.0), 0),
+      (0.1, Vectors.dense(0.0, 3.0, 0.0), 0),
+      (0.1, Vectors.dense(2.0, 0.0, 4.0), 1),
+      (0.1, Vectors.dense(0.2, 1.2, 2.0), 1),
+      (0.1, Vectors.dense(0.5, 2.2, 1.7), 1),
+      (0.1, Vectors.dense(0.5, 2.2, 1.7), 2),
+      (0.1, Vectors.dense(0.5, 2.2, 1.7), 2),
+      (0.1, Vectors.dense(0.5, 2.2, 1.7), 2)), 1)).toDF("label", "features", "group")
+
+    // The original pattern repartitions df in a round-robin manner
+    val oriRows = df.repartition(num_workers)
+      .sortWithinPartitions(df.col("group"))
+      .select("group")
+      .mapPartitions { iter =>
+        val tmp: ArrayBuffer[Int] = ArrayBuffer.empty
+        while (iter.hasNext) {
+          val r = iter.next()
+          tmp.append(r.getInt(0))
+        }
+        Iterator.single(tmp.mkString(","))
+      }.collect()
+    assert(oriRows.length == 3)
+    assert(oriRows.contains("0,1,2"))
+
+    // The fix replaces repartition with repartitionByRange, which puts the
+    // instances with the same group into the same partition
+    val ranker = new XGBoostRanker().setGroupCol("group").setNumWorkers(num_workers)
+    val (processedDf, _) = ranker.preprocess(df)
+    val rows = processedDf
+      .select("group")
+      .mapPartitions { iter =>
+        val tmp: ArrayBuffer[Int] = ArrayBuffer.empty
+        while (iter.hasNext) {
+          val r = iter.next()
+          tmp.append(r.getInt(0))
+        }
+        Iterator.single(tmp.mkString(","))
+      }.collect()
+
+    assert(rows.forall(Seq("0,0,0", "1,1,1", "2,2,2").contains))
+  }
+
   private def runLengthEncode(input: Seq[Int]): Seq[Int] = {
     if (input.isEmpty) return Seq(0)
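
Reviewer note: the behavioral difference the tests above assert can also be reproduced in a standalone snippet. The sketch below is illustrative only (the object name and toy data are not part of this patch); it contrasts `repartition`, which distributes rows round-robin and can split a group across partitions, with `repartitionByRange` on the group column, which keeps rows with equal keys in one partition:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, countDistinct, spark_partition_id}

// Illustrative sketch, not part of the patch: compares round-robin
// repartitioning with range repartitioning on a toy LTR-style dataset.
object GroupPartitioningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[3]")
      .appName("group-partitioning-demo")
      .getOrCreate()
    import spark.implicits._

    // Toy (label, feature, group) rows, two rows per group.
    val df = Seq(
      (0.1, 1, 0), (0.1, 1, 0),
      (0.1, 1, 1), (0.1, 1, 1),
      (0.1, 1, 2), (0.1, 1, 2)
    ).toDF("label", "f1", "group")

    // Round-robin repartition: a group's rows may land in different
    // partitions, so distinct_pids can exceed 1 for some groups.
    df.repartition(3)
      .withColumn("pid", spark_partition_id())
      .groupBy("group")
      .agg(countDistinct("pid").as("distinct_pids"))
      .show()

    // Range repartition on the group column: identical keys always fall in
    // the same range, so every group reports exactly one partition id.
    // This is the property XGBoostRanker.repartitionIfNeeded relies on.
    df.repartitionByRange(3, col("group"))
      .withColumn("pid", spark_partition_id())
      .groupBy("group")
      .agg(countDistinct("pid").as("distinct_pids"))
      .show()

    spark.stop()
  }
}
```

As the in-code comment notes, `repartitionByRange` samples the key column to pick range boundaries, so which partition a given group lands in is not deterministic across runs; what it does guarantee is that rows sharing a key are never split between partitions.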