
Commit

Use ToChatCompletion / ToStreamingChatCompletionUpdates in CachingChatClient (#5616)

* Use ToChatCompletion / ToStreamingChatCompletionUpdates in CachingChatClient

Adds a ToStreamingChatCompletionUpdates method that's the counterpart to the recently added ToChatCompletion.

Then uses both from CachingChatClient in place of its own bespoke coalescing implementation. When coalescing is enabled (the default), CachingChatClient caches everything as a ChatCompletion rather than distinguishing between streaming and non-streaming entries (a sketch of the resulting flow follows this list).

* Address PR feedback
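
For illustration, here is a hedged sketch (not part of this commit) of the flow the description above implies, using the ToChatCompletion / ToStreamingChatCompletionUpdates APIs touched in the diff below; the update values are invented.

using System.Collections.Generic;
using Microsoft.Extensions.AI;

// Cache miss: the streamed updates are captured and collapsed into a single ChatCompletion,
// and that completion is what gets written to the cache.
List<StreamingChatCompletionUpdate> captured =
[
    new() { Role = ChatRole.Assistant, Contents = [new TextContent("Hello, ")] },
    new() { Contents = [new TextContent("world!")] },
];
ChatCompletion cached = captured.ToChatCompletion();

// Cache hit: the cached ChatCompletion is expanded back into streaming updates to yield to the caller.
StreamingChatCompletionUpdate[] replayed = cached.ToStreamingChatCompletionUpdates();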
stephentoub authored Nov 11, 2024
1 parent c163960 commit 148e221
Showing 7 changed files with 281 additions and 131 deletions.
@@ -87,4 +87,53 @@ public ChatMessage Message
/// <inheritdoc />
public override string ToString() =>
Choices is { Count: > 0 } choices ? string.Join(Environment.NewLine, choices) : string.Empty;

/// <summary>Creates an array of <see cref="StreamingChatCompletionUpdate" /> instances that represent this <see cref="ChatCompletion" />.</summary>
/// <returns>An array of <see cref="StreamingChatCompletionUpdate" /> instances that may be used to represent this <see cref="ChatCompletion" />.</returns>
public StreamingChatCompletionUpdate[] ToStreamingChatCompletionUpdates()
{
StreamingChatCompletionUpdate? extra = null;
if (AdditionalProperties is not null || Usage is not null)
{
extra = new StreamingChatCompletionUpdate
{
AdditionalProperties = AdditionalProperties
};

if (Usage is { } usage)
{
extra.Contents.Add(new UsageContent(usage));
}
}

int choicesCount = Choices.Count;
var updates = new StreamingChatCompletionUpdate[choicesCount + (extra is null ? 0 : 1)];

for (int choiceIndex = 0; choiceIndex < choicesCount; choiceIndex++)
{
ChatMessage choice = Choices[choiceIndex];
updates[choiceIndex] = new StreamingChatCompletionUpdate
{
ChoiceIndex = choiceIndex,

AdditionalProperties = choice.AdditionalProperties,
AuthorName = choice.AuthorName,
Contents = choice.Contents,
RawRepresentation = choice.RawRepresentation,
Role = choice.Role,

CompletionId = CompletionId,
CreatedAt = CreatedAt,
FinishReason = FinishReason,
ModelId = ModelId
};
}

if (extra is not null)
{
updates[choicesCount] = extra;
}

return updates;
}
}
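
As a usage note, a small hedged sketch of how the method added above behaves when usage details are present; the values are invented, and the expected output in the comments follows from the code shown above.

using System;
using Microsoft.Extensions.AI;

ChatCompletion completion = new(new ChatMessage(ChatRole.Assistant, "Hi there"))
{
    Usage = new UsageDetails { InputTokenCount = 10, OutputTokenCount = 3 },
};

StreamingChatCompletionUpdate[] updates = completion.ToStreamingChatCompletionUpdates();

// One update per choice, plus a trailing update carrying the UsageContent (and any AdditionalProperties).
Console.WriteLine(updates.Length);                          // 2
Console.WriteLine(updates[^1].Contents[0] is UsageContent); // True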
@@ -9,14 +9,35 @@

namespace Microsoft.Extensions.AI;

// Conceptually this combines the roles of ChatCompletion and ChatMessage in streaming output.
// For ease of consumption, it also flattens the nested structure you see on streaming chunks in
// the OpenAI/Gemini APIs, so instead of a dictionary of choices, each update represents a single
// choice (and hence has its own role, choice ID, etc.).

/// <summary>
/// Represents a single response chunk from an <see cref="IChatClient"/>.
/// Represents a single streaming response chunk from an <see cref="IChatClient"/>.
/// </summary>
/// <remarks>
/// <para>
/// Conceptually, this combines the roles of <see cref="ChatCompletion"/> and <see cref="ChatMessage"/>
/// in streaming output. For ease of consumption, it also flattens the nested structure you see on
/// streaming chunks in some AI services, so instead of a dictionary of choices, each update represents a
/// single choice (and hence has its own role, choice ID, etc.).
/// </para>
/// <para>
/// <see cref="StreamingChatCompletionUpdate"/> is so named because it represents streaming updates
/// to a single chat completion. As such, it is considered erroneous for multiple updates that are part
/// of the same completion to contain competing values. For example, some updates that are part of
/// the same completion may have a <see langword="null"/> <see cref="StreamingChatCompletionUpdate.Role"/>
/// value, and others may have a non-<see langword="null"/> value, but all of those with a non-<see langword="null"/>
/// value must have the same value (e.g. <see cref="ChatRole.Assistant"/>). It should never be the case, for example,
/// that one <see cref="StreamingChatCompletionUpdate"/> in a completion has a role of <see cref="ChatRole.Assistant"/>
/// while another has a role of "AI".
/// </para>
/// <para>
/// The relationship between <see cref="ChatCompletion"/> and <see cref="StreamingChatCompletionUpdate"/> is
/// codified in the <see cref="StreamingChatCompletionUpdateExtensions.ToChatCompletionAsync"/> and
/// <see cref="ChatCompletion.ToStreamingChatCompletionUpdates"/>, which enable bidirectional conversions
/// between the two. Note, however, that the conversion may be slightly lossy, for example if multiple updates
/// all have different <see cref="StreamingChatCompletionUpdate.RawRepresentation"/> objects whereas there's
/// only one slot for such an object available in <see cref="ChatCompletion.RawRepresentation"/>.
/// </para>
/// </remarks>
public class StreamingChatCompletionUpdate
{
/// <summary>The completion update content items.</summary>
@@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq;
#if NET
using System.Runtime.InteropServices;
#endif
@@ -133,7 +134,22 @@ private static void ProcessUpdate(StreamingChatCompletionUpdate update, Dictiona
/// <param name="coalesceContent">The corresponding option value provided to <see cref="ToChatCompletion"/> or <see cref="ToChatCompletionAsync"/>.</param>
private static void AddMessagesToCompletion(Dictionary<int, ChatMessage> messages, ChatCompletion completion, bool coalesceContent)
{
foreach (var entry in messages)
if (messages.Count <= 1)
{
foreach (var entry in messages)
{
AddMessage(completion, coalesceContent, entry);
}
}
else
{
foreach (var entry in messages.OrderBy(entry => entry.Key))
{
AddMessage(completion, coalesceContent, entry);
}
}

static void AddMessage(ChatCompletion completion, bool coalesceContent, KeyValuePair<int, ChatMessage> entry)
{
if (entry.Value.Role == default)
{
@@ -154,6 +170,8 @@ private static void AddMessagesToCompletion(Dictionary<int, ChatMessage> message
if (content is UsageContent c)
{
completion.Usage = c.Details;
entry.Value.Contents = entry.Value.Contents.ToList();
_ = entry.Value.Contents.Remove(c);
break;
}
}
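
To make the effect of these changes concrete, a hedged sketch (not in the diff): messages are now added in choice-index order, and usage reported via a UsageContent is surfaced on ChatCompletion.Usage instead of being left behind in the message's contents. The values are invented; the commented output follows from the code shown above.

using System;
using System.Linq;
using Microsoft.Extensions.AI;

StreamingChatCompletionUpdate[] updates =
[
    new() { Role = ChatRole.Assistant, Contents = [new TextContent("Hello")] },
    new() { Contents = [new UsageContent(new UsageDetails { TotalTokenCount = 42 })] },
];

ChatCompletion completion = updates.ToChatCompletion();

Console.WriteLine(completion.Usage?.TotalTokenCount);                        // 42
Console.WriteLine(completion.Message.Contents.OfType<UsageContent>().Any()); // False after this change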
@@ -3,7 +3,6 @@

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Shared.Diagnostics;
@@ -48,13 +47,12 @@ public override async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chat
// concurrent callers might trigger duplicate requests, but that's acceptable.
var cacheKey = GetCacheKey(false, chatMessages, options);

if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is ChatCompletion existing)
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is not { } result)
{
return existing;
result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false);
}

var result = await base.CompleteAsync(chatMessages, options, cancellationToken).ConfigureAwait(false);
await WriteCacheAsync(cacheKey, result, cancellationToken).ConfigureAwait(false);
return result;
}
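
From the caller's perspective, the read-through pattern above looks roughly like this hedged sketch (not part of the diff). The method name, prompt text, and the assumption that cachingClient is a DistributedCachingChatClient (or other CachingChatClient-derived wrapper) around some inner IChatClient are all illustrative.

using System;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

static async Task ShowCacheHitAsync(IChatClient cachingClient)
{
    // Miss: forwarded to the inner client, and the resulting completion is written to the cache.
    ChatCompletion first = await cachingClient.CompleteAsync(
        [new ChatMessage(ChatRole.User, "Summarize output caching in one sentence.")]);

    // Hit: identical messages and options produce the same cache key, so this is served from the cache.
    ChatCompletion second = await cachingClient.CompleteAsync(
        [new ChatMessage(ChatRole.User, "Summarize output caching in one sentence.")]);

    Console.WriteLine(first.Message.Text);
    Console.WriteLine(second.Message.Text);
}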

@@ -64,127 +62,59 @@ public override async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteSt
{
_ = Throw.IfNull(chatMessages);

var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
if (CoalesceStreamingUpdates)
{
// Yield all of the cached items.
foreach (var chunk in existingChunks)
// When coalescing updates, we cache non-streaming results coalesced from streaming ones. That means
// we make a streaming request, yielding those results, but then convert those into a non-streaming
// result and cache it. When we get a cache hit, we yield the non-streaming result as a streaming one.

var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } chatCompletion)
{
yield return chunk;
// Yield all of the cached items.
foreach (var chunk in chatCompletion.ToStreamingChatCompletionUpdates())
{
yield return chunk;
}
}
else
{
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
capturedItems.Add(chunk);
yield return chunk;
}

// Write the captured items to the cache as a non-streaming result.
await WriteCacheAsync(cacheKey, capturedItems.ToChatCompletion(), cancellationToken).ConfigureAwait(false);
}
}
else
{
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
var cacheKey = GetCacheKey(true, chatMessages, options);
if (await ReadCacheStreamingAsync(cacheKey, cancellationToken).ConfigureAwait(false) is { } existingChunks)
{
capturedItems.Add(chunk);
yield return chunk;
// Yield all of the cached items.
foreach (var chunk in existingChunks)
{
yield return chunk;
}
}

// If the caching client is configured to coalesce streaming updates, do so now within the capturedItems list.
if (CoalesceStreamingUpdates)
else
{
StringBuilder coalescedText = new();

// Iterate through all of the items in the list looking for contiguous items that can be coalesced.
for (int startInclusive = 0; startInclusive < capturedItems.Count; startInclusive++)
// Yield and store all of the items.
List<StreamingChatCompletionUpdate> capturedItems = [];
await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false))
{
// If an item isn't generally coalescable, skip it.
StreamingChatCompletionUpdate update = capturedItems[startInclusive];
if (update.ChoiceIndex != 0 ||
update.Contents.Count != 1 ||
update.Contents[0] is not TextContent textContent)
{
continue;
}

// We found a coalescable item. Look for more contiguous items that are also coalescable with it.
int endExclusive = startInclusive + 1;
for (; endExclusive < capturedItems.Count; endExclusive++)
{
StreamingChatCompletionUpdate next = capturedItems[endExclusive];
if (next.ChoiceIndex != 0 ||
next.Contents.Count != 1 ||
next.Contents[0] is not TextContent ||

// changing role or author would be really strange, but check anyway
(update.Role is not null && next.Role is not null && update.Role != next.Role) ||
(update.AuthorName is not null && next.AuthorName is not null && update.AuthorName != next.AuthorName))
{
break;
}
}

// If we couldn't find anything to coalesce, there's nothing to do.
if (endExclusive - startInclusive <= 1)
{
continue;
}

// We found a coalescable run of items. Create a new node to represent the run. We create a new one
// rather than reappropriating one of the existing ones so as not to mutate an item already yielded.
_ = coalescedText.Clear().Append(capturedItems[startInclusive].Text);

TextContent coalescedContent = new(null) // will patch the text after examining all items in the run
{
AdditionalProperties = textContent.AdditionalProperties?.Clone(),
};

StreamingChatCompletionUpdate coalesced = new()
{
AdditionalProperties = update.AdditionalProperties?.Clone(),
AuthorName = update.AuthorName,
CompletionId = update.CompletionId,
Contents = [coalescedContent],
CreatedAt = update.CreatedAt,
FinishReason = update.FinishReason,
ModelId = update.ModelId,
Role = update.Role,

// Explicitly don't include RawRepresentation. It's not applicable if one update ends up being used
// to represent multiple, and it won't be serialized anyway.
};

// Replace the starting node with the coalesced node.
capturedItems[startInclusive] = coalesced;

// Now iterate through all the rest of the updates in the run, updating the coalesced node with relevant properties,
// and nulling out the nodes along the way. We do this rather than removing the entry in order to avoid an O(N^2) operation.
// We'll remove all the null entries at the end of the loop, using RemoveAll to do so, which can remove all of
// the nulls in a single O(N) pass.
for (int i = startInclusive + 1; i < endExclusive; i++)
{
// Grab the next item.
StreamingChatCompletionUpdate next = capturedItems[i];
capturedItems[i] = null!;

var nextContent = (TextContent)next.Contents[0];
_ = coalescedText.Append(nextContent.Text);

coalesced.AuthorName ??= next.AuthorName;
coalesced.CompletionId ??= next.CompletionId;
coalesced.CreatedAt ??= next.CreatedAt;
coalesced.FinishReason ??= next.FinishReason;
coalesced.ModelId ??= next.ModelId;
coalesced.Role ??= next.Role;
}

// Complete the coalescing by patching the text of the coalesced node.
coalesced.Text = coalescedText.ToString();

// Jump to the last update in the run, so that when we loop around and bump ahead,
// we're at the next update just after the run.
startInclusive = endExclusive - 1;
capturedItems.Add(chunk);
yield return chunk;
}

// Remove all of the null slots left over from the coalescing process.
_ = capturedItems.RemoveAll(u => u is null);
// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
}

// Write the captured items to the cache.
await WriteCacheStreamingAsync(cacheKey, capturedItems, cancellationToken).ConfigureAwait(false);
}
}
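
Finally, a hedged end-to-end sketch (not in the commit) of exercising this streaming path through DistributedCachingChatClient, the concrete subclass of the CachingChatClient modified here. The method name and prompt are illustrative, the inner client is a placeholder parameter, and the in-memory distributed cache is a stand-in for a real cache.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Caching.Distributed;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;

static async Task StreamWithCachingAsync(IChatClient innerClient)
{
    IDistributedCache cache = new MemoryDistributedCache(
        Options.Create(new MemoryDistributedCacheOptions()));

    // CoalesceStreamingUpdates defaults to true, so with this change the streamed response is
    // cached as a single ChatCompletion and replayed via ToStreamingChatCompletionUpdates on a hit.
    IChatClient cachingClient = new DistributedCachingChatClient(innerClient, cache);

    List<ChatMessage> messages = [new(ChatRole.User, "What is output caching?")];

    // The first call streams from the inner client and populates the cache;
    // an identical second call replays entirely from the cache.
    await foreach (var update in cachingClient.CompleteStreamingAsync(messages))
    {
        Console.Write(update.Text);
    }
}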


