diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 2c075fdf72a..368e7d6e84b 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1066,9 +1066,14 @@ ${convo} }); } - getStreamText() { + /** + * + * @param {string[]} [intermediateReply] + * @returns {string} + */ + getStreamText(intermediateReply) { if (!this.streamHandler) { - return ''; + return intermediateReply?.join('') ?? ''; } let thinkMatch; @@ -1088,7 +1093,10 @@ ${convo} } } - const reasoningTokens = reasoningText.length > 0 ? `:::thinking\n${reasoningText}\n:::\n` : ''; + const reasoningTokens = + reasoningText.length > 0 + ? `:::thinking\n${reasoningText.replace('<think>', '').replace('</think>', '').trim()}\n:::\n` + : ''; return `${reasoningTokens}${this.streamHandler.tokens.join('')}`; } @@ -1327,11 +1335,19 @@ ${convo} streamPromise = new Promise((resolve) => { streamResolve = resolve; }); + /** @type {OpenAI.OpenAI.CompletionCreateParamsStreaming} */ + const params = { + ...modelOptions, + stream: true, + }; + if ( + this.options.endpoint === EModelEndpoint.openAI || + this.options.endpoint === EModelEndpoint.azureOpenAI + ) { + params.stream_options = { include_usage: true }; + } const stream = await openai.beta.chat.completions - .stream({ - ...modelOptions, - stream: true, - }) + .stream(params) .on('abort', () => { /* Do nothing here */ }) @@ -1471,7 +1487,7 @@ ${convo} err?.message?.includes('abort') || (err instanceof OpenAI.APIError && err?.message?.includes('abort')) ) { - return intermediateReply.join(''); + return this.getStreamText(intermediateReply); } if ( err?.message?.includes( @@ -1489,7 +1505,7 @@ ${convo} if (this.streamHandler && this.streamHandler.reasoningTokens.length) { return this.getStreamText(); } else if (intermediateReply.length > 0) { - return intermediateReply.join(''); + return this.getStreamText(intermediateReply); } else { throw err; } @@ -1497,7 +1513,7 @@ ${convo} if (this.streamHandler && 
this.streamHandler.reasoningTokens.length) { return this.getStreamText(); } else if (intermediateReply.length > 0) { - return intermediateReply.join(''); + return this.getStreamText(intermediateReply); } else { throw err; }