Skip to content

Commit

Permalink
Avoid lock up on unexpected ExecutorService errors while executing Lo…
Browse files Browse the repository at this point in the history
…cal Activities (#2371)
  • Loading branch information
Sushisource authored Jan 17, 2025
1 parent b187644 commit 3ad0b0e
Show file tree
Hide file tree
Showing 13 changed files with 248 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ public boolean start() {
new TaskHandlerImpl(handler),
pollerOptions,
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
true,
options.isUsingVirtualThreads());
poller =
new Poller<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import io.temporal.worker.tuning.LocalActivitySlotInfo;
import io.temporal.worker.tuning.SlotPermit;
import io.temporal.worker.tuning.SlotReleaseReason;
import io.temporal.workflow.Functions;
import java.util.concurrent.*;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -77,10 +78,11 @@ static final class QueuedLARequest {
}

private void processQueue() {
try {
while (running || !requestQueue.isEmpty()) {
QueuedLARequest request = requestQueue.take();
SlotPermit slotPermit;
while (running || !requestQueue.isEmpty()) {
SlotPermit slotPermit = null;
QueuedLARequest request = null;
try {
request = requestQueue.take();
try {
slotPermit = slotSupplier.reserveSlot(request.data);
} catch (InterruptedException e) {
Expand All @@ -95,9 +97,22 @@ private void processQueue() {
}
request.task.getExecutionContext().setPermit(slotPermit);
afterReservedCallback.apply(request.task);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (Throwable e) {
// Fail the workflow task if something went wrong executing the local activity (at the
// executor level, otherwise, the LA handler itself should be handling errors)
log.error("Unexpected error submitting local activity task to worker", e);
if (slotPermit != null) {
slotSupplier.releaseSlot(SlotReleaseReason.error(new RuntimeException(e)), slotPermit);
}
if (request != null) {
LocalActivityExecutionContext executionContext = request.task.getExecutionContext();
executionContext.callback(
LocalActivityResult.processingFailed(
executionContext.getActivityId(), request.task.getAttemptTask().getAttempt(), e));
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,6 @@ public boolean start() {
new AttemptTaskHandlerImpl(handler),
pollerOptions,
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
false,
options.isUsingVirtualThreads());

this.workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ public boolean start() {
new TaskHandlerImpl(handler),
pollerOptions,
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
true,
options.isUsingVirtualThreads());
poller =
new Poller<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ public interface TaskHandler<TT> {
@Nonnull String identity,
@Nonnull TaskHandler<T> handler,
@Nonnull PollerOptions pollerOptions,
int workerTaskSlots,
boolean synchronousQueue,
int threadPoolMax,
boolean useVirtualThreads) {
this.namespace = Objects.requireNonNull(namespace);
this.taskQueue = Objects.requireNonNull(taskQueue);
Expand All @@ -63,8 +62,10 @@ public interface TaskHandler<TT> {

this.pollThreadNamePrefix =
pollerOptions.getPollThreadNamePrefix().replaceFirst("Poller", "Executor");
// If virtual threads are enabled, we use a virtual thread executor.
if (useVirtualThreads) {
if (pollerOptions.getPollerTaskExecutorOverride() != null) {
this.taskExecutor = pollerOptions.getPollerTaskExecutorOverride();
} else if (useVirtualThreads) {
// If virtual threads are enabled, we use a virtual thread executor.
AtomicInteger threadIndex = new AtomicInteger();
this.taskExecutor =
VirtualThreadDelegate.newVirtualThreadExecutor(
Expand All @@ -74,18 +75,7 @@ public interface TaskHandler<TT> {
});
} else {
ThreadPoolExecutor threadPoolTaskExecutor =
new ThreadPoolExecutor(
// for SynchronousQueue we can afford to set it to 0, because the queue is always full
// or empty
// for LinkedBlockingQueue we have to set slots to workerTaskSlots to avoid situation
// when the queue grows, but the amount of threads is not, because the queue is not
// (and
// never) full
synchronousQueue ? 0 : workerTaskSlots,
workerTaskSlots,
10,
TimeUnit.SECONDS,
synchronousQueue ? new SynchronousQueue<>() : new LinkedBlockingQueue<>());
new ThreadPoolExecutor(0, threadPoolMax, 10, TimeUnit.SECONDS, new SynchronousQueue<>());
threadPoolTaskExecutor.allowCoreThreadTimeOut(true);
threadPoolTaskExecutor.setThreadFactory(
new ExecutorThreadFactory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ private void logPollErrors(Thread t, Throwable e) {

/**
* Some exceptions are considered normal during shutdown {@link #shouldIgnoreDuringShutdown} and
* we log them in the most quite manner.
* we log them in the most quiet manner.
*
* @param t thread where the exception happened
* @param e the exception itself
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.time.Duration;
import java.util.concurrent.ExecutorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -63,6 +64,7 @@ public static final class Builder {
private String pollThreadNamePrefix;
private Thread.UncaughtExceptionHandler uncaughtExceptionHandler;
private boolean usingVirtualThreads;
private ExecutorService pollerTaskExecutorOverride;

private Builder() {}

Expand All @@ -81,6 +83,7 @@ private Builder(PollerOptions options) {
this.pollThreadNamePrefix = options.getPollThreadNamePrefix();
this.uncaughtExceptionHandler = options.getUncaughtExceptionHandler();
this.usingVirtualThreads = options.isUsingVirtualThreads();
this.pollerTaskExecutorOverride = options.getPollerTaskExecutorOverride();
}

/** Defines interval for measuring poll rate. Larger the interval more spiky can be the load. */
Expand Down Expand Up @@ -162,6 +165,12 @@ public Builder setUsingVirtualThreads(boolean usingVirtualThreads) {
return this;
}

/** Override the task executor ExecutorService */
public Builder setPollerTaskExecutorOverride(ExecutorService overrideTaskExecutor) {
this.pollerTaskExecutorOverride = overrideTaskExecutor;
return this;
}

public PollerOptions build() {
if (uncaughtExceptionHandler == null) {
uncaughtExceptionHandler =
Expand Down Expand Up @@ -189,7 +198,8 @@ public PollerOptions build() {
pollThreadCount,
uncaughtExceptionHandler,
pollThreadNamePrefix,
usingVirtualThreads);
usingVirtualThreads,
pollerTaskExecutorOverride);
}
}

Expand All @@ -206,6 +216,7 @@ public PollerOptions build() {
private final Thread.UncaughtExceptionHandler uncaughtExceptionHandler;
private final String pollThreadNamePrefix;
private final boolean usingVirtualThreads;
private final ExecutorService pollerTaskExecutorOverride;

private PollerOptions(
int maximumPollRateIntervalMilliseconds,
Expand All @@ -218,7 +229,8 @@ private PollerOptions(
int pollThreadCount,
Thread.UncaughtExceptionHandler uncaughtExceptionHandler,
String pollThreadNamePrefix,
boolean usingVirtualThreads) {
boolean usingVirtualThreads,
ExecutorService pollerTaskExecutorOverride) {
this.maximumPollRateIntervalMilliseconds = maximumPollRateIntervalMilliseconds;
this.maximumPollRatePerSecond = maximumPollRatePerSecond;
this.backoffCoefficient = backoffCoefficient;
Expand All @@ -230,6 +242,7 @@ private PollerOptions(
this.uncaughtExceptionHandler = uncaughtExceptionHandler;
this.pollThreadNamePrefix = pollThreadNamePrefix;
this.usingVirtualThreads = usingVirtualThreads;
this.pollerTaskExecutorOverride = pollerTaskExecutorOverride;
}

public int getMaximumPollRateIntervalMilliseconds() {
Expand Down Expand Up @@ -276,6 +289,10 @@ public boolean isUsingVirtualThreads() {
return usingVirtualThreads;
}

public ExecutorService getPollerTaskExecutorOverride() {
return pollerTaskExecutorOverride;
}

@Override
public String toString() {
return "PollerOptions{"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ public boolean start() {
new TaskHandlerImpl(handler),
pollerOptions,
this.slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
true,
options.isUsingVirtualThreads());
stickyQueueBalancer =
new StickyQueueBalancer(
Expand Down
7 changes: 6 additions & 1 deletion temporal-sdk/src/main/java/io/temporal/worker/Worker.java
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,12 @@ private static SingleWorkerOptions toLocalActivityOptions(
List<ContextPropagator> contextPropagators,
Scope metricsScope) {
return toSingleWorkerOptions(factoryOptions, options, clientOptions, contextPropagators)
.setPollerOptions(PollerOptions.newBuilder().setPollThreadCount(1).build())
.setPollerOptions(
PollerOptions.newBuilder()
.setPollThreadCount(1)
.setPollerTaskExecutorOverride(
factoryOptions.getOverrideLocalActivityTaskExecutor())
.build())
.setMetricsScope(metricsScope)
.setUsingVirtualThreads(options.isUsingVirtualThreadsOnLocalActivityWorker())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

package io.temporal.worker;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.temporal.common.Experimental;
import io.temporal.common.interceptors.WorkerInterceptor;
import java.time.Duration;
import java.util.concurrent.ExecutorService;
import javax.annotation.Nullable;

public class WorkerFactoryOptions {
Expand Down Expand Up @@ -57,6 +59,7 @@ public static class Builder {
private WorkerInterceptor[] workerInterceptors;
private boolean enableLoggingInReplay;
private boolean usingVirtualWorkflowThreads;
private ExecutorService overrideLocalActivityTaskExecutor;

private Builder() {}

Expand All @@ -71,6 +74,7 @@ private Builder(WorkerFactoryOptions options) {
this.workerInterceptors = options.workerInterceptors;
this.enableLoggingInReplay = options.enableLoggingInReplay;
this.usingVirtualWorkflowThreads = options.usingVirtualWorkflowThreads;
this.overrideLocalActivityTaskExecutor = options.overrideLocalActivityTaskExecutor;
}

/**
Expand Down Expand Up @@ -143,6 +147,14 @@ public Builder setWorkflowHostLocalPollThreadCount(int workflowHostLocalPollThre
return this;
}

/** For internal use only. Overrides the local activity task ExecutorService. */
@VisibleForTesting
Builder setOverrideLocalActivityTaskExecutor(
ExecutorService overrideLocalActivityTaskExecutor) {
this.overrideLocalActivityTaskExecutor = overrideLocalActivityTaskExecutor;
return this;
}

public WorkerFactoryOptions build() {
return new WorkerFactoryOptions(
workflowCacheSize,
Expand All @@ -151,6 +163,7 @@ public WorkerFactoryOptions build() {
workerInterceptors,
enableLoggingInReplay,
usingVirtualWorkflowThreads,
overrideLocalActivityTaskExecutor,
false);
}

Expand All @@ -162,6 +175,7 @@ public WorkerFactoryOptions validateAndBuildWithDefaults() {
workerInterceptors == null ? new WorkerInterceptor[0] : workerInterceptors,
enableLoggingInReplay,
usingVirtualWorkflowThreads,
overrideLocalActivityTaskExecutor,
true);
}
}
Expand All @@ -172,6 +186,7 @@ public WorkerFactoryOptions validateAndBuildWithDefaults() {
private final WorkerInterceptor[] workerInterceptors;
private final boolean enableLoggingInReplay;
private final boolean usingVirtualWorkflowThreads;
private final ExecutorService overrideLocalActivityTaskExecutor;

private WorkerFactoryOptions(
int workflowCacheSize,
Expand All @@ -180,6 +195,7 @@ private WorkerFactoryOptions(
WorkerInterceptor[] workerInterceptors,
boolean enableLoggingInReplay,
boolean usingVirtualWorkflowThreads,
ExecutorService overrideLocalActivityTaskExecutor,
boolean validate) {
if (validate) {
Preconditions.checkState(workflowCacheSize >= 0, "negative workflowCacheSize");
Expand Down Expand Up @@ -207,6 +223,7 @@ private WorkerFactoryOptions(
this.workerInterceptors = workerInterceptors;
this.enableLoggingInReplay = enableLoggingInReplay;
this.usingVirtualWorkflowThreads = usingVirtualWorkflowThreads;
this.overrideLocalActivityTaskExecutor = overrideLocalActivityTaskExecutor;
}

public int getWorkflowCacheSize() {
Expand Down Expand Up @@ -235,6 +252,16 @@ public boolean isUsingVirtualWorkflowThreads() {
return usingVirtualWorkflowThreads;
}

/**
* For internal use only.
*
* @return the ExecutorService to use for local activity tasks, or null if the default should be
* used
*/
ExecutorService getOverrideLocalActivityTaskExecutor() {
return overrideLocalActivityTaskExecutor;
}

/**
* @deprecated not used anymore by JavaSDK, this value doesn't have any effect
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
* at once.
*
* @param <SI> The type of information that will be used to reserve a slot. The three info types are
* {@link WorkflowSlotInfo}, {@link ActivitySlotInfo}, and {@link LocalActivitySlotInfo}.
* {@link WorkflowSlotInfo}, {@link ActivitySlotInfo}, {@link LocalActivitySlotInfo}, and {@link
* NexusSlotInfo}.
*/
@Experimental
public interface SlotSupplier<SI extends SlotInfo> {
Expand Down Expand Up @@ -77,11 +78,11 @@ public interface SlotSupplier<SI extends SlotInfo> {
void releaseSlot(SlotReleaseContext<SI> ctx);

/**
* Because we currently use thread pools to execute tasks, there must be *some* defined
* upper-limit on the size of the thread pool for each kind of task. You must not hand out more
* permits than this number. If unspecified, the default is {@link Integer#MAX_VALUE}. Be aware
* that if your implementation hands out unreasonable numbers of permits, you could easily
* oversubscribe the worker, and cause it to run out of resources.
* Because we use thread pools to execute tasks when virtual threads are not enabled, there must
* be *some* defined upper-limit on the size of the thread pool for each kind of task. You must
* not hand out more permits than this number. If unspecified, the default is {@link
* Integer#MAX_VALUE}. Be aware that if your implementation hands out unreasonable numbers of
* permits, you could easily oversubscribe the worker, and cause it to run out of resources.
*
* <p>If a non-empty value is returned, it is assumed to be meaningful, and the worker will emit
* {@link io.temporal.worker.MetricsType#WORKER_TASK_SLOTS_AVAILABLE} metrics based on this value.
Expand Down
Loading

0 comments on commit 3ad0b0e

Please sign in to comment.