From 100231246c0290dd67b0cdb09b59d9ba9c39319f Mon Sep 17 00:00:00 2001
From: lixw <>
Date: Wed, 25 Dec 2024 16:52:33 +0800
Subject: [PATCH 01/13] add spring boot starter of xinference
---
.../pom.xml | 88 ++++++++
.../xinference/spring/AutoConfig.java | 183 +++++++++++++++
.../spring/ChatModelProperties.java | 183 +++++++++++++++
.../spring/EmbeddingModelProperties.java | 101 +++++++++
.../spring/ImageModelProperties.java | 138 ++++++++++++
.../spring/LanguageModelProperties.java | 174 ++++++++++++++
.../xinference/spring/Properties.java | 86 +++++++
.../xinference/spring/ProxyProperties.java | 42 ++++
.../spring/ScoringModelProperties.java | 119 ++++++++++
.../xinference/spring/AutoConfigIT.java | 213 ++++++++++++++++++
spring-boot-starters/pom.xml | 1 +
11 files changed, 1328 insertions(+)
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
new file mode 100644
index 0000000..44a387f
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
@@ -0,0 +1,88 @@
+
+
+ 4.0.0
+
+ dev.langchain4j
+ langchain4j-community-spring-boot-starters
+ 1.0.0-alpha1
+ ../pom.xml
+
+
+ langchain4j-community-xinference-spring-boot-starter
+ LangChain4j :: Community :: Spring Boot starter :: Xinference
+
+
+
+ Apache-2.0
+ https://www.apache.org/licenses/LICENSE-2.0.txt
+ repo
+ A business-friendly OSS license
+
+
+
+
+
+ dev.langchain4j
+ langchain4j-community-xinference
+ ${project.version}
+
+
+
+ org.springframework.boot
+ spring-boot-starter
+
+
+ ch.qos.logback
+ logback-classic
+
+
+
+
+
+ ch.qos.logback
+ logback-classic
+
+
+
+ org.springframework.boot
+ spring-boot-autoconfigure-processor
+ true
+
+
+
+
+ org.springframework.boot
+ spring-boot-configuration-processor
+ true
+
+
+
+ org.springframework.boot
+ spring-boot-starter-test
+ test
+
+
+
+
+
+
+ org.honton.chas
+ license-maven-plugin
+
+
+
+
+ Eclipse Public License
+ http://www.eclipse.org/legal/epl-v10.html
+
+
+ GNU Lesser General Public License
+ http://www.gnu.org/licenses/old-licenses/lgpl-2.1.html
+
+
+
+
+
+
+
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
new file mode 100644
index 0000000..79c512f
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -0,0 +1,183 @@
+package dev.langchain4j.community.xinference.spring;
+
+import static dev.langchain4j.community.xinference.spring.Properties.PREFIX;
+
+import dev.langchain4j.community.model.xinference.XinferenceChatModel;
+import dev.langchain4j.community.model.xinference.XinferenceEmbeddingModel;
+import dev.langchain4j.community.model.xinference.XinferenceImageModel;
+import dev.langchain4j.community.model.xinference.XinferenceLanguageModel;
+import dev.langchain4j.community.model.xinference.XinferenceScoringModel;
+import dev.langchain4j.community.model.xinference.XinferenceStreamingChatModel;
+import dev.langchain4j.community.model.xinference.XinferenceStreamingLanguageModel;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+
+@AutoConfiguration
+@EnableConfigurationProperties(Properties.class)
+public class AutoConfig {
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".chat-model.base-url")
+    public XinferenceChatModel chatModel(Properties properties) {
+        ChatModelProperties chatModelProperties = properties.getChatModel();
+        return XinferenceChatModel.builder()
+                .baseUrl(chatModelProperties.getBaseUrl())
+                .apiKey(chatModelProperties.getApiKey())
+                .modelName(chatModelProperties.getModelName())
+                .temperature(chatModelProperties.getTemperature())
+                .topP(chatModelProperties.getTopP())
+                .stop(chatModelProperties.getStop())
+                .maxTokens(chatModelProperties.getMaxTokens())
+                .presencePenalty(chatModelProperties.getPresencePenalty())
+                .frequencyPenalty(chatModelProperties.getFrequencyPenalty())
+                .seed(chatModelProperties.getSeed())
+                .user(chatModelProperties.getUser())
+                .toolChoice(chatModelProperties.getToolChoice())
+                .parallelToolCalls(chatModelProperties.getParallelToolCalls())
+                .maxRetries(chatModelProperties.getMaxRetries())
+                .timeout(chatModelProperties.getTimeout())
+                .proxy(ProxyProperties.convert(chatModelProperties.getProxy()))
+                .logRequests(chatModelProperties.getLogRequests())
+                .logResponses(chatModelProperties.getLogResponses())
+                .customHeaders(chatModelProperties.getCustomHeaders())
+                .build();
+    }
+
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url")
+    public XinferenceStreamingChatModel streamingChatModel(Properties properties) {
+        ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
+        return XinferenceStreamingChatModel.builder()
+                .baseUrl(chatModelProperties.getBaseUrl())
+                .apiKey(chatModelProperties.getApiKey())
+                .modelName(chatModelProperties.getModelName())
+                .temperature(chatModelProperties.getTemperature())
+                .topP(chatModelProperties.getTopP())
+                .stop(chatModelProperties.getStop())
+                .maxTokens(chatModelProperties.getMaxTokens())
+                .presencePenalty(chatModelProperties.getPresencePenalty())
+                .frequencyPenalty(chatModelProperties.getFrequencyPenalty())
+                .seed(chatModelProperties.getSeed())
+                .user(chatModelProperties.getUser())
+                .toolChoice(chatModelProperties.getToolChoice())
+                .parallelToolCalls(chatModelProperties.getParallelToolCalls())
+                .timeout(chatModelProperties.getTimeout())
+                .proxy(ProxyProperties.convert(chatModelProperties.getProxy()))
+                .logRequests(chatModelProperties.getLogRequests())
+                .logResponses(chatModelProperties.getLogResponses())
+                .customHeaders(chatModelProperties.getCustomHeaders())
+                .build();
+    }
+
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".language-model.base-url")
+    public XinferenceLanguageModel languageModel(Properties properties) {
+        final LanguageModelProperties props = properties.getLanguageModel();
+        return XinferenceLanguageModel.builder()
+                .baseUrl(props.getBaseUrl())
+                .apiKey(props.getApiKey())
+                .modelName(props.getModelName())
+                .maxTokens(props.getMaxTokens())
+                .temperature(props.getTemperature())
+                .topP(props.getTopP())
+                .logprobs(props.getLogprobs())
+                .echo(props.getEcho())
+                .stop(props.getStop())
+                .presencePenalty(props.getPresencePenalty())
+                .frequencyPenalty(props.getFrequencyPenalty())
+                .user(props.getUser())
+                .maxRetries(props.getMaxRetries())
+                .timeout(props.getTimeout())
+                .proxy(ProxyProperties.convert(props.getProxy()))
+                .logRequests(props.getLogRequests())
+                .logResponses(props.getLogResponses())
+                .customHeaders(props.getCustomHeaders())
+                .build();
+    }
+
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url")
+    public XinferenceStreamingLanguageModel streamingLanguageModel(Properties properties) {
+        final LanguageModelProperties props = properties.getStreamingLanguageModel();
+        return XinferenceStreamingLanguageModel.builder()
+                .baseUrl(props.getBaseUrl())
+                .apiKey(props.getApiKey())
+                .modelName(props.getModelName())
+                .maxTokens(props.getMaxTokens())
+                .temperature(props.getTemperature())
+                .topP(props.getTopP())
+                .logprobs(props.getLogprobs())
+                .echo(props.getEcho())
+                .stop(props.getStop())
+                .presencePenalty(props.getPresencePenalty())
+                .frequencyPenalty(props.getFrequencyPenalty())
+                .user(props.getUser())
+                .timeout(props.getTimeout())
+                .proxy(ProxyProperties.convert(props.getProxy()))
+                .logRequests(props.getLogRequests())
+                .logResponses(props.getLogResponses())
+                .customHeaders(props.getCustomHeaders())
+                .build();
+    }
+
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".embedding-model.base-url")
+    public XinferenceEmbeddingModel embeddingModel(Properties properties) {
+        final EmbeddingModelProperties props = properties.getEmbeddingModel();
+        return XinferenceEmbeddingModel.builder()
+                .baseUrl(props.getBaseUrl())
+                .apiKey(props.getApiKey())
+                .modelName(props.getModelName())
+                .user(props.getUser())
+                .maxRetries(props.getMaxRetries())
+                .timeout(props.getTimeout())
+                .proxy(ProxyProperties.convert(props.getProxy()))
+                .logRequests(props.getLogRequests())
+                .logResponses(props.getLogResponses())
+                .customHeaders(props.getCustomHeaders())
+                .build();
+    }
+
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".image-model.base-url")
+    public XinferenceImageModel imageModel(Properties properties) {
+        final ImageModelProperties props = properties.getImageModel();
+        return XinferenceImageModel.builder()
+                .baseUrl(props.getBaseUrl())
+                .apiKey(props.getApiKey())
+                .modelName(props.getModelName())
+                .negativePrompt(props.getNegativePrompt())
+                .responseFormat(props.getResponseFormat())
+                .size(props.getSize())
+                .kwargs(props.getKwargs())
+                .user(props.getUser())
+                .maxRetries(props.getMaxRetries())
+                .timeout(props.getTimeout())
+                .proxy(ProxyProperties.convert(props.getProxy()))
+                .logRequests(props.getLogRequests())
+                .logResponses(props.getLogResponses())
+                .customHeaders(props.getCustomHeaders())
+                .build();
+    }
+
+    @Bean
+    @ConditionalOnProperty(PREFIX + ".scoring-model.base-url")
+    public XinferenceScoringModel scoringModel(Properties properties) {
+        final ScoringModelProperties props = properties.getScoringModel();
+        return XinferenceScoringModel.builder()
+                .baseUrl(props.getBaseUrl())
+                .apiKey(props.getApiKey())
+                .modelName(props.getModelName())
+                .topN(props.getTopN())
+                .returnDocuments(props.getReturnDocuments())
+                .returnLen(props.getReturnLen())
+                .maxRetries(props.getMaxRetries())
+                .timeout(props.getTimeout())
+                .proxy(ProxyProperties.convert(props.getProxy()))
+                .logRequests(props.getLogRequests())
+                .logResponses(props.getLogResponses())
+                .customHeaders(props.getCustomHeaders())
+                .build();
+    }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java
new file mode 100644
index 0000000..13e0705
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java
@@ -0,0 +1,183 @@
+package dev.langchain4j.community.xinference.spring;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.Map;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+public class ChatModelProperties {
+ private String baseUrl;
+ private String apiKey;
+ private String modelName;
+ private Double temperature;
+ private Double topP;
+ private List stop;
+ private Integer maxTokens;
+ private Double presencePenalty;
+ private Double frequencyPenalty;
+ private Integer seed;
+ private String user;
+ private Object toolChoice;
+ private Boolean parallelToolCalls;
+ private Integer maxRetries;
+ private Duration timeout;
+
+ @NestedConfigurationProperty
+ private ProxyProperties proxy;
+
+ private Boolean logRequests;
+ private Boolean logResponses;
+ private Map customHeaders;
+
+ public String getBaseUrl() {
+ return baseUrl;
+ }
+
+ public void setBaseUrl(final String baseUrl) {
+ this.baseUrl = baseUrl;
+ }
+
+ public String getApiKey() {
+ return apiKey;
+ }
+
+ public void setApiKey(final String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public void setModelName(final String modelName) {
+ this.modelName = modelName;
+ }
+
+ public Double getTemperature() {
+ return temperature;
+ }
+
+ public void setTemperature(final Double temperature) {
+ this.temperature = temperature;
+ }
+
+ public Double getTopP() {
+ return topP;
+ }
+
+ public void setTopP(final Double topP) {
+ this.topP = topP;
+ }
+
+ public List getStop() {
+ return stop;
+ }
+
+ public void setStop(final List stop) {
+ this.stop = stop;
+ }
+
+ public Integer getMaxTokens() {
+ return maxTokens;
+ }
+
+ public void setMaxTokens(final Integer maxTokens) {
+ this.maxTokens = maxTokens;
+ }
+
+ public Double getPresencePenalty() {
+ return presencePenalty;
+ }
+
+ public void setPresencePenalty(final Double presencePenalty) {
+ this.presencePenalty = presencePenalty;
+ }
+
+ public Double getFrequencyPenalty() {
+ return frequencyPenalty;
+ }
+
+ public void setFrequencyPenalty(final Double frequencyPenalty) {
+ this.frequencyPenalty = frequencyPenalty;
+ }
+
+ public Integer getSeed() {
+ return seed;
+ }
+
+ public void setSeed(final Integer seed) {
+ this.seed = seed;
+ }
+
+ public String getUser() {
+ return user;
+ }
+
+ public void setUser(final String user) {
+ this.user = user;
+ }
+
+ public Object getToolChoice() {
+ return toolChoice;
+ }
+
+ public void setToolChoice(final Object toolChoice) {
+ this.toolChoice = toolChoice;
+ }
+
+ public Boolean getParallelToolCalls() {
+ return parallelToolCalls;
+ }
+
+ public void setParallelToolCalls(final Boolean parallelToolCalls) {
+ this.parallelToolCalls = parallelToolCalls;
+ }
+
+ public Integer getMaxRetries() {
+ return maxRetries;
+ }
+
+ public void setMaxRetries(final Integer maxRetries) {
+ this.maxRetries = maxRetries;
+ }
+
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ public void setTimeout(final Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ public ProxyProperties getProxy() {
+ return proxy;
+ }
+
+ public void setProxy(final ProxyProperties proxy) {
+ this.proxy = proxy;
+ }
+
+ public Boolean getLogRequests() {
+ return logRequests;
+ }
+
+ public void setLogRequests(final Boolean logRequests) {
+ this.logRequests = logRequests;
+ }
+
+ public Boolean getLogResponses() {
+ return logResponses;
+ }
+
+ public void setLogResponses(final Boolean logResponses) {
+ this.logResponses = logResponses;
+ }
+
+ public Map getCustomHeaders() {
+ return customHeaders;
+ }
+
+ public void setCustomHeaders(final Map customHeaders) {
+ this.customHeaders = customHeaders;
+ }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java
new file mode 100644
index 0000000..3fe8edc
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java
@@ -0,0 +1,101 @@
+package dev.langchain4j.community.xinference.spring;
+
+import java.time.Duration;
+import java.util.Map;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+public class EmbeddingModelProperties {
+ private String baseUrl;
+ private String apiKey;
+ private String modelName;
+ private String user;
+ private Integer maxRetries;
+ private Duration timeout;
+
+ @NestedConfigurationProperty
+ private ProxyProperties proxy;
+
+ private Boolean logRequests;
+ private Boolean logResponses;
+ private Map customHeaders;
+
+ public String getBaseUrl() {
+ return baseUrl;
+ }
+
+ public void setBaseUrl(final String baseUrl) {
+ this.baseUrl = baseUrl;
+ }
+
+ public String getApiKey() {
+ return apiKey;
+ }
+
+ public void setApiKey(final String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public void setModelName(final String modelName) {
+ this.modelName = modelName;
+ }
+
+ public String getUser() {
+ return user;
+ }
+
+ public void setUser(final String user) {
+ this.user = user;
+ }
+
+ public Integer getMaxRetries() {
+ return maxRetries;
+ }
+
+ public void setMaxRetries(final Integer maxRetries) {
+ this.maxRetries = maxRetries;
+ }
+
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ public void setTimeout(final Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ public ProxyProperties getProxy() {
+ return proxy;
+ }
+
+ public void setProxy(final ProxyProperties proxy) {
+ this.proxy = proxy;
+ }
+
+ public Boolean getLogRequests() {
+ return logRequests;
+ }
+
+ public void setLogRequests(final Boolean logRequests) {
+ this.logRequests = logRequests;
+ }
+
+ public Boolean getLogResponses() {
+ return logResponses;
+ }
+
+ public void setLogResponses(final Boolean logResponses) {
+ this.logResponses = logResponses;
+ }
+
+ public Map getCustomHeaders() {
+ return customHeaders;
+ }
+
+ public void setCustomHeaders(final Map customHeaders) {
+ this.customHeaders = customHeaders;
+ }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java
new file mode 100644
index 0000000..22289dc
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java
@@ -0,0 +1,138 @@
+package dev.langchain4j.community.xinference.spring;
+
+import dev.langchain4j.community.model.xinference.client.image.ResponseFormat;
+import java.time.Duration;
+import java.util.Map;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+public class ImageModelProperties {
+ private String baseUrl;
+ private String apiKey;
+ private String modelName;
+ private String negativePrompt;
+ private ResponseFormat responseFormat;
+ private String size;
+ private String kwargs;
+ private String user;
+ private Integer maxRetries;
+ private Duration timeout;
+
+ @NestedConfigurationProperty
+ private ProxyProperties proxy;
+
+ private Boolean logRequests;
+ private Boolean logResponses;
+ private Map customHeaders;
+
+ public String getBaseUrl() {
+ return baseUrl;
+ }
+
+ public void setBaseUrl(final String baseUrl) {
+ this.baseUrl = baseUrl;
+ }
+
+ public String getApiKey() {
+ return apiKey;
+ }
+
+ public void setApiKey(final String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public void setModelName(final String modelName) {
+ this.modelName = modelName;
+ }
+
+ public String getNegativePrompt() {
+ return negativePrompt;
+ }
+
+ public void setNegativePrompt(final String negativePrompt) {
+ this.negativePrompt = negativePrompt;
+ }
+
+ public ResponseFormat getResponseFormat() {
+ return responseFormat;
+ }
+
+ public void setResponseFormat(final ResponseFormat responseFormat) {
+ this.responseFormat = responseFormat;
+ }
+
+ public String getSize() {
+ return size;
+ }
+
+ public void setSize(final String size) {
+ this.size = size;
+ }
+
+ public String getKwargs() {
+ return kwargs;
+ }
+
+ public void setKwargs(final String kwargs) {
+ this.kwargs = kwargs;
+ }
+
+ public String getUser() {
+ return user;
+ }
+
+ public void setUser(final String user) {
+ this.user = user;
+ }
+
+ public Integer getMaxRetries() {
+ return maxRetries;
+ }
+
+ public void setMaxRetries(final Integer maxRetries) {
+ this.maxRetries = maxRetries;
+ }
+
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ public void setTimeout(final Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ public ProxyProperties getProxy() {
+ return proxy;
+ }
+
+ public void setProxy(final ProxyProperties proxy) {
+ this.proxy = proxy;
+ }
+
+ public Boolean getLogRequests() {
+ return logRequests;
+ }
+
+ public void setLogRequests(final Boolean logRequests) {
+ this.logRequests = logRequests;
+ }
+
+ public Boolean getLogResponses() {
+ return logResponses;
+ }
+
+ public void setLogResponses(final Boolean logResponses) {
+ this.logResponses = logResponses;
+ }
+
+ public Map getCustomHeaders() {
+ return customHeaders;
+ }
+
+ public void setCustomHeaders(final Map customHeaders) {
+ this.customHeaders = customHeaders;
+ }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java
new file mode 100644
index 0000000..3a23229
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java
@@ -0,0 +1,174 @@
+package dev.langchain4j.community.xinference.spring;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.Map;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+public class LanguageModelProperties {
+ private String baseUrl;
+ private String apiKey;
+ private String modelName;
+ private Integer maxTokens;
+ private Double temperature;
+ private Double topP;
+ private Integer logprobs;
+ private Boolean echo;
+ private List stop;
+ private Double presencePenalty;
+ private Double frequencyPenalty;
+ private String user;
+ private Integer maxRetries;
+ private Duration timeout;
+
+ @NestedConfigurationProperty
+ private ProxyProperties proxy;
+
+ private Boolean logRequests;
+ private Boolean logResponses;
+ private Map customHeaders;
+
+ public String getBaseUrl() {
+ return baseUrl;
+ }
+
+ public void setBaseUrl(final String baseUrl) {
+ this.baseUrl = baseUrl;
+ }
+
+ public String getApiKey() {
+ return apiKey;
+ }
+
+ public void setApiKey(final String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public void setModelName(final String modelName) {
+ this.modelName = modelName;
+ }
+
+ public Integer getMaxTokens() {
+ return maxTokens;
+ }
+
+ public void setMaxTokens(final Integer maxTokens) {
+ this.maxTokens = maxTokens;
+ }
+
+ public Double getTemperature() {
+ return temperature;
+ }
+
+ public void setTemperature(final Double temperature) {
+ this.temperature = temperature;
+ }
+
+ public Double getTopP() {
+ return topP;
+ }
+
+ public void setTopP(final Double topP) {
+ this.topP = topP;
+ }
+
+ public Integer getLogprobs() {
+ return logprobs;
+ }
+
+ public void setLogprobs(final Integer logprobs) {
+ this.logprobs = logprobs;
+ }
+
+ public Boolean getEcho() {
+ return echo;
+ }
+
+ public void setEcho(final Boolean echo) {
+ this.echo = echo;
+ }
+
+ public List getStop() {
+ return stop;
+ }
+
+ public void setStop(final List stop) {
+ this.stop = stop;
+ }
+
+ public Double getPresencePenalty() {
+ return presencePenalty;
+ }
+
+ public void setPresencePenalty(final Double presencePenalty) {
+ this.presencePenalty = presencePenalty;
+ }
+
+ public Double getFrequencyPenalty() {
+ return frequencyPenalty;
+ }
+
+ public void setFrequencyPenalty(final Double frequencyPenalty) {
+ this.frequencyPenalty = frequencyPenalty;
+ }
+
+ public String getUser() {
+ return user;
+ }
+
+ public void setUser(final String user) {
+ this.user = user;
+ }
+
+ public Integer getMaxRetries() {
+ return maxRetries;
+ }
+
+ public void setMaxRetries(final Integer maxRetries) {
+ this.maxRetries = maxRetries;
+ }
+
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ public void setTimeout(final Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ public ProxyProperties getProxy() {
+ return proxy;
+ }
+
+ public void setProxy(final ProxyProperties proxy) {
+ this.proxy = proxy;
+ }
+
+ public Boolean getLogRequests() {
+ return logRequests;
+ }
+
+ public void setLogRequests(final Boolean logRequests) {
+ this.logRequests = logRequests;
+ }
+
+ public Boolean getLogResponses() {
+ return logResponses;
+ }
+
+ public void setLogResponses(final Boolean logResponses) {
+ this.logResponses = logResponses;
+ }
+
+ public Map getCustomHeaders() {
+ return customHeaders;
+ }
+
+ public void setCustomHeaders(final Map customHeaders) {
+ this.customHeaders = customHeaders;
+ }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java
new file mode 100644
index 0000000..5511231
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java
@@ -0,0 +1,86 @@
+package dev.langchain4j.community.xinference.spring;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+@ConfigurationProperties(prefix = Properties.PREFIX)
+public class Properties {
+ static final String PREFIX = "langchain4j.community.xinference";
+
+ @NestedConfigurationProperty
+ private ChatModelProperties chatModel;
+
+ @NestedConfigurationProperty
+ private ChatModelProperties streamingChatModel;
+
+ @NestedConfigurationProperty
+ private EmbeddingModelProperties embeddingModel;
+
+ @NestedConfigurationProperty
+ private ImageModelProperties imageModel;
+
+ @NestedConfigurationProperty
+ private LanguageModelProperties languageModel;
+
+ @NestedConfigurationProperty
+ private LanguageModelProperties streamingLanguageModel;
+
+ @NestedConfigurationProperty
+ private ScoringModelProperties scoringModel;
+
+ public ChatModelProperties getChatModel() {
+ return chatModel;
+ }
+
+ public void setChatModel(final ChatModelProperties chatModel) {
+ this.chatModel = chatModel;
+ }
+
+ public ChatModelProperties getStreamingChatModel() {
+ return streamingChatModel;
+ }
+
+ public void setStreamingChatModel(final ChatModelProperties streamingChatModel) {
+ this.streamingChatModel = streamingChatModel;
+ }
+
+ public EmbeddingModelProperties getEmbeddingModel() {
+ return embeddingModel;
+ }
+
+ public void setEmbeddingModel(final EmbeddingModelProperties embeddingModel) {
+ this.embeddingModel = embeddingModel;
+ }
+
+ public ImageModelProperties getImageModel() {
+ return imageModel;
+ }
+
+ public void setImageModel(final ImageModelProperties imageModel) {
+ this.imageModel = imageModel;
+ }
+
+ public LanguageModelProperties getLanguageModel() {
+ return languageModel;
+ }
+
+ public void setLanguageModel(final LanguageModelProperties languageModel) {
+ this.languageModel = languageModel;
+ }
+
+ public LanguageModelProperties getStreamingLanguageModel() {
+ return streamingLanguageModel;
+ }
+
+ public void setStreamingLanguageModel(final LanguageModelProperties streamingLanguageModel) {
+ this.streamingLanguageModel = streamingLanguageModel;
+ }
+
+ public ScoringModelProperties getScoringModel() {
+ return scoringModel;
+ }
+
+ public void setScoringModel(final ScoringModelProperties scoringModel) {
+ this.scoringModel = scoringModel;
+ }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java
new file mode 100644
index 0000000..f110310
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ProxyProperties.java
@@ -0,0 +1,42 @@
+package dev.langchain4j.community.xinference.spring;
+
+import java.net.InetSocketAddress;
+import java.net.Proxy;
+import java.util.Objects;
+
+public class ProxyProperties {
+ private Proxy.Type type;
+ private String host;
+ private Integer port;
+
+ public Proxy.Type getType() {
+ return type;
+ }
+
+ public void setType(final Proxy.Type type) {
+ this.type = type;
+ }
+
+ public String getHost() {
+ return host;
+ }
+
+ public void setHost(final String host) {
+ this.host = host;
+ }
+
+ public Integer getPort() {
+ return port;
+ }
+
+ public void setPort(final Integer port) {
+ this.port = port;
+ }
+
+    public static Proxy convert(ProxyProperties properties) {
+        if (Objects.isNull(properties)) {
+            return null;
+        }
+        return properties.getType() == Proxy.Type.DIRECT ? Proxy.NO_PROXY : new Proxy(properties.getType(), new InetSocketAddress(properties.getHost(), properties.getPort()));
+    }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java
new file mode 100644
index 0000000..29c29c6
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java
@@ -0,0 +1,119 @@
+package dev.langchain4j.community.xinference.spring;
+
+import java.time.Duration;
+import java.util.Map;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+public class ScoringModelProperties {
+ private String baseUrl;
+ private String apiKey;
+ private String modelName;
+ private Integer topN;
+ private Boolean returnDocuments;
+ private Boolean returnLen;
+ private Integer maxRetries;
+ private Duration timeout;
+
+ @NestedConfigurationProperty
+ private ProxyProperties proxy;
+
+ private Boolean logRequests;
+ private Boolean logResponses;
+ private Map customHeaders;
+
+ public String getBaseUrl() {
+ return baseUrl;
+ }
+
+ public void setBaseUrl(final String baseUrl) {
+ this.baseUrl = baseUrl;
+ }
+
+ public String getApiKey() {
+ return apiKey;
+ }
+
+ public void setApiKey(final String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public void setModelName(final String modelName) {
+ this.modelName = modelName;
+ }
+
+ public Integer getTopN() {
+ return topN;
+ }
+
+ public void setTopN(final Integer topN) {
+ this.topN = topN;
+ }
+
+ public Boolean getReturnDocuments() {
+ return returnDocuments;
+ }
+
+ public void setReturnDocuments(final Boolean returnDocuments) {
+ this.returnDocuments = returnDocuments;
+ }
+
+ public Boolean getReturnLen() {
+ return returnLen;
+ }
+
+ public void setReturnLen(final Boolean returnLen) {
+ this.returnLen = returnLen;
+ }
+
+ public Integer getMaxRetries() {
+ return maxRetries;
+ }
+
+ public void setMaxRetries(final Integer maxRetries) {
+ this.maxRetries = maxRetries;
+ }
+
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ public void setTimeout(final Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ public ProxyProperties getProxy() {
+ return proxy;
+ }
+
+ public void setProxy(final ProxyProperties proxy) {
+ this.proxy = proxy;
+ }
+
+ public Boolean getLogRequests() {
+ return logRequests;
+ }
+
+ public void setLogRequests(final Boolean logRequests) {
+ this.logRequests = logRequests;
+ }
+
+ public Boolean getLogResponses() {
+ return logResponses;
+ }
+
+ public void setLogResponses(final Boolean logResponses) {
+ this.logResponses = logResponses;
+ }
+
+ public Map getCustomHeaders() {
+ return customHeaders;
+ }
+
+ public void setCustomHeaders(final Map customHeaders) {
+ this.customHeaders = customHeaders;
+ }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
new file mode 100644
index 0000000..c1bda4c
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
@@ -0,0 +1,213 @@
+package dev.langchain4j.community.xinference.spring;
+
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.assertj.core.api.Assertions.assertThat;
+
+import dev.langchain4j.community.model.xinference.XinferenceChatModel;
+import dev.langchain4j.community.model.xinference.XinferenceEmbeddingModel;
+import dev.langchain4j.community.model.xinference.XinferenceImageModel;
+import dev.langchain4j.community.model.xinference.XinferenceLanguageModel;
+import dev.langchain4j.community.model.xinference.XinferenceScoringModel;
+import dev.langchain4j.community.model.xinference.XinferenceStreamingChatModel;
+import dev.langchain4j.community.model.xinference.XinferenceStreamingLanguageModel;
+import dev.langchain4j.data.message.AiMessage;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.model.StreamingResponseHandler;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.chat.StreamingChatLanguageModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;
+import dev.langchain4j.model.image.ImageModel;
+import dev.langchain4j.model.language.LanguageModel;
+import dev.langchain4j.model.language.StreamingLanguageModel;
+import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.scoring.ScoringModel;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+/**
+ * Xinference Cloud
+ * https://docs.inference.top/zh
+ */
+@EnabledIfEnvironmentVariable(named = "XINFERENCE_API_KEY", matches = ".+")
+class AutoConfigIT {
+ private static final String API_KEY = System.getenv("XINFERENCE_API_KEY");
+ private static final String BASE_URL = System.getenv("XINFERENCE_BASE_URL");
+ ApplicationContextRunner contextRunner =
+ new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(AutoConfig.class));
+
+ @AfterEach
+ void waitForNextTest() throws InterruptedException {
+ Thread.sleep(3000); // 每个测试后延时3秒
+ }
+
+ @Test
+ void should_provide_chat_model() {
+ contextRunner
+ .withPropertyValues(
+ "langchain4j.community.xinference.chat-model.base-url=" + BASE_URL,
+ "langchain4j.community.xinference.chat-model.api-key=" + API_KEY,
+ "langchain4j.community.xinference.chat-model.model-name=qwen2-vl-instruct")
+ .run(context -> {
+ ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class);
+ assertThat(chatLanguageModel).isInstanceOf(XinferenceChatModel.class);
+ assertThat(chatLanguageModel.generate("What is the capital of Germany?"))
+ .contains("Berlin");
+ assertThat(context.getBean(XinferenceChatModel.class)).isSameAs(chatLanguageModel);
+ });
+ }
+
+ @Test
+ void should_provide_streaming_chat_model() {
+ contextRunner
+ .withPropertyValues(
+ "langchain4j.community.xinference.streaming-chat-model.base-url=" + BASE_URL,
+ "langchain4j.community.xinference.streaming-chat-model.api-key=" + API_KEY,
+ "langchain4j.community.xinference.streaming-chat-model.model-name=qwen2-vl-instruct")
+ .run(context -> {
+ StreamingChatLanguageModel streamingChatLanguageModel =
+ context.getBean(StreamingChatLanguageModel.class);
+ assertThat(streamingChatLanguageModel).isInstanceOf(XinferenceStreamingChatModel.class);
+ CompletableFuture> future = new CompletableFuture<>();
+ streamingChatLanguageModel.generate(
+ "What is the capital of Germany?", new StreamingResponseHandler() {
+ @Override
+ public void onNext(String token) {
+ }
+
+ @Override
+ public void onComplete(Response response) {
+ future.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ }
+ });
+ Response response = future.get(60, SECONDS);
+ assertThat(response.content().text()).contains("Berlin");
+ assertThat(context.getBean(XinferenceStreamingChatModel.class))
+ .isSameAs(streamingChatLanguageModel);
+ });
+ }
+
+ @Test
+ void should_provide_language_model() {
+ contextRunner
+ .withPropertyValues(
+ "langchain4j.community.xinference.language-model.base-url=" + BASE_URL,
+ "langchain4j.community.xinference.language-model.api-key=" + API_KEY,
+ "langchain4j.community.xinference.language-model.model-name=qwen2-vl-instruct",
+ "langchain4j.community.xinference.language-model.logRequests=true",
+ "langchain4j.community.xinference.language-model.logResponses=true")
+ .run(context -> {
+ LanguageModel languageModel = context.getBean(LanguageModel.class);
+ assertThat(languageModel).isInstanceOf(XinferenceLanguageModel.class);
+ assertThat(languageModel
+ .generate("What is the capital of Germany?")
+ .content())
+ .contains("Berlin");
+ assertThat(context.getBean(XinferenceLanguageModel.class)).isSameAs(languageModel);
+ });
+ }
+
+ @Test
+ void should_provide_streaming_language_model() {
+ contextRunner
+ .withPropertyValues(
+ "langchain4j.community.xinference.streaming-language-model.base-url=" + BASE_URL,
+ "langchain4j.community.xinference.streaming-language-model.api-key=" + API_KEY,
+ "langchain4j.community.xinference.streaming-language-model.model-name=qwen2-vl-instruct",
+ "langchain4j.community.xinference.streaming-language-model.logRequests=true",
+ "langchain4j.community.xinference.streaming-language-model.logResponses=true")
+ .run(context -> {
+ StreamingLanguageModel streamingLanguageModel = context.getBean(StreamingLanguageModel.class);
+ assertThat(streamingLanguageModel).isInstanceOf(XinferenceStreamingLanguageModel.class);
+ CompletableFuture> future = new CompletableFuture<>();
+ streamingLanguageModel.generate(
+ "What is the capital of Germany?", new StreamingResponseHandler() {
+ @Override
+ public void onNext(String token) {
+ }
+
+ @Override
+ public void onComplete(Response response) {
+ future.complete(response);
+ }
+
+ @Override
+ public void onError(Throwable error) {
+ }
+ });
+ Response response = future.get(60, SECONDS);
+ assertThat(response.content()).contains("Berlin");
+
+ assertThat(context.getBean(XinferenceStreamingLanguageModel.class))
+ .isSameAs(streamingLanguageModel);
+ });
+ }
+
+ @Test
+ void should_provide_embedding_model() {
+ contextRunner
+ .withPropertyValues(
+ "langchain4j.community.xinference.embeddingModel.base-url=" + BASE_URL,
+ "langchain4j.community.xinference.embeddingModel.api-key=" + API_KEY,
+ "langchain4j.community.xinference.embeddingModel.modelName=bge-m3")
+ .run(context -> {
+ EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class);
+ assertThat(embeddingModel).isInstanceOf(XinferenceEmbeddingModel.class);
+ assertThat(embeddingModel.embed("hello world").content().dimension())
+ .isEqualTo(1024);
+ assertThat(context.getBean(XinferenceEmbeddingModel.class)).isSameAs(embeddingModel);
+ });
+ }
+
+ @Test
+ @Disabled("Xinference Cloud Not Support")
+ void should_provide_sc_model() {
+ contextRunner
+ .withPropertyValues(
+ "langchain4j.community.xinference.scoringModel.base-url=" + BASE_URL,
+ "langchain4j.community.xinference.scoringModel.api-key=" + API_KEY,
+ "langchain4j.community.xinference.scoringModel.modelName=bge-m3")
+ .run(context -> {
+ ScoringModel scoringModel = context.getBean(ScoringModel.class);
+ assertThat(scoringModel).isInstanceOf(XinferenceScoringModel.class);
+ TextSegment catSegment = TextSegment.from("The Maine Coon is a large domesticated cat breed.");
+ TextSegment dogSegment = TextSegment.from(
+ "The sweet-faced, lovable Labrador Retriever is one of America's most popular dog breeds, year after year.");
+ List segments = Arrays.asList(catSegment, dogSegment);
+ String query = "tell me about dogs";
+ Response> response = scoringModel.scoreAll(segments, query);
+ List scores = response.content();
+ assertThat(scores).hasSize(2);
+ assertThat(scores.get(0)).isLessThan(scores.get(1));
+ assertThat(context.getBean(XinferenceScoringModel.class)).isSameAs(scoringModel);
+ });
+ }
+
+ @Test
+ void should_provide_image_model() {
+ contextRunner
+ .withPropertyValues(
+ "langchain4j.community.xinference.imageModel.base-url=" + BASE_URL,
+ "langchain4j.community.xinference.imageModel.api-key=" + API_KEY,
+ "langchain4j.community.xinference.imageModel.modelName=sd3-medium")
+ .run(context -> {
+ ImageModel imageModel = context.getBean(ImageModel.class);
+ assertThat(imageModel).isInstanceOf(XinferenceImageModel.class);
+ assertThat(imageModel.generate("banana").content().base64Data())
+ .isNotNull();
+ assertThat(context.getBean(XinferenceImageModel.class)).isSameAs(imageModel);
+ });
+ }
+}
diff --git a/spring-boot-starters/pom.xml b/spring-boot-starters/pom.xml
index 2648a01..d05a547 100644
--- a/spring-boot-starters/pom.xml
+++ b/spring-boot-starters/pom.xml
@@ -26,6 +26,7 @@
langchain4j-community-dashscope-spring-boot-starter
langchain4j-community-qianfan-spring-boot-starter
+ langchain4j-community-xinference-spring-boot-starter
From ed6a7317b22b8ba5d6663c54bb1175e2b2fd836d Mon Sep 17 00:00:00 2001
From: lixw <>
Date: Wed, 25 Dec 2024 16:55:07 +0800
Subject: [PATCH 02/13] add spring boot starter of xinference
---
.../pom.xml | 3 +--
.../xinference/spring/AutoConfigIT.java | 18 ++++++------------
2 files changed, 7 insertions(+), 14 deletions(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
index 44a387f..9fdbe01 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
@@ -1,6 +1,5 @@
-
+
4.0.0
dev.langchain4j
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
index c1bda4c..cd2e6c5 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
@@ -21,11 +21,9 @@
import dev.langchain4j.model.language.StreamingLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
-
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
-
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
@@ -80,8 +78,7 @@ void should_provide_streaming_chat_model() {
streamingChatLanguageModel.generate(
"What is the capital of Germany?", new StreamingResponseHandler() {
@Override
- public void onNext(String token) {
- }
+ public void onNext(String token) {}
@Override
public void onComplete(Response response) {
@@ -89,8 +86,7 @@ public void onComplete(Response response) {
}
@Override
- public void onError(Throwable error) {
- }
+ public void onError(Throwable error) {}
});
Response response = future.get(60, SECONDS);
assertThat(response.content().text()).contains("Berlin");
@@ -112,8 +108,8 @@ void should_provide_language_model() {
LanguageModel languageModel = context.getBean(LanguageModel.class);
assertThat(languageModel).isInstanceOf(XinferenceLanguageModel.class);
assertThat(languageModel
- .generate("What is the capital of Germany?")
- .content())
+ .generate("What is the capital of Germany?")
+ .content())
.contains("Berlin");
assertThat(context.getBean(XinferenceLanguageModel.class)).isSameAs(languageModel);
});
@@ -135,8 +131,7 @@ void should_provide_streaming_language_model() {
streamingLanguageModel.generate(
"What is the capital of Germany?", new StreamingResponseHandler() {
@Override
- public void onNext(String token) {
- }
+ public void onNext(String token) {}
@Override
public void onComplete(Response response) {
@@ -144,8 +139,7 @@ public void onComplete(Response response) {
}
@Override
- public void onError(Throwable error) {
- }
+ public void onError(Throwable error) {}
});
Response response = future.get(60, SECONDS);
assertThat(response.content()).contains("Berlin");
From 3e16d3700ce8077f893c81b7947268a40d52c85a Mon Sep 17 00:00:00 2001
From: lixw <>
Date: Thu, 26 Dec 2024 17:20:47 +0800
Subject: [PATCH 03/13] switch testcontainers
---
.../pom.xml | 12 +++
.../xinference/spring/AutoConfigIT.java | 99 +++++++++++--------
.../spring/XinferenceContainer.java | 96 ++++++++++++++++++
.../xinference/spring/XinferenceUtils.java | 69 +++++++++++++
4 files changed, 234 insertions(+), 42 deletions(-)
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
index 9fdbe01..21e5b09 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/pom.xml
@@ -61,6 +61,18 @@
spring-boot-starter-test
test
+
+
+ org.testcontainers
+ testcontainers
+ test
+
+
+
+ org.testcontainers
+ junit-jupiter
+ test
+
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
index cd2e6c5..1605849 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
@@ -1,5 +1,12 @@
package dev.langchain4j.community.xinference.spring;
+import static dev.langchain4j.community.xinference.spring.XinferenceUtils.CHAT_MODEL_NAME;
+import static dev.langchain4j.community.xinference.spring.XinferenceUtils.EMBEDDING_MODEL_NAME;
+import static dev.langchain4j.community.xinference.spring.XinferenceUtils.GENERATE_MODEL_NAME;
+import static dev.langchain4j.community.xinference.spring.XinferenceUtils.IMAGE_MODEL_NAME;
+import static dev.langchain4j.community.xinference.spring.XinferenceUtils.RERANK_MODEL_NAME;
+import static dev.langchain4j.community.xinference.spring.XinferenceUtils.XINFERENCE_IMAGE;
+import static dev.langchain4j.community.xinference.spring.XinferenceUtils.launchCmd;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
@@ -21,39 +28,37 @@
import dev.langchain4j.model.language.StreamingLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
+import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
-import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
/**
- * Xinference Cloud
- * https://docs.inference.top/zh
+ *
*/
-@EnabledIfEnvironmentVariable(named = "XINFERENCE_API_KEY", matches = ".+")
+@Testcontainers
class AutoConfigIT {
- private static final String API_KEY = System.getenv("XINFERENCE_API_KEY");
- private static final String BASE_URL = System.getenv("XINFERENCE_BASE_URL");
ApplicationContextRunner contextRunner =
new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(AutoConfig.class));
- @AfterEach
- void waitForNextTest() throws InterruptedException {
- Thread.sleep(3000); // 每个测试后延时3秒
- }
+ @Container
+ XinferenceContainer chatModelContainer = new XinferenceContainer(XINFERENCE_IMAGE);
@Test
- void should_provide_chat_model() {
+ void should_provide_chat_model() throws IOException, InterruptedException {
+ chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.chat-model.base-url=" + BASE_URL,
- "langchain4j.community.xinference.chat-model.api-key=" + API_KEY,
- "langchain4j.community.xinference.chat-model.model-name=qwen2-vl-instruct")
+ "langchain4j.community.xinference.chat-model.base-url=" + chatModelContainer.getEndpoint(),
+ "langchain4j.community.xinference.chat-model.model-name=" + CHAT_MODEL_NAME,
+ "langchain4j.community.xinference.language-model.logRequests=true",
+ "langchain4j.community.xinference.language-model.logResponses=true")
.run(context -> {
ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class);
assertThat(chatLanguageModel).isInstanceOf(XinferenceChatModel.class);
@@ -64,12 +69,15 @@ void should_provide_chat_model() {
}
@Test
- void should_provide_streaming_chat_model() {
+ void should_provide_streaming_chat_model() throws IOException, InterruptedException {
+ chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.streaming-chat-model.base-url=" + BASE_URL,
- "langchain4j.community.xinference.streaming-chat-model.api-key=" + API_KEY,
- "langchain4j.community.xinference.streaming-chat-model.model-name=qwen2-vl-instruct")
+ "langchain4j.community.xinference.streaming-chat-model.base-url="
+ + chatModelContainer.getEndpoint(),
+ "langchain4j.community.xinference.streaming-chat-model.model-name=" + CHAT_MODEL_NAME,
+ "langchain4j.community.xinference.language-model.logRequests=true",
+ "langchain4j.community.xinference.language-model.logResponses=true")
.run(context -> {
StreamingChatLanguageModel streamingChatLanguageModel =
context.getBean(StreamingChatLanguageModel.class);
@@ -96,12 +104,12 @@ public void onError(Throwable error) {}
}
@Test
- void should_provide_language_model() {
+ void should_provide_language_model() throws IOException, InterruptedException {
+ chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.language-model.base-url=" + BASE_URL,
- "langchain4j.community.xinference.language-model.api-key=" + API_KEY,
- "langchain4j.community.xinference.language-model.model-name=qwen2-vl-instruct",
+ "langchain4j.community.xinference.language-model.base-url=" + chatModelContainer.getEndpoint(),
+ "langchain4j.community.xinference.language-model.model-name=" + GENERATE_MODEL_NAME,
"langchain4j.community.xinference.language-model.logRequests=true",
"langchain4j.community.xinference.language-model.logResponses=true")
.run(context -> {
@@ -116,12 +124,13 @@ void should_provide_language_model() {
}
@Test
- void should_provide_streaming_language_model() {
+ void should_provide_streaming_language_model() throws IOException, InterruptedException {
+ chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.streaming-language-model.base-url=" + BASE_URL,
- "langchain4j.community.xinference.streaming-language-model.api-key=" + API_KEY,
- "langchain4j.community.xinference.streaming-language-model.model-name=qwen2-vl-instruct",
+ "langchain4j.community.xinference.streaming-language-model.base-url="
+ + chatModelContainer.getEndpoint(),
+ "langchain4j.community.xinference.streaming-language-model.model-name=" + GENERATE_MODEL_NAME,
"langchain4j.community.xinference.streaming-language-model.logRequests=true",
"langchain4j.community.xinference.streaming-language-model.logResponses=true")
.run(context -> {
@@ -150,29 +159,32 @@ public void onError(Throwable error) {}
}
@Test
- void should_provide_embedding_model() {
+ void should_provide_embedding_model() throws IOException, InterruptedException {
+ chatModelContainer.execInContainer("bash", "-c", launchCmd(EMBEDDING_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.embeddingModel.base-url=" + BASE_URL,
- "langchain4j.community.xinference.embeddingModel.api-key=" + API_KEY,
- "langchain4j.community.xinference.embeddingModel.modelName=bge-m3")
+ "langchain4j.community.xinference.embeddingModel.base-url=" + chatModelContainer.getEndpoint(),
+ "langchain4j.community.xinference.embeddingModel.modelName=" + EMBEDDING_MODEL_NAME,
+ "langchain4j.community.xinference.language-model.logRequests=true",
+ "langchain4j.community.xinference.language-model.logResponses=true")
.run(context -> {
EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class);
assertThat(embeddingModel).isInstanceOf(XinferenceEmbeddingModel.class);
assertThat(embeddingModel.embed("hello world").content().dimension())
- .isEqualTo(1024);
+ .isEqualTo(768);
assertThat(context.getBean(XinferenceEmbeddingModel.class)).isSameAs(embeddingModel);
});
}
@Test
- @Disabled("Xinference Cloud Not Support")
- void should_provide_sc_model() {
+ void should_provide_sc_model() throws IOException, InterruptedException {
+ chatModelContainer.execInContainer("bash", "-c", launchCmd(RERANK_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.scoringModel.base-url=" + BASE_URL,
- "langchain4j.community.xinference.scoringModel.api-key=" + API_KEY,
- "langchain4j.community.xinference.scoringModel.modelName=bge-m3")
+ "langchain4j.community.xinference.scoringModel.base-url=" + chatModelContainer.getEndpoint(),
+ "langchain4j.community.xinference.scoringModel.modelName=" + RERANK_MODEL_NAME,
+ "langchain4j.community.xinference.language-model.logRequests=true",
+ "langchain4j.community.xinference.language-model.logResponses=true")
.run(context -> {
ScoringModel scoringModel = context.getBean(ScoringModel.class);
assertThat(scoringModel).isInstanceOf(XinferenceScoringModel.class);
@@ -184,18 +196,21 @@ void should_provide_sc_model() {
Response> response = scoringModel.scoreAll(segments, query);
List scores = response.content();
assertThat(scores).hasSize(2);
- assertThat(scores.get(0)).isLessThan(scores.get(1));
+ assertThat(scores.get(0)).isGreaterThan(scores.get(1));
assertThat(context.getBean(XinferenceScoringModel.class)).isSameAs(scoringModel);
});
}
@Test
- void should_provide_image_model() {
+ @Disabled("Not supported to run in a Docker environment without GPU .")
+ void should_provide_image_model() throws IOException, InterruptedException {
+ chatModelContainer.execInContainer("bash", "-c", launchCmd(IMAGE_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.imageModel.base-url=" + BASE_URL,
- "langchain4j.community.xinference.imageModel.api-key=" + API_KEY,
- "langchain4j.community.xinference.imageModel.modelName=sd3-medium")
+ "langchain4j.community.xinference.imageModel.base-url=" + chatModelContainer.getEndpoint(),
+ "langchain4j.community.xinference.imageModel.modelName=" + IMAGE_MODEL_NAME,
+ "langchain4j.community.xinference.language-model.logRequests=true",
+ "langchain4j.community.xinference.language-model.logResponses=true")
.run(context -> {
ImageModel imageModel = context.getBean(ImageModel.class);
assertThat(imageModel).isInstanceOf(XinferenceImageModel.class);
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java
new file mode 100644
index 0000000..acaf356
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java
@@ -0,0 +1,96 @@
+package dev.langchain4j.community.xinference.spring;
+
+import static dev.langchain4j.community.xinference.spring.XinferenceUtils.launchCmd;
+
+import com.github.dockerjava.api.DockerClient;
+import com.github.dockerjava.api.command.InspectContainerResponse;
+import com.github.dockerjava.api.model.DeviceRequest;
+import com.github.dockerjava.api.model.Image;
+import com.github.dockerjava.api.model.Info;
+import com.github.dockerjava.api.model.RuntimeInfo;
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testcontainers.DockerClientFactory;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.wait.strategy.Wait;
+import org.testcontainers.utility.DockerImageName;
+
+class XinferenceContainer extends GenericContainer {
+ private static final Logger log = LoggerFactory.getLogger(XinferenceContainer.class);
+ private static final DockerImageName DOCKER_IMAGE_NAME = DockerImageName.parse("xprobe/xinference");
+ private static final Integer EXPOSED_PORT = 9997;
+ private String modelName;
+
+ public XinferenceContainer(String image) {
+ this(DockerImageName.parse(image));
+ }
+
+ public XinferenceContainer(DockerImageName image) {
+ super(image);
+ image.assertCompatibleWith(DOCKER_IMAGE_NAME);
+ Info info = this.dockerClient.infoCmd().exec();
+ Map runtimes = info.getRuntimes();
+ if (runtimes != null && runtimes.containsKey("nvidia")) {
+ this.withCreateContainerCmdModifier((cmd) -> {
+ Objects.requireNonNull(cmd.getHostConfig())
+ .withDeviceRequests(Collections.singletonList((new DeviceRequest())
+ .withCapabilities(Collections.singletonList(Collections.singletonList("gpu")))
+ .withCount(-1)));
+ });
+ }
+ this.withExposedPorts(EXPOSED_PORT);
+ this.withEnv(Map.of("XINFERENCE_MODEL_SRC", "modelscope"));
+ // https://github.com/xorbitsai/inference/issues/2573
+ this.withCommand("bash", "-c", "xinference-local -H 0.0.0.0");
+ this.waitingFor(Wait.forListeningPort().withStartupTimeout(Duration.ofMinutes(10)));
+ }
+
+ @Override
+ protected void containerIsStarted(InspectContainerResponse containerInfo) {
+ if (this.modelName != null) {
+ try {
+ log.info("Start pulling the '{}' model ... would take several minutes ...", this.modelName);
+ ExecResult r = execInContainer("bash", "-c", launchCmd(this.modelName));
+ if (r.getExitCode() != 0) {
+ throw new RuntimeException(r.getStderr());
+ }
+ log.info("Model pulling competed! {}", r);
+ } catch (IOException | InterruptedException e) {
+ throw new RuntimeException("Error pulling model", e);
+ }
+ }
+ }
+
+ public XinferenceContainer withModel(String modelName) {
+ this.modelName = modelName;
+ return this;
+ }
+
+ public void commitToImage(String imageName) {
+ DockerImageName dockerImageName = DockerImageName.parse(this.getDockerImageName());
+ if (!dockerImageName.equals(DockerImageName.parse(imageName))) {
+ DockerClient dockerClient = DockerClientFactory.instance().client();
+ List images =
+ dockerClient.listImagesCmd().withReferenceFilter(imageName).exec();
+ if (images.isEmpty()) {
+ DockerImageName imageModel = DockerImageName.parse(imageName);
+ dockerClient
+ .commitCmd(this.getContainerId())
+ .withRepository(imageModel.getUnversionedPart())
+ .withLabels(Collections.singletonMap("org.testcontainers.sessionId", ""))
+ .withTag(imageModel.getVersionPart())
+ .exec();
+ }
+ }
+ }
+
+ public String getEndpoint() {
+ return "http://" + this.getHost() + ":" + this.getMappedPort(EXPOSED_PORT);
+ }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java
new file mode 100644
index 0000000..edb94be
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceUtils.java
@@ -0,0 +1,69 @@
+package dev.langchain4j.community.xinference.spring;
+
+import com.github.dockerjava.api.DockerClient;
+import com.github.dockerjava.api.model.Image;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.testcontainers.DockerClientFactory;
+import org.testcontainers.utility.DockerImageName;
+
+class XinferenceUtils {
+ public static final String XINFERENCE_BASE_URL = System.getenv("XINFERENCE_BASE_URL");
+ public static final String XINFERENCE_API_KEY = System.getenv("XINFERENCE_BASE_URL");
+ // CPU
+ public static final String XINFERENCE_IMAGE = "xprobe/xinference:latest-cpu";
+ // GPU
+ // public static final String XINFERENCE_IMAGE = "xprobe/xinference:latest";
+
+ public static final String CHAT_MODEL_NAME = "qwen2.5-instruct";
+ public static final String GENERATE_MODEL_NAME = "qwen2.5";
+ public static final String VISION_MODEL_NAME = "qwen2-vl-instruct";
+ public static final String IMAGE_MODEL_NAME = "sd3-medium";
+ public static final String EMBEDDING_MODEL_NAME = "text2vec-base-chinese";
+ public static final String RERANK_MODEL_NAME = "bge-reranker-base";
+
+ private static final Map MODEL_LAUNCH_MAP = new HashMap<>() {
+ {
+ put(
+ CHAT_MODEL_NAME,
+ String.format(
+ "xinference launch --model-engine Transformers --model-name %s --size-in-billions 0_5 --model-format pytorch --quantization none",
+ CHAT_MODEL_NAME));
+ put(
+ GENERATE_MODEL_NAME,
+ String.format(
+ "xinference launch --model-engine Transformers --model-name %s --size-in-billions 0_5 --model-format pytorch --quantization none",
+ GENERATE_MODEL_NAME));
+ put(
+ VISION_MODEL_NAME,
+ String.format(
+ "xinference launch --model-engine Transformers --model-name %s --size-in-billions 2 --model-format pytorch --quantization none",
+ VISION_MODEL_NAME));
+ put(
+ RERANK_MODEL_NAME,
+ String.format("xinference launch --model-name %s --model-type rerank", RERANK_MODEL_NAME));
+ put(
+ IMAGE_MODEL_NAME,
+ String.format("xinference launch --model-name %s --model-type image", IMAGE_MODEL_NAME));
+ put(
+ EMBEDDING_MODEL_NAME,
+ String.format("xinference launch --model-name %s --model-type embedding", EMBEDDING_MODEL_NAME));
+ }
+ };
+
+ public static DockerImageName resolve(String baseImage, String localImageName) {
+ DockerImageName dockerImageName = DockerImageName.parse(baseImage);
+ DockerClient dockerClient = DockerClientFactory.instance().client();
+ List images =
+ dockerClient.listImagesCmd().withReferenceFilter(localImageName).exec();
+ if (images.isEmpty()) {
+ return dockerImageName;
+ }
+ return DockerImageName.parse(localImageName).asCompatibleSubstituteFor(baseImage);
+ }
+
+ public static String launchCmd(String modelName) {
+ return MODEL_LAUNCH_MAP.get(modelName);
+ }
+}
From ac41eaef1e1c10d70844a6fa1a884aac197c41db Mon Sep 17 00:00:00 2001
From: alvinlee518
Date: Fri, 27 Dec 2024 18:44:41 +0800
Subject: [PATCH 04/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
.../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
index 79c512f..1bc98e5 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -16,7 +16,7 @@
@AutoConfiguration
@EnableConfigurationProperties(Properties.class)
-public class AutoConfig {
+public class XinferenceAutoConfiguration {
@Bean
@ConditionalOnProperty(PREFIX + ".chat-model.base-url")
public XinferenceChatModel chatModel(Properties properties) {
From aa6474293aea156a8002d7fbaaddbbbe16411275 Mon Sep 17 00:00:00 2001
From: alvinlee518
Date: Fri, 27 Dec 2024 18:44:50 +0800
Subject: [PATCH 05/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
.../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
index 1bc98e5..d1de16f 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -19,7 +19,7 @@
public class XinferenceAutoConfiguration {
@Bean
@ConditionalOnProperty(PREFIX + ".chat-model.base-url")
- public XinferenceChatModel chatModel(Properties properties) {
+ public XinferenceChatModel xinferenceChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return XinferenceChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
From ae275c59dd1aaffbbfe2b37390461da5179900f2 Mon Sep 17 00:00:00 2001
From: alvinlee518
Date: Fri, 27 Dec 2024 18:45:00 +0800
Subject: [PATCH 06/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
.../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
index d1de16f..4c253a2 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -46,7 +46,7 @@ public XinferenceChatModel xinferenceChatModel(Properties properties) {
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url")
- public XinferenceStreamingChatModel streamingChatModel(Properties properties) {
+ public XinferenceStreamingChatModel xinferenceStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return XinferenceStreamingChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
From 782a3eed939806e3b8e4de26f4acb4df5122f826 Mon Sep 17 00:00:00 2001
From: alvinlee518
Date: Fri, 27 Dec 2024 18:45:08 +0800
Subject: [PATCH 07/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
.../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
index 4c253a2..0ea14c3 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -72,7 +72,7 @@ public XinferenceStreamingChatModel xinferenceStreamingChatModel(Properties prop
@Bean
@ConditionalOnProperty(PREFIX + ".language-model.base-url")
- public XinferenceLanguageModel languageModel(Properties properties) {
+ public XinferenceLanguageModel xinferenceLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
return XinferenceLanguageModel.builder()
.baseUrl(languageModelProperties.getBaseUrl())
From 7a79892e9c1aa43711a698e28f255704c01f1003 Mon Sep 17 00:00:00 2001
From: alvinlee518
Date: Fri, 27 Dec 2024 18:45:16 +0800
Subject: [PATCH 08/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
.../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
index 0ea14c3..c9380a4 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -98,7 +98,7 @@ public XinferenceLanguageModel xinferenceLanguageModel(Properties properties) {
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url")
- public XinferenceStreamingLanguageModel streamingLanguageModel(Properties properties) {
+ public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
return XinferenceStreamingLanguageModel.builder()
.baseUrl(languageModelProperties.getBaseUrl())
From e14ab0f0d8ea0a441628e3d91519742e7173d489 Mon Sep 17 00:00:00 2001
From: alvinlee518
Date: Fri, 27 Dec 2024 18:45:22 +0800
Subject: [PATCH 09/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
.../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
index c9380a4..6b56128 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -123,7 +123,7 @@ public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel(Propert
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.base-url")
- public XinferenceEmbeddingModel embeddingModel(Properties properties) {
+ public XinferenceEmbeddingModel xinferenceEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return XinferenceEmbeddingModel.builder()
.baseUrl(embeddingModelProperties.getBaseUrl())
From df094bcd75cc63a2d032a31bc5f102ffa91752c6 Mon Sep 17 00:00:00 2001
From: alvinlee518
Date: Fri, 27 Dec 2024 18:45:28 +0800
Subject: [PATCH 10/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
.../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
index 6b56128..e9bfbd8 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -141,7 +141,7 @@ public XinferenceEmbeddingModel xinferenceEmbeddingModel(Properties properties)
@Bean
@ConditionalOnProperty(PREFIX + ".image-model.base-url")
- public XinferenceImageModel imageModel(Properties properties) {
+ public XinferenceImageModel xinferenceImageModel(Properties properties) {
ImageModelProperties imageModelProperties = properties.getImageModel();
return XinferenceImageModel.builder()
.baseUrl(imageModelProperties.getBaseUrl())
From 6095fd1af6820abac36eba7fb8cc3fd4c139d623 Mon Sep 17 00:00:00 2001
From: alvinlee518
Date: Fri, 27 Dec 2024 18:45:34 +0800
Subject: [PATCH 11/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
.../dev/langchain4j/community/xinference/spring/AutoConfig.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
index e9bfbd8..8513988 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
@@ -163,7 +163,7 @@ public XinferenceImageModel xinferenceImageModel(Properties properties) {
@Bean
@ConditionalOnProperty(PREFIX + ".scoring-model.base-url")
- public XinferenceScoringModel scoringModel(Properties properties) {
+ public XinferenceScoringModel xinferenceScoringModel(Properties properties) {
ScoringModelProperties scoringModelProperties = properties.getScoringModel();
return XinferenceScoringModel.builder()
.baseUrl(scoringModelProperties.getBaseUrl())
From ea3ce23b43a39e9bae2c86290c2dfbcddccba990 Mon Sep 17 00:00:00 2001
From: lixw <>
Date: Mon, 30 Dec 2024 11:09:41 +0800
Subject: [PATCH 12/13] Update
spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
Co-authored-by: Martin7-1
---
langchain4j-community-bom/pom.xml | 9 +-
.../spring/ChatModelProperties.java | 3 +
.../spring/EmbeddingModelProperties.java | 4 +
.../spring/ImageModelProperties.java | 4 +
.../spring/LanguageModelProperties.java | 4 +
.../xinference/spring/Properties.java | 86 ---------
.../spring/ScoringModelProperties.java | 4 +
.../spring/StreamingChatModelProperties.java | 178 ++++++++++++++++++
.../StreamingLanguageModelProperties.java | 169 +++++++++++++++++
....java => XinferenceAutoConfiguration.java} | 129 +++++++------
...ot.autoconfigure.AutoConfiguration.imports | 1 +
.../xinference/spring/AutoConfigIT.java | 60 +++---
12 files changed, 473 insertions(+), 178 deletions(-)
delete mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java
rename spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/{AutoConfig.java => XinferenceAutoConfiguration.java} (59%)
create mode 100644 spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
diff --git a/langchain4j-community-bom/pom.xml b/langchain4j-community-bom/pom.xml
index 38b4bc4..32a7bef 100644
--- a/langchain4j-community-bom/pom.xml
+++ b/langchain4j-community-bom/pom.xml
@@ -11,7 +11,8 @@
pom
LangChain4j :: Community :: BOM
- Bill of Materials POM for getting full, complete set of compatible versions of LangChain4j Community modules
+ Bill of Materials POM for getting full, complete set of compatible versions of LangChain4j Community
+ modules
@@ -82,6 +83,12 @@
${project.version}
+
+ dev.langchain4j
+ langchain4j-community-xinference-spring-boot-starter
+ ${project.version}
+
+
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java
index 13e0705..d86d41d 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ChatModelProperties.java
@@ -3,9 +3,12 @@
import java.time.Duration;
import java.util.List;
import java.util.Map;
+import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
+@ConfigurationProperties(prefix = ChatModelProperties.PREFIX)
public class ChatModelProperties {
+ static final String PREFIX = "langchain4j.community.xinference.chat-model";
private String baseUrl;
private String apiKey;
private String modelName;
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java
index 3fe8edc..d628324 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/EmbeddingModelProperties.java
@@ -2,9 +2,13 @@
import java.time.Duration;
import java.util.Map;
+import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
+@ConfigurationProperties(prefix = EmbeddingModelProperties.PREFIX)
public class EmbeddingModelProperties {
+ static final String PREFIX = "langchain4j.community.xinference.embedding-model";
+
private String baseUrl;
private String apiKey;
private String modelName;
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java
index 22289dc..e59ecfb 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ImageModelProperties.java
@@ -3,9 +3,13 @@
import dev.langchain4j.community.model.xinference.client.image.ResponseFormat;
import java.time.Duration;
import java.util.Map;
+import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
+@ConfigurationProperties(prefix = ImageModelProperties.PREFIX)
public class ImageModelProperties {
+ static final String PREFIX = "langchain4j.community.xinference.image-model";
+
private String baseUrl;
private String apiKey;
private String modelName;
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java
index 3a23229..43bfdd0 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/LanguageModelProperties.java
@@ -3,9 +3,13 @@
import java.time.Duration;
import java.util.List;
import java.util.Map;
+import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
+@ConfigurationProperties(prefix = LanguageModelProperties.PREFIX)
public class LanguageModelProperties {
+ static final String PREFIX = "langchain4j.community.xinference.language-model";
+
private String baseUrl;
private String apiKey;
private String modelName;
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java
deleted file mode 100644
index 5511231..0000000
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/Properties.java
+++ /dev/null
@@ -1,86 +0,0 @@
-package dev.langchain4j.community.xinference.spring;
-
-import org.springframework.boot.context.properties.ConfigurationProperties;
-import org.springframework.boot.context.properties.NestedConfigurationProperty;
-
-@ConfigurationProperties(prefix = Properties.PREFIX)
-public class Properties {
- static final String PREFIX = "langchain4j.community.xinference";
-
- @NestedConfigurationProperty
- private ChatModelProperties chatModel;
-
- @NestedConfigurationProperty
- private ChatModelProperties streamingChatModel;
-
- @NestedConfigurationProperty
- private EmbeddingModelProperties embeddingModel;
-
- @NestedConfigurationProperty
- private ImageModelProperties imageModel;
-
- @NestedConfigurationProperty
- private LanguageModelProperties languageModel;
-
- @NestedConfigurationProperty
- private LanguageModelProperties streamingLanguageModel;
-
- @NestedConfigurationProperty
- private ScoringModelProperties scoringModel;
-
- public ChatModelProperties getChatModel() {
- return chatModel;
- }
-
- public void setChatModel(final ChatModelProperties chatModel) {
- this.chatModel = chatModel;
- }
-
- public ChatModelProperties getStreamingChatModel() {
- return streamingChatModel;
- }
-
- public void setStreamingChatModel(final ChatModelProperties streamingChatModel) {
- this.streamingChatModel = streamingChatModel;
- }
-
- public EmbeddingModelProperties getEmbeddingModel() {
- return embeddingModel;
- }
-
- public void setEmbeddingModel(final EmbeddingModelProperties embeddingModel) {
- this.embeddingModel = embeddingModel;
- }
-
- public ImageModelProperties getImageModel() {
- return imageModel;
- }
-
- public void setImageModel(final ImageModelProperties imageModel) {
- this.imageModel = imageModel;
- }
-
- public LanguageModelProperties getLanguageModel() {
- return languageModel;
- }
-
- public void setLanguageModel(final LanguageModelProperties languageModel) {
- this.languageModel = languageModel;
- }
-
- public LanguageModelProperties getStreamingLanguageModel() {
- return streamingLanguageModel;
- }
-
- public void setStreamingLanguageModel(final LanguageModelProperties streamingLanguageModel) {
- this.streamingLanguageModel = streamingLanguageModel;
- }
-
- public ScoringModelProperties getScoringModel() {
- return scoringModel;
- }
-
- public void setScoringModel(final ScoringModelProperties scoringModel) {
- this.scoringModel = scoringModel;
- }
-}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java
index 29c29c6..a53a280 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/ScoringModelProperties.java
@@ -2,9 +2,13 @@
import java.time.Duration;
import java.util.Map;
+import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
+@ConfigurationProperties(prefix = ScoringModelProperties.PREFIX)
public class ScoringModelProperties {
+ static final String PREFIX = "langchain4j.community.xinference.scoring-model";
+
private String baseUrl;
private String apiKey;
private String modelName;
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java
new file mode 100644
index 0000000..6f1121f
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingChatModelProperties.java
@@ -0,0 +1,178 @@
+package dev.langchain4j.community.xinference.spring;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.Map;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+/**
+ * Configuration properties for the Xinference streaming chat model, bound from
+ * {@code langchain4j.community.xinference.streaming-chat-model.*}.
+ */
+@ConfigurationProperties(prefix = StreamingChatModelProperties.PREFIX)
+public class StreamingChatModelProperties {
+    static final String PREFIX = "langchain4j.community.xinference.streaming-chat-model";
+
+    private String baseUrl;
+    private String apiKey;
+    private String modelName;
+    private Double temperature;
+    private Double topP;
+    // Restored generic type argument: a raw List defeats Spring's property binding conversion.
+    private List<String> stop;
+    private Integer maxTokens;
+    private Double presencePenalty;
+    private Double frequencyPenalty;
+    private Integer seed;
+    private String user;
+    // NOTE(review): typed as Object — presumably the client module has a dedicated
+    // tool-choice type; confirm and narrow if possible.
+    private Object toolChoice;
+    private Boolean parallelToolCalls;
+    private Duration timeout;
+
+    @NestedConfigurationProperty
+    private ProxyProperties proxy;
+
+    private Boolean logRequests;
+    private Boolean logResponses;
+    // Restored generic type argument: extra HTTP headers sent with each request.
+    private Map<String, String> customHeaders;
+
+    public String getBaseUrl() {
+        return baseUrl;
+    }
+
+    public void setBaseUrl(final String baseUrl) {
+        this.baseUrl = baseUrl;
+    }
+
+    public String getApiKey() {
+        return apiKey;
+    }
+
+    public void setApiKey(final String apiKey) {
+        this.apiKey = apiKey;
+    }
+
+    public String getModelName() {
+        return modelName;
+    }
+
+    public void setModelName(final String modelName) {
+        this.modelName = modelName;
+    }
+
+    public Double getTemperature() {
+        return temperature;
+    }
+
+    public void setTemperature(final Double temperature) {
+        this.temperature = temperature;
+    }
+
+    public Double getTopP() {
+        return topP;
+    }
+
+    public void setTopP(final Double topP) {
+        this.topP = topP;
+    }
+
+    public List<String> getStop() {
+        return stop;
+    }
+
+    public void setStop(final List<String> stop) {
+        this.stop = stop;
+    }
+
+    public Integer getMaxTokens() {
+        return maxTokens;
+    }
+
+    public void setMaxTokens(final Integer maxTokens) {
+        this.maxTokens = maxTokens;
+    }
+
+    public Double getPresencePenalty() {
+        return presencePenalty;
+    }
+
+    public void setPresencePenalty(final Double presencePenalty) {
+        this.presencePenalty = presencePenalty;
+    }
+
+    public Double getFrequencyPenalty() {
+        return frequencyPenalty;
+    }
+
+    public void setFrequencyPenalty(final Double frequencyPenalty) {
+        this.frequencyPenalty = frequencyPenalty;
+    }
+
+    public Integer getSeed() {
+        return seed;
+    }
+
+    public void setSeed(final Integer seed) {
+        this.seed = seed;
+    }
+
+    public String getUser() {
+        return user;
+    }
+
+    public void setUser(final String user) {
+        this.user = user;
+    }
+
+    public Object getToolChoice() {
+        return toolChoice;
+    }
+
+    public void setToolChoice(final Object toolChoice) {
+        this.toolChoice = toolChoice;
+    }
+
+    public Boolean getParallelToolCalls() {
+        return parallelToolCalls;
+    }
+
+    public void setParallelToolCalls(final Boolean parallelToolCalls) {
+        this.parallelToolCalls = parallelToolCalls;
+    }
+
+    public Duration getTimeout() {
+        return timeout;
+    }
+
+    public void setTimeout(final Duration timeout) {
+        this.timeout = timeout;
+    }
+
+    public ProxyProperties getProxy() {
+        return proxy;
+    }
+
+    public void setProxy(final ProxyProperties proxy) {
+        this.proxy = proxy;
+    }
+
+    public Boolean getLogRequests() {
+        return logRequests;
+    }
+
+    public void setLogRequests(final Boolean logRequests) {
+        this.logRequests = logRequests;
+    }
+
+    public Boolean getLogResponses() {
+        return logResponses;
+    }
+
+    public void setLogResponses(final Boolean logResponses) {
+        this.logResponses = logResponses;
+    }
+
+    public Map<String, String> getCustomHeaders() {
+        return customHeaders;
+    }
+
+    public void setCustomHeaders(final Map<String, String> customHeaders) {
+        this.customHeaders = customHeaders;
+    }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java
new file mode 100644
index 0000000..2937006
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/StreamingLanguageModelProperties.java
@@ -0,0 +1,169 @@
+package dev.langchain4j.community.xinference.spring;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.Map;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+@ConfigurationProperties(prefix = StreamingLanguageModelProperties.PREFIX)
+public class StreamingLanguageModelProperties {
+ static final String PREFIX = "langchain4j.community.xinference.streaming-language-model";
+
+ private String baseUrl;
+ private String apiKey;
+ private String modelName;
+ private Integer maxTokens;
+ private Double temperature;
+ private Double topP;
+ private Integer logprobs;
+ private Boolean echo;
+ private List stop;
+ private Double presencePenalty;
+ private Double frequencyPenalty;
+ private String user;
+ private Duration timeout;
+
+ @NestedConfigurationProperty
+ private ProxyProperties proxy;
+
+ private Boolean logRequests;
+ private Boolean logResponses;
+ private Map customHeaders;
+
+ public String getBaseUrl() {
+ return baseUrl;
+ }
+
+ public void setBaseUrl(final String baseUrl) {
+ this.baseUrl = baseUrl;
+ }
+
+ public String getApiKey() {
+ return apiKey;
+ }
+
+ public void setApiKey(final String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public String getModelName() {
+ return modelName;
+ }
+
+ public void setModelName(final String modelName) {
+ this.modelName = modelName;
+ }
+
+ public Integer getMaxTokens() {
+ return maxTokens;
+ }
+
+ public void setMaxTokens(final Integer maxTokens) {
+ this.maxTokens = maxTokens;
+ }
+
+ public Double getTemperature() {
+ return temperature;
+ }
+
+ public void setTemperature(final Double temperature) {
+ this.temperature = temperature;
+ }
+
+ public Double getTopP() {
+ return topP;
+ }
+
+ public void setTopP(final Double topP) {
+ this.topP = topP;
+ }
+
+ public Integer getLogprobs() {
+ return logprobs;
+ }
+
+ public void setLogprobs(final Integer logprobs) {
+ this.logprobs = logprobs;
+ }
+
+ public Boolean getEcho() {
+ return echo;
+ }
+
+ public void setEcho(final Boolean echo) {
+ this.echo = echo;
+ }
+
+ public List getStop() {
+ return stop;
+ }
+
+ public void setStop(final List stop) {
+ this.stop = stop;
+ }
+
+ public Double getPresencePenalty() {
+ return presencePenalty;
+ }
+
+ public void setPresencePenalty(final Double presencePenalty) {
+ this.presencePenalty = presencePenalty;
+ }
+
+ public Double getFrequencyPenalty() {
+ return frequencyPenalty;
+ }
+
+ public void setFrequencyPenalty(final Double frequencyPenalty) {
+ this.frequencyPenalty = frequencyPenalty;
+ }
+
+ public String getUser() {
+ return user;
+ }
+
+ public void setUser(final String user) {
+ this.user = user;
+ }
+
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ public void setTimeout(final Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ public ProxyProperties getProxy() {
+ return proxy;
+ }
+
+ public void setProxy(final ProxyProperties proxy) {
+ this.proxy = proxy;
+ }
+
+ public Boolean getLogRequests() {
+ return logRequests;
+ }
+
+ public void setLogRequests(final Boolean logRequests) {
+ this.logRequests = logRequests;
+ }
+
+ public Boolean getLogResponses() {
+ return logResponses;
+ }
+
+ public void setLogResponses(final Boolean logResponses) {
+ this.logResponses = logResponses;
+ }
+
+ public Map getCustomHeaders() {
+ return customHeaders;
+ }
+
+ public void setCustomHeaders(final Map customHeaders) {
+ this.customHeaders = customHeaders;
+ }
+}
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java
similarity index 59%
rename from spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
rename to spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java
index 8513988..3f59135 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/AutoConfig.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/java/dev/langchain4j/community/xinference/spring/XinferenceAutoConfiguration.java
@@ -1,7 +1,5 @@
package dev.langchain4j.community.xinference.spring;
-import static dev.langchain4j.community.xinference.spring.Properties.PREFIX;
-
import dev.langchain4j.community.model.xinference.XinferenceChatModel;
import dev.langchain4j.community.model.xinference.XinferenceEmbeddingModel;
import dev.langchain4j.community.model.xinference.XinferenceImageModel;
@@ -10,17 +8,26 @@
import dev.langchain4j.community.model.xinference.XinferenceStreamingChatModel;
import dev.langchain4j.community.model.xinference.XinferenceStreamingLanguageModel;
import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
@AutoConfiguration
-@EnableConfigurationProperties(Properties.class)
+@EnableConfigurationProperties({
+ ChatModelProperties.class,
+ StreamingChatModelProperties.class,
+ LanguageModelProperties.class,
+ StreamingLanguageModelProperties.class,
+ EmbeddingModelProperties.class,
+ ImageModelProperties.class,
+ ScoringModelProperties.class
+})
public class XinferenceAutoConfiguration {
@Bean
- @ConditionalOnProperty(PREFIX + ".chat-model.base-url")
- public XinferenceChatModel xinferenceChatModel(Properties properties) {
- ChatModelProperties chatModelProperties = properties.getChatModel();
+ @ConditionalOnMissingBean
+ @ConditionalOnProperty(ChatModelProperties.PREFIX + ".base-url")
+ public XinferenceChatModel xinferenceChatModel(ChatModelProperties chatModelProperties) {
return XinferenceChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
@@ -28,7 +35,7 @@ public XinferenceChatModel xinferenceChatModel(Properties properties) {
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.stop(chatModelProperties.getStop())
- .maxTokens(chatModelProperties.getMaxRetries())
+ .maxTokens(chatModelProperties.getMaxTokens())
.presencePenalty(chatModelProperties.getPresencePenalty())
.frequencyPenalty(chatModelProperties.getFrequencyPenalty())
.seed(chatModelProperties.getSeed())
@@ -45,35 +52,36 @@ public XinferenceChatModel xinferenceChatModel(Properties properties) {
}
@Bean
- @ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url")
- public XinferenceStreamingChatModel xinferenceStreamingChatModel(Properties properties) {
- ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
+ @ConditionalOnMissingBean
+ @ConditionalOnProperty(StreamingChatModelProperties.PREFIX + ".base-url")
+ public XinferenceStreamingChatModel xinferenceStreamingChatModel(
+ StreamingChatModelProperties streamingChatModelProperties) {
return XinferenceStreamingChatModel.builder()
- .baseUrl(chatModelProperties.getBaseUrl())
- .apiKey(chatModelProperties.getApiKey())
- .modelName(chatModelProperties.getModelName())
- .temperature(chatModelProperties.getTemperature())
- .topP(chatModelProperties.getTopP())
- .stop(chatModelProperties.getStop())
- .maxTokens(chatModelProperties.getMaxRetries())
- .presencePenalty(chatModelProperties.getPresencePenalty())
- .frequencyPenalty(chatModelProperties.getFrequencyPenalty())
- .seed(chatModelProperties.getSeed())
- .user(chatModelProperties.getUser())
- .toolChoice(chatModelProperties.getToolChoice())
- .parallelToolCalls(chatModelProperties.getParallelToolCalls())
- .timeout(chatModelProperties.getTimeout())
- .proxy(ProxyProperties.convert(chatModelProperties.getProxy()))
- .logRequests(chatModelProperties.getLogRequests())
- .logResponses(chatModelProperties.getLogResponses())
- .customHeaders(chatModelProperties.getCustomHeaders())
+ .baseUrl(streamingChatModelProperties.getBaseUrl())
+ .apiKey(streamingChatModelProperties.getApiKey())
+ .modelName(streamingChatModelProperties.getModelName())
+ .temperature(streamingChatModelProperties.getTemperature())
+ .topP(streamingChatModelProperties.getTopP())
+ .stop(streamingChatModelProperties.getStop())
+ .maxTokens(streamingChatModelProperties.getMaxTokens())
+ .presencePenalty(streamingChatModelProperties.getPresencePenalty())
+ .frequencyPenalty(streamingChatModelProperties.getFrequencyPenalty())
+ .seed(streamingChatModelProperties.getSeed())
+ .user(streamingChatModelProperties.getUser())
+ .toolChoice(streamingChatModelProperties.getToolChoice())
+ .parallelToolCalls(streamingChatModelProperties.getParallelToolCalls())
+ .timeout(streamingChatModelProperties.getTimeout())
+ .proxy(ProxyProperties.convert(streamingChatModelProperties.getProxy()))
+ .logRequests(streamingChatModelProperties.getLogRequests())
+ .logResponses(streamingChatModelProperties.getLogResponses())
+ .customHeaders(streamingChatModelProperties.getCustomHeaders())
.build();
}
@Bean
- @ConditionalOnProperty(PREFIX + ".language-model.base-url")
- public XinferenceLanguageModel xinferenceLanguageModel(Properties properties) {
- LanguageModelProperties languageModelProperties = properties.getLanguageModel();
+ @ConditionalOnMissingBean
+ @ConditionalOnProperty(LanguageModelProperties.PREFIX + ".base-url")
+ public XinferenceLanguageModel xinferenceLanguageModel(LanguageModelProperties languageModelProperties) {
return XinferenceLanguageModel.builder()
.baseUrl(languageModelProperties.getBaseUrl())
.apiKey(languageModelProperties.getApiKey())
@@ -97,34 +105,35 @@ public XinferenceLanguageModel xinferenceLanguageModel(Properties properties) {
}
@Bean
- @ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url")
- public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel(Properties properties) {
- LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
+ @ConditionalOnMissingBean
+ @ConditionalOnProperty(StreamingLanguageModelProperties.PREFIX + ".base-url")
+ public XinferenceStreamingLanguageModel xinferenceStreamingLanguageModel(
+ StreamingLanguageModelProperties streamingLanguageModelProperties) {
return XinferenceStreamingLanguageModel.builder()
- .baseUrl(languageModelProperties.getBaseUrl())
- .apiKey(languageModelProperties.getApiKey())
- .modelName(languageModelProperties.getModelName())
- .maxTokens(languageModelProperties.getMaxTokens())
- .temperature(languageModelProperties.getTemperature())
- .topP(languageModelProperties.getTopP())
- .logprobs(languageModelProperties.getLogprobs())
- .echo(languageModelProperties.getEcho())
- .stop(languageModelProperties.getStop())
- .presencePenalty(languageModelProperties.getPresencePenalty())
- .frequencyPenalty(languageModelProperties.getFrequencyPenalty())
- .user(languageModelProperties.getUser())
- .timeout(languageModelProperties.getTimeout())
- .proxy(ProxyProperties.convert(languageModelProperties.getProxy()))
- .logRequests(languageModelProperties.getLogRequests())
- .logResponses(languageModelProperties.getLogResponses())
- .customHeaders(languageModelProperties.getCustomHeaders())
+ .baseUrl(streamingLanguageModelProperties.getBaseUrl())
+ .apiKey(streamingLanguageModelProperties.getApiKey())
+ .modelName(streamingLanguageModelProperties.getModelName())
+ .maxTokens(streamingLanguageModelProperties.getMaxTokens())
+ .temperature(streamingLanguageModelProperties.getTemperature())
+ .topP(streamingLanguageModelProperties.getTopP())
+ .logprobs(streamingLanguageModelProperties.getLogprobs())
+ .echo(streamingLanguageModelProperties.getEcho())
+ .stop(streamingLanguageModelProperties.getStop())
+ .presencePenalty(streamingLanguageModelProperties.getPresencePenalty())
+ .frequencyPenalty(streamingLanguageModelProperties.getFrequencyPenalty())
+ .user(streamingLanguageModelProperties.getUser())
+ .timeout(streamingLanguageModelProperties.getTimeout())
+ .proxy(ProxyProperties.convert(streamingLanguageModelProperties.getProxy()))
+ .logRequests(streamingLanguageModelProperties.getLogRequests())
+ .logResponses(streamingLanguageModelProperties.getLogResponses())
+ .customHeaders(streamingLanguageModelProperties.getCustomHeaders())
.build();
}
@Bean
- @ConditionalOnProperty(PREFIX + ".embedding-model.base-url")
- public XinferenceEmbeddingModel xinferenceEmbeddingModel(Properties properties) {
- EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
+ @ConditionalOnMissingBean
+ @ConditionalOnProperty(EmbeddingModelProperties.PREFIX + ".base-url")
+ public XinferenceEmbeddingModel xinferenceEmbeddingModel(EmbeddingModelProperties embeddingModelProperties) {
return XinferenceEmbeddingModel.builder()
.baseUrl(embeddingModelProperties.getBaseUrl())
.apiKey(embeddingModelProperties.getApiKey())
@@ -140,9 +149,9 @@ public XinferenceEmbeddingModel xinferenceEmbeddingModel(Properties properties)
}
@Bean
- @ConditionalOnProperty(PREFIX + ".image-model.base-url")
- public XinferenceImageModel xinferenceImageModel(Properties properties) {
- ImageModelProperties imageModelProperties = properties.getImageModel();
+ @ConditionalOnMissingBean
+ @ConditionalOnProperty(ImageModelProperties.PREFIX + ".base-url")
+ public XinferenceImageModel xinferenceImageModel(ImageModelProperties imageModelProperties) {
return XinferenceImageModel.builder()
.baseUrl(imageModelProperties.getBaseUrl())
.apiKey(imageModelProperties.getApiKey())
@@ -162,9 +171,9 @@ public XinferenceImageModel xinferenceImageModel(Properties properties) {
}
@Bean
- @ConditionalOnProperty(PREFIX + ".scoring-model.base-url")
- public XinferenceScoringModel xinferenceScoringModel(Properties properties) {
- ScoringModelProperties scoringModelProperties = properties.getScoringModel();
+ @ConditionalOnMissingBean
+ @ConditionalOnProperty(ScoringModelProperties.PREFIX + ".base-url")
+ public XinferenceScoringModel xinferenceScoringModel(ScoringModelProperties scoringModelProperties) {
return XinferenceScoringModel.builder()
.baseUrl(scoringModelProperties.getBaseUrl())
.apiKey(scoringModelProperties.getApiKey())
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
new file mode 100644
index 0000000..4de934a
--- /dev/null
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@@ -0,0 +1 @@
+dev.langchain4j.community.xinference.spring.XinferenceAutoConfiguration
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
index 1605849..5198026 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/AutoConfigIT.java
@@ -45,7 +45,7 @@
@Testcontainers
class AutoConfigIT {
ApplicationContextRunner contextRunner =
- new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(AutoConfig.class));
+ new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(XinferenceAutoConfiguration.class));
@Container
XinferenceContainer chatModelContainer = new XinferenceContainer(XINFERENCE_IMAGE);
@@ -55,10 +55,10 @@ void should_provide_chat_model() throws IOException, InterruptedException {
chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.chat-model.base-url=" + chatModelContainer.getEndpoint(),
- "langchain4j.community.xinference.chat-model.model-name=" + CHAT_MODEL_NAME,
- "langchain4j.community.xinference.language-model.logRequests=true",
- "langchain4j.community.xinference.language-model.logResponses=true")
+ ChatModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(),
+ ChatModelProperties.PREFIX + ".model-name=" + CHAT_MODEL_NAME,
+ ChatModelProperties.PREFIX + ".logRequests=true",
+ ChatModelProperties.PREFIX + ".logResponses=true")
.run(context -> {
ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class);
assertThat(chatLanguageModel).isInstanceOf(XinferenceChatModel.class);
@@ -73,11 +73,10 @@ void should_provide_streaming_chat_model() throws IOException, InterruptedExcept
chatModelContainer.execInContainer("bash", "-c", launchCmd(CHAT_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.streaming-chat-model.base-url="
- + chatModelContainer.getEndpoint(),
- "langchain4j.community.xinference.streaming-chat-model.model-name=" + CHAT_MODEL_NAME,
- "langchain4j.community.xinference.language-model.logRequests=true",
- "langchain4j.community.xinference.language-model.logResponses=true")
+ StreamingChatModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(),
+ StreamingChatModelProperties.PREFIX + ".model-name=" + CHAT_MODEL_NAME,
+ StreamingChatModelProperties.PREFIX + ".logRequests=true",
+ StreamingChatModelProperties.PREFIX + ".logResponses=true")
.run(context -> {
StreamingChatLanguageModel streamingChatLanguageModel =
context.getBean(StreamingChatLanguageModel.class);
@@ -108,10 +107,10 @@ void should_provide_language_model() throws IOException, InterruptedException {
chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.language-model.base-url=" + chatModelContainer.getEndpoint(),
- "langchain4j.community.xinference.language-model.model-name=" + GENERATE_MODEL_NAME,
- "langchain4j.community.xinference.language-model.logRequests=true",
- "langchain4j.community.xinference.language-model.logResponses=true")
+ LanguageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(),
+ LanguageModelProperties.PREFIX + ".model-name=" + GENERATE_MODEL_NAME,
+ LanguageModelProperties.PREFIX + ".logRequests=true",
+ LanguageModelProperties.PREFIX + ".logResponses=true")
.run(context -> {
LanguageModel languageModel = context.getBean(LanguageModel.class);
assertThat(languageModel).isInstanceOf(XinferenceLanguageModel.class);
@@ -128,11 +127,10 @@ void should_provide_streaming_language_model() throws IOException, InterruptedEx
chatModelContainer.execInContainer("bash", "-c", launchCmd(GENERATE_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.streaming-language-model.base-url="
- + chatModelContainer.getEndpoint(),
- "langchain4j.community.xinference.streaming-language-model.model-name=" + GENERATE_MODEL_NAME,
- "langchain4j.community.xinference.streaming-language-model.logRequests=true",
- "langchain4j.community.xinference.streaming-language-model.logResponses=true")
+ StreamingLanguageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(),
+ StreamingLanguageModelProperties.PREFIX + ".model-name=" + GENERATE_MODEL_NAME,
+ StreamingLanguageModelProperties.PREFIX + ".logRequests=true",
+ StreamingLanguageModelProperties.PREFIX + ".logResponses=true")
.run(context -> {
StreamingLanguageModel streamingLanguageModel = context.getBean(StreamingLanguageModel.class);
assertThat(streamingLanguageModel).isInstanceOf(XinferenceStreamingLanguageModel.class);
@@ -163,10 +161,10 @@ void should_provide_embedding_model() throws IOException, InterruptedException {
chatModelContainer.execInContainer("bash", "-c", launchCmd(EMBEDDING_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.embeddingModel.base-url=" + chatModelContainer.getEndpoint(),
- "langchain4j.community.xinference.embeddingModel.modelName=" + EMBEDDING_MODEL_NAME,
- "langchain4j.community.xinference.language-model.logRequests=true",
- "langchain4j.community.xinference.language-model.logResponses=true")
+ EmbeddingModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(),
+ EmbeddingModelProperties.PREFIX + ".modelName=" + EMBEDDING_MODEL_NAME,
+ EmbeddingModelProperties.PREFIX + ".logRequests=true",
+ EmbeddingModelProperties.PREFIX + ".logResponses=true")
.run(context -> {
EmbeddingModel embeddingModel = context.getBean(EmbeddingModel.class);
assertThat(embeddingModel).isInstanceOf(XinferenceEmbeddingModel.class);
@@ -181,10 +179,10 @@ void should_provide_sc_model() throws IOException, InterruptedException {
chatModelContainer.execInContainer("bash", "-c", launchCmd(RERANK_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.scoringModel.base-url=" + chatModelContainer.getEndpoint(),
- "langchain4j.community.xinference.scoringModel.modelName=" + RERANK_MODEL_NAME,
- "langchain4j.community.xinference.language-model.logRequests=true",
- "langchain4j.community.xinference.language-model.logResponses=true")
+ ScoringModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(),
+ ScoringModelProperties.PREFIX + ".modelName=" + RERANK_MODEL_NAME,
+ ScoringModelProperties.PREFIX + ".logRequests=true",
+ ScoringModelProperties.PREFIX + ".logResponses=true")
.run(context -> {
ScoringModel scoringModel = context.getBean(ScoringModel.class);
assertThat(scoringModel).isInstanceOf(XinferenceScoringModel.class);
@@ -207,10 +205,10 @@ void should_provide_image_model() throws IOException, InterruptedException {
chatModelContainer.execInContainer("bash", "-c", launchCmd(IMAGE_MODEL_NAME));
contextRunner
.withPropertyValues(
- "langchain4j.community.xinference.imageModel.base-url=" + chatModelContainer.getEndpoint(),
- "langchain4j.community.xinference.imageModel.modelName=" + IMAGE_MODEL_NAME,
- "langchain4j.community.xinference.language-model.logRequests=true",
- "langchain4j.community.xinference.language-model.logResponses=true")
+ ImageModelProperties.PREFIX + ".base-url=" + chatModelContainer.getEndpoint(),
+ ImageModelProperties.PREFIX + ".modelName=" + IMAGE_MODEL_NAME,
+ ImageModelProperties.PREFIX + ".logRequests=true",
+ ImageModelProperties.PREFIX + ".logResponses=true")
.run(context -> {
ImageModel imageModel = context.getBean(ImageModel.class);
assertThat(imageModel).isInstanceOf(XinferenceImageModel.class);
From 8e8a9426f5604d9bdfda1e0e26e26407b62c764a Mon Sep 17 00:00:00 2001
From: lixw <>
Date: Mon, 30 Dec 2024 11:11:47 +0800
Subject: [PATCH 13/13] Add Spring Boot starter for Xinference
---
.../community/xinference/spring/XinferenceContainer.java | 1 -
1 file changed, 1 deletion(-)
diff --git a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java
index acaf356..7ed4fe5 100644
--- a/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java
+++ b/spring-boot-starters/langchain4j-community-xinference-spring-boot-starter/src/test/java/dev/langchain4j/community/xinference/spring/XinferenceContainer.java
@@ -45,7 +45,6 @@ public XinferenceContainer(DockerImageName image) {
});
}
this.withExposedPorts(EXPOSED_PORT);
- this.withEnv(Map.of("XINFERENCE_MODEL_SRC", "modelscope"));
// https://github.com/xorbitsai/inference/issues/2573
this.withCommand("bash", "-c", "xinference-local -H 0.0.0.0");
this.waitingFor(Wait.forListeningPort().withStartupTimeout(Duration.ofMinutes(10)));