Skip to content

Commit

Permalink
Merge pull request #961 from martindevans/experimental_custom_sampler…
Browse files Browse the repository at this point in the history
…_wip

Custom Sampler Stages
  • Loading branch information
martindevans authored Oct 29, 2024
2 parents 40ea046 + 8af713e commit 6df26d3
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 38 deletions.
1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class ExampleRunner
{ "Batched Executor: LLava", BatchedExecutorLLava.Run },
{ "Batched Executor: BoolQ Benchmark", BatchedExecutorBoolQ.Run },
{ "Batched Executor: Beam Search", BatchedExecutorBeamSearch.Run },
{ "Custom Sampling Pipeline", CustomSampler.Run },
{ "Speech Chat: Integration with Whisper.net", SpeechChat.Run },
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
};
Expand Down
114 changes: 114 additions & 0 deletions LLama.Examples/Examples/CustomSampler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using LLama.Common;
using LLama.Examples.Extensions;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples
{
/// <summary>
/// Example demonstrating a custom sampling pipeline containing a custom sampler stage.
/// Runs a stateless executor in a loop, answering questions typed by the user.
/// </summary>
public class CustomSampler
{
    /// <summary>
    /// Load the model selected in user settings and run an interactive Q&amp;A loop
    /// using <see cref="CustomSamplingPipeline"/>. Loops forever; the process is
    /// exited from the example runner menu.
    /// </summary>
    public static async Task Run()
    {
        var modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = await LLamaWeights.LoadFromFileAsync(parameters);

        var ex = new StatelessExecutor(model, parameters);

        Console.ForegroundColor = ConsoleColor.Yellow;
        // Note: the segments are joined with an explicit blank line; plain concatenation
        // of adjacent string literals inserts no separator between sentences.
        Console.WriteLine(
            "In this example a custom sampling pipeline with a custom sampler stage is being used. This demonstrates how to customise the samplers used, and " +
            "how to create a completely custom sampler stage which modifies the logits or selects a token." +
            "\n\n" +
            "In this case the custom sampler stage removes the most likely token. This will probably produce bad results, it's just a demo!"
        );
        Console.ForegroundColor = ConsoleColor.White;

        var inferenceParams = new InferenceParams
        {
            SamplingPipeline = new CustomSamplingPipeline(),
            MaxTokens = 50
        };

        while (true)
        {
            Console.Write("\nQuestion: ");
            Console.ForegroundColor = ConsoleColor.Green;
            var prompt = Console.ReadLine();
            Console.ForegroundColor = ConsoleColor.White;
            Console.Write("Answer: ");
            // Guard against a null ReadLine result (e.g. redirected/closed stdin).
            prompt = $"Question: {prompt?.Trim()} Answer: ";
            await foreach (var text in ex.InferAsync(prompt, inferenceParams).Spinner())
            {
                Console.Write(text);
            }
        }
    }
}

/// <summary>
/// A sampling pipeline built from a hand-assembled sampler chain: top-k filtering,
/// a custom stage that bans the most likely token, then distribution sampling.
/// </summary>
public class CustomSamplingPipeline
    : BaseSamplingPipeline
{
    /// <summary>
    /// Assemble the sampler chain used by this pipeline.
    /// </summary>
    /// <param name="context">Context the chain will sample for.</param>
    /// <returns>The configured sampler chain handle.</returns>
    protected override SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context)
    {
        var samplerChain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());

        // Keep just the 10 highest probability candidates...
        samplerChain.AddTopK(10);

        // ...then knock out the single best one with the custom stage.
        samplerChain.AddCustom(new RemoveMostLikelyToken());

        // Normalise what's left and pick a token from the distribution (fixed seed).
        samplerChain.AddSoftmax();
        samplerChain.AddDistributionSampler(42);

        return samplerChain;
    }
}

/// <summary>
/// A custom sampler stage which makes the single most likely token impossible to select,
/// by setting its logit to negative infinity.
/// </summary>
public class RemoveMostLikelyToken
    : ICustomSampler
{
    /// <inheritdoc />
    public string Name => "Remove Most Likely Token";

    /// <summary>
    /// Modify the candidate token logits in place, banning the top token.
    /// </summary>
    /// <param name="tokenData">Candidate tokens and their logits.</param>
    public void Apply(ref LLamaTokenDataArrayNative tokenData)
    {
        // With a single candidate (or none) there is nothing sensible to remove.
        if (tokenData.Size <= 1)
            return;

        // The most likely token must be at index 0. If the data is not already
        // sorted, sort it descending by logit (largest first).
        if (!tokenData.Sorted)
            tokenData.Data.Sort((x, y) => y.Logit.CompareTo(x.Logit));

        // Ban the top token by giving it an impossible logit.
        tokenData.Data[0].Logit = float.NegativeInfinity;

        // It's **critically** important to set this if the logits are no longer sorted after the custom
        // sampler has run. If you're not sure, it's always safer to set it to false.
        //
        // Setting index 0 to negative infinity means the data is definitely no
        // longer sorted in descending order.
        tokenData.Sorted = false;
    }

    /// <inheritdoc />
    public void Accept(LLamaToken token)
    {
        // Stateless stage: nothing to record.
    }

    /// <inheritdoc />
    public void Reset()
    {
        // Stateless stage: nothing to reset.
    }

    /// <inheritdoc />
    public ICustomSampler Clone() => new RemoveMostLikelyToken();

    /// <inheritdoc />
    public void Dispose()
    {
        // No unmanaged resources held.
    }
}
}
22 changes: 18 additions & 4 deletions LLama/Native/LLamaTokenDataArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public struct LLamaTokenDataArrayNative
/// <summary>
/// Number of LLamaTokenData in the array
/// </summary>
public ulong size;
private ulong _size;

/// <summary>
/// The index in the array (i.e. not the token id)
Expand All @@ -167,13 +167,13 @@ public Span<LLamaTokenData> Data
{
unsafe
{
return new Span<LLamaTokenData>(_data, checked((int)size));
return new Span<LLamaTokenData>(_data, checked((int)Size));
}
}
}

/// <summary>
/// Indicates if the items in the array are sorted
/// Indicates if the items in the array are sorted, so the most likely token is first
/// </summary>
public bool Sorted
{
Expand All @@ -190,6 +190,20 @@ public long Selected
set => _selected = value;
}

/// <summary>
/// Number of LLamaTokenData in the array. Set this to shrink the array
/// </summary>
public ulong Size
{
    get
    {
        return _size;
    }
    set
    {
        // The underlying buffer cannot grow, so the size may only shrink (or stay the same).
        if (value > _size)
            throw new ArgumentOutOfRangeException(nameof(value), "Cannot set Size property to a larger value");

        _size = value;
    }
}

/// <summary>
/// Create a new LLamaTokenDataArrayNative around the data in the LLamaTokenDataArray
/// </summary>
Expand All @@ -205,7 +219,7 @@ public static MemoryHandle Create(LLamaTokenDataArray array, out LLamaTokenDataA
native = new LLamaTokenDataArrayNative
{
_data = (LLamaTokenData*)handle.Pointer,
size = (ulong)array.Data.Length,
Size = (ulong)array.Data.Length,
Sorted = array.Sorted
};
}
Expand Down
Loading

0 comments on commit 6df26d3

Please sign in to comment.