From 9989f97d133b32654d4d5e56c64dbe7568673fef Mon Sep 17 00:00:00 2001 From: Ruben Talstra Date: Thu, 13 Feb 2025 23:10:57 +0100 Subject: [PATCH] started. but not the right implementation. --- api/app/clients/BaseClient.js | 22 +- api/models/Message.js | 84 +++++-- api/models/Share.js | 67 +++++- api/models/schema/messageSchema.js | 177 ++++++++++++++- api/models/schema/shareSchema.js | 21 +- api/server/routes/messages.js | 205 ++++++++++++++++-- api/server/routes/share.js | 172 +++++++-------- api/server/services/Threads/manage.js | 28 ++- api/server/utils/encryptAES.js | 153 +++++++++++++ api/server/utils/encryptionUtil.js | 56 +++++ api/server/utils/index.js | 2 + .../components/Nav/SettingsTabs/Chat/Chat.tsx | 4 + .../SettingsTabs/Chat/EncryptionSettings.tsx | 181 ++++++++++++++++ client/src/hooks/Conversations/useSearch.ts | 14 +- client/src/hooks/SSE/useSSE.ts | 14 +- client/src/hooks/useEncryptionHeaders.ts | 9 + client/src/locales/en/translation.json | 2 + client/src/routes/Root.tsx | 19 +- client/src/routes/Search.tsx | 31 ++- client/src/store/search.ts | 6 + client/src/utils/encryptAES.ts | 141 ++++++++++++ packages/data-provider/src/request.ts | 24 ++ 22 files changed, 1266 insertions(+), 166 deletions(-) create mode 100644 api/server/utils/encryptAES.js create mode 100644 api/server/utils/encryptionUtil.js create mode 100644 client/src/components/Nav/SettingsTabs/Chat/EncryptionSettings.tsx create mode 100644 client/src/hooks/useEncryptionHeaders.ts create mode 100644 client/src/utils/encryptAES.ts diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index ebf3ca12d9e..317f11cd996 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -793,7 +793,12 @@ class BaseClient { async loadHistory(conversationId, parentMessageId = null) { logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId }); - const messages = (await getMessages({ conversationId })) ?? []; + const messages = + (await getMessages({ + conversationId, + encryptionKey: this.options.req?.headers['x-encryption-key'], + isEncrypted: this.options.req?.headers['x-encryption-enabled'] === 'true', + })) ?? []; if (messages.length === 0) { return []; @@ -856,7 +861,11 @@ class BaseClient { unfinished: false, user, }, - { context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveMessage' }, + { + context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveMessage', + encryptionKey: this.options.req?.headers['x-encryption-key'], + isEncrypted: this.options.req?.headers['x-encryption-enabled'] === 'true', + }, ); if (this.skipSaveConvo) { @@ -882,7 +891,12 @@ class BaseClient { * @param {Partial} message */ async updateMessageInDatabase(message) { - await updateMessage(this.options.req, message); + await updateMessage(this.options.req, { + ...message, + // Pass encryption headers to the update operation + encryptionKey: this.options.req?.headers['x-encryption-key'], + isEncrypted: this.options.req?.headers['x-encryption-enabled'] === 'true', + }); } /** @@ -1121,4 +1135,4 @@ class BaseClient { } } -module.exports = BaseClient; +module.exports = BaseClient; \ No newline at end of file diff --git a/api/models/Message.js b/api/models/Message.js index e651b20ad0a..ea8a6ce9ffc 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -4,6 +4,13 @@ const { logger } = require('~/config'); const idSchema = z.string().uuid(); +function getEncryptionOptions(req) { + if (req?.headers?.['x-encryption-enabled'] === 'true' && req?.headers?.['x-encryption-key']) { + return { encryptionKey: req.headers['x-encryption-key'] }; + } + return null; +} + /** * Saves a message in the database. * @@ -61,12 +68,18 @@ async function saveMessage(req, params, metadata) { update.expiredAt = null; } - const message = await Message.findOneAndUpdate( + const query = Message.findOneAndUpdate( { messageId: params.messageId, user: req.user.id }, update, { upsert: true, new: true }, ); + const encryptionOptions = getEncryptionOptions(req); + if (encryptionOptions) { + query.setOptions(encryptionOptions); + } + + const message = await query; return message.toObject(); } catch (err) { logger.error('Error saving message:', err); @@ -128,7 +141,6 @@ async function recordMessage({ ...rest }) { try { - // No parsing of convoId as may use threadId const message = { user, endpoint, @@ -138,10 +150,16 @@ async function recordMessage({ ...rest, }; - return await Message.findOneAndUpdate({ user, messageId }, message, { + const query = Message.findOneAndUpdate({ user, messageId }, message, { upsert: true, new: true, }); + + if (rest.encryptionKey && rest.isEncrypted) { + query.setOptions({ encryptionKey: rest.encryptionKey }); + } + + return await query; } catch (err) { logger.error('Error recording message:', err); throw err; @@ -162,7 +180,14 @@ async function recordMessage({ */ async function updateMessageText(req, { messageId, text }) { try { - await Message.updateOne({ messageId, user: req.user.id }, { text }); + const query = Message.updateOne({ messageId, user: req.user.id }, { text }); + + const encryptionOptions = getEncryptionOptions(req); + if (encryptionOptions) { + query.setOptions(encryptionOptions); + } + + await query; } catch (err) { logger.error('Error updating message text:', err); throw err; @@ -190,13 +215,14 @@ async function updateMessageText(req, { messageId, text }) { async function updateMessage(req, message, metadata) { try { const { messageId, ...update } = message; - const updatedMessage = await Message.findOneAndUpdate( - { messageId, user: req.user.id }, - update, - { - new: true, - }, - ); + const query = Message.findOneAndUpdate({ messageId, user: req.user.id }, update, { new: true }); + + const encryptionOptions = getEncryptionOptions(req); + if (encryptionOptions) { + query.setOptions(encryptionOptions); + } + + const updatedMessage = await query; if (!updatedMessage) { throw new Error('Message not found or user not authorized.'); @@ -234,11 +260,20 @@ async function updateMessage(req, message, metadata) { */ async function deleteMessagesSince(req, { messageId, conversationId }) { try { - const message = await Message.findOne({ messageId, user: req.user.id }).lean(); + const query = Message.findOne({ messageId, user: req.user.id }); + const encryptionOptions = getEncryptionOptions(req); + if (encryptionOptions) { + query.setOptions(encryptionOptions); + } + + const message = await query.lean(); if (message) { - const query = Message.find({ conversationId, user: req.user.id }); - return await query.deleteMany({ + const deleteQuery = Message.find({ conversationId, user: req.user.id }); + if (encryptionOptions) { + deleteQuery.setOptions(encryptionOptions); + } + return await deleteQuery.deleteMany({ createdAt: { $gt: message.createdAt }, }); } @@ -260,11 +295,20 @@ async function deleteMessagesSince(req, { messageId, conversationId }) { */ async function getMessages(filter, select) { try { + const query = Message.find(filter); + + if (filter.encryptionKey) { + query.setOptions({ + encryptionKey: filter.encryptionKey, + }); + delete filter.encryptionKey; // Remove from filter after setting options + } + if (select) { - return await Message.find(filter).select(select).sort({ createdAt: 1 }).lean(); + query.select(select); } - return await Message.find(filter).sort({ createdAt: 1 }).lean(); + return await query.sort({ createdAt: 1 }).lean(); } catch (err) { logger.error('Error getting messages:', err); throw err; @@ -281,10 +325,12 @@ async function getMessages(filter, select) { */ async function getMessage({ user, messageId }) { try { - return await Message.findOne({ + const query = Message.findOne({ user, messageId, - }).lean(); + }); + + return await query.lean(); } catch (err) { logger.error('Error getting message:', err); throw err; @@ -320,4 +366,4 @@ module.exports = { getMessages, getMessage, deleteMessages, -}; +}; \ No newline at end of file diff --git a/api/models/Share.js b/api/models/Share.js index 041927ec616..b6d50e5fe6c 100644 --- a/api/models/Share.js +++ b/api/models/Share.js @@ -2,8 +2,9 @@ const { nanoid } = require('nanoid'); const { Constants } = require('librechat-data-provider'); const { Conversation } = require('~/models/Conversation'); const SharedLink = require('./schema/shareSchema'); -const { getMessages } = require('./Message'); const logger = require('~/config/winston'); +const { getMessages } = require('./Message'); +const { decrypt } = require('~/server/utils/encryptionUtil'); class ShareServiceError extends Error { constructor(message, code) { @@ -63,7 +64,15 @@ function anonymizeMessages(messages, newConvoId) { }); } -async function getSharedMessages(shareId) { +const isEncrypted = (text) => { + if (!text || typeof text !== 'string') { + return false; + } + const parts = text.split(':'); + return parts.length === 2 && parts[0].length === 32; +}; + +async function getSharedMessages(shareId, encryptionKey) { try { const share = await SharedLink.findOne({ shareId, isPublic: true }) .populate({ @@ -77,6 +86,27 @@ async function getSharedMessages(shareId) { return null; } + // Decrypt messages if encryption key provided + if (encryptionKey) { + share.messages = share.messages.map((message) => { + if (message.text && isEncrypted(message.text)) { + message.text = decrypt(message.text, encryptionKey); + } + if (message.content) { + message.content = message.content.map((item) => { + if (item.text && isEncrypted(item.text)) { + return { + ...item, + text: decrypt(item.text, encryptionKey), + }; + } + return item; + }); + } + return message; + }); + } + const newConvoId = anonymizeConvoId(share.conversationId); const result = { ...share, @@ -86,7 +116,7 @@ async function getSharedMessages(shareId) { return result; } catch (error) { - logger.error('[getShare] Error getting share link', { + logger.error('[getShare] Error getting share link:', { error: error.message, shareId, }); @@ -187,13 +217,13 @@ async function deleteAllSharedLinks(user) { } } -async function createSharedLink(user, conversationId) { +async function createSharedLink(user, conversationId, encryptionKey) { if (!user || !conversationId) { throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); } try { - const [existingShare, conversationMessages] = await Promise.all([ + const [existingShare, messages] = await Promise.all([ SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(), getMessages({ conversationId }), ]); @@ -207,11 +237,31 @@ async function createSharedLink(user, conversationId) { const conversation = await Conversation.findOne({ conversationId }).lean(); const title = conversation?.title || 'Untitled'; + // Ensure messages are decrypted before sharing + let sharableMessages = messages; + if (encryptionKey) { + sharableMessages = messages.map((message) => { + const decryptedMessage = { ...message }; + if (message.text && isEncrypted(message.text)) { + decryptedMessage.text = decrypt(message.text, encryptionKey); + } + if (message.content) { + decryptedMessage.content = message.content.map((item) => { + if (item.text && isEncrypted(item.text)) { + return { ...item, text: decrypt(item.text, encryptionKey) }; + } + return item; + }); + } + return decryptedMessage; + }); + } + const shareId = nanoid(); await SharedLink.create({ shareId, conversationId, - messages: conversationMessages, + messages: sharableMessages, title, user, }); @@ -243,7 +293,7 @@ async function getSharedLink(user, conversationId) { return { shareId: share.shareId, success: true }; } catch (error) { - logger.error('[getSharedLink] Error getting shared link', { + logger.error('[getSharedLink] Error getting shared link:', { error: error.message, user, conversationId, @@ -289,7 +339,7 @@ async function updateSharedLink(user, shareId) { return { shareId: newShareId, conversationId: updatedShare.conversationId }; } catch (error) { - logger.error('[updateSharedLink] Error updating shared link', { + logger.error('[updateSharedLink] Error updating shared link:', { error: error.message, user, shareId, @@ -337,4 +387,5 @@ module.exports = { deleteSharedLink, getSharedMessages, deleteAllSharedLinks, + isEncrypted, }; diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index be711552955..ab8e4758ce3 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -1,4 +1,5 @@ const mongoose = require('mongoose'); +const { encrypt, decrypt } = require('~/server/utils/encryptionUtil'); const mongoMeili = require('~/models/plugins/mongoMeili'); const messageSchema = mongoose.Schema( { @@ -141,6 +142,180 @@ const messageSchema = mongoose.Schema( { timestamps: true }, ); +messageSchema.statics.getEncryptionKey = function () { + return this.getOptions()?.encryptionKey; +}; + +messageSchema.pre('save', function (next) { + const encryptionKey = this.constructor.getEncryptionKey(); + if (!encryptionKey) { + return next(); + } + + try { + if (this.text) { + this.text = encrypt(this.text, encryptionKey); + } + if (this.content) { + this.content = this.content.map((item) => { + if (item.text) { + return { ...item, text: encrypt(item.text, encryptionKey) }; + } + return item; + }); + } + if (this.summary) { + this.summary = encrypt(this.summary, encryptionKey); + } + next(); + } catch (error) { + next(error); + } +}); + +// Post-find middleware to decrypt data +messageSchema.post('find', function (docs) { + const encryptionKey = this.getOptions()?.encryptionKey; + if (!docs || !encryptionKey) { + return; + } + + docs.forEach((doc) => { + if (!doc) { + return; + } + + try { + const isPlainObject = !doc.toObject; + + if (doc.text) { + if (isPlainObject) { + doc.text = decrypt(doc.text, encryptionKey); + } else { + doc.text = decrypt(doc.text, encryptionKey); + doc.markModified('text'); + } + } + if (doc.content) { + doc.content = doc.content.map((item) => { + if (item.text) { + return { ...item, text: decrypt(item.text, encryptionKey) }; + } + return item; + }); + if (!isPlainObject) { + doc.markModified('content'); + } + } + } catch (error) { + console.error('Decryption error:', error); + } + }); +}); + +messageSchema.post('findOne', function (doc) { + const encryptionKey = this.getOptions()?.encryptionKey; + if (!doc || !encryptionKey) { + return; + } + + try { + if (doc.text) { + doc.text = decrypt(doc.text, encryptionKey); + } + if (doc.content) { + doc.content = doc.content.map((item) => { + if (item.text) { + return { ...item, text: decrypt(item.text, encryptionKey) }; + } + return item; + }); + } + } catch (error) { + console.error('Decryption error:', error); + } +}); + +messageSchema.pre('findOneAndUpdate', function (next) { + const encryptionKey = this.getOptions()?.encryptionKey; + if (!encryptionKey) { + return next(); + } + + try { + const update = this.getUpdate(); + if (update.text) { + update.text = encrypt(update.text, encryptionKey); + } + if (update.content) { + update.content = update.content.map((item) => { + if (item.text) { + return { ...item, text: encrypt(item.text, encryptionKey) }; + } + return item; + }); + } + next(); + } catch (error) { + next(error); + } +}); + +messageSchema.pre('updateOne', function (next) { + const encryptionKey = this.getOptions()?.encryptionKey; + if (!encryptionKey) { + return next(); + } + + try { + const update = this.getUpdate(); + if (update.text) { + update.text = encrypt(update.text, encryptionKey); + } + if (update.content) { + update.content = update.content.map((item) => { + if (item.text) { + return { ...item, text: encrypt(item.text, encryptionKey) }; + } + return item; + }); + } + next(); + } catch (error) { + next(error); + } +}); + +messageSchema.pre('bulkWrite', function (next) { + const encryptionKey = this.getOptions()?.encryptionKey; + if (!encryptionKey) { + return next(); + } + + try { + const operations = this.getOperations(); + operations.forEach((op) => { + if (op.updateOne && op.updateOne.update) { + const update = op.updateOne.update; + if (update.text) { + update.text = encrypt(update.text, encryptionKey); + } + if (update.content) { + update.content = update.content.map((item) => { + if (item.text) { + return { ...item, text: encrypt(item.text, encryptionKey) }; + } + return item; + }); + } + } + }); + next(); + } catch (error) { + next(error); + } +}); + if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { messageSchema.plugin(mongoMeili, { host: process.env.MEILI_HOST, @@ -156,4 +331,4 @@ messageSchema.index({ messageId: 1, user: 1 }, { unique: true }); /** @type {mongoose.Model} */ const Message = mongoose.models.Message || mongoose.model('Message', messageSchema); -module.exports = Message; +module.exports = Message; \ No newline at end of file diff --git a/api/models/schema/shareSchema.js b/api/models/schema/shareSchema.js index 12699a39ec6..feb2a3b35a2 100644 --- a/api/models/schema/shareSchema.js +++ b/api/models/schema/shareSchema.js @@ -1,24 +1,21 @@ const mongoose = require('mongoose'); -const shareSchema = mongoose.Schema( +const shareSchema = new mongoose.Schema( { - conversationId: { + shareId: { type: String, required: true, + unique: true, }, - title: { + conversationId: { type: String, - index: true, + required: true, }, user: { type: String, - index: true, - }, - messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }], - shareId: { - type: String, - index: true, + required: true, }, + messages: [mongoose.Schema.Types.Mixed], // Store messages directly isPublic: { type: Boolean, default: true, @@ -27,4 +24,6 @@ const shareSchema = mongoose.Schema( { timestamps: true }, ); -module.exports = mongoose.model('SharedLink', shareSchema); +const SharedLink = mongoose.models.SharedLink || mongoose.model('SharedLink', shareSchema); + +module.exports = SharedLink; \ No newline at end of file diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 54c4aab1c2d..23d9d42a501 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -1,31 +1,172 @@ const express = require('express'); const { ContentTypes } = require('librechat-data-provider'); -const { - saveConvo, - saveMessage, - getMessage, - getMessages, - updateMessage, - deleteMessages, -} = require('~/models'); +const { Message } = require('~/models/Message'); +const { saveConvo, saveMessage, getMessages, updateMessage, deleteMessages } = require('~/models'); const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update'); const { requireJwtAuth, validateMessageReq } = require('~/server/middleware'); -const { countTokens } = require('~/server/utils'); +const { countTokens, decrypt, encrypt } = require('~/server/utils'); const { logger } = require('~/config'); const router = express.Router(); router.use(requireJwtAuth); +const isEncrypted = (text) => { + if (!text || typeof text !== 'string') { + return false; + } + const parts = text.split(':'); + return parts.length === 2 && parts[0].length === 32; +}; + +router.post('/encrypt', async (req, res) => { + logger.info('Encrypting messages'); + try { + const encryptionKey = req.headers['x-encryption-key']; + if (!encryptionKey) { + return res.status(400).json({ error: 'Encryption key required' }); + } + + if (!req.user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + // Find all messages for this user + const messages = await Message.find({ user: req.user.id }); + if (!messages) { + return res.status(404).json({ error: 'No messages found' }); + } + + logger.info(`Found ${messages.length} messages to encrypt for user ${req.user.id}`); + + let successCount = 0; + let errorCount = 0; + + for (const message of messages) { + try { + const messageToEncrypt = { ...message.toObject() }; + + // Only encrypt if not already encrypted + if (messageToEncrypt.text && !isEncrypted(messageToEncrypt.text)) { + messageToEncrypt.text = encrypt(messageToEncrypt.text, encryptionKey); + } + + // Encrypt content if it exists and not already encrypted + if (messageToEncrypt.content) { + messageToEncrypt.content = messageToEncrypt.content.map((item) => { + if (item.text && !isEncrypted(item.text)) { + return { ...item, text: encrypt(item.text, encryptionKey) }; + } + return item; + }); + } + + // Save the encrypted message + await Message.findOneAndUpdate({ _id: message._id }, messageToEncrypt, { new: true }); + successCount++; + } catch (error) { + errorCount++; + logger.error(`Error processing message ${message.messageId}:`, error); + } + } + + res.status(200).json({ + success: true, + stats: { + total: messages.length, + success: successCount, + errors: errorCount, + }, + }); + } catch (error) { + logger.error('Error encrypting messages:', error); + res.status(500).json({ error: error.message || 'Internal server error' }); + } +}); + +router.post('/decrypt', async (req, res) => { + logger.info('Decrypting messages'); + try { + const encryptionKey = req.headers['x-encryption-key']; + if (!encryptionKey) { + return res.status(400).json({ error: 'Encryption key required' }); + } + + if (!req.user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + // Find all messages for this user + const messages = await Message.find({ user: req.user.id }); + if (!messages) { + return res.status(404).json({ error: 'No messages found' }); + } + + logger.info(`Found ${messages.length} messages to decrypt for user ${req.user.id}`); + + let successCount = 0; + let errorCount = 0; + + for (const message of messages) { + try { + const decryptedMessage = { ...message.toObject() }; + + // Decrypt text if it exists and is encrypted + if (decryptedMessage.text && isEncrypted(decryptedMessage.text)) { + decryptedMessage.text = decrypt(decryptedMessage.text, encryptionKey); + } + + // Decrypt content if it exists + if (decryptedMessage.content) { + decryptedMessage.content = decryptedMessage.content.map((item) => { + if (item.text && isEncrypted(item.text)) { + return { ...item, text: decrypt(item.text, encryptionKey) }; + } + return item; + }); + } + + // Save the decrypted message + await Message.findOneAndUpdate({ _id: message._id }, decryptedMessage, { new: true }); + successCount++; + } catch (error) { + errorCount++; + logger.error(`Error processing message ${message.messageId}:`, error); + } + } + + res.status(200).json({ + success: true, + stats: { + total: messages.length, + success: successCount, + errors: errorCount, + }, + }); + } catch (error) { + logger.error('Error decrypting messages:', error); + res.status(500).json({ error: error.message || 'Internal server error' }); + } +}); + router.post('/artifact/:messageId', async (req, res) => { try { - const { messageId } = req.params; + const { messageId, conversationId } = req.params; const { index, original, updated } = req.body; if (typeof index !== 'number' || index < 0 || original == null || updated == null) { return res.status(400).json({ error: 'Invalid request parameters' }); } - const message = await getMessage({ user: req.user.id, messageId }); + const filter = { + conversationId, + messageId, + encryptionKey: + req.headers['x-encryption-enabled'] === 'true' + ? req.headers['x-encryption-key'] + : undefined, + }; + const message = await getMessages(filter, '-_id -__v -user'); + if (!message) { return res.status(404).json({ error: 'Message not found' }); } @@ -78,11 +219,17 @@ router.post('/artifact/:messageId', async (req, res) => { } }); -/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */ router.get('/:conversationId', validateMessageReq, async (req, res) => { try { const { conversationId } = req.params; - const messages = await getMessages({ conversationId }, '-_id -__v -user'); + const filter = { + conversationId, + encryptionKey: + req.headers['x-encryption-enabled'] === 'true' + ? req.headers['x-encryption-key'] + : undefined, + }; + const messages = await getMessages(filter, '-_id -__v -user'); res.status(200).json(messages); } catch (error) { logger.error('Error fetching messages:', error); @@ -112,7 +259,15 @@ router.post('/:conversationId', validateMessageReq, async (req, res) => { router.get('/:conversationId/:messageId', validateMessageReq, async (req, res) => { try { const { conversationId, messageId } = req.params; - const message = await getMessages({ conversationId, messageId }, '-_id -__v -user'); + const filter = { + conversationId, + messageId, + encryptionKey: + req.headers['x-encryption-enabled'] === 'true' + ? req.headers['x-encryption-key'] + : undefined, + }; + const message = await getMessages(filter, '-_id -__v -user'); if (!message) { return res.status(404).json({ error: 'Message not found' }); } @@ -138,7 +293,15 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = return res.status(400).json({ error: 'Invalid index' }); } - const message = (await getMessages({ conversationId, messageId }, 'content tokenCount'))?.[0]; + const filter = { + conversationId, + messageId, + encryptionKey: + req.headers['x-encryption-enabled'] === 'true' + ? req.headers['x-encryption-key'] + : undefined, + }; + const message = (await getMessages(filter, 'content tokenCount'))?.[0]; if (!message) { return res.status(404).json({ error: 'Message not found' }); } @@ -178,7 +341,15 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => { try { const { messageId } = req.params; - await deleteMessages({ messageId }); + const filter = { + messageId, + user: req.user.id, + encryptionKey: + req.headers['x-encryption-enabled'] === 'true' + ? req.headers['x-encryption-key'] + : undefined, + }; + await deleteMessages(filter); res.status(204).send(); } catch (error) { logger.error('Error deleting message:', error); @@ -186,4 +357,4 @@ router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res } }); -module.exports = router; +module.exports = router; \ No newline at end of file diff --git a/api/server/routes/share.js b/api/server/routes/share.js index e551f4a354e..841d78721cc 100644 --- a/api/server/routes/share.js +++ b/api/server/routes/share.js @@ -11,6 +11,7 @@ const { const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); const { isEnabled } = require('~/server/utils'); const router = express.Router(); +const { logger } = require('~/config'); /** * Shared messages @@ -39,102 +40,101 @@ if (allowSharedLinks) { } }, ); -} -/** - * Shared links - */ -router.get('/', requireJwtAuth, async (req, res) => { - try { - const params = { - pageParam: req.query.cursor, - pageSize: Math.max(1, parseInt(req.query.pageSize) || 10), - isPublic: isEnabled(req.query.isPublic), - sortBy: ['createdAt', 'title'].includes(req.query.sortBy) ? req.query.sortBy : 'createdAt', - sortDirection: ['asc', 'desc'].includes(req.query.sortDirection) - ? req.query.sortDirection - : 'desc', - search: req.query.search - ? decodeURIComponent(req.query.search.trim()) - : undefined, - }; + /** + * Shared links + */ + router.get('/', requireJwtAuth, async (req, res) => { + try { + const params = { + pageParam: req.query.cursor, + pageSize: Math.max(1, parseInt(req.query.pageSize) || 10), + isPublic: isEnabled(req.query.isPublic), + sortBy: ['createdAt', 'title'].includes(req.query.sortBy) ? req.query.sortBy : 'createdAt', + sortDirection: ['asc', 'desc'].includes(req.query.sortDirection) + ? req.query.sortDirection + : 'desc', + search: req.query.search ? decodeURIComponent(req.query.search.trim()) : undefined, + }; - const result = await getSharedLinks( - req.user.id, - params.pageParam, - params.pageSize, - params.isPublic, - params.sortBy, - params.sortDirection, - params.search, - ); + const result = await getSharedLinks( + req.user.id, + params.pageParam, + params.pageSize, + params.isPublic, + params.sortBy, + params.sortDirection, + params.search, + ); - res.status(200).send({ - links: result.links, - nextCursor: result.nextCursor, - hasNextPage: result.hasNextPage, - }); - } catch (error) { - console.error('Error getting shared links:', error); - res.status(500).json({ - message: 'Error getting shared links', - error: error.message, - }); - } -}); + res.status(200).send({ + links: result.links, + nextCursor: result.nextCursor, + hasNextPage: result.hasNextPage, + }); + } catch (error) { + logger.error('Error getting shared links:', error); + res.status(500).json({ + message: 'Error getting shared links', + error: error.message, + }); + } + }); + + router.get('/link/:conversationId', requireJwtAuth, async (req, res) => { + try { + const share = await getSharedLink(req.user.id, req.params.conversationId); -router.get('/link/:conversationId', requireJwtAuth, async (req, res) => { - try { - const share = await getSharedLink(req.user.id, req.params.conversationId); + return res.status(200).json({ + success: share.success, + shareId: share.shareId, + conversationId: req.params.conversationId, + }); + } catch (error) { + res.status(500).json({ message: 'Error getting shared link' }); + } + }); - return res.status(200).json({ - success: share.success, - shareId: share.shareId, - conversationId: req.params.conversationId, - }); - } catch (error) { - res.status(500).json({ message: 'Error getting shared link' }); - } -}); + router.post('/:conversationId', requireJwtAuth, async (req, res) => { + try { + const encryptionKey = req.headers['x-encryption-key']; + const created = await createSharedLink(req.user.id, req.params.conversationId, encryptionKey); -router.post('/:conversationId', requireJwtAuth, async (req, res) => { - try { - const created = await createSharedLink(req.user.id, req.params.conversationId); - if (created) { - res.status(200).json(created); - } else { - res.status(404).end(); + if (created) { + res.status(200).json(created); + } else { + res.status(404).end(); + } + } catch (error) { + res.status(500).json({ message: 'Error creating shared link' }); } - } catch (error) { - res.status(500).json({ message: 'Error creating shared link' }); - } -}); + }); -router.patch('/:shareId', requireJwtAuth, async (req, res) => { - try { - const updatedShare = await updateSharedLink(req.user.id, req.params.shareId); - if (updatedShare) { - res.status(200).json(updatedShare); - } else { - res.status(404).end(); + router.patch('/:shareId', requireJwtAuth, async (req, res) => { + try { + const updatedShare = await updateSharedLink(req.user.id, req.params.shareId); + if (updatedShare) { + res.status(200).json(updatedShare); + } else { + res.status(404).end(); + } + } catch (error) { + res.status(500).json({ message: 'Error updating shared link' }); } - } catch (error) { - res.status(500).json({ message: 'Error updating shared link' }); - } -}); + }); -router.delete('/:shareId', requireJwtAuth, async (req, res) => { - try { - const result = await deleteSharedLink(req.user.id, req.params.shareId); + router.delete('/:shareId', requireJwtAuth, async (req, res) => { + try { + const result = await deleteSharedLink(req.user.id, req.params.shareId); + if (!result) { + return res.status(404).json({ message: 'Share not found' }); + } - if (!result) { - return res.status(404).json({ message: 'Share not found' }); + return res.status(200).json(result); + } catch (error) { + return res.status(400).json({ message: error.message }); } + }); +} - return res.status(200).json(result); - } catch (error) { - return res.status(400).json({ message: error.message }); - } -}); - -module.exports = router; +module.exports = router; \ No newline at end of file diff --git a/api/server/services/Threads/manage.js b/api/server/services/Threads/manage.js index f99dca7534a..24138070bfc 100644 --- a/api/server/services/Threads/manage.js +++ b/api/server/services/Threads/manage.js @@ -93,6 +93,8 @@ async function saveUserMessage(req, params) { text: params.text, isCreatedByUser: true, tokenCount, + encryptionKey: req.headers['x-encryption-key'], + isEncrypted: req.headers['x-encryption-enabled'] === 'true', }; const convo = { @@ -109,7 +111,11 @@ async function saveUserMessage(req, params) { convo.file_ids = params.file_ids; } - const message = await recordMessage(userMessage); + const message = await recordMessage({ + ...userMessage, + encryptionKey: req.headers['x-encryption-key'], + isEncrypted: req.headers['x-encryption-enabled'] === 'true', + }); await saveConvo(req, convo, { context: 'api/server/services/Threads/manage.js #saveUserMessage', }); @@ -154,6 +160,8 @@ async function saveAssistantMessage(req, params) { text: params.text, unfinished: false, // tokenCount, + encryptionKey: req.headers['x-encryption-key'], + isEncrypted: req.headers['x-encryption-enabled'] === 'true', }); await saveConvo( @@ -228,17 +236,26 @@ async function syncMessages({ const modifyPromises = []; const recordPromises = []; + const encryptionHeaders = { + encryptionKey: openai.req.headers['x-encryption-key'], + isEncrypted: openai.req.headers['x-encryption-enabled'] === 'true', + }; /** - * * Modify API message and save newMessage to DB * * @param {Object} params - The parameters object * @param {TMessage} params.dbMessage - * @param {dbMessage} params.apiMessage + * @param {ThreadMessage} params.apiMessage */ const processNewMessage = async ({ dbMessage, apiMessage }) => { - recordPromises.push(recordMessage({ ...dbMessage, user: openai.req.user.id })); + recordPromises.push( + recordMessage({ + ...dbMessage, + user: openai.req.user.id, + ...encryptionHeaders, + }), + ); if (!apiMessage.id.includes('msg_')) { return; @@ -337,7 +354,6 @@ async function syncMessages({ if (msg.role === 'user' && msg.file_ids?.length) { return [...acc, ...msg.file_ids]; } - return acc; }, []); @@ -684,4 +700,4 @@ module.exports = { addThreadMetadata, mapMessagesToSteps, saveAssistantMessage, -}; +}; \ No newline at end of file diff --git a/api/server/utils/encryptAES.js b/api/server/utils/encryptAES.js new file mode 100644 index 00000000000..848b0c7cc75 --- /dev/null +++ b/api/server/utils/encryptAES.js @@ -0,0 +1,153 @@ +/** + * Encrypts a plaintext string with AES-GCM using a 256-bit key from `keyHex`. + * + * @param {string} plaintext - The raw text to encrypt. + * @param {string} keyHex - A 64-character hex string representing a 256-bit key. + * @returns {Promise} A base64 string of the form "ivBase64:cipherBase64". + * @throws {Error} If encryption fails. + */ +async function encryptLocallyAESGCM(plaintext, keyHex) { + try { + // 1. Convert the hex key into an ArrayBuffer. + const keyBuffer = hexToArrayBuffer(keyHex); + + // 2. Import the key for AES-GCM. + const cryptoKey = await crypto.subtle.importKey( + 'raw', + keyBuffer, + { name: 'AES-GCM' }, + false, // non-extractable. + ['encrypt'], + ); + + // 3. Encode plaintext to bytes. + const encoder = new TextEncoder(); + const plaintextBytes = encoder.encode(plaintext); + + // 4. Generate a random 12-byte IV. + const iv = crypto.getRandomValues(new Uint8Array(12)); + + // 5. Encrypt using AES-GCM. + const cipherBuffer = await crypto.subtle.encrypt( + { name: 'AES-GCM', iv }, + cryptoKey, + plaintextBytes, + ); + + // 6. Convert IV and ciphertext to base64. + const ivBase64 = arrayBufferToBase64(iv); + const cipherBase64 = arrayBufferToBase64(cipherBuffer); + + // 7. Return them as a single string: "ivBase64:cipherBase64". + return ivBase64 + ':' + cipherBase64; + } catch (error) { + console.error('Encryption error:', error); + throw new Error('Failed to encrypt data.'); + } +} + +/** + * Decrypts an AES-GCM ciphertext produced by `encryptLocallyAESGCM`. + * + * @param {string} cipherString - The string from encryption in the format "ivBase64:cipherBase64". + * @param {string} keyHex - The same 64-character hex key used for encryption. + * @returns {Promise} The original plaintext string. + * @throws {Error} If decryption fails. + */ +async function decryptLocallyAESGCM(cipherString, keyHex) { + try { + // 1. Split into ivBase64 and cipherBase64. + const parts = cipherString.split(':'); + if (parts.length !== 2) { + throw new Error('Invalid cipher string format. Expected "ivBase64:cipherBase64".'); + } + const [ivBase64, cipherBase64] = parts; + + // 2. Convert hex key to ArrayBuffer. + const keyBuffer = hexToArrayBuffer(keyHex); + + // 3. Import the key. + const cryptoKey = await crypto.subtle.importKey( + 'raw', + keyBuffer, + { name: 'AES-GCM' }, + false, + ['decrypt'], + ); + + // 4. Convert base64 strings to ArrayBuffers. + const iv = new Uint8Array(base64ToArrayBuffer(ivBase64)); + const cipherBytes = base64ToArrayBuffer(cipherBase64); + + // 5. Decrypt. + const plainBuffer = await crypto.subtle.decrypt( + { name: 'AES-GCM', iv }, + cryptoKey, + cipherBytes, + ); + + // 6. Convert decrypted bytes back to a string. + const decoder = new TextDecoder(); + return decoder.decode(plainBuffer); + } catch (error) { + console.error('Decryption error:', error); + throw new Error('Failed to decrypt data.'); + } +} + +/** + * Converts a base64 string to an ArrayBuffer. + * + * @param {string} base64 - The base64 string. + * @returns {ArrayBuffer} The resulting ArrayBuffer. + */ +function base64ToArrayBuffer(base64) { + const binaryString = window.atob(base64); + const len = binaryString.length; + const bytes = new Uint8Array(len); + for (let i = 0; i < len; i++) { + bytes[i] = binaryString.charCodeAt(i); + } + return bytes.buffer; +} + +/** + * Converts an ArrayBuffer to a base64 string. + * + * @param {ArrayBuffer} buffer - The ArrayBuffer to convert. + * @returns {string} The base64-encoded string. + */ +function arrayBufferToBase64(buffer) { + const bytes = new Uint8Array(buffer); + let binary = ''; + for (let i = 0; i < bytes.byteLength; i++) { + binary += String.fromCharCode(bytes[i]); + } + return window.btoa(binary); +} + +/** + * Converts a hex string to an ArrayBuffer. + * + * @param {string} hexString - The hex string. + * @returns {ArrayBuffer} The resulting ArrayBuffer. + * @throws {Error} If the hex string length is not even. + */ +function hexToArrayBuffer(hexString) { + if (hexString.length % 2 !== 0) { + throw new Error('Invalid hex string'); + } + const bytes = new Uint8Array(hexString.length / 2); + for (let i = 0; i < hexString.length; i += 2) { + bytes[i / 2] = parseInt(hexString.substring(i, i + 2), 16); + } + return bytes.buffer; +} + +module.exports = { + encryptLocallyAESGCM, + decryptLocallyAESGCM, + base64ToArrayBuffer, + arrayBufferToBase64, + hexToArrayBuffer, +}; \ No newline at end of file diff --git a/api/server/utils/encryptionUtil.js b/api/server/utils/encryptionUtil.js new file mode 100644 index 00000000000..3686f391d10 --- /dev/null +++ b/api/server/utils/encryptionUtil.js @@ -0,0 +1,56 @@ +const crypto = require('crypto'); +const IV_LENGTH = 16; +function normalizeKey(key) { + // Remove quotes and clean the key + const cleanKey = key.replace(/"/g, ''); + // If key is hex string, convert to Buffer + if (/^[0-9a-f]+$/i.test(cleanKey)) { + return Buffer.from(cleanKey, 'hex'); + } + // If not hex, hash the key to ensure proper length + const hash = crypto.createHash('sha256'); + hash.update(cleanKey); + return hash.digest(); +} +function encrypt(text, userKey) { + if (!text || !userKey) { + throw new Error('Both text and encryption key are required'); + } + try { + const key = normalizeKey(userKey); + const textToEncrypt = typeof text === 'string' ? text : JSON.stringify(text); + const iv = crypto.randomBytes(IV_LENGTH); + const cipher = crypto.createCipheriv('aes-256-cbc', key, iv); + let encrypted = cipher.update(textToEncrypt, 'utf8', 'hex'); + encrypted += cipher.final('hex'); + return `${iv.toString('hex')}:${encrypted}`; + } catch (error) { + console.error('Encryption error:', error); + throw error; + } +} +function decrypt(text, userKey) { + if (!text || !userKey) { + throw new Error('Both text and encryption key are required'); + } + try { + if (!text.includes(':')) { + return text; + } + const key = normalizeKey(userKey); + const [ivHex, encryptedHex] = text.split(':'); + if (!ivHex || !encryptedHex) { + return text; + } + const iv = Buffer.from(ivHex, 'hex'); + const encryptedText = Buffer.from(encryptedHex, 'hex'); + const decipher = crypto.createDecipheriv('aes-256-cbc', key, iv); + let decrypted = decipher.update(encryptedText, 'hex', 'utf8'); + decrypted += decipher.final('utf8'); + return decrypted; + } catch (error) { + console.error('Decryption error:', error); + throw error; + } +} +module.exports = { encrypt, decrypt }; \ No newline at end of file diff --git a/api/server/utils/index.js b/api/server/utils/index.js index b79b42f00d5..071cebafdfb 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -7,6 +7,7 @@ const cryptoUtils = require('./crypto'); const queue = require('./queue'); const files = require('./files'); const math = require('./math'); +const encryptionUtil = require('./encryptionUtil'); /** * Check if email configuration is set @@ -26,6 +27,7 @@ module.exports = { checkEmailConfig, ...cryptoUtils, ...handleText, + ...encryptionUtil, countTokens, removePorts, sendEmail, diff --git a/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx b/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx index 1fd2e5c7bd4..db38efe9021 100644 --- a/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx +++ b/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx @@ -1,4 +1,5 @@ import { memo } from 'react'; +import EncryptionSettings from './EncryptionSettings'; import MaximizeChatSpace from './MaximizeChatSpace'; import FontSizeSelector from './FontSizeSelector'; import SendMessageKeyEnter from './EnterToSend'; @@ -20,6 +21,9 @@ function Chat() {
+
+ +
diff --git a/client/src/components/Nav/SettingsTabs/Chat/EncryptionSettings.tsx b/client/src/components/Nav/SettingsTabs/Chat/EncryptionSettings.tsx new file mode 100644 index 00000000000..fd93a9e175a --- /dev/null +++ b/client/src/components/Nav/SettingsTabs/Chat/EncryptionSettings.tsx @@ -0,0 +1,181 @@ +import React, { useState } from 'react'; +import { Button } from '~/components'; +import useLocalStorage from '~/hooks/useLocalStorage'; +import HoverCardSettings from '../HoverCardSettings'; +import { useAuthContext, useLocalize } from '~/hooks'; + +const EncryptionSettings = () => { + const [isEncryptionEnabled, setIsEncryptionEnabled] = useLocalStorage( + 'isEncryptionEnabled', + false, + ); + const [passwordInput, setPasswordInput] = useState(''); + const [encryptionKey, setEncryptionKey] = useLocalStorage('encryptionKey', ''); + const [isSettingPassword, setIsSettingPassword] = useState(false); + const [isDecrypting, setIsDecrypting] = useState(false); + const [isEncrypting, setIsEncrypting] = useState(false); + const { token } = useAuthContext(); + const localize = useLocalize(); + + const generateKeyFromPassword = async (password: string) => { + const encoder = new TextEncoder(); + const data = encoder.encode(password); + const hash = await crypto.subtle.digest('SHA-256', data); + return Array.from(new Uint8Array(hash)) + .map((b) => b.toString(16).padStart(2, '0')) + .join(''); + }; + + const handlePasswordSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + if (passwordInput.length >= 8) { + const key = await generateKeyFromPassword(passwordInput); + setEncryptionKey(key); + setPasswordInput(''); + setIsSettingPassword(false); + handleEncryptAll(key); // Always encrypt when setting password + } + }; + + const handlePasswordChange = (e: React.ChangeEvent) => { + setPasswordInput(e.target.value); + }; + + const handleEncryptAll = async (providedKey?: string) => { + try { + setIsEncrypting(true); + const keyToUse = providedKey || encryptionKey; + + const response = await fetch('/api/messages/encrypt', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + 'x-encryption-key': keyToUse, + 'x-encryption-enabled': 'true', + }, + }); + + if (!response.ok) { + const errorData = await response + .json() + .catch(() => ({ error: 'Failed to encrypt messages' })); + throw new Error(errorData.error || 'Failed to encrypt messages'); + } + + const result = await response.json(); + console.log('Encryption results:', result); + setIsEncryptionEnabled(true); + } catch (err: unknown) { + const error = err as Error; + console.error('Error encrypting messages:', error); + window.alert(error.message || 'Failed to encrypt messages'); + setIsEncryptionEnabled(false); + setEncryptionKey(''); + } finally { + setIsEncrypting(false); + } + }; + + const handleDecryptAll = async () => { + try { + setIsDecrypting(true); + const response = await fetch('/api/messages/decrypt', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + 'x-encryption-key': encryptionKey, + 'x-encryption-enabled': 'true', + }, + }); + + if (!response.ok) { + const errorData = await response + .json() + .catch(() => ({ error: 'Failed to decrypt messages' })); + throw new Error(errorData.error || 'Failed to decrypt messages'); + } + + const result = await response.json(); + console.log('Decryption results:', result); + + // Only disable encryption after successful decryption + setIsEncryptionEnabled(false); + setEncryptionKey(''); + setPasswordInput(''); + } catch (err: unknown) { + const error = err as Error; + console.error('Error decrypting messages:', error); + window.alert(error.message || 'Failed to decrypt messages'); + } finally { + setIsDecrypting(false); + } + }; + + const getToggleButtonText = () => { + if (isEncrypting) { + return 'Encrypting...'; + } + if (isEncryptionEnabled) { + return 'On'; + } + return 'Off'; + }; + + const toggleEncryption = () => { + const newState = !isEncryptionEnabled; + if (newState) { + if (!encryptionKey) { + setIsSettingPassword(true); + } else { + handleEncryptAll(); + setIsEncryptionEnabled(true); + } + } else if (encryptionKey) { + handleDecryptAll(); + } else { + setIsEncryptionEnabled(false); + setPasswordInput(''); + setEncryptionKey(''); + setIsSettingPassword(false); + } + }; + + const showPasswordInput = (isEncryptionEnabled && !encryptionKey) || isSettingPassword; + + return ( +
+
+ {localize('com_nav_encryption')} + + {showPasswordInput && ( +
+ + +
+ )} +
+ +
+ ); +}; + +export default EncryptionSettings; \ No newline at end of file diff --git a/client/src/hooks/Conversations/useSearch.ts b/client/src/hooks/Conversations/useSearch.ts index 753f9d9822b..1933b5bbf29 100644 --- a/client/src/hooks/Conversations/useSearch.ts +++ b/client/src/hooks/Conversations/useSearch.ts @@ -24,11 +24,17 @@ export default function useSearchMessages({ isAuthenticated }: { isAuthenticated const searchQuery = useRecoilValue(store.searchQuery); const setIsSearchEnabled = useSetRecoilState(store.isSearchEnabled); + const isEncryptionEnabled = useRecoilValue(store.isEncryptionEnabled); + + const searchEnabledQuery = useGetSearchEnabledQuery({ + enabled: isAuthenticated && !isEncryptionEnabled, + }); - const searchEnabledQuery = useGetSearchEnabledQuery({ enabled: isAuthenticated }); const searchQueryRes = useSearchInfiniteQuery( { pageNumber: pageNumber.toString(), searchQuery: searchQuery, isArchived: false }, - { enabled: isAuthenticated && !!searchQuery.length }, + { + enabled: isAuthenticated && !!searchQuery.length && !isEncryptionEnabled, + }, ) as UseInfiniteQueryResult | undefined; useEffect(() => { @@ -42,7 +48,7 @@ export default function useSearchMessages({ isAuthenticated }: { isAuthenticated } navigate('/c/new', { replace: true }); /* Disabled eslint rule because we don't want to run this effect when location changes */ - // eslint-disable-next-line react-hooks/exhaustive-deps + }, [navigate, searchQuery]); useEffect(() => { @@ -76,4 +82,4 @@ export default function useSearchMessages({ isAuthenticated }: { isAuthenticated setPageNumber, searchQueryRes, }; -} +} \ No newline at end of file diff --git a/client/src/hooks/SSE/useSSE.ts b/client/src/hooks/SSE/useSSE.ts index a52928caadc..0361dd6e963 100644 --- a/client/src/hooks/SSE/useSSE.ts +++ b/client/src/hooks/SSE/useSSE.ts @@ -16,6 +16,7 @@ import type { TResData } from '~/common'; import { useGenTitleMutation, useGetStartupConfig, useGetUserBalance } from '~/data-provider'; import { useAuthContext } from '~/hooks/AuthContext'; import useEventHandlers from './useEventHandlers'; +import { useEncryptionHeaders } from '~/hooks/useEncryptionHeaders'; import store from '~/store'; type ChatHelpers = Pick< @@ -41,6 +42,7 @@ export default function useSSE( const [completed, setCompleted] = useState(new Set()); const setAbortScroll = useSetRecoilState(store.abortScrollFamily(runIndex)); const setShowStopButton = useSetRecoilState(store.showStopButtonByIndex(runIndex)); + const encryptionHeaders = useEncryptionHeaders(); const { setMessages, @@ -96,8 +98,13 @@ export default function useSSE( const sse = new SSE(payloadData.server, { payload: JSON.stringify(payload), - headers: { 'Content-Type': 'application/json', Authorization: `Bearer ${token}` }, + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${token}`, + ...encryptionHeaders, + }, }); + console.log('SSE Headers:', sse.headers); sse.addEventListener('attachment', (e: MessageEvent) => { try { @@ -196,6 +203,7 @@ export default function useSSE( sse.headers = { 'Content-Type': 'application/json', Authorization: `Bearer ${token}`, + ...encryptionHeaders, }; request.dispatchTokenUpdatedEvent(token); @@ -234,6 +242,6 @@ export default function useSSE( sse.dispatchEvent(e); } }; - // eslint-disable-next-line react-hooks/exhaustive-deps + }, [submission]); -} +} \ No newline at end of file diff --git a/client/src/hooks/useEncryptionHeaders.ts b/client/src/hooks/useEncryptionHeaders.ts new file mode 100644 index 00000000000..935828f42ff --- /dev/null +++ b/client/src/hooks/useEncryptionHeaders.ts @@ -0,0 +1,9 @@ +import useLocalStorage from './useLocalStorage'; +export const useEncryptionHeaders = () => { + const [isEncryptionEnabled] = useLocalStorage('isEncryptionEnabled', false); + const [encryptionKey] = useLocalStorage('encryptionKey', ''); + return { + 'x-encryption-enabled': isEncryptionEnabled.toString(), + 'x-encryption-key': encryptionKey, + }; +}; \ No newline at end of file diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index fc6b1800d92..5476c6bb5a5 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -784,5 +784,7 @@ "com_ui_yes": "Yes", "com_ui_zoom": "Zoom", "com_user_message": "You", + "com_nav_encryption": "Encryption", + "com_nav_info_encryption": "Your messages will be encrypted using a key derived from your password. The password itself is never stored - only the derived key is used for encryption on backend side. IMPORTANT: If you forget your password, you will not be able to decrypt your messages as the encryption key cannot be recovered. Make sure to save your password in a secure location.", "com_warning_resubmit_unsupported": "Resubmitting the AI message is not supported for this endpoint." } \ No newline at end of file diff --git a/client/src/routes/Root.tsx b/client/src/routes/Root.tsx index a7d999ae459..dc808d6088f 100644 --- a/client/src/routes/Root.tsx +++ b/client/src/routes/Root.tsx @@ -1,5 +1,6 @@ import React, { useState, useEffect } from 'react'; import { Outlet, useNavigate } from 'react-router-dom'; +import { useSetRecoilState } from 'recoil'; import type { ContextType } from '~/common'; import { AgentsMapContext, @@ -8,11 +9,19 @@ import { SearchContext, SetConvoProvider, } from '~/Providers'; -import { useAuthContext, useAssistantsMap, useAgentsMap, useFileMap, useSearch } from '~/hooks'; +import { + useAuthContext, + useAssistantsMap, + useAgentsMap, + useFileMap, + useSearch, + useLocalStorage, +} from '~/hooks'; import TermsAndConditionsModal from '~/components/ui/TermsAndConditionsModal'; import { useUserTermsQuery, useGetStartupConfig } from '~/data-provider'; import { Nav, MobileNav } from '~/components/Nav'; import { Banner } from '~/components/Banners'; +import store from '~/store'; export default function Root() { const navigate = useNavigate(); @@ -24,6 +33,8 @@ export default function Root() { }); const { isAuthenticated, logout } = useAuthContext(); + const [isEncryptionEnabled] = useLocalStorage('isEncryptionEnabled', false); + const setIsEncryptionEnabledRecoil = useSetRecoilState(store.isEncryptionEnabled); const assistantsMap = useAssistantsMap({ isAuthenticated }); const agentsMap = useAgentsMap({ isAuthenticated }); const fileMap = useFileMap({ isAuthenticated }); @@ -34,6 +45,10 @@ export default function Root() { enabled: isAuthenticated && config?.interface?.termsOfService?.modalAcceptance === true, }); + useEffect(() => { + setIsEncryptionEnabledRecoil(isEncryptionEnabled); + }, [isEncryptionEnabled, setIsEncryptionEnabledRecoil]); + useEffect(() => { if (termsData) { setShowTerms(!termsData.termsAccepted); @@ -86,4 +101,4 @@ export default function Root() { ); -} +} \ No newline at end of file diff --git a/client/src/routes/Search.tsx b/client/src/routes/Search.tsx index 5d944a6fe35..5b5f4274204 100644 --- a/client/src/routes/Search.tsx +++ b/client/src/routes/Search.tsx @@ -1,27 +1,48 @@ import { useMemo } from 'react'; +import { useRecoilValue } from 'recoil'; import MinimalMessagesWrapper from '~/components/Chat/Messages/MinimalMessages'; import SearchMessage from '~/components/Chat/Messages/SearchMessage'; import { useSearchContext, useFileMapContext } from '~/Providers'; import { useNavScrolling, useLocalize } from '~/hooks'; import { buildTree } from '~/utils'; +import store from '~/store'; export default function Search() { const localize = useLocalize(); const fileMap = useFileMapContext(); const { searchQuery, searchQueryRes } = useSearchContext(); + const isEncryptionEnabled = useRecoilValue(store.isEncryptionEnabled); const { containerRef } = useNavScrolling({ setShowLoading: () => ({}), - hasNextPage: searchQueryRes?.hasNextPage, + hasNextPage: searchQueryRes?.hasNextPage ?? false, fetchNextPage: searchQueryRes?.fetchNextPage, isFetchingNextPage: searchQueryRes?.isFetchingNextPage ?? false, }); const messages = useMemo(() => { + if (isEncryptionEnabled) { + return null; + } const msgs = searchQueryRes?.data?.pages.flatMap((page) => page.messages) || []; const dataTree = buildTree({ messages: msgs, fileMap }); - return dataTree?.length === 0 ? null : dataTree ?? null; - }, [fileMap, searchQueryRes?.data?.pages]); + return dataTree?.length === 0 ? null : (dataTree ?? null); + }, [fileMap, searchQueryRes?.data?.pages, isEncryptionEnabled]); + + if (isEncryptionEnabled) { + return ( +
+
+

+ Search Unavailable +

+

+ Search is disabled when encryption is enabled +

+
+
+ ); + } if (!searchQuery || !searchQueryRes?.data) { return null; @@ -29,7 +50,7 @@ export default function Search() { return ( - {(messages && messages.length == 0) || messages == null ? ( + {(messages && messages.length === 0) || messages == null ? (
{localize('com_ui_nothing_found')}
@@ -39,4 +60,4 @@ export default function Search() {
); -} +} \ No newline at end of file diff --git a/client/src/store/search.ts b/client/src/store/search.ts index 77c4d9be4f7..87335a8a04e 100644 --- a/client/src/store/search.ts +++ b/client/src/store/search.ts @@ -1,5 +1,10 @@ import { atom } from 'recoil'; +const isEncryptionEnabled = atom({ + key: 'isEncryptionEnabled', + default: false, +}); + const isSearchEnabled = atom({ key: 'isSearchEnabled', default: null, @@ -16,6 +21,7 @@ const isSearching = atom({ }); export default { + isEncryptionEnabled, isSearchEnabled, searchQuery, isSearching, diff --git a/client/src/utils/encryptAES.ts b/client/src/utils/encryptAES.ts new file mode 100644 index 00000000000..b8272347595 --- /dev/null +++ b/client/src/utils/encryptAES.ts @@ -0,0 +1,141 @@ +/** + * Encrypts a plaintext string with AES-GCM using a 256-bit key from `keyHex`. + * + * @param plaintext - The raw text to encrypt + * @param keyHex - A 64-char hex string representing a 256-bit key + * @returns A base64 string of the form: `${ivBase64}:${cipherBase64}` + * @throws Error if encryption fails + */ +export async function encryptLocallyAESGCM( + plaintext: string, + keyHex: string +): Promise { + try { + // 1. Convert the hex key into ArrayBuffer + const keyBuffer = hexToArrayBuffer(keyHex); + + // 2. Import the key for AES-GCM + const cryptoKey = await crypto.subtle.importKey( + 'raw', + keyBuffer, + { name: 'AES-GCM' }, + false, // not extractable + ['encrypt'] + ); + + // 3. Encode plaintext to bytes + const encoder = new TextEncoder(); + const plaintextBytes = encoder.encode(plaintext); + + // 4. Generate a random 12-byte IV + const iv = crypto.getRandomValues(new Uint8Array(12)); + + // 5. Encrypt using AES-GCM + const cipherBuffer = await crypto.subtle.encrypt( + { name: 'AES-GCM', iv }, + cryptoKey, + plaintextBytes + ); + + // 6. Convert IV and ciphertext to base64 + const ivBase64 = arrayBufferToBase64(iv); + const cipherBase64 = arrayBufferToBase64(cipherBuffer); + + // 7. Return them as a single string e.g. "ivBase64:cipherBase64" + return `${ivBase64}:${cipherBase64}`; + } catch (error) { + console.error('Encryption error:', error); + throw new Error('Failed to encrypt data.'); + } +} + +/** + * Decrypts an AES-GCM ciphertext produced by `encryptLocallyAESGCM`. + * + * @param cipherString - The string from encryption, in form "ivBase64:cipherBase64" + * @param keyHex - The same 64-char hex key used for encryption + * @returns The original plaintext string + * @throws Error if decryption fails + */ +export async function decryptLocallyAESGCM( + cipherString: string, + keyHex: string +): Promise { + try { + // 1. Split into ivBase64 and cipherBase64 + const [ivBase64, cipherBase64] = cipherString.split(':'); + if (!ivBase64 || !cipherBase64) { + throw new Error('Invalid cipher string format. Expected "ivBase64:cipherBase64".'); + } + + // 2. Convert hex key to ArrayBuffer + const keyBuffer = hexToArrayBuffer(keyHex); + + // 3. Import the key + const cryptoKey = await crypto.subtle.importKey( + 'raw', + keyBuffer, + { name: 'AES-GCM' }, + false, + ['decrypt'] + ); + + // 4. Convert base64 -> ArrayBuffer + const iv = new Uint8Array(base64ToArrayBuffer(ivBase64)); + const cipherBytes = base64ToArrayBuffer(cipherBase64); + + // 5. Decrypt + const plainBuffer = await crypto.subtle.decrypt( + { name: 'AES-GCM', iv }, + cryptoKey, + cipherBytes + ); + + // 6. Convert bytes -> string + const decoder = new TextDecoder(); + const plaintext = decoder.decode(plainBuffer); + return plaintext; + } catch (error) { + console.error('Decryption error:', error); + throw new Error('Failed to decrypt data.'); + } +} + +/** + * Convert a base64 string to an ArrayBuffer + */ +function base64ToArrayBuffer(base64: string): ArrayBuffer { + const binaryString = window.atob(base64); + const len = binaryString.length; + const bytes = new Uint8Array(len); + for (let i = 0; i < len; i++) { + bytes[i] = binaryString.charCodeAt(i); + } + return bytes.buffer; +} + +/** + * Convert an ArrayBuffer to a base64 string + */ +function arrayBufferToBase64(buffer: ArrayBuffer): string { + const bytes = new Uint8Array(buffer); + let binary = ''; + for (let i = 0; i < bytes.byteLength; i++) { + binary += String.fromCharCode(bytes[i]); + } + return window.btoa(binary); +} + +/** + * Convert a hex string to an ArrayBuffer + */ +function hexToArrayBuffer(hexString: string): ArrayBuffer { + if (hexString.length % 2 !== 0) { + throw new Error('Invalid hex string'); + } + const bytes = new Uint8Array(hexString.length / 2); + for (let i = 0; i < hexString.length; i += 2) { + bytes[i / 2] = parseInt(hexString.substring(i, i + 2), 16); + } + return bytes.buffer; +} \ No newline at end of file diff --git a/packages/data-provider/src/request.ts b/packages/data-provider/src/request.ts index 740e9cbe6c8..4527919f979 100644 --- a/packages/data-provider/src/request.ts +++ b/packages/data-provider/src/request.ts @@ -4,6 +4,30 @@ import * as endpoints from './api-endpoints'; import { setTokenHeader } from './headers-helpers'; import type * as t from './types'; +const getEncryptionHeaders = (): Record => { + if (typeof localStorage === 'undefined') { + return {}; + } + const isEncryptionEnabled = localStorage.getItem('isEncryptionEnabled') === 'true'; + const encryptionKey = localStorage.getItem('encryptionKey'); + if (isEncryptionEnabled && encryptionKey) { + return { + 'x-encryption-enabled': 'true', + 'x-encryption-key': encryptionKey, + }; + } + return {}; +}; +axios.interceptors.request.use((config) => { + const headers = config.headers ?? axios.defaults.headers.common; + const encryptionHeaders = getEncryptionHeaders(); + config.headers = new axios.AxiosHeaders({ + ...headers, + ...encryptionHeaders, + }); + return config; +}); + async function _get(url: string, options?: AxiosRequestConfig): Promise { const response = await axios.get(url, { ...options }); return response.data;