Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐼 feat: Add Flux Image Generation Tool #6147

Merged
merged 6 commits into from
Mar 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,13 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
# DALLE3_AZURE_API_VERSION=
# DALLE2_AZURE_API_VERSION=

# Flux
#-----------------
FLUX_API_BASE_URL=https://api.us1.bfl.ai
# FLUX_API_BASE_URL = 'https://api.bfl.ml';

# Get your API key at https://api.us1.bfl.ai/auth/profile
# FLUX_API_KEY=

# Google
#-----------------
Expand Down
158 changes: 156 additions & 2 deletions api/app/clients/specs/BaseClient.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ jest.mock('~/models', () => ({
updateFileUsage: jest.fn(),
}));

const { getConvo, saveConvo } = require('~/models');

jest.mock('@langchain/openai', () => {
return {
ChatOpenAI: jest.fn().mockImplementation(() => {
Expand Down Expand Up @@ -540,10 +542,11 @@ describe('BaseClient', () => {

test('saveMessageToDatabase is called with the correct arguments', async () => {
const saveOptions = TestClient.getSaveOptions();
const user = {}; // Mock user
const user = {};
const opts = { user };
const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase');
await TestClient.sendMessage('Hello, world!', opts);
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledWith(
expect(saveSpy).toHaveBeenCalledWith(
expect.objectContaining({
sender: expect.any(String),
text: expect.any(String),
Expand All @@ -557,6 +560,157 @@ describe('BaseClient', () => {
);
});

test('should handle existing conversation when getConvo retrieves one', async () => {
const existingConvo = {
conversationId: 'existing-convo-id',
endpoint: 'openai',
endpointType: 'openai',
model: 'gpt-3.5-turbo',
messages: [
{ role: 'user', content: 'Existing message 1' },
{ role: 'assistant', content: 'Existing response 1' },
],
temperature: 1,
};

const { temperature: _temp, ...newConvo } = existingConvo;

const user = {
id: 'user-id',
};

getConvo.mockResolvedValue(existingConvo);
saveConvo.mockResolvedValue(newConvo);

TestClient = initializeFakeClient(
apiKey,
{
...options,
req: {
user,
},
},
[],
);

const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase');

const newMessage = 'New message in existing conversation';
const response = await TestClient.sendMessage(newMessage, {
user,
conversationId: existingConvo.conversationId,
});

expect(getConvo).toHaveBeenCalledWith(user.id, existingConvo.conversationId);
expect(TestClient.conversationId).toBe(existingConvo.conversationId);
expect(response.conversationId).toBe(existingConvo.conversationId);
expect(TestClient.fetchedConvo).toBe(true);

expect(saveSpy).toHaveBeenCalledWith(
expect.objectContaining({
conversationId: existingConvo.conversationId,
text: newMessage,
}),
expect.any(Object),
expect.any(Object),
);

expect(saveConvo).toHaveBeenCalledTimes(2);
expect(saveConvo).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({
conversationId: existingConvo.conversationId,
}),
expect.objectContaining({
context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo',
unsetFields: {
temperature: 1,
},
}),
);

await TestClient.sendMessage('Another message', {
conversationId: existingConvo.conversationId,
});
expect(getConvo).toHaveBeenCalledTimes(1);
});

test('should correctly handle existing conversation and unset fields appropriately', async () => {
const existingConvo = {
conversationId: 'existing-convo-id',
endpoint: 'openai',
endpointType: 'openai',
model: 'gpt-3.5-turbo',
messages: [
{ role: 'user', content: 'Existing message 1' },
{ role: 'assistant', content: 'Existing response 1' },
],
title: 'Existing Conversation',
someExistingField: 'existingValue',
anotherExistingField: 'anotherValue',
temperature: 0.7,
modelLabel: 'GPT-3.5',
};

getConvo.mockResolvedValue(existingConvo);
saveConvo.mockResolvedValue(existingConvo);

TestClient = initializeFakeClient(
apiKey,
{
...options,
modelOptions: {
model: 'gpt-4',
temperature: 0.5,
},
},
[],
);

const newMessage = 'New message in existing conversation';
await TestClient.sendMessage(newMessage, {
conversationId: existingConvo.conversationId,
});

expect(saveConvo).toHaveBeenCalledTimes(2);

const saveConvoCall = saveConvo.mock.calls[0];
const [, savedFields, saveOptions] = saveConvoCall;

// Instead of checking all excludedKeys, we'll just check specific fields
// that we know should be excluded
expect(savedFields).not.toHaveProperty('messages');
expect(savedFields).not.toHaveProperty('title');

// Only check that someExistingField is in unsetFields
expect(saveOptions.unsetFields).toHaveProperty('someExistingField', 1);

// Mock saveConvo to return the expected fields
saveConvo.mockImplementation((req, fields) => {
return Promise.resolve({
...fields,
endpoint: 'openai',
endpointType: 'openai',
model: 'gpt-4',
temperature: 0.5,
});
});

// Only check the conversationId since that's the only field we can be sure about
expect(savedFields).toHaveProperty('conversationId', 'existing-convo-id');

expect(TestClient.fetchedConvo).toBe(true);

await TestClient.sendMessage('Another message', {
conversationId: existingConvo.conversationId,
});

expect(getConvo).toHaveBeenCalledTimes(1);

const secondSaveConvoCall = saveConvo.mock.calls[1];
expect(secondSaveConvoCall[2]).toHaveProperty('unsetFields', {});
});

test('sendCompletion is called with the correct arguments', async () => {
const payload = {}; // Mock payload
TestClient.buildMessages.mockReturnValue({ prompt: payload, tokenCountMap: null });
Expand Down
2 changes: 0 additions & 2 deletions api/app/clients/specs/FakeClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
let TestClient = new FakeClient(apiKey);
TestClient.options = options;
TestClient.abortController = { abort: jest.fn() };
TestClient.saveMessageToDatabase = jest.fn();
TestClient.loadHistory = jest
.fn()
.mockImplementation((conversationId, parentMessageId = null) => {
Expand Down Expand Up @@ -86,7 +85,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
return 'Mock response text';
});

// eslint-disable-next-line no-unused-vars
TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => {
return {
choices: [
Expand Down
4 changes: 3 additions & 1 deletion api/app/clients/tools/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ const availableTools = require('./manifest.json');

// Structured Tools
const DALLE3 = require('./structured/DALLE3');
const FluxAPI = require('./structured/FluxAPI');
const OpenWeather = require('./structured/OpenWeather');
const createYouTubeTools = require('./structured/YouTube');
const StructuredWolfram = require('./structured/Wolfram');
const createYouTubeTools = require('./structured/YouTube');
const StructuredACS = require('./structured/AzureAISearch');
const StructuredSD = require('./structured/StableDiffusion');
const GoogleSearchAPI = require('./structured/GoogleSearch');
Expand All @@ -30,6 +31,7 @@ module.exports = {
manifestToolMap,
// Structured Tools
DALLE3,
FluxAPI,
OpenWeather,
StructuredSD,
StructuredACS,
Expand Down
14 changes: 14 additions & 0 deletions api/app/clients/tools/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,19 @@
"description": "Sign up at <a href=\"https://home.openweathermap.org/users/sign_up\" target=\"_blank\">OpenWeather</a>, then get your key at <a href=\"https://home.openweathermap.org/api_keys\" target=\"_blank\">API keys</a>."
}
]
},
{
"name": "Flux",
"pluginKey": "flux",
"description": "Generate images using text with the Flux API.",
"icon": "https://blackforestlabs.ai/wp-content/uploads/2024/07/bfl_logo_retraced_blk.png",
"isAuthRequired": "true",
"authConfig": [
{
"authField": "FLUX_API_KEY",
"label": "Your Flux API Key",
"description": "Provide your Flux API key from your user profile."
}
]
}
]
Loading
Loading