Skip to content

Commit

Permalink
Merge pull request #474 from martindevans/embeddings_generator_decode
Browse files Browse the repository at this point in the history
Swapped `GetEmbeddings` to `llama_decode`
  • Loading branch information
martindevans authored Jan 31, 2024
2 parents 3b08874 + c9c8cd0 commit 3523c51
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 53 deletions.
19 changes: 6 additions & 13 deletions LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
using LLama;
using LLama.Abstractions;
using LLama.Common;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.SemanticKernel.AI.Embeddings;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace LLamaSharp.KernelMemory
{
Expand Down Expand Up @@ -80,24 +73,24 @@ public void Dispose()
}

/// <inheritdoc/>
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
{
    // Embed each input sequentially — the underlying embedder shares one
    // llama context, so inputs are not processed concurrently.
    IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>();

    foreach (var d in data)
    {
        var embeddings = await _embedder.GetEmbeddings(d, cancellationToken);
        results.Add(new ReadOnlyMemory<float>(embeddings));
    }

    return results;
}

/// <inheritdoc/>
public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
    // Single-text convenience path: one embedder call wrapped in the
    // KernelMemory Embedding type.
    var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
    return new Embedding(embeddings);
}

/// <inheritdoc/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace LLamaSharp.SemanticKernel.TextEmbedding;

public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
{
private LLamaEmbedder _embedder;
private readonly LLamaEmbedder _embedder;

private readonly Dictionary<string, object?> _attributes = new();

Expand All @@ -20,7 +20,11 @@ public LLamaSharpEmbeddingGeneration(LLamaEmbedder embedder)
/// <inheritdoc/>
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
    var result = new List<ReadOnlyMemory<float>>();

    // float[] converts implicitly to ReadOnlyMemory<float>; embed one item
    // at a time since the embedder is backed by a single llama context.
    foreach (var item in data)
        result.Add(await _embedder.GetEmbeddings(item, cancellationToken));

    return result;
}
}
25 changes: 15 additions & 10 deletions LLama.Unittest/LLamaEmbedderTests.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
using LLama.Common;
using Xunit.Abstractions;

namespace LLama.Unittest;

public sealed class LLamaEmbedderTests
: IDisposable
{
private readonly ITestOutputHelper _testOutputHelper;
private readonly LLamaEmbedder _embedder;

public LLamaEmbedderTests()
public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
var @params = new ModelParams(Constants.ModelPath)
{
EmbeddingMode = true,
Expand Down Expand Up @@ -41,21 +44,23 @@ private static float Dot(float[] a, float[] b)
}

[Fact]
public async Task EmbedCompare()
{
    var cat = await _embedder.GetEmbeddings("cat");
    var kitten = await _embedder.GetEmbeddings("kitten");
    var spoon = await _embedder.GetEmbeddings("spoon");

    Normalize(cat);
    Normalize(kitten);
    Normalize(spoon);

    // Cosine *distance* (1 - dot product of unit vectors):
    // 0 means identical, larger means less similar.
    var close = 1 - Dot(cat, kitten);
    var far = 1 - Dot(cat, spoon);

    // "cat" should be a smaller distance from "kitten" than from "spoon".
    Assert.True(close < far);

    _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
    _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
    _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
}
}
67 changes: 40 additions & 27 deletions LLama/LLamaEmbedder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using LLama.Exceptions;
using LLama.Abstractions;
using Microsoft.Extensions.Logging;
using System.Threading;
using System.Threading.Tasks;

namespace LLama
{
Expand Down Expand Up @@ -40,50 +42,61 @@ public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logg
/// <summary>
/// Get the embeddings of the text.
/// </summary>
/// <param name="text">Text to embed (BOS token is added).</param>
/// <param name="cancellationToken">Token used to cancel the decode.</param>
/// <returns>The embedding vector for <paramref name="text"/>.</returns>
/// <exception cref="ArgumentException">Thrown if the prompt does not fit in the context window.</exception>
/// <exception cref="LLamaDecodeError">Thrown if decoding fails.</exception>
public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
{
    // Delegate to the full overload, adding BOS by default.
    return GetEmbeddings(text, true, cancellationToken);
}

/// <summary>
/// Get the embeddings of the text.
/// </summary>
/// <param name="text">Text to embed.</param>
/// <param name="addBos">Add bos to the text.</param>
/// <param name="cancellationToken">Token used to cancel the decode.</param>
/// <returns>The embedding vector for <paramref name="text"/>.</returns>
/// <exception cref="ArgumentException">Thrown if the tokenized prompt is longer than the context window.</exception>
/// <exception cref="LLamaDecodeError">Thrown if llama_decode returns a non-zero code.</exception>
public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
{
    var tokens = Context.Tokenize(text, addBos);
    if (tokens.Length > Context.ContextSize)
        throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));

    // Evaluate prompt in batch-size chunks
    var n_past = 0;
    var batch = new LLamaBatch();
    var batchSize = (int)Context.Params.BatchSize;
    for (var i = 0; i < tokens.Length; i += batchSize)
    {
        // Clamp the final chunk to the remaining token count.
        var n_eval = Math.Min(tokens.Length - i, batchSize);

        batch.Clear();
        for (var j = 0; j < n_eval; j++)
            batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, false);

        var returnCode = await Context.DecodeAsync(batch, cancellationToken);
        if (returnCode != 0)
            throw new LLamaDecodeError(returnCode);
    }

    // TODO(Rinne): deal with log of prompt

    // Copy the embeddings out BEFORE clearing the cache — the native buffer
    // belongs to the context.
    var embeddings = GetEmbeddingsArray();

    // Remove everything we just evaluated from the context cache
    Context.NativeHandle.KvCacheClear();

    return embeddings;

    // Snapshot the native embeddings buffer into a managed array.
    float[] GetEmbeddingsArray()
    {
        var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
        if (embeddings == null)
            return Array.Empty<float>();
        return embeddings.ToArray();
    }
}
}

/// <summary>
Expand Down

0 comments on commit 3523c51

Please sign in to comment.