Skip to content

Commit

Permalink
[FEATURE] Support customize dispatcher of okHttpClient in DefaultOpen…
Browse files Browse the repository at this point in the history
…AiClient #37 (#38)

1. Support customize dispatcher of okHttpClient in DefaultOpenAiClient.
2. Use custom dispatcher and validate by thread id in
ChatCompletionStreamingTest.
  • Loading branch information
jsonwan authored Dec 2, 2024
1 parent 87ae803 commit 7aca342
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ private DefaultOpenAiClient(Builder serviceBuilder) {
.readTimeout(serviceBuilder.readTimeout)
.writeTimeout(serviceBuilder.writeTimeout);

if (serviceBuilder.dispatcher != null) {
okHttpClientBuilder.dispatcher(serviceBuilder.dispatcher);
}

if (serviceBuilder.openAiApiKey == null && serviceBuilder.azureApiKey == null) {
throw new IllegalArgumentException("openAiApiKey OR azureApiKey must be defined");
}
Expand Down
8 changes: 8 additions & 0 deletions src/main/java/dev/ai4j/openai4j/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import dev.ai4j.openai4j.moderation.ModerationResult;
import dev.ai4j.openai4j.spi.OpenAiClientBuilderFactory;
import dev.ai4j.openai4j.spi.ServiceHelper;
import okhttp3.Dispatcher;

import java.net.InetSocketAddress;
import java.net.Proxy;
import java.nio.file.Path;
Expand Down Expand Up @@ -156,6 +158,7 @@ public abstract static class Builder<T extends OpenAiClient, B extends Builder<T
public Duration connectTimeout = Duration.ofSeconds(60);
public Duration readTimeout = Duration.ofSeconds(60);
public Duration writeTimeout = Duration.ofSeconds(60);
public Dispatcher dispatcher;
public Proxy proxy;
public String userAgent;
public boolean logRequests;
Expand Down Expand Up @@ -260,6 +263,11 @@ public B writeTimeout(Duration writeTimeout) {
return (B) this;
}

public B dispatcher(Dispatcher dispatcher) {
this.dispatcher = dispatcher;
return (B) this;
}

public B proxy(Proxy.Type type, String ip, int port) {
this.proxy = new Proxy(type, new InetSocketAddress(ip, port));
return (B) this;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package dev.ai4j.openai4j.chat;

import dev.ai4j.openai4j.DefaultOpenAiClient;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.RateLimitAwareTest;
import dev.ai4j.openai4j.ResponseHandle;
import dev.ai4j.openai4j.shared.StreamOptions;
import dev.ai4j.openai4j.shared.Usage;
import okhttp3.Dispatcher;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

Expand All @@ -29,14 +36,29 @@

class ChatCompletionStreamingTest extends RateLimitAwareTest {

private static final Logger log = LoggerFactory.getLogger(ChatCompletionStreamingTest.class);
private final OpenAiClient client = OpenAiClient.builder()
.baseUrl(System.getenv("OPENAI_BASE_URL"))
.openAiApiKey(System.getenv("OPENAI_API_KEY"))
.logRequests()
.logResponses()
.logStreamingResponses()
.dispatcher(getCustomDispatcher())
.build();

private Dispatcher getCustomDispatcher() {
return new Dispatcher(
new ThreadPoolExecutor(0, 100, 60, TimeUnit.SECONDS,
new SynchronousQueue<>(),
runnable -> {
Thread thread = new Thread(runnable, "CustomDispatcher");
log.debug("CustomDispatcher thread created, id={}", thread.getId());
return thread;
}
)
);
}

@Test
void testSimpleApi() throws Exception {

Expand All @@ -46,7 +68,10 @@ void testSimpleApi() throws Exception {

client.chatCompletion(USER_MESSAGE)
.onPartialResponse(responseBuilder::append)
.onComplete(() -> future.complete(responseBuilder.toString()))
.onComplete(() -> {
log.debug("Current thread id={}", Thread.currentThread().getId());
future.complete(responseBuilder.toString());
})
.onError(future::completeExceptionally)
.execute();

Expand Down

0 comments on commit 7aca342

Please sign in to comment.