Skip to content

Commit

Permalink
Added per-call headers capabilities to OpenAiClient
Browse files Browse the repository at this point in the history
  • Loading branch information
singularity-sg committed May 9, 2024
1 parent e4d41a0 commit bc207e7
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 33 deletions.
44 changes: 26 additions & 18 deletions src/main/java/dev/ai4j/openai4j/DefaultOpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,12 @@ public DefaultOpenAiClient build() {
}

@Override
public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request) {
public SyncOrAsyncOrStreaming<CompletionResponse> completion(OpenAiClientContext context,
CompletionRequest request) {
CompletionRequest syncRequest = CompletionRequest.builder().from(request).stream(null).build();

return new RequestExecutor<>(
openAiApi.completions(syncRequest, apiVersion),
openAiApi.completions(context.headers(), syncRequest, apiVersion),
r -> r,
okHttpClient,
formatUrl("completions"),
Expand All @@ -144,13 +145,13 @@ public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest r
}

@Override
public SyncOrAsyncOrStreaming<String> completion(String prompt) {
public SyncOrAsyncOrStreaming<String> completion(OpenAiClientContext context, String prompt) {
CompletionRequest request = CompletionRequest.builder().prompt(prompt).build();

CompletionRequest syncRequest = CompletionRequest.builder().from(request).stream(null).build();

return new RequestExecutor<>(
openAiApi.completions(syncRequest, apiVersion),
openAiApi.completions(context.headers(), syncRequest, apiVersion),
CompletionResponse::text,
okHttpClient,
formatUrl("completions"),
Expand All @@ -162,11 +163,12 @@ public SyncOrAsyncOrStreaming<String> completion(String prompt) {
}

@Override
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatCompletionRequest request) {
public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(OpenAiClientContext context,
ChatCompletionRequest request) {
ChatCompletionRequest syncRequest = ChatCompletionRequest.builder().from(request).stream(null).build();

return new RequestExecutor<>(
openAiApi.chatCompletions(syncRequest, apiVersion),
openAiApi.chatCompletions(context.headers(), syncRequest, apiVersion),
r -> r,
okHttpClient,
formatUrl("chat/completions"),
Expand All @@ -178,13 +180,13 @@ public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatComplet
}

@Override
public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
public SyncOrAsyncOrStreaming<String> chatCompletion(OpenAiClientContext context, String userMessage) {
ChatCompletionRequest request = ChatCompletionRequest.builder().addUserMessage(userMessage).build();

ChatCompletionRequest syncRequest = ChatCompletionRequest.builder().from(request).stream(null).build();

return new RequestExecutor<>(
openAiApi.chatCompletions(syncRequest, apiVersion),
openAiApi.chatCompletions(context.headers(), syncRequest, apiVersion),
ChatCompletionResponse::content,
okHttpClient,
formatUrl("chat/completions"),
Expand All @@ -196,32 +198,38 @@ public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
}

@Override
public SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request) {
return new RequestExecutor<>(openAiApi.embeddings(request, apiVersion), r -> r);
public SyncOrAsync<EmbeddingResponse> embedding(OpenAiClientContext context, EmbeddingRequest request) {
return new RequestExecutor<>(openAiApi.embeddings(context.headers(), request, apiVersion), r -> r);
}

@Override
public SyncOrAsync<List<Float>> embedding(String input) {
public SyncOrAsync<List<Float>> embedding(OpenAiClientContext context, String input) {
EmbeddingRequest request = EmbeddingRequest.builder().input(input).build();

return new RequestExecutor<>(openAiApi.embeddings(request, apiVersion), EmbeddingResponse::embedding);
return new RequestExecutor<>(openAiApi.embeddings(context.headers(), request, apiVersion),
EmbeddingResponse::embedding);
}

@Override
public SyncOrAsync<ModerationResponse> moderation(ModerationRequest request) {
return new RequestExecutor<>(openAiApi.moderations(request, apiVersion), r -> r);
public SyncOrAsync<ModerationResponse> moderation(OpenAiClientContext context,
ModerationRequest request) {
return new RequestExecutor<>(openAiApi.moderations(context.headers(), request, apiVersion),
r -> r);
}

@Override
public SyncOrAsync<ModerationResult> moderation(String input) {
public SyncOrAsync<ModerationResult> moderation(OpenAiClientContext context, String input) {
ModerationRequest request = ModerationRequest.builder().input(input).build();

return new RequestExecutor<>(openAiApi.moderations(request, apiVersion), r -> r.results().get(0));
return new RequestExecutor<>(openAiApi.moderations(context.headers(), request, apiVersion),
r -> r.results().get(0));
}

@Override
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request) {
return new RequestExecutor<>(openAiApi.imagesGenerations(request, apiVersion), r -> r);
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(OpenAiClientContext context,
GenerateImagesRequest request) {
return new RequestExecutor<>(openAiApi.imagesGenerations(context.headers(), request, apiVersion),
r -> r);
}

private String formatUrl(String endpoint) {
Expand Down
52 changes: 48 additions & 4 deletions src/main/java/dev/ai4j/openai4j/OpenAiApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,26 @@
import dev.ai4j.openai4j.image.GenerateImagesResponse;
import dev.ai4j.openai4j.moderation.ModerationRequest;
import dev.ai4j.openai4j.moderation.ModerationResponse;
import java.util.Map;
import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.HeaderMap;
import retrofit2.http.Headers;
import retrofit2.http.POST;
import retrofit2.http.Query;

interface OpenAiApi {
@POST("completions")
@Headers("Content-Type: application/json")
Call<CompletionResponse> completions(@Body CompletionRequest request, @Query("api-version") String apiVersion);
Call<CompletionResponse> completions(@Body CompletionRequest request,
@Query("api-version") String apiVersion);

@POST("completions")
@Headers("Content-Type: application/json")
Call<CompletionResponse> completions(
@HeaderMap Map<String, String> headers,
@Body CompletionRequest request,
@Query("api-version") String apiVersion);

@POST("chat/completions")
@Headers("Content-Type: application/json")
Expand All @@ -28,17 +38,51 @@ Call<ChatCompletionResponse> chatCompletions(
@Query("api-version") String apiVersion
);

@POST("chat/completions")
@Headers("Content-Type: application/json")
Call<ChatCompletionResponse> chatCompletions(
@HeaderMap Map<String, String> headers,
@Body ChatCompletionRequest request,
@Query("api-version") String apiVersion
);

@POST("embeddings")
@Headers("Content-Type: application/json")
Call<EmbeddingResponse> embeddings(
@Body EmbeddingRequest request,
@Query("api-version") String apiVersion);

@POST("embeddings")
@Headers("Content-Type: application/json")
Call<EmbeddingResponse> embeddings(@Body EmbeddingRequest request, @Query("api-version") String apiVersion);
Call<EmbeddingResponse> embeddings(
@HeaderMap Map<String, String> headers,
@Body EmbeddingRequest request,
@Query("api-version") String apiVersion);

@POST("moderations")
@Headers("Content-Type: application/json")
Call<ModerationResponse> moderations(
@Body ModerationRequest request,
@Query("api-version") String apiVersion);

@POST("moderations")
@Headers("Content-Type: application/json")
Call<ModerationResponse> moderations(@Body ModerationRequest request, @Query("api-version") String apiVersion);
Call<ModerationResponse> moderations(
@HeaderMap Map<String, String> headers,
@Body ModerationRequest request,
@Query("api-version") String apiVersion);

@POST("images/generations")
@Headers({"Content-Type: application/json"})
Call<GenerateImagesResponse> imagesGenerations(
@Body GenerateImagesRequest request,
@Query("api-version") String apiVersion
);

@POST("images/generations")
@Headers({ "Content-Type: application/json" })
@Headers({"Content-Type: application/json"})
Call<GenerateImagesResponse> imagesGenerations(
@HeaderMap Map<String, String> headers,
@Body GenerateImagesRequest request,
@Query("api-version") String apiVersion
);
Expand Down
90 changes: 79 additions & 11 deletions src/main/java/dev/ai4j/openai4j/OpenAiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
Expand All @@ -26,23 +27,72 @@

public abstract class OpenAiClient {

public abstract SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request);
public abstract SyncOrAsyncOrStreaming<CompletionResponse> completion(
OpenAiClientContext clientContext, CompletionRequest request);

public abstract SyncOrAsyncOrStreaming<String> completion(String prompt);
public abstract SyncOrAsyncOrStreaming<String> completion(OpenAiClientContext clientContext,
String prompt);

public abstract SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(ChatCompletionRequest request);
public SyncOrAsyncOrStreaming<CompletionResponse> completion(CompletionRequest request) {
return completion(new OpenAiClientContext(), request);
}

public SyncOrAsyncOrStreaming<String> completion(String prompt) {
return completion(new OpenAiClientContext(), prompt);
}

public SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(
ChatCompletionRequest request) {
return chatCompletion(new OpenAiClientContext(), request);
}

public abstract SyncOrAsyncOrStreaming<ChatCompletionResponse> chatCompletion(
OpenAiClientContext clientContext,
ChatCompletionRequest request);

public SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage) {
return chatCompletion(new OpenAiClientContext(), userMessage);
}

public abstract SyncOrAsyncOrStreaming<String> chatCompletion(
OpenAiClientContext clientContext,
String userMessage);

public SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request) {
return embedding(new OpenAiClientContext(), request);
}

public abstract SyncOrAsyncOrStreaming<String> chatCompletion(String userMessage);
public abstract SyncOrAsync<EmbeddingResponse> embedding(OpenAiClientContext clientContext,
EmbeddingRequest request);

public abstract SyncOrAsync<EmbeddingResponse> embedding(EmbeddingRequest request);
public SyncOrAsync<List<Float>> embedding(String input) {
return embedding(new OpenAiClientContext(), input);
}

public abstract SyncOrAsync<List<Float>> embedding(OpenAiClientContext clientContext,
String input);

public SyncOrAsync<ModerationResponse> moderation(ModerationRequest request) {
return moderation(new OpenAiClientContext(), request);
}

public abstract SyncOrAsync<List<Float>> embedding(String input);
public abstract SyncOrAsync<ModerationResponse> moderation(OpenAiClientContext clientContext,
ModerationRequest request);

public abstract SyncOrAsync<ModerationResponse> moderation(ModerationRequest request);
public SyncOrAsync<ModerationResult> moderation(String input) {
return moderation(new OpenAiClientContext(), input);
}

public abstract SyncOrAsync<ModerationResult> moderation(String input);
public abstract SyncOrAsync<ModerationResult> moderation(OpenAiClientContext clientContext,
String input);

public abstract SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request);
public SyncOrAsync<GenerateImagesResponse> imagesGeneration(GenerateImagesRequest request) {
return imagesGeneration(new OpenAiClientContext(), request);
}

public abstract SyncOrAsync<GenerateImagesResponse> imagesGeneration(
OpenAiClientContext clientContext,
GenerateImagesRequest request);

public abstract void shutdown();

Expand All @@ -55,8 +105,26 @@ public static OpenAiClient.Builder builder() {
return DefaultOpenAiClient.builder();
}

@SuppressWarnings("unchecked")
public abstract static class Builder<T extends OpenAiClient, B extends Builder<T, B>> {
public static class OpenAiClientContext {
private final Map<String, String> headers = new HashMap<>();

public OpenAiClientContext addHeaders(Map<String, String> headers) {
this.headers.putAll(headers);
return this;
}

public OpenAiClientContext addHeader(String key, String value) {
headers.put(key, value);
return this;
}

public Map<String, String> headers() {
return headers;
}
}

@SuppressWarnings("unchecked")
public abstract static class Builder<T extends OpenAiClient, B extends Builder<T, B>> {

public String baseUrl = "https://api.openai.com/v1/";
public String organizationId;
Expand Down

0 comments on commit bc207e7

Please sign in to comment.