From 100231246c0290dd67b0cdb09b59d9ba9c39319f Mon Sep 17 00:00:00 2001 From: lixw <> Date: Wed, 25 Dec 2024 16:52:33 +0800 Subject: [PATCH 01/13] add spring boot starter of xinference --- .../pom.xml | 88 ++++++++ .../xinference/spring/AutoConfig.java | 183 +++++++++++++++ .../spring/ChatModelProperties.java | 183 +++++++++++++++ .../spring/EmbeddingModelProperties.java | 101 +++++++++ .../spring/ImageModelProperties.java | 138 ++++++++++++ .../spring/LanguageModelProperties.java | 174 ++++++++++++++ .../xinference/spring/Properties.java | 86 +++++++ .../xinference/spring/ProxyProperties.java | 42 ++++ .../spring/ScoringModelProperties.java | 119 ++++++++++ .../xinference/spring/AutoConfigIT.java | 213 ++++++++++++++++++ spring-boot-starters/pom.xml | 1 + 11 files changed, 1328 insertions(+) create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java create mode 100644 
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml new file mode 100644 index 0000000..44a387f --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml @@ -0,0 +1,88 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-community-spring-boot-starters + 1.0.0-alpha1 + ../pom.xml + + + langchain4j-community-xinference-spring-boot-starter + LangChain4j :: Community :: Spring Boot starter :: Xinference + + + + Apache-2.0 + https://www.apache.org/licenses/LICENSE-2.0.txt + repo + A business-friendly OSS license + + + + + + dev.langchain4j + langchain4j-community-xinference + ${project.version} + + + + org.springframework.boot + spring-boot-starter + + + ch.qos.logback + logback-classic + + + + + + ch.qos.logback + logback-classic + + + + org.springframework.boot + spring-boot-autoconfigure-processor + true + + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + org.springframework.boot + spring-boot-starter-test + test + + + + + + + org.honton.chas + license-maven-plugin + + + + + Eclipse Public License + http://www.eclipse.org/legal/epl-v10.html + + + GNU Lesser General Public License + http://www.gnu.org/licenses/old-licenses/lgpl-2.1.html + + + + + + + diff --git 
a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java new file mode 100644 index 0000000..79c512f --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -0,0 +1,183 @@ +package dev.langchain4j.community.xinference.spring; + +import static dev.langchain4j.community.xinference.spring.Properties.PREFIX; + +import dev.langchain4j.community.model.xinference.XinferenceChatModel; +import dev.langchain4j.community.model.xinference.XinferenceEmbeddingModel; +import dev.langchain4j.community.model.xinference.XinferenceImageModel; +import dev.langchain4j.community.model.xinference.XinferenceLanguageModel; +import dev.langchain4j.community.model.xinference.XinferenceScoringModel; +import dev.langchain4j.community.model.xinference.XinferenceStreamingChatModel; +import dev.langchain4j.community.model.xinference.XinferenceStreamingLanguageModel; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; + +@AutoConfiguration +@EnableConfigurationProperties(Properties.class) +public class AutoConfig { + @Bean + @ConditionalOnProperty(PREFIX + ".chat-model.base-url") + public XinferenceChatModel chatModel(Properties properties) { + ChatModelProperties chatModelProperties = properties.getChatModel(); + return XinferenceChatModel.builder() + .baseUrl(chatModelProperties.getBaseUrl()) + .apiKey(chatModelProperties.getApiKey()) + .modelName(chatModelProperties.getModelName()) + 
.temperature(chatModelProperties.getTemperature()) + .topP(chatModelProperties.getTopP()) + .stop(chatModelProperties.getStop()) + .maxTokens(chatModelProperties.getMaxTokens()) + .presencePenalty(chatModelProperties.getPresencePenalty()) + .frequencyPenalty(chatModelProperties.getFrequencyPenalty()) + .seed(chatModelProperties.getSeed()) + .user(chatModelProperties.getUser()) + .toolChoice(chatModelProperties.getToolChoice()) + .parallelToolCalls(chatModelProperties.getParallelToolCalls()) + .maxRetries(chatModelProperties.getMaxRetries()) + .timeout(chatModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(chatModelProperties.getProxy())) + .logRequests(chatModelProperties.getLogRequests()) + .logResponses(chatModelProperties.getLogResponses()) + .customHeaders(chatModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url") + public XinferenceStreamingChatModel streamingChatModel(Properties properties) { + ChatModelProperties chatModelProperties = properties.getStreamingChatModel(); + return XinferenceStreamingChatModel.builder() + .baseUrl(chatModelProperties.getBaseUrl()) + .apiKey(chatModelProperties.getApiKey()) + .modelName(chatModelProperties.getModelName()) + .temperature(chatModelProperties.getTemperature()) + .topP(chatModelProperties.getTopP()) + .stop(chatModelProperties.getStop()) + .maxTokens(chatModelProperties.getMaxTokens()) + .presencePenalty(chatModelProperties.getPresencePenalty()) + .frequencyPenalty(chatModelProperties.getFrequencyPenalty()) + .seed(chatModelProperties.getSeed()) + .user(chatModelProperties.getUser()) + .toolChoice(chatModelProperties.getToolChoice()) + .parallelToolCalls(chatModelProperties.getParallelToolCalls()) + .timeout(chatModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(chatModelProperties.getProxy())) + .logRequests(chatModelProperties.getLogRequests()) + .logResponses(chatModelProperties.getLogResponses()) + 
.customHeaders(chatModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnProperty(PREFIX + ".language-model.base-url") + public XinferenceLanguageModel languageModel(Properties properties) { + LanguageModelProperties languageModelProperties = properties.getLanguageModel(); + return XinferenceLanguageModel.builder() + .baseUrl(languageModelProperties.getBaseUrl()) + .apiKey(languageModelProperties.getApiKey()) + .modelName(languageModelProperties.getModelName()) + .maxTokens(languageModelProperties.getMaxTokens()) + .temperature(languageModelProperties.getTemperature()) + .topP(languageModelProperties.getTopP()) + .logprobs(languageModelProperties.getLogprobs()) + .echo(languageModelProperties.getEcho()) + .stop(languageModelProperties.getStop()) + .presencePenalty(languageModelProperties.getPresencePenalty()) + .frequencyPenalty(languageModelProperties.getFrequencyPenalty()) + .user(languageModelProperties.getUser()) + .maxRetries(languageModelProperties.getMaxRetries()) + .timeout(languageModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(languageModelProperties.getProxy())) + .logRequests(languageModelProperties.getLogRequests()) + .logResponses(languageModelProperties.getLogResponses()) + .customHeaders(languageModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url") + public XinferenceStreamingLanguageModel streamingLanguageModel(Properties properties) { + LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel(); + return XinferenceStreamingLanguageModel.builder() + .baseUrl(languageModelProperties.getBaseUrl()) + .apiKey(languageModelProperties.getApiKey()) + .modelName(languageModelProperties.getModelName()) + .maxTokens(languageModelProperties.getMaxTokens()) + .temperature(languageModelProperties.getTemperature()) + .topP(languageModelProperties.getTopP()) + .logprobs(languageModelProperties.getLogprobs()) + 
.echo(languageModelProperties.getEcho()) + .stop(languageModelProperties.getStop()) + .presencePenalty(languageModelProperties.getPresencePenalty()) + .frequencyPenalty(languageModelProperties.getFrequencyPenalty()) + .user(languageModelProperties.getUser()) + .timeout(languageModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(languageModelProperties.getProxy())) + .logRequests(languageModelProperties.getLogRequests()) + .logResponses(languageModelProperties.getLogResponses()) + .customHeaders(languageModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnProperty(PREFIX + ".embedding-model.base-url") + public XinferenceEmbeddingModel embeddingModel(Properties properties) { + EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel(); + return XinferenceEmbeddingModel.builder() + .baseUrl(embeddingModelProperties.getBaseUrl()) + .apiKey(embeddingModelProperties.getApiKey()) + .modelName(embeddingModelProperties.getModelName()) + .user(embeddingModelProperties.getUser()) + .maxRetries(embeddingModelProperties.getMaxRetries()) + .timeout(embeddingModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(embeddingModelProperties.getProxy())) + .logRequests(embeddingModelProperties.getLogRequests()) + .logResponses(embeddingModelProperties.getLogResponses()) + .customHeaders(embeddingModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnProperty(PREFIX + ".image-model.base-url") + public XinferenceImageModel imageModel(Properties properties) { + ImageModelProperties imageModelProperties = properties.getImageModel(); + return XinferenceImageModel.builder() + .baseUrl(imageModelProperties.getBaseUrl()) + .apiKey(imageModelProperties.getApiKey()) + .modelName(imageModelProperties.getModelName()) + .negativePrompt(imageModelProperties.getNegativePrompt()) + .responseFormat(imageModelProperties.getResponseFormat()) + .size(imageModelProperties.getSize()) + 
.kwargs(imageModelProperties.getKwargs()) + .user(imageModelProperties.getUser()) + .maxRetries(imageModelProperties.getMaxRetries()) + .timeout(imageModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(imageModelProperties.getProxy())) + .logRequests(imageModelProperties.getLogRequests()) + .logResponses(imageModelProperties.getLogResponses()) + .customHeaders(imageModelProperties.getCustomHeaders()) + .build(); + } + + @Bean + @ConditionalOnProperty(PREFIX + ".scoring-model.base-url") + public XinferenceScoringModel scoringModel(Properties properties) { + ScoringModelProperties scoringModelProperties = properties.getScoringModel(); + return XinferenceScoringModel.builder() + .baseUrl(scoringModelProperties.getBaseUrl()) + .apiKey(scoringModelProperties.getApiKey()) + .modelName(scoringModelProperties.getModelName()) + .topN(scoringModelProperties.getTopN()) + .returnDocuments(scoringModelProperties.getReturnDocuments()) + .returnLen(scoringModelProperties.getReturnLen()) + .maxRetries(scoringModelProperties.getMaxRetries()) + .timeout(scoringModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(scoringModelProperties.getProxy())) + .logRequests(scoringModelProperties.getLogRequests()) + .logResponses(scoringModelProperties.getLogResponses()) + .customHeaders(scoringModelProperties.getCustomHeaders()) + .build(); + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java new file mode 100644 index 0000000..13e0705 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java @@ -0,0 +1,183 @@ +package dev.langchain4j.community.xinference.spring; + +import 
java.time.Duration; +import java.util.List; +import java.util.Map; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +public class ChatModelProperties { + private String baseUrl; + private String apiKey; + private String modelName; + private Double temperature; + private Double topP; + private List stop; + private Integer maxTokens; + private Double presencePenalty; + private Double frequencyPenalty; + private Integer seed; + private String user; + private Object toolChoice; + private Boolean parallelToolCalls; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(final Double temperature) { + this.temperature = temperature; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(final Double topP) { + this.topP = topP; + } + + public List getStop() { + return stop; + } + + public void setStop(final List stop) { + this.stop = stop; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(final Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(final Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void 
setFrequencyPenalty(final Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public Integer getSeed() { + return seed; + } + + public void setSeed(final Integer seed) { + this.seed = seed; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Object getToolChoice() { + return toolChoice; + } + + public void setToolChoice(final Object toolChoice) { + this.toolChoice = toolChoice; + } + + public Boolean getParallelToolCalls() { + return parallelToolCalls; + } + + public void setParallelToolCalls(final Boolean parallelToolCalls) { + this.parallelToolCalls = parallelToolCalls; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java new file mode 100644 index 
0000000..3fe8edc --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java @@ -0,0 +1,101 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.Map; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +public class EmbeddingModelProperties { + private String baseUrl; + private String apiKey; + private String modelName; + private String user; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void 
setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java new file mode 100644 index 0000000..22289dc --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java @@ -0,0 +1,138 @@ +package dev.langchain4j.community.xinference.spring; + +import dev.langchain4j.community.model.xinference.client.image.ResponseFormat; +import java.time.Duration; +import java.util.Map; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +public class ImageModelProperties { + private String baseUrl; + private String apiKey; + private String modelName; + private String negativePrompt; + private ResponseFormat responseFormat; + private String size; + private String kwargs; + private String user; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public 
String getNegativePrompt() { + return negativePrompt; + } + + public void setNegativePrompt(final String negativePrompt) { + this.negativePrompt = negativePrompt; + } + + public ResponseFormat getResponseFormat() { + return responseFormat; + } + + public void setResponseFormat(final ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + } + + public String getSize() { + return size; + } + + public void setSize(final String size) { + this.size = size; + } + + public String getKwargs() { + return kwargs; + } + + public void setKwargs(final String kwargs) { + this.kwargs = kwargs; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java new file mode 100644 index 0000000..3a23229 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java @@ -0,0 +1,174 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +public class LanguageModelProperties { + private String baseUrl; + private String apiKey; + private String modelName; + private Integer maxTokens; + private Double temperature; + private Double topP; + private Integer logprobs; + private Boolean echo; + private List stop; + private Double presencePenalty; + private Double frequencyPenalty; + private String user; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(final Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(final Double temperature) { + this.temperature = temperature; + } + + public Double getTopP() { + return topP; + } + + public void 
setTopP(final Double topP) { + this.topP = topP; + } + + public Integer getLogprobs() { + return logprobs; + } + + public void setLogprobs(final Integer logprobs) { + this.logprobs = logprobs; + } + + public Boolean getEcho() { + return echo; + } + + public void setEcho(final Boolean echo) { + this.echo = echo; + } + + public List getStop() { + return stop; + } + + public void setStop(final List stop) { + this.stop = stop; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(final Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(final Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git 
a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java new file mode 100644 index 0000000..5511231 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java @@ -0,0 +1,86 @@ +package dev.langchain4j.community.xinference.spring; + +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = Properties.PREFIX) +public class Properties { + static final String PREFIX = "langchain4j.community.xinference"; + + @NestedConfigurationProperty + private ChatModelProperties chatModel; + + @NestedConfigurationProperty + private ChatModelProperties streamingChatModel; + + @NestedConfigurationProperty + private EmbeddingModelProperties embeddingModel; + + @NestedConfigurationProperty + private ImageModelProperties imageModel; + + @NestedConfigurationProperty + private LanguageModelProperties languageModel; + + @NestedConfigurationProperty + private LanguageModelProperties streamingLanguageModel; + + @NestedConfigurationProperty + private ScoringModelProperties scoringModel; + + public ChatModelProperties getChatModel() { + return chatModel; + } + + public void setChatModel(final ChatModelProperties chatModel) { + this.chatModel = chatModel; + } + + public ChatModelProperties getStreamingChatModel() { + return streamingChatModel; + } + + public void setStreamingChatModel(final ChatModelProperties streamingChatModel) { + this.streamingChatModel = streamingChatModel; + } + + public EmbeddingModelProperties getEmbeddingModel() { + return embeddingModel; + } + + public void setEmbeddingModel(final 
EmbeddingModelProperties embeddingModel) { + this.embeddingModel = embeddingModel; + } + + public ImageModelProperties getImageModel() { + return imageModel; + } + + public void setImageModel(final ImageModelProperties imageModel) { + this.imageModel = imageModel; + } + + public LanguageModelProperties getLanguageModel() { + return languageModel; + } + + public void setLanguageModel(final LanguageModelProperties languageModel) { + this.languageModel = languageModel; + } + + public LanguageModelProperties getStreamingLanguageModel() { + return streamingLanguageModel; + } + + public void setStreamingLanguageModel(final LanguageModelProperties streamingLanguageModel) { + this.streamingLanguageModel = streamingLanguageModel; + } + + public ScoringModelProperties getScoringModel() { + return scoringModel; + } + + public void setScoringModel(final ScoringModelProperties scoringModel) { + this.scoringModel = scoringModel; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java new file mode 100644 index 0000000..f110310 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java @@ -0,0 +1,42 @@ +package dev.langchain4j.community.xinference.spring; + +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.util.Objects; + +public class ProxyProperties { + private Proxy.Type type; + private String host; + private Integer port; + + public Proxy.Type getType() { + return type; + } + + public void setType(final Proxy.Type type) { + this.type = type; + } + + public String getHost() { + return host; + } + + public void setHost(final String host) { + this.host = host; + } + + public 
Integer getPort() { + return port; + } + + public void setPort(final Integer port) { + this.port = port; + } + + public static Proxy convert(ProxyProperties properties) { + if (Objects.isNull(properties)) { + return null; + } + return new Proxy(properties.getType(), new InetSocketAddress(properties.getHost(), properties.getPort())); + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java new file mode 100644 index 0000000..29c29c6 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java @@ -0,0 +1,119 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.Map; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +public class ScoringModelProperties { + private String baseUrl; + private String apiKey; + private String modelName; + private Integer topN; + private Boolean returnDocuments; + private Boolean returnLen; + private Integer maxRetries; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Integer getTopN() { + return topN; + } + 
+ public void setTopN(final Integer topN) { + this.topN = topN; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + public void setReturnDocuments(final Boolean returnDocuments) { + this.returnDocuments = returnDocuments; + } + + public Boolean getReturnLen() { + return returnLen; + } + + public void setReturnLen(final Boolean returnLen) { + this.returnLen = returnLen; + } + + public Integer getMaxRetries() { + return maxRetries; + } + + public void setMaxRetries(final Integer maxRetries) { + this.maxRetries = maxRetries; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java new file mode 100644 index 0000000..c1bda4c --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java @@ -0,0 +1,213 @@ +package dev.langchain4j.community.xinference.spring; + +import static java.util.concurrent.TimeUnit.SECONDS; 
+import static org.assertj.core.api.Assertions.assertThat; + +import dev.langchain4j.community.model.xinference.XinferenceChatModel; +import dev.langchain4j.community.model.xinference.XinferenceEmbeddingModel; +import dev.langchain4j.community.model.xinference.XinferenceImageModel; +import dev.langchain4j.community.model.xinference.XinferenceLanguageModel; +import dev.langchain4j.community.model.xinference.XinferenceScoringModel; +import dev.langchain4j.community.model.xinference.XinferenceStreamingChatModel; +import dev.langchain4j.community.model.xinference.XinferenceStreamingLanguageModel; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.image.ImageModel; +import dev.langchain4j.model.language.LanguageModel; +import dev.langchain4j.model.language.StreamingLanguageModel; +import dev.langchain4j.model.output.Response; +import dev.langchain4j.model.scoring.ScoringModel; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +/** + * Xinference Cloud + * https://docs.inference.top/zh + */ +@EnabledIfEnvironmentVariable(named = "XINFERENCE_API_KEY", matches = ".+") +class AutoConfigIT { + private static final String API_KEY = System.getenv("XINFERENCE_API_KEY"); + private static final String BASE_URL = System.getenv("XINFERENCE_BASE_URL"); + ApplicationContextRunner contextRunner = + new 
ApplicationContextRunner().withConfiguration(AutoConfigurations.of(AutoConfig.class)); + + @AfterEach + void waitForNextTest() throws InterruptedException { + Thread.sleep(3000); // 每个测试后延时3秒 + } + + @Test + void should_provide_chat_model() { + contextRunner + .withPropertyValues( + "langchain4j.community.xinference.chat-model.base-url=" + BASE_URL, + "langchain4j.community.xinference.chat-model.api-key=" + API_KEY, + "langchain4j.community.xinference.chat-model.model-name=qwen2-vl-instruct") + .run(context -> { + ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class); + assertThat(chatLanguageModel).isInstanceOf(XinferenceChatModel.class); + assertThat(chatLanguageModel.generate("What is the capital of Germany?")) + .contains("Berlin"); + assertThat(context.getBean(XinferenceChatModel.class)).isSameAs(chatLanguageModel); + }); + } + + @Test + void should_provide_streaming_chat_model() { + contextRunner + .withPropertyValues( + "langchain4j.community.xinference.streaming-chat-model.base-url=" + BASE_URL, + "langchain4j.community.xinference.streaming-chat-model.api-key=" + API_KEY, + "langchain4j.community.xinference.streaming-chat-model.model-name=qwen2-vl-instruct") + .run(context -> { + StreamingChatLanguageModel streamingChatLanguageModel = + context.getBean(StreamingChatLanguageModel.class); + assertThat(streamingChatLanguageModel).isInstanceOf(XinferenceStreamingChatModel.class); + CompletableFuture> future = new CompletableFuture<>(); + streamingChatLanguageModel.generate( + "What is the capital of Germany?", new StreamingResponseHandler() { + @Override + public void onNext(String token) { + } + + @Override + public void onComplete(Response response) { + future.complete(response); + } + + @Override + public void onError(Throwable error) { + } + }); + Response response = future.get(60, SECONDS); + assertThat(response.content().text()).contains("Berlin"); + assertThat(context.getBean(XinferenceStreamingChatModel.class)) + 
.isSameAs(streamingChatLanguageModel); + }); + } + + @Test + void should_provide_language_model() { + contextRunner + .withPropertyValues( + "langchain4j.community.xinference.language-model.base-url=" + BASE_URL, + "langchain4j.community.xinference.language-model.api-key=" + API_KEY, + "langchain4j.community.xinference.language-model.model-name=qwen2-vl-instruct", + "langchain4j.community.xinference.language-model.logRequests=true", + "langchain4j.community.xinference.language-model.logResponses=true") + .run(context -> { + LanguageModel languageModel = context.getBean(LanguageModel.class); + assertThat(languageModel).isInstanceOf(XinferenceLanguageModel.class); + assertThat(languageModel + .generate("What is the capital of Germany?") + .content()) + .contains("Berlin"); + assertThat(context.getBean(XinferenceLanguageModel.class)).isSameAs(languageModel); + }); + } + + @Test + void should_provide_streaming_language_model() { + contextRunner + .withPropertyValues( + "langchain4j.community.xinference.streaming-language-model.base-url=" + BASE_URL, + "langchain4j.community.xinference.streaming-language-model.api-key=" + API_KEY, + "langchain4j.community.xinference.streaming-language-model.model-name=qwen2-vl-instruct", + "langchain4j.community.xinference.streaming-language-model.logRequests=true", + "langchain4j.community.xinference.streaming-language-model.logResponses=true") + .run(context -> { + StreamingLanguageModel streamingLanguageModel = context.getBean(StreamingLanguageModel.class); + assertThat(streamingLanguageModel).isInstanceOf(XinferenceStreamingLanguageModel.class); + CompletableFuture> future = new CompletableFuture<>(); + streamingLanguageModel.generate( + "What is the capital of Germany?", new StreamingResponseHandler() { + @Override + public void onNext(String token) { + } + + @Override + public void onComplete(Response response) { + future.complete(response); + } + + @Override + public void onError(Throwable error) { + } + }); + Response response = 
future.get(60, SECONDS); + assertThat(response.content()).contains("Berlin"); + + assertThat(context.getBean(XinferenceStreamingLanguageModel.class)) + .isSameAs(streamingLanguageModel); + }); + } + + @Test + void should_provide_embedding_model() { + contextRunner + .withPropertyValues( + "langchain4j.community.xinference.embeddingModel.base-url=" + BASE_URL, + "langchain4j.community.xinference.embeddingModel.api-key=" + API_KEY, + "langchain4j.community.xinference.embeddingModel.modelName=bge-m3") + .run(context -> { + EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class); + assertThat(embeddingModel).isInstanceOf(XinferenceEmbeddingModel.class); + assertThat(embeddingModel.embed("hello world").content().dimension()) + .isEqualTo(1024); + assertThat(context.getBean(XinferenceEmbeddingModel.class)).isSameAs(embeddingModel); + }); + } + + @Test + @Disabled("Xinference Cloud Not Support") + void should_provide_sc_model() { + contextRunner + .withPropertyValues( + "langchain4j.community.xinference.scoringModel.base-url=" + BASE_URL, + "langchain4j.community.xinference.scoringModel.api-key=" + API_KEY, + "langchain4j.community.xinference.scoringModel.modelName=bge-m3") + .run(context -> { + ScoringModel scoringModel = context.getBean(ScoringModel.class); + assertThat(scoringModel).isInstanceOf(XinferenceScoringModel.class); + TextSegment catSegment = TextSegment.from("The Maine Coon is a large domesticated cat breed."); + TextSegment dogSegment = TextSegment.from( + "The sweet-faced, lovable Labrador Retriever is one of America's most popular dog breeds, year after year."); + List segments = Arrays.asList(catSegment, dogSegment); + String query = "tell me about dogs"; + Response> response = scoringModel.scoreAll(segments, query); + List scores = response.content(); + assertThat(scores).hasSize(2); + assertThat(scores.get(0)).isLessThan(scores.get(1)); + assertThat(context.getBean(XinferenceScoringModel.class)).isSameAs(scoringModel); + }); + } + + @Test 
+ void should_provide_image_model() { + contextRunner + .withPropertyValues( + "langchain4j.community.xinference.imageModel.base-url=" + BASE_URL, + "langchain4j.community.xinference.imageModel.api-key=" + API_KEY, + "langchain4j.community.xinference.imageModel.modelName=sd3-medium") + .run(context -> { + ImageModel imageModel = context.getBean(ImageModel.class); + assertThat(imageModel).isInstanceOf(XinferenceImageModel.class); + assertThat(imageModel.generate("banana").content().base64Data()) + .isNotNull(); + assertThat(context.getBean(XinferenceImageModel.class)).isSameAs(imageModel); + }); + } +} diff --git a/spring-boot-starters/pom.xml b/spring-boot-starters/pom.xml index 2648a01..d05a547 100644 --- a/spring-boot-starters/pom.xml +++ b/spring-boot-starters/pom.xml @@ -26,6 +26,7 @@ langchain4j-community-dashscope-spring-boot-starter langchain4j-community-qianfan-spring-boot-starter + langchain4j-community-xinference-spring-boot-starter From ed6a7317b22b8ba5d6663c54bb1175e2b2fd836d Mon Sep 17 00:00:00 2001 From: lixw <> Date: Wed, 25 Dec 2024 16:55:07 +0800 Subject: [PATCH 02/13] add spring boot starter of xinference --- .../pom.xml | 3 +-- .../xinference/spring/AutoConfigIT.java | 18 ++++++------------ 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml index 44a387f..9fdbe01 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml @@ -1,6 +1,5 @@ - + 4.0.0 dev.langchain4j diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java index c1bda4c..cd2e6c5 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java @@ -21,11 +21,9 @@ import dev.langchain4j.model.language.StreamingLanguageModel; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.scoring.ScoringModel; - import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; - import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -80,8 +78,7 @@ void should_provide_streaming_chat_model() { streamingChatLanguageModel.generate( "What is the capital of Germany?", new StreamingResponseHandler() { @Override - public void onNext(String token) { - } + public void onNext(String token) {} @Override public void onComplete(Response response) { @@ -89,8 +86,7 @@ public void onComplete(Response response) { } @Override - public void onError(Throwable error) { - } + public void onError(Throwable error) {} }); Response response = future.get(60, SECONDS); assertThat(response.content().text()).contains("Berlin"); @@ -112,8 +108,8 @@ void should_provide_language_model() { LanguageModel languageModel = context.getBean(LanguageModel.class); assertThat(languageModel).isInstanceOf(XinferenceLanguageModel.class); assertThat(languageModel - .generate("What is the capital of Germany?") - .content()) + .generate("What is the capital of Germany?") + .content()) .contains("Berlin"); assertThat(context.getBean(XinferenceLanguageModel.class)).isSameAs(languageModel); }); @@ -135,8 +131,7 @@ void should_provide_streaming_language_model() { 
streamingLanguageModel.generate( "What is the capital of Germany?", new StreamingResponseHandler() { @Override - public void onNext(String token) { - } + public void onNext(String token) {} @Override public void onComplete(Response response) { @@ -144,8 +139,7 @@ public void onComplete(Response response) { } @Override - public void onError(Throwable error) { - } + public void onError(Throwable error) {} }); Response response = future.get(60, SECONDS); assertThat(response.content()).contains("Berlin"); From 3e16d3700ce8077f893c81b7947268a40d52c85a Mon Sep 17 00:00:00 2001 From: lixw <> Date: Thu, 26 Dec 2024 17:20:47 +0800 Subject: [PATCH 03/13] switch testcontainers --- .../pom.xml | 12 +++ .../xinference/spring/AutoConfigIT.java | 99 +++++++++++-------- .../spring/XinferenceContainer.java | 96 ++++++++++++++++++ .../xinference/spring/XinferenceUtils.java | 69 +++++++++++++ 4 files changed, 234 insertions(+), 42 deletions(-) create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml index 9fdbe01..21e5b09 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml @@ -61,6 +61,18 @@ spring-boot-starter-test test + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + junit-jupiter + test + diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java index cd2e6c5..1605849 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java @@ -1,5 +1,12 @@ package dev.langchain4j.community.xinference.spring; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.CHAT_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.EMBEDDING_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.GENERATE_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.IMAGE_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.RERANK_MODEL_NAME; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.XINFERENCE_IMAGE; +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.launchCmd; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; @@ -21,39 +28,37 @@ import dev.langchain4j.model.language.StreamingLanguageModel; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.scoring.ScoringModel; +import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import 
org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; /** - * Xinference Cloud - * https://docs.inference.top/zh + * */ -@EnabledIfEnvironmentVariable(named = "XINFERENCE_API_KEY", matches = ".+") +@Testcontainers class AutoConfigIT { - private static final String API_KEY = System.getenv("XINFERENCE_API_KEY"); - private static final String BASE_URL = System.getenv("XINFERENCE_BASE_URL"); ApplicationContextRunner contextRunner = new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(AutoConfig.class)); - @AfterEach - void waitForNextTest() throws InterruptedException { - Thread.sleep(3000); // 每个测试后延时3秒 - } + @Container + XinferenceContainer chatModelContainer = new XinferenceContainer(XINFERENCE_IMAGE); @Test - void should_provide_chat_model() { + void should_provide_chat_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.chat-model.base-url=" + BASE_URL, - "langchain4j.community.xinference.chat-model.api-key=" + API_KEY, - "langchain4j.community.xinference.chat-model.model-name=qwen2-vl-instruct") + "langchain4j.community.xinference.chat-model.base-url=" + chatModelContainer.getEndpoint(), + "langchain4j.community.xinference.chat-model.model-name=" + CHAT_MODEL_NAME, + "langchain4j.community.xinference.language-model.logRequests=true", + "langchain4j.community.xinference.language-model.logResponses=true") .run(context -> { ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class); assertThat(chatLanguageModel).isInstanceOf(XinferenceChatModel.class); @@ -64,12 +69,15 @@ void should_provide_chat_model() { } @Test - void should_provide_streaming_chat_model() { + void should_provide_streaming_chat_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME)); 
contextRunner .withPropertyValues( - "langchain4j.community.xinference.streaming-chat-model.base-url=" + BASE_URL, - "langchain4j.community.xinference.streaming-chat-model.api-key=" + API_KEY, - "langchain4j.community.xinference.streaming-chat-model.model-name=qwen2-vl-instruct") + "langchain4j.community.xinference.streaming-chat-model.base-url=" + + chatModelContainer.getEndpoint(), + "langchain4j.community.xinference.streaming-chat-model.model-name=" + CHAT_MODEL_NAME, + "langchain4j.community.xinference.language-model.logRequests=true", + "langchain4j.community.xinference.language-model.logResponses=true") .run(context -> { StreamingChatLanguageModel streamingChatLanguageModel = context.getBean(StreamingChatLanguageModel.class); @@ -96,12 +104,12 @@ public void onError(Throwable error) {} } @Test - void should_provide_language_model() { + void should_provide_language_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.language-model.base-url=" + BASE_URL, - "langchain4j.community.xinference.language-model.api-key=" + API_KEY, - "langchain4j.community.xinference.language-model.model-name=qwen2-vl-instruct", + "langchain4j.community.xinference.language-model.base-url=" + chatModelContainer.getEndpoint(), + "langchain4j.community.xinference.language-model.model-name=" + GENERATE_MODEL_NAME, "langchain4j.community.xinference.language-model.logRequests=true", "langchain4j.community.xinference.language-model.logResponses=true") .run(context -> { @@ -116,12 +124,13 @@ void should_provide_language_model() { } @Test - void should_provide_streaming_language_model() { + void should_provide_streaming_language_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME)); contextRunner .withPropertyValues( - 
"langchain4j.community.xinference.streaming-language-model.base-url=" + BASE_URL, - "langchain4j.community.xinference.streaming-language-model.api-key=" + API_KEY, - "langchain4j.community.xinference.streaming-language-model.model-name=qwen2-vl-instruct", + "langchain4j.community.xinference.streaming-language-model.base-url=" + + chatModelContainer.getEndpoint(), + "langchain4j.community.xinference.streaming-language-model.model-name=" + GENERATE_MODEL_NAME, "langchain4j.community.xinference.streaming-language-model.logRequests=true", "langchain4j.community.xinference.streaming-language-model.logResponses=true") .run(context -> { @@ -150,29 +159,32 @@ public void onError(Throwable error) {} } @Test - void should_provide_embedding_model() { + void should_provide_embedding_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(EMBEDDING_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.embeddingModel.base-url=" + BASE_URL, - "langchain4j.community.xinference.embeddingModel.api-key=" + API_KEY, - "langchain4j.community.xinference.embeddingModel.modelName=bge-m3") + "langchain4j.community.xinference.embeddingModel.base-url=" + chatModelContainer.getEndpoint(), + "langchain4j.community.xinference.embeddingModel.modelName=" + EMBEDDING_MODEL_NAME, + "langchain4j.community.xinference.language-model.logRequests=true", + "langchain4j.community.xinference.language-model.logResponses=true") .run(context -> { EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class); assertThat(embeddingModel).isInstanceOf(XinferenceEmbeddingModel.class); assertThat(embeddingModel.embed("hello world").content().dimension()) - .isEqualTo(1024); + .isEqualTo(768); assertThat(context.getBean(XinferenceEmbeddingModel.class)).isSameAs(embeddingModel); }); } @Test - @Disabled("Xinference Cloud Not Support") - void should_provide_sc_model() { + void should_provide_sc_model() throws IOException, 
InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(RERANK_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.scoringModel.base-url=" + BASE_URL, - "langchain4j.community.xinference.scoringModel.api-key=" + API_KEY, - "langchain4j.community.xinference.scoringModel.modelName=bge-m3") + "langchain4j.community.xinference.scoringModel.base-url=" + chatModelContainer.getEndpoint(), + "langchain4j.community.xinference.scoringModel.modelName=" + RERANK_MODEL_NAME, + "langchain4j.community.xinference.language-model.logRequests=true", + "langchain4j.community.xinference.language-model.logResponses=true") .run(context -> { ScoringModel scoringModel = context.getBean(ScoringModel.class); assertThat(scoringModel).isInstanceOf(XinferenceScoringModel.class); @@ -184,18 +196,21 @@ void should_provide_sc_model() { Response> response = scoringModel.scoreAll(segments, query); List scores = response.content(); assertThat(scores).hasSize(2); - assertThat(scores.get(0)).isLessThan(scores.get(1)); + assertThat(scores.get(0)).isGreaterThan(scores.get(1)); assertThat(context.getBean(XinferenceScoringModel.class)).isSameAs(scoringModel); }); } @Test - void should_provide_image_model() { + @Disabled("Not supported to run in a Docker environment without GPU .") + void should_provide_image_model() throws IOException, InterruptedException { + chatModelContainer.execInContainer("bash", "-c", launchCmd(IMAGE_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.imageModel.base-url=" + BASE_URL, - "langchain4j.community.xinference.imageModel.api-key=" + API_KEY, - "langchain4j.community.xinference.imageModel.modelName=sd3-medium") + "langchain4j.community.xinference.imageModel.base-url=" + chatModelContainer.getEndpoint(), + "langchain4j.community.xinference.imageModel.modelName=" + IMAGE_MODEL_NAME, + "langchain4j.community.xinference.language-model.logRequests=true", + 
"langchain4j.community.xinference.language-model.logResponses=true") .run(context -> { ImageModel imageModel = context.getBean(ImageModel.class); assertThat(imageModel).isInstanceOf(XinferenceImageModel.class); diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java new file mode 100644 index 0000000..acaf356 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java @@ -0,0 +1,96 @@ +package dev.langchain4j.community.xinference.spring; + +import static dev.langchain4j.community.xinference.spring.XinferenceUtils.launchCmd; + +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.command.InspectContainerResponse; +import com.github.dockerjava.api.model.DeviceRequest; +import com.github.dockerjava.api.model.Image; +import com.github.dockerjava.api.model.Info; +import com.github.dockerjava.api.model.RuntimeInfo; +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.DockerClientFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.utility.DockerImageName; + +class XinferenceContainer extends GenericContainer { + private static final Logger log = LoggerFactory.getLogger(XinferenceContainer.class); + private static final DockerImageName DOCKER_IMAGE_NAME = DockerImageName.parse("xprobe/xinference"); + private static final Integer EXPOSED_PORT = 9997; + private String 
modelName; + + public XinferenceContainer(String image) { + this(DockerImageName.parse(image)); + } + + public XinferenceContainer(DockerImageName image) { + super(image); + image.assertCompatibleWith(DOCKER_IMAGE_NAME); + Info info = this.dockerClient.infoCmd().exec(); + Map runtimes = info.getRuntimes(); + if (runtimes != null && runtimes.containsKey("nvidia")) { + this.withCreateContainerCmdModifier((cmd) -> { + Objects.requireNonNull(cmd.getHostConfig()) + .withDeviceRequests(Collections.singletonList((new DeviceRequest()) + .withCapabilities(Collections.singletonList(Collections.singletonList("gpu"))) + .withCount(-1))); + }); + } + this.withExposedPorts(EXPOSED_PORT); + this.withEnv(Map.of("XINFERENCE_MODEL_SRC", "modelscope")); + // https://github.com/xorbitsai/inference/issues/2573 + this.withCommand("bash", "-c", "xinference-local -H 0.0.0.0"); + this.waitingFor(Wait.forListeningPort().withStartupTimeout(Duration.ofMinutes(10))); + } + + @Override + protected void containerIsStarted(InspectContainerResponse containerInfo) { + if (this.modelName != null) { + try { + log.info("Start pulling the '{}' model ... would take several minutes ...", this.modelName); + ExecResult r = execInContainer("bash", "-c", launchCmd(this.modelName)); + if (r.getExitCode() != 0) { + throw new RuntimeException(r.getStderr()); + } + log.info("Model pulling competed! 
{}", r); + } catch (IOException | InterruptedException e) { + throw new RuntimeException("Error pulling model", e); + } + } + } + + public XinferenceContainer withModel(String modelName) { + this.modelName = modelName; + return this; + } + + public void commitToImage(String imageName) { + DockerImageName dockerImageName = DockerImageName.parse(this.getDockerImageName()); + if (!dockerImageName.equals(DockerImageName.parse(imageName))) { + DockerClient dockerClient = DockerClientFactory.instance().client(); + List images = + dockerClient.listImagesCmd().withReferenceFilter(imageName).exec(); + if (images.isEmpty()) { + DockerImageName imageModel = DockerImageName.parse(imageName); + dockerClient + .commitCmd(this.getContainerId()) + .withRepository(imageModel.getUnversionedPart()) + .withLabels(Collections.singletonMap("org.testcontainers.sessionId", "")) + .withTag(imageModel.getVersionPart()) + .exec(); + } + } + } + + public String getEndpoint() { + return "http://" + this.getHost() + ":" + this.getMappedPort(EXPOSED_PORT); + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java new file mode 100644 index 0000000..edb94be --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java @@ -0,0 +1,69 @@ +package dev.langchain4j.community.xinference.spring; + +import com.github.dockerjava.api.DockerClient; +import com.github.dockerjava.api.model.Image; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.testcontainers.DockerClientFactory; +import org.testcontainers.utility.DockerImageName; + +class XinferenceUtils { + public static final String 
XINFERENCE_BASE_URL = System.getenv("XINFERENCE_BASE_URL"); + public static final String XINFERENCE_API_KEY = System.getenv("XINFERENCE_BASE_URL"); + // CPU + public static final String XINFERENCE_IMAGE = "xprobe/xinference:latest-cpu"; + // GPU + // public static final String XINFERENCE_IMAGE = "xprobe/xinference:latest"; + + public static final String CHAT_MODEL_NAME = "qwen2.5-instruct"; + public static final String GENERATE_MODEL_NAME = "qwen2.5"; + public static final String VISION_MODEL_NAME = "qwen2-vl-instruct"; + public static final String IMAGE_MODEL_NAME = "sd3-medium"; + public static final String EMBEDDING_MODEL_NAME = "text2vec-base-chinese"; + public static final String RERANK_MODEL_NAME = "bge-reranker-base"; + + private static final Map MODEL_LAUNCH_MAP = new HashMap<>() { + { + put( + CHAT_MODEL_NAME, + String.format( + "xinference launch --model-engine Transformers --model-name %s --size-in-billions 0_5 --model-format pytorch --quantization none", + CHAT_MODEL_NAME)); + put( + GENERATE_MODEL_NAME, + String.format( + "xinference launch --model-engine Transformers --model-name %s --size-in-billions 0_5 --model-format pytorch --quantization none", + GENERATE_MODEL_NAME)); + put( + VISION_MODEL_NAME, + String.format( + "xinference launch --model-engine Transformers --model-name %s --size-in-billions 2 --model-format pytorch --quantization none", + VISION_MODEL_NAME)); + put( + RERANK_MODEL_NAME, + String.format("xinference launch --model-name %s --model-type rerank", RERANK_MODEL_NAME)); + put( + IMAGE_MODEL_NAME, + String.format("xinference launch --model-name %s --model-type image", IMAGE_MODEL_NAME)); + put( + EMBEDDING_MODEL_NAME, + String.format("xinference launch --model-name %s --model-type embedding", EMBEDDING_MODEL_NAME)); + } + }; + + public static DockerImageName resolve(String baseImage, String localImageName) { + DockerImageName dockerImageName = DockerImageName.parse(baseImage); + DockerClient dockerClient = 
DockerClientFactory.instance().client(); + List images = + dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec(); + if (images.isEmpty()) { + return dockerImageName; + } + return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage); + } + + public static String launchCmd(String modelName) { + return MODEL_LAUNCH_MAP.get(modelName); + } +} From ac41eaef1e1c10d70844a6fa1a884aac197c41db Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: Fri, 27 Dec 2024 18:44:41 +0800 Subject: [PATCH 04/13] Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- .../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java index 79c512f..1bc98e5 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -16,7 +16,7 @@ @AutoConfiguration @EnableConfigurationProperties(Properties.class) -public class AutoConfig { +public class XinferenceAutoConfiguration { @Bean @ConditionalOnProperty(PREFIX + ".chat-model.base-url") public XinferenceChatModel chatModel(Properties properties) { From aa6474293aea156a8002d7fbaaddbbbe16411275 Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: Fri, 27 Dec 2024 18:44:50 +0800 Subject: [PATCH 05/13] Update 
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- .../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java index 1bc98e5..d1de16f 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -19,7 +19,7 @@ public class XinferenceAutoConfiguration { @Bean @ConditionalOnProperty(PREFIX + ".chat-model.base-url") - public XinferenceChatModel chatModel(Properties properties) { + public XinferenceChatModel xinferenceChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getChatModel(); return XinferenceChatModel.builder() .baseUrl(chatModelProperties.getBaseUrl()) From ae275c59dd1aaffbbfe2b37390461da5179900f2 Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: Fri, 27 Dec 2024 18:45:00 +0800 Subject: [PATCH 06/13] Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- .../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java index d1de16f..4c253a2 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -46,7 +46,7 @@ public XinferenceChatModel xinferenceChatModel(Properties properties) { @Bean @ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url") - public XinferenceStreamingChatModel streamingChatModel(Properties properties) { + public XinferenceStreamingChatModel xinferenceStreamingChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getStreamingChatModel(); return XinferenceStreamingChatModel.builder() .baseUrl(chatModelProperties.getBaseUrl()) From 782a3eed939806e3b8e4de26f4acb4df5122f826 Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: Fri, 27 Dec 2024 18:45:08 +0800 Subject: [PATCH 07/13] Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- .../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java index 4c253a2..0ea14c3 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -72,7 +72,7 @@ public XinferenceStreamingChatModel xinferenceStreamingChatModel(Properties prop @Bean @ConditionalOnProperty(PREFIX + ".language-model.base-url") - public XinferenceLanguageModel languageModel(Properties properties) { + public XinferenceLanguageModel xinferenceLanguageModel(Properties properties) { LanguageModelProperties languageModelProperties = properties.getLanguageModel(); return XinferenceLanguageModel.builder() .baseUrl(languageModelProperties.getBaseUrl()) From 7a79892e9c1aa43711a698e28f255704c01f1003 Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: Fri, 27 Dec 2024 18:45:16 +0800 Subject: [PATCH 08/13] Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- .../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java index 0ea14c3..c9380a4 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -98,7 +98,7 @@ public XinferenceLanguageModel xinferenceLanguageModel(Properties properties) { @Bean @ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url") - public XinferenceStreamingLanguageModel streamingLanguageModel(Properties 
properties) { + public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel(Properties properties) { LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel(); return XinferenceStreamingLanguageModel.builder() .baseUrl(languageModelProperties.getBaseUrl()) From e14ab0f0d8ea0a441628e3d91519742e7173d489 Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: Fri, 27 Dec 2024 18:45:22 +0800 Subject: [PATCH 09/13] Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- .../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java index c9380a4..6b56128 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -123,7 +123,7 @@ public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel(Propert @Bean @ConditionalOnProperty(PREFIX + ".embedding-model.base-url") - public XinferenceEmbeddingModel embeddingModel(Properties properties) { + public XinferenceEmbeddingModel xinferenceEmbeddingModel(Properties properties) { EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel(); return XinferenceEmbeddingModel.builder() .baseUrl(embeddingModelProperties.getBaseUrl()) From df094bcd75cc63a2d032a31bc5f102ffa91752c6 Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: 
Fri, 27 Dec 2024 18:45:28 +0800 Subject: [PATCH 10/13] Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- .../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java index 6b56128..e9bfbd8 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -141,7 +141,7 @@ public XinferenceEmbeddingModel xinferenceEmbeddingModel(Properties properties) @Bean @ConditionalOnProperty(PREFIX + ".image-model.base-url") - public XinferenceImageModel imageModel(Properties properties) { + public XinferenceImageModel xinferenceImageModel(Properties properties) { ImageModelProperties imageModelProperties = properties.getImageModel(); return XinferenceImageModel.builder() .baseUrl(imageModelProperties.getBaseUrl()) From 6095fd1af6820abac36eba7fb8cc3fd4c139d623 Mon Sep 17 00:00:00 2001 From: alvinlee518 Date: Fri, 27 Dec 2024 18:45:34 +0800 Subject: [PATCH 11/13] Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- .../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java index e9bfbd8..8513988 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java @@ -163,7 +163,7 @@ public XinferenceImageModel xinferenceImageModel(Properties properties) { @Bean @ConditionalOnProperty(PREFIX + ".scoring-model.base-url") - public XinferenceScoringModel scoringModel(Properties properties) { + public XinferenceScoringModel xinferenceScoringModel(Properties properties) { ScoringModelProperties scoringModelProperties = properties.getScoringModel(); return XinferenceScoringModel.builder() .baseUrl(scoringModelProperties.getBaseUrl()) From ea3ce23b43a39e9bae2c86290c2dfbcddccba990 Mon Sep 17 00:00:00 2001 From: lixw <> Date: Mon, 30 Dec 2024 11:09:41 +0800 Subject: [PATCH 12/13] Update spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java Co-authored-by: Martin7-1 --- langchain4j-community-bom/pom.xml | 9 +- .../spring/ChatModelProperties.java | 3 + .../spring/EmbeddingModelProperties.java | 4 + .../spring/ImageModelProperties.java | 4 + .../spring/LanguageModelProperties.java | 4 + .../xinference/spring/Properties.java | 86 --------- .../spring/ScoringModelProperties.java | 4 + .../spring/StreamingChatModelProperties.java | 178 ++++++++++++++++++ .../StreamingLanguageModelProperties.java | 169 +++++++++++++++++ ....java => XinferenceAutoConfiguration.java} | 129 +++++++------ 
...ot.autoconfigure.AutoConfiguration.imports | 1 + .../xinference/spring/AutoConfigIT.java | 60 +++--- 12 files changed, 473 insertions(+), 178 deletions(-) delete mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java rename spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/{AutoConfig.java => XinferenceAutoConfiguration.java} (59%) create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports diff --git a/langchain4j-community-bom/pom.xml b/langchain4j-community-bom/pom.xml index 38b4bc4..32a7bef 100644 --- a/langchain4j-community-bom/pom.xml +++ b/langchain4j-community-bom/pom.xml @@ -11,7 +11,8 @@ pom LangChain4j :: Community :: BOM - Bill of Materials POM for getting full, complete set of compatible versions of LangChain4j Community modules + Bill of Materials POM for getting full, complete set of compatible versions of LangChain4j Community + modules @@ -82,6 +83,12 @@ ${project.version} + + dev.langchain4j + langchain4j-community-xinference-spring-boot-starter + ${project.version} + + diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java index 13e0705..d86d41d 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java @@ -3,9 +3,12 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +@ConfigurationProperties(prefix = ChatModelProperties.PREFIX) public class ChatModelProperties { + static final String PREFIX = "langchain4j.community.xinference.chat-model"; private String baseUrl; private String apiKey; private String modelName; diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java index 3fe8edc..d628324 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java @@ -2,9 +2,13 @@ import java.time.Duration; import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +@ConfigurationProperties(prefix = 
EmbeddingModelProperties.PREFIX) public class EmbeddingModelProperties { + static final String PREFIX = "langchain4j.community.xinference.embedding-model"; + private String baseUrl; private String apiKey; private String modelName; diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java index 22289dc..e59ecfb 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java @@ -3,9 +3,13 @@ import dev.langchain4j.community.model.xinference.client.image.ResponseFormat; import java.time.Duration; import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +@ConfigurationProperties(prefix = ImageModelProperties.PREFIX) public class ImageModelProperties { + static final String PREFIX = "langchain4j.community.xinference.image-model"; + private String baseUrl; private String apiKey; private String modelName; diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java index 3a23229..43bfdd0 100644 --- 
a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java @@ -3,9 +3,13 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +@ConfigurationProperties(prefix = LanguageModelProperties.PREFIX) public class LanguageModelProperties { + static final String PREFIX = "langchain4j.community.xinference.language-model"; + private String baseUrl; private String apiKey; private String modelName; diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java deleted file mode 100644 index 5511231..0000000 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java +++ /dev/null @@ -1,86 +0,0 @@ -package dev.langchain4j.community.xinference.spring; - -import org.springframework.boot.context.properties.ConfigurationProperties; -import org.springframework.boot.context.properties.NestedConfigurationProperty; - -@ConfigurationProperties(prefix = Properties.PREFIX) -public class Properties { - static final String PREFIX = "langchain4j.community.xinference"; - - @NestedConfigurationProperty - private ChatModelProperties chatModel; - - @NestedConfigurationProperty - private ChatModelProperties streamingChatModel; - - @NestedConfigurationProperty - private EmbeddingModelProperties embeddingModel; - - 
@NestedConfigurationProperty - private ImageModelProperties imageModel; - - @NestedConfigurationProperty - private LanguageModelProperties languageModel; - - @NestedConfigurationProperty - private LanguageModelProperties streamingLanguageModel; - - @NestedConfigurationProperty - private ScoringModelProperties scoringModel; - - public ChatModelProperties getChatModel() { - return chatModel; - } - - public void setChatModel(final ChatModelProperties chatModel) { - this.chatModel = chatModel; - } - - public ChatModelProperties getStreamingChatModel() { - return streamingChatModel; - } - - public void setStreamingChatModel(final ChatModelProperties streamingChatModel) { - this.streamingChatModel = streamingChatModel; - } - - public EmbeddingModelProperties getEmbeddingModel() { - return embeddingModel; - } - - public void setEmbeddingModel(final EmbeddingModelProperties embeddingModel) { - this.embeddingModel = embeddingModel; - } - - public ImageModelProperties getImageModel() { - return imageModel; - } - - public void setImageModel(final ImageModelProperties imageModel) { - this.imageModel = imageModel; - } - - public LanguageModelProperties getLanguageModel() { - return languageModel; - } - - public void setLanguageModel(final LanguageModelProperties languageModel) { - this.languageModel = languageModel; - } - - public LanguageModelProperties getStreamingLanguageModel() { - return streamingLanguageModel; - } - - public void setStreamingLanguageModel(final LanguageModelProperties streamingLanguageModel) { - this.streamingLanguageModel = streamingLanguageModel; - } - - public ScoringModelProperties getScoringModel() { - return scoringModel; - } - - public void setScoringModel(final ScoringModelProperties scoringModel) { - this.scoringModel = scoringModel; - } -} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java index 29c29c6..a53a280 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java @@ -2,9 +2,13 @@ import java.time.Duration; import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; +@ConfigurationProperties(prefix = ScoringModelProperties.PREFIX) public class ScoringModelProperties { + static final String PREFIX = "langchain4j.community.xinference.scoring-model"; + private String baseUrl; private String apiKey; private String modelName; diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java new file mode 100644 index 0000000..6f1121f --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java @@ -0,0 +1,178 @@ +package dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = StreamingChatModelProperties.PREFIX) +public class 
StreamingChatModelProperties { + static final String PREFIX = "langchain4j.community.xinference.streaming-chat-model"; + + private String baseUrl; + private String apiKey; + private String modelName; + private Double temperature; + private Double topP; + private List stop; + private Integer maxTokens; + private Double presencePenalty; + private Double frequencyPenalty; + private Integer seed; + private String user; + private Object toolChoice; + private Boolean parallelToolCalls; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(final Double temperature) { + this.temperature = temperature; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(final Double topP) { + this.topP = topP; + } + + public List getStop() { + return stop; + } + + public void setStop(final List stop) { + this.stop = stop; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(final Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(final Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(final Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; 
+ } + + public Integer getSeed() { + return seed; + } + + public void setSeed(final Integer seed) { + this.seed = seed; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Object getToolChoice() { + return toolChoice; + } + + public void setToolChoice(final Object toolChoice) { + this.toolChoice = toolChoice; + } + + public Boolean getParallelToolCalls() { + return parallelToolCalls; + } + + public void setParallelToolCalls(final Boolean parallelToolCalls) { + this.parallelToolCalls = parallelToolCalls; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java new file mode 100644 index 0000000..2937006 --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java @@ -0,0 +1,169 @@ +package 
dev.langchain4j.community.xinference.spring; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@ConfigurationProperties(prefix = StreamingLanguageModelProperties.PREFIX) +public class StreamingLanguageModelProperties { + static final String PREFIX = "langchain4j.community.xinference.streaming-language-model"; + + private String baseUrl; + private String apiKey; + private String modelName; + private Integer maxTokens; + private Double temperature; + private Double topP; + private Integer logprobs; + private Boolean echo; + private List stop; + private Double presencePenalty; + private Double frequencyPenalty; + private String user; + private Duration timeout; + + @NestedConfigurationProperty + private ProxyProperties proxy; + + private Boolean logRequests; + private Boolean logResponses; + private Map customHeaders; + + public String getBaseUrl() { + return baseUrl; + } + + public void setBaseUrl(final String baseUrl) { + this.baseUrl = baseUrl; + } + + public String getApiKey() { + return apiKey; + } + + public void setApiKey(final String apiKey) { + this.apiKey = apiKey; + } + + public String getModelName() { + return modelName; + } + + public void setModelName(final String modelName) { + this.modelName = modelName; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public void setMaxTokens(final Integer maxTokens) { + this.maxTokens = maxTokens; + } + + public Double getTemperature() { + return temperature; + } + + public void setTemperature(final Double temperature) { + this.temperature = temperature; + } + + public Double getTopP() { + return topP; + } + + public void setTopP(final Double topP) { + this.topP = topP; + } + + public Integer getLogprobs() { + return logprobs; + } + + public void setLogprobs(final Integer logprobs) { + this.logprobs = logprobs; + 
} + + public Boolean getEcho() { + return echo; + } + + public void setEcho(final Boolean echo) { + this.echo = echo; + } + + public List getStop() { + return stop; + } + + public void setStop(final List stop) { + this.stop = stop; + } + + public Double getPresencePenalty() { + return presencePenalty; + } + + public void setPresencePenalty(final Double presencePenalty) { + this.presencePenalty = presencePenalty; + } + + public Double getFrequencyPenalty() { + return frequencyPenalty; + } + + public void setFrequencyPenalty(final Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + } + + public String getUser() { + return user; + } + + public void setUser(final String user) { + this.user = user; + } + + public Duration getTimeout() { + return timeout; + } + + public void setTimeout(final Duration timeout) { + this.timeout = timeout; + } + + public ProxyProperties getProxy() { + return proxy; + } + + public void setProxy(final ProxyProperties proxy) { + this.proxy = proxy; + } + + public Boolean getLogRequests() { + return logRequests; + } + + public void setLogRequests(final Boolean logRequests) { + this.logRequests = logRequests; + } + + public Boolean getLogResponses() { + return logResponses; + } + + public void setLogResponses(final Boolean logResponses) { + this.logResponses = logResponses; + } + + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(final Map customHeaders) { + this.customHeaders = customHeaders; + } +} diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java similarity index 59% rename from 
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java rename to spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java index 8513988..3f59135 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java @@ -1,7 +1,5 @@ package dev.langchain4j.community.xinference.spring; -import static dev.langchain4j.community.xinference.spring.Properties.PREFIX; - import dev.langchain4j.community.model.xinference.XinferenceChatModel; import dev.langchain4j.community.model.xinference.XinferenceEmbeddingModel; import dev.langchain4j.community.model.xinference.XinferenceImageModel; @@ -10,17 +8,26 @@ import dev.langchain4j.community.model.xinference.XinferenceStreamingChatModel; import dev.langchain4j.community.model.xinference.XinferenceStreamingLanguageModel; import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; @AutoConfiguration -@EnableConfigurationProperties(Properties.class) +@EnableConfigurationProperties({ + ChatModelProperties.class, + StreamingChatModelProperties.class, + LanguageModelProperties.class, + StreamingLanguageModelProperties.class, + EmbeddingModelProperties.class, + ImageModelProperties.class, + ScoringModelProperties.class +}) public class XinferenceAutoConfiguration { @Bean - 
@ConditionalOnProperty(PREFIX + ".chat-model.base-url") - public XinferenceChatModel xinferenceChatModel(Properties properties) { - ChatModelProperties chatModelProperties = properties.getChatModel(); + @ConditionalOnMissingBean + @ConditionalOnProperty(ChatModelProperties.PREFIX + ".base-url") + public XinferenceChatModel xinferenceChatModel(ChatModelProperties chatModelProperties) { return XinferenceChatModel.builder() .baseUrl(chatModelProperties.getBaseUrl()) .apiKey(chatModelProperties.getApiKey()) @@ -28,7 +35,7 @@ public XinferenceChatModel xinferenceChatModel(Properties properties) { .temperature(chatModelProperties.getTemperature()) .topP(chatModelProperties.getTopP()) .stop(chatModelProperties.getStop()) - .maxTokens(chatModelProperties.getMaxRetries()) + .maxTokens(chatModelProperties.getMaxTokens()) .presencePenalty(chatModelProperties.getPresencePenalty()) .frequencyPenalty(chatModelProperties.getFrequencyPenalty()) .seed(chatModelProperties.getSeed()) @@ -45,35 +52,36 @@ public XinferenceChatModel xinferenceChatModel(Properties properties) { } @Bean - @ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url") - public XinferenceStreamingChatModel xinferenceStreamingChatModel(Properties properties) { - ChatModelProperties chatModelProperties = properties.getStreamingChatModel(); + @ConditionalOnMissingBean + @ConditionalOnProperty(StreamingChatModelProperties.PREFIX + ".base-url") + public XinferenceStreamingChatModel xinferenceStreamingChatModel( + StreamingChatModelProperties streamingChatModelProperties) { return XinferenceStreamingChatModel.builder() - .baseUrl(chatModelProperties.getBaseUrl()) - .apiKey(chatModelProperties.getApiKey()) - .modelName(chatModelProperties.getModelName()) - .temperature(chatModelProperties.getTemperature()) - .topP(chatModelProperties.getTopP()) - .stop(chatModelProperties.getStop()) - .maxTokens(chatModelProperties.getMaxRetries()) - .presencePenalty(chatModelProperties.getPresencePenalty()) - 
.frequencyPenalty(chatModelProperties.getFrequencyPenalty()) - .seed(chatModelProperties.getSeed()) - .user(chatModelProperties.getUser()) - .toolChoice(chatModelProperties.getToolChoice()) - .parallelToolCalls(chatModelProperties.getParallelToolCalls()) - .timeout(chatModelProperties.getTimeout()) - .proxy(ProxyProperties.convert(chatModelProperties.getProxy())) - .logRequests(chatModelProperties.getLogRequests()) - .logResponses(chatModelProperties.getLogResponses()) - .customHeaders(chatModelProperties.getCustomHeaders()) + .baseUrl(streamingChatModelProperties.getBaseUrl()) + .apiKey(streamingChatModelProperties.getApiKey()) + .modelName(streamingChatModelProperties.getModelName()) + .temperature(streamingChatModelProperties.getTemperature()) + .topP(streamingChatModelProperties.getTopP()) + .stop(streamingChatModelProperties.getStop()) + .maxTokens(streamingChatModelProperties.getMaxTokens()) + .presencePenalty(streamingChatModelProperties.getPresencePenalty()) + .frequencyPenalty(streamingChatModelProperties.getFrequencyPenalty()) + .seed(streamingChatModelProperties.getSeed()) + .user(streamingChatModelProperties.getUser()) + .toolChoice(streamingChatModelProperties.getToolChoice()) + .parallelToolCalls(streamingChatModelProperties.getParallelToolCalls()) + .timeout(streamingChatModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(streamingChatModelProperties.getProxy())) + .logRequests(streamingChatModelProperties.getLogRequests()) + .logResponses(streamingChatModelProperties.getLogResponses()) + .customHeaders(streamingChatModelProperties.getCustomHeaders()) .build(); } @Bean - @ConditionalOnProperty(PREFIX + ".language-model.base-url") - public XinferenceLanguageModel xinferenceLanguageModel(Properties properties) { - LanguageModelProperties languageModelProperties = properties.getLanguageModel(); + @ConditionalOnMissingBean + @ConditionalOnProperty(LanguageModelProperties.PREFIX + ".base-url") + public XinferenceLanguageModel 
xinferenceLanguageModel(LanguageModelProperties languageModelProperties) { return XinferenceLanguageModel.builder() .baseUrl(languageModelProperties.getBaseUrl()) .apiKey(languageModelProperties.getApiKey()) @@ -97,34 +105,35 @@ public XinferenceLanguageModel xinferenceLanguageModel(Properties properties) { } @Bean - @ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url") - public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel(Properties properties) { - LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel(); + @ConditionalOnMissingBean + @ConditionalOnProperty(StreamingLanguageModelProperties.PREFIX + ".base-url") + public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel( + StreamingLanguageModelProperties streamingLanguageModelProperties) { return XinferenceStreamingLanguageModel.builder() - .baseUrl(languageModelProperties.getBaseUrl()) - .apiKey(languageModelProperties.getApiKey()) - .modelName(languageModelProperties.getModelName()) - .maxTokens(languageModelProperties.getMaxTokens()) - .temperature(languageModelProperties.getTemperature()) - .topP(languageModelProperties.getTopP()) - .logprobs(languageModelProperties.getLogprobs()) - .echo(languageModelProperties.getEcho()) - .stop(languageModelProperties.getStop()) - .presencePenalty(languageModelProperties.getPresencePenalty()) - .frequencyPenalty(languageModelProperties.getFrequencyPenalty()) - .user(languageModelProperties.getUser()) - .timeout(languageModelProperties.getTimeout()) - .proxy(ProxyProperties.convert(languageModelProperties.getProxy())) - .logRequests(languageModelProperties.getLogRequests()) - .logResponses(languageModelProperties.getLogResponses()) - .customHeaders(languageModelProperties.getCustomHeaders()) + .baseUrl(streamingLanguageModelProperties.getBaseUrl()) + .apiKey(streamingLanguageModelProperties.getApiKey()) + .modelName(streamingLanguageModelProperties.getModelName()) + 
.maxTokens(streamingLanguageModelProperties.getMaxTokens()) + .temperature(streamingLanguageModelProperties.getTemperature()) + .topP(streamingLanguageModelProperties.getTopP()) + .logprobs(streamingLanguageModelProperties.getLogprobs()) + .echo(streamingLanguageModelProperties.getEcho()) + .stop(streamingLanguageModelProperties.getStop()) + .presencePenalty(streamingLanguageModelProperties.getPresencePenalty()) + .frequencyPenalty(streamingLanguageModelProperties.getFrequencyPenalty()) + .user(streamingLanguageModelProperties.getUser()) + .timeout(streamingLanguageModelProperties.getTimeout()) + .proxy(ProxyProperties.convert(streamingLanguageModelProperties.getProxy())) + .logRequests(streamingLanguageModelProperties.getLogRequests()) + .logResponses(streamingLanguageModelProperties.getLogResponses()) + .customHeaders(streamingLanguageModelProperties.getCustomHeaders()) .build(); } @Bean - @ConditionalOnProperty(PREFIX + ".embedding-model.base-url") - public XinferenceEmbeddingModel xinferenceEmbeddingModel(Properties properties) { - EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel(); + @ConditionalOnMissingBean + @ConditionalOnProperty(EmbeddingModelProperties.PREFIX + ".base-url") + public XinferenceEmbeddingModel xinferenceEmbeddingModel(EmbeddingModelProperties embeddingModelProperties) { return XinferenceEmbeddingModel.builder() .baseUrl(embeddingModelProperties.getBaseUrl()) .apiKey(embeddingModelProperties.getApiKey()) @@ -140,9 +149,9 @@ public XinferenceEmbeddingModel xinferenceEmbeddingModel(Properties properties) } @Bean - @ConditionalOnProperty(PREFIX + ".image-model.base-url") - public XinferenceImageModel xinferenceImageModel(Properties properties) { - ImageModelProperties imageModelProperties = properties.getImageModel(); + @ConditionalOnMissingBean + @ConditionalOnProperty(ImageModelProperties.PREFIX + ".base-url") + public XinferenceImageModel xinferenceImageModel(ImageModelProperties imageModelProperties) { return 
XinferenceImageModel.builder() .baseUrl(imageModelProperties.getBaseUrl()) .apiKey(imageModelProperties.getApiKey()) @@ -162,9 +171,9 @@ public XinferenceImageModel xinferenceImageModel(Properties properties) { } @Bean - @ConditionalOnProperty(PREFIX + ".scoring-model.base-url") - public XinferenceScoringModel xinferenceScoringModel(Properties properties) { - ScoringModelProperties scoringModelProperties = properties.getScoringModel(); + @ConditionalOnMissingBean + @ConditionalOnProperty(ScoringModelProperties.PREFIX + ".base-url") + public XinferenceScoringModel xinferenceScoringModel(ScoringModelProperties scoringModelProperties) { return XinferenceScoringModel.builder() .baseUrl(scoringModelProperties.getBaseUrl()) .apiKey(scoringModelProperties.getApiKey()) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 0000000..4de934a --- /dev/null +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1 @@ +dev.langchain4j.community.xinference.spring.XinferenceAutoConfiguration diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java index 1605849..5198026 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java +++ 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java @@ -45,7 +45,7 @@ @Testcontainers class AutoConfigIT { ApplicationContextRunner contextRunner = - new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(AutoConfig.class)); + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(XinferenceAutoConfiguration.class)); @Container XinferenceContainer chatModelContainer = new XinferenceContainer(XINFERENCE_IMAGE); @@ -55,10 +55,10 @@ void should_provide_chat_model() throws IOException, InterruptedException { chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.chat-model.base-url=" + chatModelContainer.getEndpoint(), - "langchain4j.community.xinference.chat-model.model-name=" + CHAT_MODEL_NAME, - "langchain4j.community.xinference.language-model.logRequests=true", - "langchain4j.community.xinference.language-model.logResponses=true") + ChatModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + ChatModelProperties.PREFIX + ".model-name=" + CHAT_MODEL_NAME, + ChatModelProperties.PREFIX + ".logRequests=true", + ChatModelProperties.PREFIX + ".logResponses=true") .run(context -> { ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class); assertThat(chatLanguageModel).isInstanceOf(XinferenceChatModel.class); @@ -73,11 +73,10 @@ void should_provide_streaming_chat_model() throws IOException, InterruptedExcept chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.streaming-chat-model.base-url=" - + chatModelContainer.getEndpoint(), - "langchain4j.community.xinference.streaming-chat-model.model-name=" + CHAT_MODEL_NAME, - "langchain4j.community.xinference.language-model.logRequests=true", - 
"langchain4j.community.xinference.language-model.logResponses=true") + StreamingChatModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + StreamingChatModelProperties.PREFIX + ".model-name=" + CHAT_MODEL_NAME, + StreamingChatModelProperties.PREFIX + ".logRequests=true", + StreamingChatModelProperties.PREFIX + ".logResponses=true") .run(context -> { StreamingChatLanguageModel streamingChatLanguageModel = context.getBean(StreamingChatLanguageModel.class); @@ -108,10 +107,10 @@ void should_provide_language_model() throws IOException, InterruptedException { chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.language-model.base-url=" + chatModelContainer.getEndpoint(), - "langchain4j.community.xinference.language-model.model-name=" + GENERATE_MODEL_NAME, - "langchain4j.community.xinference.language-model.logRequests=true", - "langchain4j.community.xinference.language-model.logResponses=true") + LanguageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + LanguageModelProperties.PREFIX + ".model-name=" + GENERATE_MODEL_NAME, + LanguageModelProperties.PREFIX + ".logRequests=true", + LanguageModelProperties.PREFIX + ".logResponses=true") .run(context -> { LanguageModel languageModel = context.getBean(LanguageModel.class); assertThat(languageModel).isInstanceOf(XinferenceLanguageModel.class); @@ -128,11 +127,10 @@ void should_provide_streaming_language_model() throws IOException, InterruptedEx chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.streaming-language-model.base-url=" - + chatModelContainer.getEndpoint(), - "langchain4j.community.xinference.streaming-language-model.model-name=" + GENERATE_MODEL_NAME, - "langchain4j.community.xinference.streaming-language-model.logRequests=true", - 
"langchain4j.community.xinference.streaming-language-model.logResponses=true") + StreamingLanguageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + StreamingLanguageModelProperties.PREFIX + ".model-name=" + GENERATE_MODEL_NAME, + StreamingLanguageModelProperties.PREFIX + ".logRequests=true", + StreamingLanguageModelProperties.PREFIX + ".logResponses=true") .run(context -> { StreamingLanguageModel streamingLanguageModel = context.getBean(StreamingLanguageModel.class); assertThat(streamingLanguageModel).isInstanceOf(XinferenceStreamingLanguageModel.class); @@ -163,10 +161,10 @@ void should_provide_embedding_model() throws IOException, InterruptedException { chatModelContainer.execInContainer("bash", "-c", launchCmd(EMBEDDING_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.embeddingModel.base-url=" + chatModelContainer.getEndpoint(), - "langchain4j.community.xinference.embeddingModel.modelName=" + EMBEDDING_MODEL_NAME, - "langchain4j.community.xinference.language-model.logRequests=true", - "langchain4j.community.xinference.language-model.logResponses=true") + EmbeddingModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + EmbeddingModelProperties.PREFIX + ".modelName=" + EMBEDDING_MODEL_NAME, + EmbeddingModelProperties.PREFIX + ".logRequests=true", + EmbeddingModelProperties.PREFIX + ".logResponses=true") .run(context -> { EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class); assertThat(embeddingModel).isInstanceOf(XinferenceEmbeddingModel.class); @@ -181,10 +179,10 @@ void should_provide_sc_model() throws IOException, InterruptedException { chatModelContainer.execInContainer("bash", "-c", launchCmd(RERANK_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.scoringModel.base-url=" + chatModelContainer.getEndpoint(), - "langchain4j.community.xinference.scoringModel.modelName=" + RERANK_MODEL_NAME, - 
"langchain4j.community.xinference.language-model.logRequests=true", - "langchain4j.community.xinference.language-model.logResponses=true") + ScoringModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + ScoringModelProperties.PREFIX + ".modelName=" + RERANK_MODEL_NAME, + ScoringModelProperties.PREFIX + ".logRequests=true", + ScoringModelProperties.PREFIX + ".logResponses=true") .run(context -> { ScoringModel scoringModel = context.getBean(ScoringModel.class); assertThat(scoringModel).isInstanceOf(XinferenceScoringModel.class); @@ -207,10 +205,10 @@ void should_provide_image_model() throws IOException, InterruptedException { chatModelContainer.execInContainer("bash", "-c", launchCmd(IMAGE_MODEL_NAME)); contextRunner .withPropertyValues( - "langchain4j.community.xinference.imageModel.base-url=" + chatModelContainer.getEndpoint(), - "langchain4j.community.xinference.imageModel.modelName=" + IMAGE_MODEL_NAME, - "langchain4j.community.xinference.language-model.logRequests=true", - "langchain4j.community.xinference.language-model.logResponses=true") + ImageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(), + ImageModelProperties.PREFIX + ".modelName=" + IMAGE_MODEL_NAME, + ImageModelProperties.PREFIX + ".logRequests=true", + ImageModelProperties.PREFIX + ".logResponses=true") .run(context -> { ImageModel imageModel = context.getBean(ImageModel.class); assertThat(imageModel).isInstanceOf(XinferenceImageModel.class); From 8e8a9426f5604d9bdfda1e0e26e26407b62c764a Mon Sep 17 00:00:00 2001 From: lixw <> Date: Mon, 30 Dec 2024 11:11:47 +0800 Subject: [PATCH 13/13] add spring boot starter of xinference --- .../community/xinference/spring/XinferenceContainer.java | 1 - 1 file changed, 1 deletion(-) diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java 
b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java index acaf356..7ed4fe5 100644 --- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java +++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java @@ -45,7 +45,6 @@ public XinferenceContainer(DockerImageName image) { }); } this.withExposedPorts(EXPOSED_PORT); - this.withEnv(Map.of("XINFERENCE_MODEL_SRC", "modelscope")); // https://github.com/xorbitsai/inference/issues/2573 this.withCommand("bash", "-c", "xinference-local -H 0.0.0.0"); this.waitingFor(Wait.forListeningPort().withStartupTimeout(Duration.ofMinutes(10)));