diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java index 7ea390168f8..9586f519f3e 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/ChunkStreamManager.java @@ -22,8 +22,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; -import scala.Tuple2; - import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -194,11 +192,6 @@ public StreamState getStreamState(long streamId) { return streams.get(streamId); } - public Tuple2 getShuffleKeyAndFileName(long streamId) { - StreamState state = streams.get(streamId); - return new Tuple2<>(state.shuffleKey, state.fileName); - } - public int getStreamsCount() { return streams.size(); } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 42009835e9b..ae693e93297 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -459,9 +459,11 @@ class FetchHandler( streamType match { case StreamType.ChunkStream => val streamState = chunkStreamManager.getStreamState(streamId) - val (shuffleKey, fileName) = (streamState.shuffleKey, streamState.fileName) - workerSource.recordAppActiveConnection(client, shuffleKey) - getRawFileInfo(shuffleKey, fileName).closeStream(streamId) + if (streamState != null) { + val (shuffleKey, fileName) = (streamState.shuffleKey, streamState.fileName) + workerSource.recordAppActiveConnection(client, shuffleKey) + getRawFileInfo(shuffleKey, fileName).closeStream(streamId) + } case StreamType.CreditStream => val shuffleKey = creditStreamManager.getStreamShuffleKey(streamId) if (shuffleKey != null) { @@ -501,9 +503,15 @@ class FetchHandler( logDebug(s"Received req from ${NettyUtils.getRemoteAddress(client.getChannel)}" + s" to fetch block $streamChunkSlice") - workerSource.recordAppActiveConnection( - client, - chunkStreamManager.getShuffleKeyAndFileName(streamChunkSlice.streamId)._1) + val streamState = chunkStreamManager.getStreamState(streamChunkSlice.streamId) + if (streamState == null) { + val message = s"Stream ${streamChunkSlice.streamId} is not registered with worker. " + + "This can happen if the worker was restart recently." + logError(message) + workerSource.incCounter(WorkerSource.FETCH_CHUNK_FAIL_COUNT) + client.getChannel.writeAndFlush(new ChunkFetchFailure(streamChunkSlice, message)) + return + } maxChunkBeingTransferred.foreach { threshold => val chunksBeingTransferred = chunkStreamManager.chunksBeingTransferred // take high cpu usage @@ -518,6 +526,8 @@ class FetchHandler( } } + workerSource.recordAppActiveConnection(client, streamState.shuffleKey) + val reqStr = req.toString workerSource.startTimer(WorkerSource.FETCH_CHUNK_TIME, reqStr) val fetchTimeMetric = chunkStreamManager.getFetchTimeMetric(streamChunkSlice.streamId)