[CELEBORN-1319] Optimize skew partition logic for Reduce Mode to avoid sorting shuffle files

### What changes were proposed in this pull request?
Add logic to avoid sorting shuffle files in Reduce mode when optimizing skew partitions.

### Why are the changes needed?
The current logic needs to sort shuffle files when reading Reduce mode skew partition shuffle files; we observed shuffle sorting timeouts and performance issues.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Cluster tests and UTs.

Closes apache#2373 from wangshengjie123/optimize-skew-partition.

Lead-authored-by: wangshengjie <[email protected]>
Co-authored-by: wangshengjie3 <[email protected]>
Co-authored-by: Fu Chen <[email protected]>
Co-authored-by: Shuang <[email protected]>
Co-authored-by: wangshengjie3 <[email protected]>
Co-authored-by: Fei Wang <[email protected]>
Co-authored-by: Wang, Fei <[email protected]>
Signed-off-by: Shuang <[email protected]>
5 people committed Feb 19, 2025
1 parent 7ca69e2 commit d659e06
Showing 38 changed files with 2,599 additions and 135 deletions.
315 changes: 315 additions & 0 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch

Large diffs are not rendered by default.

312 changes: 312 additions & 0 deletions assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch

Large diffs are not rendered by default.

CelebornPartitionUtil.java (new file)
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn;

import java.util.*;

import org.apache.commons.lang3.tuple.Pair;

import org.apache.celeborn.common.protocol.PartitionLocation;

public class CelebornPartitionUtil {
/**
* The general idea is to divide each skew partition into smaller partitions:
*
* <p>- Spark driver will calculate the number of sub-partitions: {@code subPartitionSize =
* skewPartitionTotalSize / subPartitionTargetSize}
*
* <p>- In Celeborn, we divide the skew partition into {@code subPartitionSize} small partitions
* by PartitionLocation chunk offsets. This allows them to run in parallel Spark tasks.
*
* <p>For example, one skewed partition has 2 PartitionLocations:
*
* <ul>
* <li>PartitionLocation 0 with chunk offsets [0L, 100L, 200L, 300L, 500L, 1000L]
* <li>PartitionLocation 1 with chunk offsets [0L, 200L, 500L, 800L, 900L, 1000L]
* </ul>
*
* If we want to divide it into 3 sub-partitions (each sub-partition target size is 2000/3), the
* result will be:
*
* <ul>
* <li>sub-partition 0: uniqueId0 -> (0, 3)
* <li>sub-partition 1: uniqueId0 -> (4, 4), uniqueId1 -> (0, 0)
* <li>sub-partition 2: uniqueId1 -> (1, 4)
* </ul>
*
* Note: (0, 3) means chunks with chunk indices 0 through 3, four chunks in total.
*
* @param locations PartitionLocation information belonging to the reduce partition
* @param subPartitionSize the number of sub-partitions the reduce partition is split into
* @param subPartitionIndex the index of the current sub-partition
* @return a map from partition uniqueId to chunk-range pair for one sub-partition of the skew partition
*/
public static Map<String, Pair<Integer, Integer>> splitSkewedPartitionLocations(
ArrayList<PartitionLocation> locations, int subPartitionSize, int subPartitionIndex) {
locations.sort(Comparator.comparing((PartitionLocation p) -> p.getUniqueId()));
long totalPartitionSize =
locations.stream().mapToLong((PartitionLocation p) -> p.getStorageInfo().fileSize).sum();
long step = totalPartitionSize / subPartitionSize;
long startOffset = step * subPartitionIndex;
long endOffset =
subPartitionIndex < subPartitionSize - 1
? step * (subPartitionIndex + 1)
: totalPartitionSize + 1; // last subPartition should include all remaining data

long partitionLocationOffset = 0;
Map<String, Pair<Integer, Integer>> chunkRange = new HashMap<>();
// Walk the locations in uniqueId order, tracking the byte offset at which each
// location's file starts within the concatenated skew partition.
for (PartitionLocation p : locations) {
// left: first chunk whose global end offset exceeds startOffset;
// right: last chunk whose global end offset is still <= endOffset; -1 means unset.
int left = -1;
int right = -1;
Iterator<Long> chunkOffsets = p.getStorageInfo().getChunkOffsets().iterator();
// Start from index 1 since the first chunk offset is always 0.
chunkOffsets.next();
int j = 1;
while (chunkOffsets.hasNext()) {
long currentOffset = partitionLocationOffset + chunkOffsets.next();
if (currentOffset > startOffset && left < 0) {
left = j - 1;
}
if (currentOffset <= endOffset) {
right = j - 1;
}
if (left >= 0 && right >= 0) {
chunkRange.put(p.getUniqueId(), Pair.of(left, right));
}
j++;
}
partitionLocationOffset += p.getStorageInfo().getFileSize();
if (partitionLocationOffset > endOffset) {
// Later locations start beyond endOffset, so nothing more can be assigned.
break;
}
}
return chunkRange;
return chunkRange;
}
}
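To make the chunk-range arithmetic concrete, here is a minimal, self-contained sketch (not part of the patch; the class name SkewSplitExample and the helper loc are illustrative) that reproduces the two-PartitionLocation example from the javadoc above, constructing locations the same way CelebornPartitionUtilSuiteJ below does:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;

import org.apache.commons.lang3.tuple.Pair;

import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.StorageInfo;

public class SkewSplitExample {
  // Builds a PartitionLocation with cumulative chunk offsets, mirroring the
  // genPartitionLocation helper in CelebornPartitionUtilSuiteJ below.
  static PartitionLocation loc(int epoch, Long[] offsets) {
    PartitionLocation location =
        new PartitionLocation(0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY);
    location.setStorageInfo(
        new StorageInfo(
            StorageInfo.Type.HDD,
            "mountPoint",
            false,
            "filePath",
            StorageInfo.LOCAL_DISK_MASK,
            offsets[offsets.length - 1], // fileSize is the last cumulative offset
            Arrays.asList(offsets)));
    return location;
  }

  public static void main(String[] args) {
    ArrayList<PartitionLocation> locations = new ArrayList<>();
    locations.add(loc(0, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L}));
    locations.add(loc(1, new Long[] {0L, 200L, 500L, 800L, 900L, 1000L}));

    // totalPartitionSize = 2000 and subPartitionSize = 3, so step = 666.
    for (int i = 0; i < 3; i++) {
      Map<String, Pair<Integer, Integer>> ranges =
          CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 3, i);
      System.out.println("sub-partition " + i + ": " + ranges);
    }
    // Expected ranges (map print order may vary):
    //   sub-partition 0: {0-0=(0,3)}
    //   sub-partition 1: {0-0=(4,4), 0-1=(0,0)}
    //   sub-partition 2: {0-1=(1,4)}
  }
}

With step = 666, the sub-partitions cover the byte ranges (0, 666], (666, 1332], and (1332, 2001] (the last range is extended past the total size so the final sub-partition picks up all remaining data), which yields exactly the three chunk-range maps from the javadoc.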
CelebornShuffleReader.scala
@@ -18,14 +18,14 @@
package org.apache.spark.shuffle.celeborn

import java.io.IOException
import java.nio.file.Files
import java.util
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

import scala.collection.JavaConverters._

import com.google.common.annotations.VisibleForTesting
import org.apache.commons.lang3.tuple.Pair
import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.internal.Logging
@@ -35,7 +35,7 @@ import org.apache.spark.shuffle.celeborn.CelebornShuffleReader.streamCreatorPool
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

import org.apache.celeborn.client.{DummyShuffleClient, ShuffleClient}
import org.apache.celeborn.client.{ClientUtils, ShuffleClient}
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
@@ -122,15 +122,41 @@ class CelebornShuffleReader[K, C](
}

// host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
val workerRequestMap = new util.HashMap[
val workerRequestMap = new JHashMap[
String,
(TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]()
(TransportClient, JArrayList[PartitionLocation], PbOpenStreamList.Builder)]()
// partitionId -> (partition uniqueId -> chunkRange pair)
val partitionId2ChunkRange = new JHashMap[Int, JMap[String, Pair[Integer, Integer]]]()

val partitionId2PartitionLocations = new JHashMap[Int, JSet[PartitionLocation]]()

var partCnt = 0

// If startMapIndex > endMapIndex, the partition is a skew partition read by the Celeborn
// implementation, and its locations will be split into startMapIndex sub-partitions.
val splitSkewPartitionWithoutMapRange =
ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex)

(startPartition until endPartition).foreach { partitionId =>
if (fileGroups.partitionGroups.containsKey(partitionId)) {
fileGroups.partitionGroups.get(partitionId).asScala.foreach { location =>
var locations = fileGroups.partitionGroups.get(partitionId)
if (splitSkewPartitionWithoutMapRange) {
val partitionLocation2ChunkRange = CelebornPartitionUtil.splitSkewedPartitionLocations(
new JArrayList(locations),
startMapIndex,
endMapIndex)
partitionId2ChunkRange.put(partitionId, partitionLocation2ChunkRange)
// Filter locations to avoid extra OPEN_STREAM requests when splitting a skew partition without a map range
val filterLocations = locations.asScala
.filter { location =>
null != partitionLocation2ChunkRange &&
partitionLocation2ChunkRange.containsKey(location.getUniqueId)
}
locations = filterLocations.asJava
partitionId2PartitionLocations.put(partitionId, locations)
}

locations.asScala.foreach { location =>
partCnt += 1
val hostPort = location.hostAndFetchPort
if (!workerRequestMap.containsKey(hostPort)) {
Expand All @@ -142,7 +168,7 @@ class CelebornShuffleReader[K, C](
pbOpenStreamList.setShuffleKey(shuffleKey)
workerRequestMap.put(
hostPort,
(client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
(client, new JArrayList[PartitionLocation], pbOpenStreamList))
} catch {
case ex: Exception =>
shuffleClient.excludeFailedFetchLocation(location.hostAndFetchPort, ex)
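An aside on the flag computed above: ClientUtils.readSkewPartitionWithoutMapRange is not part of this diff, so the following is only a hedged sketch of the contract the reader relies on (the boolean parameter stands in for whatever CelebornConf setting gates the feature; this is an assumption, not the actual implementation):

// Hypothetical sketch, not the real ClientUtils method. When the skew
// optimization applies, Spark encodes subPartitionSize in startMapIndex and
// subPartitionIndex in endMapIndex; since subPartitionIndex < subPartitionSize,
// startMapIndex > endMapIndex can never occur for a normal map-range read.
static boolean readSkewPartitionWithoutMapRange(
    boolean skewOptimizationEnabled, int startMapIndex, int endMapIndex) {
  return skewOptimizationEnabled && startMapIndex > endMapIndex;
}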
@@ -203,13 +229,22 @@

def createInputStream(partitionId: Int): Unit = {
val locations =
if (fileGroups.partitionGroups.containsKey(partitionId)) {
new util.ArrayList(fileGroups.partitionGroups.get(partitionId))
} else new util.ArrayList[PartitionLocation]()
if (splitSkewPartitionWithoutMapRange) {
partitionId2PartitionLocations.get(partitionId)
} else {
fileGroups.partitionGroups.get(partitionId)
}

val locationList =
if (null == locations) {
new JArrayList[PartitionLocation]()
} else {
new JArrayList[PartitionLocation](locations)
}
val streamHandlers =
if (locations != null) {
val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size())
locations.asScala.foreach { loc =>
val streamHandlerArr = new JArrayList[PbStreamHandler](locationList.size)
locationList.asScala.foreach { loc =>
streamHandlerArr.add(locationStreamHandlerMap.get(loc))
}
streamHandlerArr
@@ -226,8 +261,10 @@
endMapIndex,
if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
else null,
locations,
locationList,
streamHandlers,
fileGroups.pushFailedBatches,
partitionId2ChunkRange.get(partitionId),
fileGroups.mapAttempts,
metricsCallback)
streams.put(partitionId, inputStream)
@@ -414,7 +451,6 @@
def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance = {
dep.serializer.newInstance()
}

}

object CelebornShuffleReader {
CelebornPartitionUtilSuiteJ.java (new file)
@@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn;

import java.util.*;

import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.Test;

import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.StorageInfo;

public class CelebornPartitionUtilSuiteJ {
@Test
public void testSkewPartitionSplit() {

ArrayList<PartitionLocation> locations = new ArrayList<>();
for (int i = 0; i < 13; i++) {
locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L}));
}
locations.add(genPartitionLocation(91, new Long[] {0L, 1L}));

int subPartitionSize = 3;

Map<String, Pair<Integer, Integer>> result1 =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 0);
Map<String, Pair<Integer, Integer>> expectResult1 =
genRanges(
new Object[][] {
{"0-0", 0, 4},
{"0-1", 0, 4},
{"0-10", 0, 4},
{"0-11", 0, 4},
{"0-12", 0, 2}
});
Assert.assertEquals(expectResult1, result1);

Map<String, Pair<Integer, Integer>> result2 =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 1);
Map<String, Pair<Integer, Integer>> expectResult2 =
genRanges(
new Object[][] {
{"0-12", 3, 4},
{"0-2", 0, 4},
{"0-3", 0, 4},
{"0-4", 0, 4},
{"0-5", 0, 3}
});
Assert.assertEquals(expectResult2, result2);

Map<String, Pair<Integer, Integer>> result3 =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, subPartitionSize, 2);
Map<String, Pair<Integer, Integer>> expectResult3 =
genRanges(
new Object[][] {
{"0-5", 4, 4},
{"0-6", 0, 4},
{"0-7", 0, 4},
{"0-8", 0, 4},
{"0-9", 0, 4},
{"0-91", 0, 0}
});
Assert.assertEquals(expectResult3, result3);
}

@Test
public void testBoundary() {
ArrayList<PartitionLocation> locations = new ArrayList<>();
locations.add(genPartitionLocation(0, new Long[] {0L, 100L, 200L, 300L, 400L, 500L}));

for (int i = 0; i < 5; i++) {
Map<String, Pair<Integer, Integer>> result =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 5, i);
Map<String, Pair<Integer, Integer>> expectResult = genRanges(new Object[][] {{"0-0", i, i}});
Assert.assertEquals(expectResult, result);
}
}

@Test
public void testSplitStable() {
ArrayList<PartitionLocation> locations = new ArrayList<>();
for (int i = 0; i < 13; i++) {
locations.add(genPartitionLocation(i, new Long[] {0L, 100L, 200L, 300L, 500L, 1000L}));
}
locations.add(genPartitionLocation(91, new Long[] {0L, 1L}));

Collections.shuffle(locations);

Map<String, Pair<Integer, Integer>> result =
CelebornPartitionUtil.splitSkewedPartitionLocations(locations, 3, 0);
Map<String, Pair<Integer, Integer>> expectResult =
genRanges(
new Object[][] {
{"0-0", 0, 4},
{"0-1", 0, 4},
{"0-10", 0, 4},
{"0-11", 0, 4},
{"0-12", 0, 2}
});
Assert.assertEquals(expectResult, result);
}

private ArrayList<PartitionLocation> genPartitionLocations(Map<Integer, Long[]> epochToOffsets) {
ArrayList<PartitionLocation> locations = new ArrayList<>();
epochToOffsets.forEach(
(epoch, offsets) -> {
PartitionLocation location =
new PartitionLocation(
0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY);
StorageInfo storageInfo =
new StorageInfo(
StorageInfo.Type.HDD,
"mountPoint",
false,
"filePath",
StorageInfo.LOCAL_DISK_MASK,
1,
Arrays.asList(offsets));
location.setStorageInfo(storageInfo);
locations.add(location);
});
return locations;
}

private PartitionLocation genPartitionLocation(int epoch, Long[] offsets) {
PartitionLocation location =
new PartitionLocation(0, epoch, "localhost", 0, 0, 0, 0, PartitionLocation.Mode.PRIMARY);
StorageInfo storageInfo =
new StorageInfo(
StorageInfo.Type.HDD,
"mountPoint",
false,
"filePath",
StorageInfo.LOCAL_DISK_MASK,
offsets[offsets.length - 1],
Arrays.asList(offsets));
location.setStorageInfo(storageInfo);
return location;
}

private Map<String, Pair<Integer, Integer>> genRanges(Object[][] inputs) {
Map<String, Pair<Integer, Integer>> ranges = new HashMap<>();
for (Object[] idToChunkRange : inputs) {
String uid = (String) idToChunkRange[0];
Pair<Integer, Integer> range = Pair.of((int) idToChunkRange[1], (int) idToChunkRange[2]);
ranges.put(uid, range);
}
return ranges;
}
}