Skip to content

Commit

Permalink
DashScope: Support deep customization of wanx model request parameters (#51)
Browse files Browse the repository at this point in the history

* Support deep customization of wanx model request parameters

* Make SonarQube happy

* Add some useful methods for enums

* Format code.
  • Loading branch information
jiangsier-xyz authored Jan 16, 2025
1 parent 379e497 commit 172714b
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import com.alibaba.dashscope.utils.OSSUtils;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.internal.Utils;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Files;
Expand Down Expand Up @@ -43,7 +42,7 @@ static String imageUrl(Image image, String model, String apiKey) {
try {
imageUrl = OSSUtils.upload(model, filePath, apiKey);
} catch (NoApiKeyException e) {
throw new RuntimeException(e);
throw new IllegalArgumentException(e);
}
} else {
throw new IllegalArgumentException("Failed to get image url from " + image);
Expand All @@ -69,7 +68,7 @@ static String saveDataAsTemporaryFile(String base64Data, String mimeType) {
try {
Files.copy(new ByteArrayInputStream(data), tmpFilePath, StandardCopyOption.REPLACE_EXISTING);
} catch (IOException e) {
throw new RuntimeException(e);
throw new IllegalStateException(e);
}
return tmpFilePath.toAbsolutePath().toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
import dev.langchain4j.model.output.Response;

import java.util.List;
import java.util.function.Consumer;

import static dev.langchain4j.community.model.dashscope.WanxHelper.imageUrl;
import static dev.langchain4j.community.model.dashscope.WanxHelper.imagesFrom;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;

/**
Expand All @@ -39,17 +41,20 @@ public class WanxImageModel implements ImageModel {
private final WanxImageSize size;
private final WanxImageStyle style;
private final ImageSynthesis imageSynthesis;

public WanxImageModel(String baseUrl,
String apiKey,
String modelName,
WanxImageRefMode refMode,
Float refStrength,
Integer seed,
WanxImageSize size,
WanxImageStyle style) {
private Consumer<ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?>> imageSynthesisParamCustomizer = p -> {};

public WanxImageModel(
String baseUrl,
String apiKey,
String modelName,
WanxImageRefMode refMode,
Float refStrength,
Integer seed,
WanxImageSize size,
WanxImageStyle style) {
if (Utils.isNullOrBlank(apiKey)) {
throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
throw new IllegalArgumentException(
"DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
}
this.modelName = Utils.isNullOrBlank(modelName) ? WanxModelName.WANX_V1 : modelName;
this.apiKey = apiKey;
Expand All @@ -58,65 +63,74 @@ public WanxImageModel(String baseUrl,
this.seed = seed;
this.size = size;
this.style = style;
this.imageSynthesis = Utils.isNullOrBlank(baseUrl) ? new ImageSynthesis() : new ImageSynthesis("text2image", baseUrl);
this.imageSynthesis =
Utils.isNullOrBlank(baseUrl) ? new ImageSynthesis() : new ImageSynthesis("text2image", baseUrl);
}

@Override
public Response<Image> generate(String prompt) {
ImageSynthesisParam param = requestBuilder(prompt).n(1).build();
ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder =
requestBuilder(prompt).n(1);

try {
ImageSynthesisResult result = imageSynthesis.call(param);
imageSynthesisParamCustomizer.accept(builder);
ImageSynthesisResult result = imageSynthesis.call(builder.build());
return Response.from(imagesFrom(result).get(0));
} catch (NoApiKeyException e) {
throw new RuntimeException(e);
throw new IllegalArgumentException(e);
}
}

@Override
public Response<List<Image>> generate(String prompt, int n) {
ImageSynthesisParam param = requestBuilder(prompt).n(n).build();
ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder =
requestBuilder(prompt).n(n);

try {
ImageSynthesisResult result = imageSynthesis.call(param);
imageSynthesisParamCustomizer.accept(builder);
ImageSynthesisResult result = imageSynthesis.call(builder.build());
return Response.from(imagesFrom(result));
} catch (NoApiKeyException e) {
throw new RuntimeException(e);
throw new IllegalArgumentException(e);
}
}

@Override
public Response<Image> edit(Image image, String prompt) {
String imageUrl = imageUrl(image, modelName, apiKey);

ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder = requestBuilder(prompt)
.refImage(imageUrl)
.n(1);
ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder =
requestBuilder(prompt).refImage(imageUrl).n(1);

if (imageUrl.startsWith("oss://")) {
builder.header("X-DashScope-OssResourceResolve", "enable");
}

try {
imageSynthesisParamCustomizer.accept(builder);
ImageSynthesisResult result = imageSynthesis.call(builder.build());
List<Image> images = imagesFrom(result);
if (images.isEmpty()) {
ImageSynthesisOutput output = result.getOutput();
String errorMessage = String.format("[%s] %s: %s",
output.getTaskStatus(), output.getCode(), output.getMessage());
String errorMessage =
String.format("[%s] %s: %s", output.getTaskStatus(), output.getCode(), output.getMessage());
throw new IllegalStateException(errorMessage);
}
return Response.from(images.get(0));
} catch (NoApiKeyException e) {
throw new RuntimeException(e);
throw new IllegalArgumentException(e);
}
}

/**
 * Replaces the hook that is handed the request builder right before each
 * request is built and sent (it is invoked from {@code generate} and {@code edit}).
 * The initial hook does nothing.
 *
 * @param imageSynthesisParamCustomizer callback that may adjust the
 *        {@code ImageSynthesisParam} builder; it is passed through
 *        {@code ensureNotNull} before being stored
 */
public void setImageSynthesisParamCustomizer(
        Consumer<ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?>> imageSynthesisParamCustomizer) {
    Consumer<ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?>> validated =
            ensureNotNull(imageSynthesisParamCustomizer, "imageSynthesisParamCustomizer");
    this.imageSynthesisParamCustomizer = validated;
}

private ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> requestBuilder(String prompt) {
ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder = ImageSynthesisParam.builder()
.apiKey(apiKey)
.model(modelName)
.prompt(prompt);
ImageSynthesisParam.ImageSynthesisParamBuilder<?, ?> builder =
ImageSynthesisParam.builder().apiKey(apiKey).model(modelName).prompt(prompt);

if (seed != null) {
builder.seed(seed);
Expand Down Expand Up @@ -160,7 +174,7 @@ public static class WanxImageModelBuilder {
private WanxImageStyle style;

public WanxImageModelBuilder() {
// This is public so it can be extended
// This is public, so it can be extended
// By default with Lombok it becomes package private
}

Expand Down Expand Up @@ -205,16 +219,7 @@ public WanxImageModelBuilder style(WanxImageStyle style) {
}

public WanxImageModel build() {
return new WanxImageModel(
baseUrl,
apiKey,
modelName,
refMode,
refStrength,
seed,
size,
style
);
return new WanxImageModel(baseUrl, apiKey, modelName, refMode, refStrength, seed, size, style);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package dev.langchain4j.community.model.dashscope;

public enum WanxImageSize {
import static dev.langchain4j.internal.Utils.isNullOrBlank;

public enum WanxImageSize {
SIZE_1024_1024("1024*1024"),
SIZE_720_1280("720*1280"),
SIZE_1280_720("1280*720");
Expand All @@ -16,4 +17,22 @@ public enum WanxImageSize {
public String toString() {
return size;
}

/**
 * Exposes the raw size string held by this constant (the same value
 * that {@link #toString()} returns).
 *
 * @return the size string of this constant
 */
public String getSize() {
    return this.size;
}

/**
 * Looks up the constant whose size string matches the given text,
 * ignoring case.
 *
 * @param size the size text to resolve, may be {@code null}
 * @return the matching constant, or {@code null} when {@code size} is
 *         null/blank (per {@code isNullOrBlank}) or matches no constant
 */
public static WanxImageSize of(String size) {
    WanxImageSize matched = null;
    if (!isNullOrBlank(size)) {
        for (WanxImageSize candidate : values()) {
            if (candidate.size.equalsIgnoreCase(size)) {
                // First hit wins; stop scanning.
                matched = candidate;
                break;
            }
        }
    }
    return matched;
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package dev.langchain4j.community.model.dashscope;

public enum WanxImageStyle {
import static dev.langchain4j.internal.Utils.isNullOrBlank;

public enum WanxImageStyle {
PHOTOGRAPHY("<photography>"),
PORTRAIT("<portrait>"),
CARTOON_3D("<3d cartoon>"),
Expand All @@ -23,4 +24,22 @@ public enum WanxImageStyle {
public String toString() {
return style;
}

/**
 * Exposes the raw style string held by this constant (the same value
 * that {@link #toString()} returns).
 *
 * @return the style string of this constant
 */
public String getStyle() {
    return this.style;
}

/**
 * Looks up the constant whose string form matches the given text,
 * ignoring case.
 *
 * @param style the style text to resolve, may be {@code null}
 * @return the matching constant, or {@code null} when {@code style} is
 *         null/blank (per {@code isNullOrBlank}) or matches no constant
 */
public static WanxImageStyle of(String style) {
    WanxImageStyle matched = null;
    if (!isNullOrBlank(style)) {
        for (WanxImageStyle candidate : values()) {
            if (candidate.toString().equalsIgnoreCase(style)) {
                // First hit wins; stop scanning.
                matched = candidate;
                break;
            }
        }
    }
    return matched;
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
package dev.langchain4j.community.model.dashscope;

import static dev.langchain4j.community.model.dashscope.QwenTestHelper.apiKey;
import static dev.langchain4j.community.model.dashscope.QwenTestHelper.multimodalImageData;
import static org.assertj.core.api.Assertions.assertThat;

import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.output.Response;
import java.net.URI;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.URI;

import static dev.langchain4j.community.model.dashscope.QwenTestHelper.apiKey;
import static dev.langchain4j.community.model.dashscope.QwenTestHelper.multimodalImageData;
import static org.assertj.core.api.Assertions.assertThat;

@EnabledIfEnvironmentVariable(named = "DASHSCOPE_API_KEY", matches = ".+")
class WanxImageModelIT {

Expand All @@ -22,10 +21,26 @@ class WanxImageModelIT {
@ParameterizedTest
@MethodSource("dev.langchain4j.community.model.dashscope.WanxTestHelper#imageModelNameProvider")
void simple_image_generation_works(String modelName) {
WanxImageModel model = WanxImageModel.builder()
.apiKey(apiKey())
.modelName(modelName)
.build();
WanxImageModel model =
WanxImageModel.builder().apiKey(apiKey()).modelName(modelName).build();

Response<Image> response = model.generate("Beautiful house on country side");

URI remoteImage = response.content().url();
log.info("Your remote image is here: {}", remoteImage);
assertThat(remoteImage).isNotNull();
}

@ParameterizedTest
@MethodSource("dev.langchain4j.community.model.dashscope.WanxTestHelper#imageModelNameProvider")
void simple_image_generation_works_by_customize_request(String modelName) {
WanxImageModel model =
WanxImageModel.builder().apiKey(apiKey()).modelName(modelName).build();

model.setImageSynthesisParamCustomizer(builder -> {
builder.extraInput("lora_index", "wanx1.4.6_textlora_jianzhi1_20240816");
builder.extraInput("trigger_word", "papercut");
});

Response<Image> response = model.generate("Beautiful house on country side");

Expand All @@ -37,10 +52,8 @@ void simple_image_generation_works(String modelName) {
@ParameterizedTest
@MethodSource("dev.langchain4j.community.model.dashscope.WanxTestHelper#imageModelNameProvider")
void simple_image_edition_works_by_url(String modelName) {
WanxImageModel model = WanxImageModel.builder()
.apiKey(apiKey())
.modelName(modelName)
.build();
WanxImageModel model =
WanxImageModel.builder().apiKey(apiKey()).modelName(modelName).build();

Image image = Image.builder()
.url("https://help-static-aliyun-doc.aliyuncs.com/assets/img/zh-CN/2476628361/p335710.png")
Expand All @@ -56,10 +69,8 @@ void simple_image_edition_works_by_url(String modelName) {
@ParameterizedTest
@MethodSource("dev.langchain4j.community.model.dashscope.WanxTestHelper#imageModelNameProvider")
void simple_image_edition_works_by_data(String modelName) {
WanxImageModel model = WanxImageModel.builder()
.apiKey(apiKey())
.modelName(modelName)
.build();
WanxImageModel model =
WanxImageModel.builder().apiKey(apiKey()).modelName(modelName).build();

Image image = Image.builder()
.base64Data(multimodalImageData())
Expand Down

0 comments on commit 172714b

Please sign in to comment.