Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔉 feat: TTS/STT rate limiters #2925

Merged
merged 5 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/cache/getLogStores.js
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ const namespaces = {
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance(ViolationTypes.TOKEN_BALANCE),
registrations: createViolationInstance('registrations'),
[ViolationTypes.TTS_LIMIT]: createViolationInstance(ViolationTypes.TTS_LIMIT),
[ViolationTypes.STT_LIMIT]: createViolationInstance(ViolationTypes.STT_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
Expand Down
7 changes: 7 additions & 0 deletions api/server/middleware/speech/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
const createTTSLimiters = require('./ttsLimiters');
const createSTTLimiters = require('./sttLimiters');

// Entry point of the speech middleware folder: re-exports the TTS and STT
// limiter factories loaded from './ttsLimiters' and './sttLimiters'.
module.exports = {
createTTSLimiters,
createSTTLimiters,
};
68 changes: 68 additions & 0 deletions api/server/middleware/speech/sttLimiters.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');

/**
 * Reads the speech-to-text (STT) rate limit settings from `process.env`.
 *
 * Every variable is parsed with `parseInt`; when the result is falsy
 * (variable missing, not numeric, or `0`) the default is used instead:
 * 100 requests per IP, 50 requests per user, and 1-minute windows.
 *
 * @returns {{
 *   sttIpWindowMs: number,
 *   sttIpMax: number,
 *   sttIpWindowInMinutes: number,
 *   sttUserWindowMs: number,
 *   sttUserMax: number,
 *   sttUserWindowInMinutes: number,
 * }} Limits and windows, with each window given both in milliseconds and in minutes.
 */
const getEnvironmentVariables = () => {
  const { STT_IP_MAX, STT_IP_WINDOW, STT_USER_MAX, STT_USER_WINDOW } = process.env;

  // Windows are configured in minutes and converted to milliseconds.
  const sttIpWindowMs = (parseInt(STT_IP_WINDOW) || 1) * 60 * 1000;
  const sttUserWindowMs = (parseInt(STT_USER_WINDOW) || 1) * 60 * 1000;

  return {
    sttIpWindowMs,
    sttIpMax: parseInt(STT_IP_MAX) || 100,
    sttIpWindowInMinutes: sttIpWindowMs / 60000,
    sttUserWindowMs,
    sttUserMax: parseInt(STT_USER_MAX) || 50,
    sttUserWindowInMinutes: sttUserWindowMs / 60000,
  };
};

/**
 * Creates the callback passed as `handler` to the STT rate limiters.
 *
 * The limit values are read once, when the handler is created. On each
 * invocation the handler logs an STT violation through `logViolation` and
 * then answers with HTTP 429.
 *
 * @param {boolean} [ip=true] - When truthy, report the IP limiter's settings;
 *   otherwise report the per-user limiter's settings.
 * @returns {(req: object, res: object) => Promise<void>} Async handler.
 */
const createSTTHandler = (ip = true) => {
  const limits = getEnvironmentVariables();
  const limiter = ip ? 'ip' : 'user';
  const max = ip ? limits.sttIpMax : limits.sttUserMax;
  const windowInMinutes = ip ? limits.sttIpWindowInMinutes : limits.sttUserWindowInMinutes;

  return async (req, res) => {
    const type = ViolationTypes.STT_LIMIT;
    // The violation is recorded before the response is sent.
    await logViolation(req, res, type, { type, max, limiter, windowInMinutes });
    res.status(429).json({ message: 'Too many STT requests. Try again later' });
  };
};

/**
 * Builds the two `express-rate-limit` middlewares used for STT requests.
 *
 * - `sttIpLimiter` supplies no `keyGenerator`, so the library's default key is used.
 * - `sttUserLimiter` keys requests by `req.user?.id`.
 *
 * NOTE(review): when `req.user` is absent the user key is `undefined`, so such
 * requests would all share one bucket — presumably the route is mounted behind
 * authentication; confirm against the router.
 *
 * @returns {{ sttIpLimiter: Function, sttUserLimiter: Function }}
 */
const createSTTLimiters = () => {
  const env = getEnvironmentVariables();

  // Key for the per-user limiter: the authenticated user's id (undefined if none).
  const byUserId = (req) => req.user?.id;

  const sttIpLimiter = rateLimit({
    windowMs: env.sttIpWindowMs,
    max: env.sttIpMax,
    handler: createSTTHandler(),
  });

  const sttUserLimiter = rateLimit({
    windowMs: env.sttUserWindowMs,
    max: env.sttUserMax,
    handler: createSTTHandler(false),
    keyGenerator: byUserId,
  });

  return { sttIpLimiter, sttUserLimiter };
};

module.exports = createSTTLimiters;
68 changes: 68 additions & 0 deletions api/server/middleware/speech/ttsLimiters.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
const rateLimit = require('express-rate-limit');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');

/**
 * Reads the text-to-speech (TTS) rate limit settings from `process.env`.
 *
 * Every variable is parsed with `parseInt`; when the result is falsy
 * (variable missing, not numeric, or `0`) the default is used instead:
 * 100 requests per IP, 50 requests per user, and 1-minute windows.
 *
 * @returns {{
 *   ttsIpWindowMs: number,
 *   ttsIpMax: number,
 *   ttsIpWindowInMinutes: number,
 *   ttsUserWindowMs: number,
 *   ttsUserMax: number,
 *   ttsUserWindowInMinutes: number,
 * }} Limits and windows, with each window given both in milliseconds and in minutes.
 */
const getEnvironmentVariables = () => {
  const { TTS_IP_MAX, TTS_IP_WINDOW, TTS_USER_MAX, TTS_USER_WINDOW } = process.env;

  // Windows are configured in minutes and converted to milliseconds.
  const ttsIpWindowMs = (parseInt(TTS_IP_WINDOW) || 1) * 60 * 1000;
  const ttsUserWindowMs = (parseInt(TTS_USER_WINDOW) || 1) * 60 * 1000;

  return {
    ttsIpWindowMs,
    ttsIpMax: parseInt(TTS_IP_MAX) || 100,
    ttsIpWindowInMinutes: ttsIpWindowMs / 60000,
    ttsUserWindowMs,
    ttsUserMax: parseInt(TTS_USER_MAX) || 50,
    ttsUserWindowInMinutes: ttsUserWindowMs / 60000,
  };
};

/**
 * Creates the callback passed as `handler` to the TTS rate limiters.
 *
 * The limit values are read once, when the handler is created. On each
 * invocation the handler logs a TTS violation through `logViolation` and
 * then answers with HTTP 429.
 *
 * @param {boolean} [ip=true] - When truthy, report the IP limiter's settings;
 *   otherwise report the per-user limiter's settings.
 * @returns {(req: object, res: object) => Promise<void>} Async handler.
 */
const createTTSHandler = (ip = true) => {
  const limits = getEnvironmentVariables();
  const limiter = ip ? 'ip' : 'user';
  const max = ip ? limits.ttsIpMax : limits.ttsUserMax;
  const windowInMinutes = ip ? limits.ttsIpWindowInMinutes : limits.ttsUserWindowInMinutes;

  return async (req, res) => {
    const type = ViolationTypes.TTS_LIMIT;
    // The violation is recorded before the response is sent.
    await logViolation(req, res, type, { type, max, limiter, windowInMinutes });
    res.status(429).json({ message: 'Too many TTS requests. Try again later' });
  };
};

/**
 * Builds the two `express-rate-limit` middlewares used for TTS requests.
 *
 * - `ttsIpLimiter` supplies no `keyGenerator`, so the library's default key is used.
 * - `ttsUserLimiter` keys requests by `req.user?.id`.
 *
 * NOTE(review): when `req.user` is absent the user key is `undefined`, so such
 * requests would all share one bucket — presumably the route is mounted behind
 * authentication; confirm against the router.
 *
 * @returns {{ ttsIpLimiter: Function, ttsUserLimiter: Function }}
 */
const createTTSLimiters = () => {
  const env = getEnvironmentVariables();

  // Key for the per-user limiter: the authenticated user's id (undefined if none).
  const byUserId = (req) => req.user?.id;

  const ttsIpLimiter = rateLimit({
    windowMs: env.ttsIpWindowMs,
    max: env.ttsIpMax,
    handler: createTTSHandler(),
  });

  const ttsUserLimiter = rateLimit({
    windowMs: env.ttsUserWindowMs,
    max: env.ttsUserMax,
    handler: createTTSHandler(false),
    keyGenerator: byUserId,
  });

  return { ttsIpLimiter, ttsUserLimiter };
};

module.exports = createTTSLimiters;
10 changes: 5 additions & 5 deletions api/server/routes/files/index.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const express = require('express');
const { uaParser, checkBan, requireJwtAuth, createFileLimiters } = require('~/server/middleware');
const { createTTSLimiters, createSTTLimiters } = require('~/server/middleware/speech');
const { createMulterInstance } = require('./multer');

const files = require('./files');
Expand All @@ -15,18 +16,17 @@ const initialize = async () => {
router.use(uaParser);

/* Important: stt/tts routes must be added before the upload limiters */
router.use('/stt', stt);
router.use('/tts', tts);
const { sttIpLimiter, sttUserLimiter } = createSTTLimiters();
const { ttsIpLimiter, ttsUserLimiter } = createTTSLimiters();
router.use('/stt', sttIpLimiter, sttUserLimiter, stt);
router.use('/tts', ttsIpLimiter, ttsUserLimiter, tts);

const upload = await createMulterInstance();
const { fileUploadIpLimiter, fileUploadUserLimiter } = createFileLimiters();
router.post('*', fileUploadIpLimiter, fileUploadUserLimiter);
router.post('/', upload.single('file'));
router.post('/images', upload.single('file'));

router.use('/stt', stt);
router.use('/tts', tts);

router.use('/', files);
router.use('/images', images);
router.use('/images/avatar', avatar);
Expand Down
53 changes: 36 additions & 17 deletions api/server/services/Config/handleRateLimits.js
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const { RateLimitPrefix } = require('librechat-data-provider');

/**
*
* @param {TCustomConfig['rateLimits'] | undefined} rateLimits
Expand All @@ -6,24 +8,41 @@ const handleRateLimits = (rateLimits) => {
if (!rateLimits) {
return;
}
const { fileUploads, conversationsImport } = rateLimits;
if (fileUploads) {
process.env.FILE_UPLOAD_IP_MAX = fileUploads.ipMax ?? process.env.FILE_UPLOAD_IP_MAX;
process.env.FILE_UPLOAD_IP_WINDOW =
fileUploads.ipWindowInMinutes ?? process.env.FILE_UPLOAD_IP_WINDOW;
process.env.FILE_UPLOAD_USER_MAX = fileUploads.userMax ?? process.env.FILE_UPLOAD_USER_MAX;
process.env.FILE_UPLOAD_USER_WINDOW =
fileUploads.userWindowInMinutes ?? process.env.FILE_UPLOAD_USER_WINDOW;
}

if (conversationsImport) {
process.env.IMPORT_IP_MAX = conversationsImport.ipMax ?? process.env.IMPORT_IP_MAX;
process.env.IMPORT_IP_WINDOW =
conversationsImport.ipWindowInMinutes ?? process.env.IMPORT_IP_WINDOW;
process.env.IMPORT_USER_MAX = conversationsImport.userMax ?? process.env.IMPORT_USER_MAX;
process.env.IMPORT_USER_WINDOW =
conversationsImport.userWindowInMinutes ?? process.env.IMPORT_USER_WINDOW;
}
const rateLimitKeys = {
fileUploads: RateLimitPrefix.FILE_UPLOAD,
conversationsImport: RateLimitPrefix.IMPORT,
tts: RateLimitPrefix.TTS,
stt: RateLimitPrefix.STT,
};

Object.entries(rateLimitKeys).forEach(([key, prefix]) => {
const rateLimit = rateLimits[key];
if (rateLimit) {
setRateLimitEnvVars(prefix, rateLimit);
}
});
};

/**
 * Set environment variables for rate limit configurations.
 *
 * For each of `ipMax`, `ipWindowInMinutes`, `userMax` and `userWindowInMinutes`
 * present on `rateLimit`, writes the value to `process.env` under
 * `${prefix}_IP_MAX`, `${prefix}_IP_WINDOW`, `${prefix}_USER_MAX` and
 * `${prefix}_USER_WINDOW` respectively. Keys whose value is `undefined` or
 * `null` are skipped, so any value already present in the environment is kept.
 *
 * @param {string} prefix - Prefix for environment variable names
 * @param {object} rateLimit - Rate limit configuration object
 * @returns {void}
 */
const setRateLimitEnvVars = (prefix, rateLimit) => {
  const envVarsMapping = {
    ipMax: `${prefix}_IP_MAX`,
    ipWindowInMinutes: `${prefix}_IP_WINDOW`,
    userMax: `${prefix}_USER_MAX`,
    userWindowInMinutes: `${prefix}_USER_WINDOW`,
  };

  for (const [key, envVar] of Object.entries(envVarsMapping)) {
    const value = rateLimit[key];
    // `== null` matches both undefined and null: a null setting must not be
    // stored as the literal string "null" and clobber an existing value.
    if (value == null) {
      continue;
    }
    // process.env only holds strings; convert explicitly rather than relying
    // on implicit coercion at assignment time.
    process.env[envVar] = String(value);
  }
};

module.exports = handleRateLimits;
4 changes: 3 additions & 1 deletion client/src/components/Chat/Input/StreamAudio.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { TMessage } from 'librechat-data-provider';
import { useCustomAudioRef, MediaSourceAppender, usePauseGlobalAudio } from '~/hooks/Audio';
import { useAuthContext } from '~/hooks';
import { globalAudioId } from '~/common';
import { getLatestText } from '~/utils';
import store from '~/store';

function timeoutPromise(ms: number, message?: string) {
Expand Down Expand Up @@ -47,13 +48,14 @@ export default function StreamAudio({ index = 0 }) {
);

useEffect(() => {
const latestText = getLatestText(latestMessage);
const shouldFetch =
token &&
automaticPlayback &&
isSubmitting &&
latestMessage &&
!latestMessage.isCreatedByUser &&
(latestMessage.text || latestMessage.content) &&
latestText &&
latestMessage.messageId &&
!latestMessage.messageId.includes('_') &&
!isFetching &&
Expand Down
24 changes: 15 additions & 9 deletions client/src/hooks/Messages/useMessageHelpers.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { useEffect, useRef, useCallback } from 'react';
import { isAssistantsEndpoint } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
import { useChatContext, useAssistantsMapContext } from '~/Providers';
import { getLatestText, getLengthAndFirstFiveChars } from '~/utils';
import useCopyToClipboard from './useCopyToClipboard';

export default function useMessageHelpers(props: TMessageProps) {
Expand All @@ -26,20 +27,25 @@ export default function useMessageHelpers(props: TMessageProps) {
const isLast = !children?.length;

useEffect(() => {
let contentChanged = message?.content
? message?.content?.length !== latestText.current
: message?.text !== latestText.current;

if (conversation?.conversationId === 'new') {
return;
}
if (!message) {
return;
}
if (!isLast) {
contentChanged = false;
return;
}

if (!message) {
const text = getLatestText(message);
const textKey = `${message?.messageId ?? ''}${getLengthAndFirstFiveChars(text)}`;

if (textKey === latestText.current) {
return;
} else if (isLast && conversation?.conversationId !== 'new' && contentChanged) {
setLatestMessage({ ...message });
latestText.current = message?.content ? message.content.length : message.text;
}

latestText.current = textKey;
setLatestMessage({ ...message });
}, [isLast, message, setLatestMessage, conversation?.conversationId]);

const enterEdit = useCallback(
Expand Down
2 changes: 1 addition & 1 deletion client/src/store/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ const localStorageAtoms = {
autoScroll: atomWithLocalStorage('autoScroll', false),
showCode: atomWithLocalStorage('showCode', false),
hideSidePanel: atomWithLocalStorage('hideSidePanel', false),
modularChat: atomWithLocalStorage('modularChat', false),
modularChat: atomWithLocalStorage('modularChat', true),
LaTeXParsing: atomWithLocalStorage('LaTeXParsing', true),
UsernameDisplay: atomWithLocalStorage('UsernameDisplay', true),
TextToSpeech: atomWithLocalStorage('textToSpeech', true),
Expand Down
1 change: 1 addition & 0 deletions client/src/utils/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export * from './latex';
export * from './convos';
export * from './presets';
export * from './textarea';
export * from './messages';
export * from './languages';
export * from './endpoints';
export * from './sharedLink';
Expand Down
26 changes: 26 additions & 0 deletions client/src/utils/messages.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { ContentTypes } from 'librechat-data-provider';
import type { TMessage } from 'librechat-data-provider';

/**
 * Builds a short fingerprint of a string: its length immediately followed by
 * its first five characters (or the whole string when it is shorter).
 *
 * @param str - Text to fingerprint. `undefined` and the empty string both yield `'0'`.
 * @returns The concatenated length and leading characters, e.g. `'11hello'`.
 */
export const getLengthAndFirstFiveChars = (str?: string) => {
  if (!str) {
    return '0';
  }
  const head = str.slice(0, 5);
  return `${str.length}${head}`;
};

/**
 * Returns the most recent text carried by a message.
 *
 * Resolution order:
 * 1. `''` when the message is `null`/`undefined`.
 * 2. `message.text` when it is non-empty.
 * 3. Otherwise, scanning `message.content` from the last part to the first,
 *    the value of the first part of type `ContentTypes.TEXT` whose text value
 *    is non-empty.
 * 4. `''` when nothing matched.
 *
 * @param message - Message to inspect; may be `null` or `undefined`.
 * @returns The latest non-empty text, or an empty string.
 */
export const getLatestText = (message?: TMessage | null) => {
  if (!message) {
    return '';
  }
  if (message.text) {
    return message.text;
  }
  if (message.content?.length) {
    for (let i = message.content.length - 1; i >= 0; i--) {
      const part = message.content[i];
      // Defensive: skip nullish entries (e.g. holes in a sparse array) instead
      // of throwing on `part.type`.
      if (!part) {
        continue;
      }
      // Default a missing length to 0 so the comparison is always number vs number.
      if (part.type === ContentTypes.TEXT && (part[ContentTypes.TEXT]?.value?.length ?? 0) > 0) {
        return part[ContentTypes.TEXT].value;
      }
    }
  }
  return '';
};
Loading
Loading