From c9c8cd0d626f80de41495380494e7e6ce87f73e1 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 31 Jan 2024 20:28:53 +0000 Subject: [PATCH] - Swapped embeddings generator to use `llama_decode` - Modified `GetEmbeddings` method to be async --- .../LLamaSharpTextEmbeddingGenerator.cs | 19 ++---- .../LLamaSharpEmbeddingGeneration.cs | 10 ++- LLama.Unittest/LLamaEmbedderTests.cs | 25 ++++--- LLama/LLamaEmbedder.cs | 67 +++++++++++-------- 4 files changed, 68 insertions(+), 53 deletions(-) diff --git a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs index 8148adc88..7806282de 100644 --- a/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs +++ b/LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs @@ -1,14 +1,7 @@ using LLama; -using LLama.Abstractions; using LLama.Common; using Microsoft.KernelMemory; using Microsoft.KernelMemory.AI; -using Microsoft.SemanticKernel.AI.Embeddings; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; namespace LLamaSharp.KernelMemory { @@ -80,24 +73,24 @@ public void Dispose() } /// - public Task>> GenerateEmbeddingsAsync(IList data, CancellationToken cancellationToken = default) + public async Task>> GenerateEmbeddingsAsync(IList data, CancellationToken cancellationToken = default) { IList> results = new List>(); foreach (var d in data) { - var embeddings = _embedder.GetEmbeddings(d); + var embeddings = await _embedder.GetEmbeddings(d, cancellationToken); results.Add(new ReadOnlyMemory(embeddings)); } - return Task.FromResult(results); + return results; } /// - public Task GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default) + public async Task GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default) { - var embeddings = _embedder.GetEmbeddings(text); - return Task.FromResult(new Embedding(embeddings)); + var embeddings = await _embedder.GetEmbeddings(text, cancellationToken); + return new Embedding(embeddings); } /// diff --git a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs index 73ceb0f21..6889ba6a7 100644 --- a/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs +++ b/LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs @@ -6,7 +6,7 @@ namespace LLamaSharp.SemanticKernel.TextEmbedding; public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService { - private LLamaEmbedder _embedder; + private readonly LLamaEmbedder _embedder; private readonly Dictionary _attributes = new(); @@ -20,7 +20,11 @@ public LLamaSharpEmbeddingGeneration(LLamaEmbedder embedder) /// public async Task>> GenerateEmbeddingsAsync(IList data, Kernel? kernel = null, CancellationToken cancellationToken = default) { - var embeddings = data.Select(text => new ReadOnlyMemory(_embedder.GetEmbeddings(text))).ToList(); - return await Task.FromResult(embeddings); + var result = new List>(); + + foreach (var item in data) + result.Add(await _embedder.GetEmbeddings(item, cancellationToken)); + + return result; } } diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs index 4c8fb37fa..052690480 100644 --- a/LLama.Unittest/LLamaEmbedderTests.cs +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -1,14 +1,17 @@ using LLama.Common; +using Xunit.Abstractions; namespace LLama.Unittest; public sealed class LLamaEmbedderTests : IDisposable { + private readonly ITestOutputHelper _testOutputHelper; private readonly LLamaEmbedder _embedder; - public LLamaEmbedderTests() + public LLamaEmbedderTests(ITestOutputHelper testOutputHelper) { + _testOutputHelper = testOutputHelper; var @params = new ModelParams(Constants.ModelPath) { EmbeddingMode = true, @@ -41,21 +44,23 @@ private static float Dot(float[] a, float[] b) } [Fact] - public void EmbedCompare() + public async Task EmbedCompare() { - var cat = _embedder.GetEmbeddings("cat"); - var kitten = _embedder.GetEmbeddings("kitten"); - var spoon = _embedder.GetEmbeddings("spoon"); + var cat = await _embedder.GetEmbeddings("cat"); + var kitten = await _embedder.GetEmbeddings("kitten"); + var spoon = await _embedder.GetEmbeddings("spoon"); Normalize(cat); Normalize(kitten); Normalize(spoon); - var close = Dot(cat, kitten); - var far = Dot(cat, spoon); + var close = 1 - Dot(cat, kitten); + var far = 1 - Dot(cat, spoon); - // This comparison seems backwards, but remember that with a - // dot product 1.0 means **identical** and 0.0 means **completely opposite**! - Assert.True(close > far); + Assert.True(close < far); + + _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]"); + _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]"); + _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]"); } } \ No newline at end of file diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 8dfc4aaba..c375d2e93 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -3,6 +3,8 @@ using LLama.Exceptions; using LLama.Abstractions; using Microsoft.Extensions.Logging; +using System.Threading; +using System.Threading.Tasks; namespace LLama { @@ -40,27 +42,12 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg /// Get the embeddings of the text. /// /// - /// unused - /// Add bos to the text. - /// unused + /// /// /// - [Obsolete("'threads' and 'encoding' parameters are no longer used")] - // ReSharper disable once MethodOverloadWithOptionalParameter - public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") + public Task GetEmbeddings(string text, CancellationToken cancellationToken = default) { - return GetEmbeddings(text, addBos); - } - - /// - /// Get the embeddings of the text. - /// - /// - /// - /// - public float[] GetEmbeddings(string text) - { - return GetEmbeddings(text, true); + return GetEmbeddings(text, true, cancellationToken); } /// @@ -68,22 +55,48 @@ public float[] GetEmbeddings(string text) /// /// /// Add bos to the text. + /// /// /// - public float[] GetEmbeddings(string text, bool addBos) + public async Task GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default) { - var embed_inp_array = Context.Tokenize(text, addBos); + var tokens = Context.Tokenize(text, addBos); + if (tokens.Length > Context.ContextSize) + throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text)); + + // Evaluate prompt in batch-size chunks + var n_past = 0; + var batch = new LLamaBatch(); + var batchSize = (int)Context.Params.BatchSize; + for (var i = 0; i < tokens.Length; i += batchSize) + { + var n_eval = tokens.Length - i; + if (n_eval > batchSize) + n_eval = batchSize; + + batch.Clear(); + for (var j = 0; j < n_eval; j++) + batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, false); + + var returnCode = await Context.DecodeAsync(batch, cancellationToken); + if (returnCode != 0) + throw new LLamaDecodeError(returnCode); + } - // TODO(Rinne): deal with log of prompt + var embeddings = GetEmbeddingsArray(); - if (embed_inp_array.Length > 0) - Context.Eval(embed_inp_array.AsSpan(), 0); + // Remove everything we just evaluated from the context cache + Context.NativeHandle.KvCacheClear(); - var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); - if (embeddings == null) - return Array.Empty(); + return embeddings; - return embeddings.ToArray(); + float[] GetEmbeddingsArray() + { + var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); + if (embeddings == null) + return Array.Empty(); + return embeddings.ToArray(); + } } ///