Commit
Merge pull request #401 from martindevans/remove_some_unsafe
Removed some unnecessary uses of `unsafe`
martindevans authored Jan 2, 2024
2 parents 5c876cb + bac3e43 commit a1a8461
Showing 5 changed files with 57 additions and 51 deletions.
11 changes: 4 additions & 7 deletions LLama/LLamaEmbedder.cs
@@ -77,14 +77,11 @@ public float[] GetEmbeddings(string text, bool addBos)
if (embed_inp_array.Length > 0)
Context.Eval(embed_inp_array, 0);

Check warning on line 78 in LLama/LLamaEmbedder.cs (GitHub Actions / Test on linux-release, windows-release, and osx-release): 'LLamaContext.Eval(int[], int)' is obsolete: 'use llama_decode() instead'

-            unsafe
-            {
-                var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
-                if (embeddings == null)
-                    return Array.Empty<float>();
+            var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
+            if (embeddings == null)
+                return Array.Empty<float>();

-                return embeddings.ToArray();
-            }
+            return embeddings.ToArray();
        }

/// <summary>
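Note: this hunk can drop `unsafe` because the `llama_get_embeddings` wrapper in NativeApi.cs (visible later in this diff) already returns a `Span<float>` rather than a raw `float*`. A minimal sketch of that wrapper pattern, with the hypothetical names `llama_get_embeddings_native` and `llama_n_embd` standing in for the real native imports:

    public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
    {
        unsafe
        {
            var ptr = llama_get_embeddings_native(ctx); // assumed to return float*
            if (ptr == null)
                return Span<float>.Empty;

            // Wrap the native buffer; n_embd is the embedding vector length
            return new Span<float>(ptr, llama_n_embd(ctx));
        }
    }

Callers can then compare the span against null (which compiles via the implicit array-to-span conversion and checks for a default span) and copy it out with `ToArray()`, all in safe code.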
9 changes: 4 additions & 5 deletions LLama/LLamaExecutorBase.cs
@@ -70,7 +70,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor
/// </summary>
protected float? MirostatMu { get; set; }

-    private StreamingTokenDecoder _decoder;
+    private readonly StreamingTokenDecoder _decoder;

/// <summary>
///
@@ -95,7 +95,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null)
/// <returns></returns>
/// <exception cref="ArgumentNullException"></exception>
/// <exception cref="RuntimeError"></exception>
-    public unsafe StatefulExecutorBase WithSessionFile(string filename)
+    public StatefulExecutorBase WithSessionFile(string filename)
{
_pathSession = filename;
if (string.IsNullOrEmpty(filename))
@@ -105,9 +105,8 @@ public unsafe StatefulExecutorBase WithSessionFile(string filename)
if (File.Exists(filename))
{
_logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}");
-            llama_token[] session_tokens = new llama_token[Context.ContextSize];
-            ulong n_token_count_out = 0;
-            if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, &n_token_count_out))
+            var session_tokens = new llama_token[Context.ContextSize];
+            if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out))
{
_logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}");
throw new RuntimeError($"Failed to load session file {_pathSession}");
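Note: the `unsafe` modifier on `WithSessionFile` existed only to form the `&n_token_count_out` pointer. Re-declaring the P/Invoke parameter as `out ulong` lets the marshaller take the address instead, so the whole method becomes safe code. A sketch of the pattern, mirroring the signature changed in this diff:

    // Before (pointer parameter forces unsafe at every call site):
    //   public static extern unsafe bool llama_load_session_file(..., ulong* n_token_count_out);

    // After: the runtime passes the address of the caller's local for us
    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
    public static extern bool llama_load_session_file(
        SafeLLamaContextHandle ctx, string path_session,
        llama_token[] tokens_out, ulong n_token_capacity,
        out ulong n_token_count_out);

    // Call site needs no unsafe block:
    //   if (!llama_load_session_file(ctx, path, tokens, (ulong)tokens.Length, out var count)) ...

This works because `out` on a blittable type like `ulong` marshals as a plain pointer to the caller's variable; the native signature is unchanged.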
8 changes: 5 additions & 3 deletions LLama/LLamaQuantizer.cs
@@ -20,8 +20,7 @@ public static class LLamaQuantizer
/// <param name="quantizeOutputTensor"></param>
/// <returns>Whether the quantization is successful.</returns>
/// <exception cref="ArgumentException"></exception>
-    public static unsafe bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true,
-        bool quantizeOutputTensor = false)
+    public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, bool quantizeOutputTensor = false)
{
if (!ValidateFtype(ftype))
{
@@ -34,7 +33,10 @@ public static unsafe bool Quantize(string srcFileName, string dstFilename, LLama
quantizeParams.nthread = nthread;
quantizeParams.allow_requantize = allowRequantize;
quantizeParams.quantize_output_tensor = quantizeOutputTensor;
-            return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0;
+            unsafe
+            {
+                return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0;
+            }
}

/// <summary>
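Note: here `unsafe` moves off the method signature and down to the smallest block that needs it. Taking `&quantizeParams` requires no pinning because the struct is a local living on the stack. A sketch of the shape, with an assumed native signature (llama.cpp's llama_model_quantize returns 0 on success):

    [DllImport("llama", CallingConvention = CallingConvention.Cdecl)]
    static extern unsafe uint llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* @params);

    static bool QuantizeSketch(string src, string dst, LLamaModelQuantizeParams quantizeParams)
    {
        unsafe
        {
            // &quantizeParams points at a stack local, valid for the duration of the call
            return llama_model_quantize(src, dst, &quantizeParams) == 0;
        }
    }

The public API surface stays safe, and readers can see at a glance exactly which operation needed the pointer.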
46 changes: 33 additions & 13 deletions LLama/Native/NativeApi.cs
@@ -175,7 +175,7 @@ public static partial class NativeApi
/// <param name="n_token_count_out"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern unsafe bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out);
+    public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);

/// <summary>
/// Save session file
@@ -439,24 +439,44 @@ public static Span<float> llama_get_embeddings(SafeLLamaContextHandle ctx)
/// <summary>
/// Get metadata key name by index
/// </summary>
-    /// <param name="model"></param>
-    /// <param name="index"></param>
-    /// <param name="buf"></param>
-    /// <param name="buf_size"></param>
+    /// <param name="model">Model to fetch from</param>
+    /// <param name="index">Index of key to fetch</param>
+    /// <param name="dest">buffer to write result into</param>
/// <returns>The length of the string on success, or -1 on failure</returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern unsafe int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
+    public static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
+    {
+        unsafe
+        {
+            fixed (byte* destPtr = dest)
+            {
+                return llama_model_meta_key_by_index_native(model, index, destPtr, dest.Length);
+            }
+        }
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")]
+        static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
+    }

/// <summary>
/// Get metadata value as a string by index
/// </summary>
-    /// <param name="model"></param>
-    /// <param name="index"></param>
-    /// <param name="buf"></param>
-    /// <param name="buf_size"></param>
+    /// <param name="model">Model to fetch from</param>
+    /// <param name="index">Index of val to fetch</param>
+    /// <param name="dest">Buffer to write result into</param>
/// <returns>The length of the string on success, or -1 on failure</returns>
-    [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-    public static extern unsafe int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
+    public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span<byte> dest)
+    {
+        unsafe
+        {
+            fixed (byte* destPtr = dest)
+            {
+                return llama_model_meta_val_str_by_index_native(model, index, destPtr, dest.Length);
+            }
+        }
+
+        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")]
+        static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size);
+    }

/// <summary>
/// Get a string describing the model type
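Note: the `fixed (byte* destPtr = dest)` wrapper has a useful property here: pinning an empty span yields a null pointer (Span<T>.GetPinnableReference is documented to return a null reference when the span is empty), so passing Array.Empty<byte>() reproduces the old `(byte*)0, 0` length probe exactly. A sketch of the expected two-call usage, mirroring SafeLlamaModelHandle below (assumes UTF-8 metadata strings, as llama.cpp uses):

    // Probe with an empty buffer: the native side sees buf == NULL, buf_size == 0
    var length = NativeApi.llama_model_meta_key_by_index(model, 0, Array.Empty<byte>());
    if (length >= 0)
    {
        var buffer = new byte[length + 1];  // +1 for the native NUL terminator
        NativeApi.llama_model_meta_key_by_index(model, 0, buffer);
        var key = Encoding.UTF8.GetString(buffer, 0, length);
    }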
34 changes: 11 additions & 23 deletions LLama/Native/SafeLlamaModelHandle.cs
@@ -221,21 +221,17 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
unsafe
{
// Check if the key exists, without getting any bytes of data
-            keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)0, 0);
+            keyLength = NativeApi.llama_model_meta_key_by_index(this, index, Array.Empty<byte>());
if (keyLength < 0)
return null;
}

// get a buffer large enough to hold it
var buffer = new byte[keyLength + 1];
-        unsafe
-        {
-            using var pin = buffer.AsMemory().Pin();
-            keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)pin.Pointer, buffer.Length);
-            Debug.Assert(keyLength >= 0);
+        keyLength = NativeApi.llama_model_meta_key_by_index(this, index, buffer);
+        Debug.Assert(keyLength >= 0);

-            return buffer.AsMemory().Slice(0, keyLength);
-        }
+        return buffer.AsMemory().Slice(0, keyLength);
}

/// <summary>
@@ -245,25 +241,17 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
/// <returns>The value, null if there is no such value or if the buffer was too small</returns>
public Memory<byte>? MetadataValueByIndex(int index)
{
-        int valueLength;
-        unsafe
-        {
-            // Check if the key exists, without getting any bytes of data
-            valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)0, 0);
-            if (valueLength < 0)
-                return null;
-        }
+        // Check if the key exists, without getting any bytes of data
+        var valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, Array.Empty<byte>());
+        if (valueLength < 0)
+            return null;

// get a buffer large enough to hold it
var buffer = new byte[valueLength + 1];
-        unsafe
-        {
-            using var pin = buffer.AsMemory().Pin();
-            valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)pin.Pointer, buffer.Length);
-            Debug.Assert(valueLength >= 0);
+        valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, buffer);
+        Debug.Assert(valueLength >= 0);

-            return buffer.AsMemory().Slice(0, valueLength);
-        }
+        return buffer.AsMemory().Slice(0, valueLength);
}

internal IReadOnlyDictionary<string, string> ReadMetadata()
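Note: with both helpers safe, the ReadMetadata method referenced above can stay entirely in managed code. A sketch of how the two calls compose (an assumed shape, not the verbatim implementation; GGUF metadata strings are UTF-8):

    var metadata = new Dictionary<string, string>();
    for (var i = 0; ; i++)
    {
        var key = model.MetadataKeyByIndex(i);
        var value = model.MetadataValueByIndex(i);
        if (key == null || value == null)
            break;  // ran past the last metadata index

        metadata[Encoding.UTF8.GetString(key.Value.Span)] = Encoding.UTF8.GetString(value.Value.Span);
    }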
