From c002642268e08829cddcefac4b6e3cc75ed5549b Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 2 Jan 2024 03:39:50 +0000 Subject: [PATCH 1/3] - Removed some `unsafe` where it wasn't necessary - Wrapped some native functions which take (pointer, length) in functions which take a `span` instead. --- LLama/LLamaEmbedder.cs | 11 +++---- LLama/LLamaExecutorBase.cs | 2 +- LLama/LLamaQuantizer.cs | 8 +++-- LLama/Native/NativeApi.cs | 46 ++++++++++++++++++++-------- LLama/Native/SafeLlamaModelHandle.cs | 34 +++++++------------- 5 files changed, 54 insertions(+), 47 deletions(-) diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index c551016c0..0c6cc87ca 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -77,14 +77,11 @@ public float[] GetEmbeddings(string text, bool addBos) if (embed_inp_array.Length > 0) Context.Eval(embed_inp_array, 0); - unsafe - { - var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); - if (embeddings == null) - return Array.Empty(); + var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); + if (embeddings == null) + return Array.Empty(); - return embeddings.ToArray(); - } + return embeddings.ToArray(); } /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index e0fde1edb..cb1b850b3 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -70,7 +70,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected float? MirostatMu { get; set; } - private StreamingTokenDecoder _decoder; + private readonly StreamingTokenDecoder _decoder; /// /// diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs index 40632724f..54b0ed021 100644 --- a/LLama/LLamaQuantizer.cs +++ b/LLama/LLamaQuantizer.cs @@ -20,8 +20,7 @@ public static class LLamaQuantizer /// /// Whether the quantization is successful. 
/// - public static unsafe bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, - bool quantizeOutputTensor = false) + public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, bool quantizeOutputTensor = false) { if (!ValidateFtype(ftype)) { @@ -34,7 +33,10 @@ public static unsafe bool Quantize(string srcFileName, string dstFilename, LLama quantizeParams.nthread = nthread; quantizeParams.allow_requantize = allowRequantize; quantizeParams.quantize_output_tensor = quantizeOutputTensor; - return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0; + unsafe + { + return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0; + } } /// diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 1c7715f66..5d76f1d87 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -175,7 +175,7 @@ public static partial class NativeApi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out); + public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); /// /// Save session file @@ -439,24 +439,44 @@ public static Span llama_get_embeddings(SafeLLamaContextHandle ctx) /// /// Get metadata key name by index /// - /// - /// - /// - /// + /// Model to fetch from + /// Index of key to fetch + /// buffer to write result into /// The length of the string on success, or -1 on failure - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int 
index, byte* buf, long buf_size); + public static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span dest) + { + unsafe + { + fixed (byte* destPtr = &dest[0]) + { + return llama_model_meta_key_by_index_native(model, index, destPtr, dest.Length); + } + } + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")] + static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + } /// /// Get metadata value as a string by index /// - /// - /// - /// - /// + /// Model to fetch from + /// Index of val to fetch + /// Buffer to write result into /// The length of the string on success, or -1 on failure - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span dest) + { + unsafe + { + fixed (byte* destPtr = &dest[0]) + { + return llama_model_meta_val_str_by_index_native(model, index, destPtr, dest.Length); + } + } + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")] + static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + } /// /// Get a string describing the model type diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 2280250ec..fa71af3f1 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -221,21 +221,17 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) unsafe { // Check if the key exists, without getting any bytes of data - keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)0, 0); + keyLength = 
NativeApi.llama_model_meta_key_by_index(this, index, Array.Empty()); if (keyLength < 0) return null; } // get a buffer large enough to hold it var buffer = new byte[keyLength + 1]; - unsafe - { - using var pin = buffer.AsMemory().Pin(); - keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)pin.Pointer, buffer.Length); - Debug.Assert(keyLength >= 0); + keyLength = NativeApi.llama_model_meta_key_by_index(this, index, buffer); + Debug.Assert(keyLength >= 0); - return buffer.AsMemory().Slice(0, keyLength); - } + return buffer.AsMemory().Slice(0, keyLength); } /// @@ -245,25 +241,17 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) /// The value, null if there is no such value or if the buffer was too small public Memory? MetadataValueByIndex(int index) { - int valueLength; - unsafe - { - // Check if the key exists, without getting any bytes of data - valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)0, 0); - if (valueLength < 0) - return null; - } + // Check if the value exists, without getting any bytes of data + var valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, Array.Empty()); + if (valueLength < 0) + return null; // get a buffer large enough to hold it var buffer = new byte[valueLength + 1]; - unsafe - { - using var pin = buffer.AsMemory().Pin(); - valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)pin.Pointer, buffer.Length); - Debug.Assert(valueLength >= 0); + valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, buffer); + Debug.Assert(valueLength >= 0); - return buffer.AsMemory().Slice(0, valueLength); - } + return buffer.AsMemory().Slice(0, valueLength); } internal IReadOnlyDictionary ReadMetadata() From 39255451479eea8fe20023b36533a527e22cdb2f Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 2 Jan 2024 03:41:16 +0000 Subject: [PATCH 2/3] Fixed LLamaExecutorBase.cs --- LLama/LLamaExecutorBase.cs | 7 +++---- 1 file 
changed, 3 insertions(+), 4 deletions(-) diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index cb1b850b3..e7b768be1 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -95,7 +95,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) /// /// /// - public unsafe StatefulExecutorBase WithSessionFile(string filename) + public StatefulExecutorBase WithSessionFile(string filename) { _pathSession = filename; if (string.IsNullOrEmpty(filename)) @@ -105,9 +105,8 @@ public unsafe StatefulExecutorBase WithSessionFile(string filename) if (File.Exists(filename)) { _logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}"); - llama_token[] session_tokens = new llama_token[Context.ContextSize]; - ulong n_token_count_out = 0; - if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, &n_token_count_out)) + var session_tokens = new llama_token[Context.ContextSize]; + if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out)) { _logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}"); throw new RuntimeError($"Failed to load session file {_pathSession}"); From bac3e43498d968aaf80c2820e2eb6c42a41d3f72 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 2 Jan 2024 03:42:54 +0000 Subject: [PATCH 3/3] Fixed handling of empty spans --- LLama/Native/NativeApi.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 5d76f1d87..38ba1bc6d 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -447,7 +447,7 @@ public static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int { unsafe { - fixed (byte* destPtr = &dest[0]) + fixed (byte* destPtr = dest) { return llama_model_meta_key_by_index_native(model, index, 
destPtr, dest.Length); } @@ -468,7 +468,7 @@ public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, { unsafe { - fixed (byte* destPtr = &dest[0]) + fixed (byte* destPtr = dest) { return llama_model_meta_val_str_by_index_native(model, index, destPtr, dest.Length); }