diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index c551016c0..0c6cc87ca 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -77,14 +77,11 @@ public float[] GetEmbeddings(string text, bool addBos) if (embed_inp_array.Length > 0) Context.Eval(embed_inp_array, 0); - unsafe - { - var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); - if (embeddings == null) - return Array.Empty(); + var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); + if (embeddings == null) + return Array.Empty(); - return embeddings.ToArray(); - } + return embeddings.ToArray(); } /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index e0fde1edb..e7b768be1 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -70,7 +70,7 @@ public abstract class StatefulExecutorBase : ILLamaExecutor /// protected float? MirostatMu { get; set; } - private StreamingTokenDecoder _decoder; + private readonly StreamingTokenDecoder _decoder; /// /// @@ -95,7 +95,7 @@ protected StatefulExecutorBase(LLamaContext context, ILogger? logger = null) /// /// /// - public unsafe StatefulExecutorBase WithSessionFile(string filename) + public StatefulExecutorBase WithSessionFile(string filename) { _pathSession = filename; if (string.IsNullOrEmpty(filename)) @@ -105,9 +105,8 @@ public unsafe StatefulExecutorBase WithSessionFile(string filename) if (File.Exists(filename)) { _logger?.LogInformation($"[LLamaExecutor] Attempting to load saved session from {filename}"); - llama_token[] session_tokens = new llama_token[Context.ContextSize]; - ulong n_token_count_out = 0; - if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, &n_token_count_out)) + var session_tokens = new llama_token[Context.ContextSize]; + if (!NativeApi.llama_load_session_file(Context.NativeHandle, _pathSession, session_tokens, (ulong)Context.ContextSize, out var n_token_count_out)) { _logger?.LogError($"[LLamaExecutor] Failed to load session file {filename}"); throw new RuntimeError($"Failed to load session file {_pathSession}"); diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs index 40632724f..54b0ed021 100644 --- a/LLama/LLamaQuantizer.cs +++ b/LLama/LLamaQuantizer.cs @@ -20,8 +20,7 @@ public static class LLamaQuantizer /// /// Whether the quantization is successful. /// - public static unsafe bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, - bool quantizeOutputTensor = false) + public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, bool quantizeOutputTensor = false) { if (!ValidateFtype(ftype)) { @@ -34,7 +33,10 @@ public static unsafe bool Quantize(string srcFileName, string dstFilename, LLama quantizeParams.nthread = nthread; quantizeParams.allow_requantize = allowRequantize; quantizeParams.quantize_output_tensor = quantizeOutputTensor; - return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0; + unsafe + { + return NativeApi.llama_model_quantize(srcFileName, dstFilename, &quantizeParams) == 0; + } } /// diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 1c7715f66..38ba1bc6d 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -175,7 +175,7 @@ public static partial class NativeApi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out); + public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out); /// /// Save session file @@ -439,24 +439,44 @@ public static Span llama_get_embeddings(SafeLLamaContextHandle ctx) /// /// Get metadata key name by index /// - /// - /// - /// - /// + /// Model to fetch from + /// Index of key to fetch + /// buffer to write result into /// The length of the string on success, or -1 on failure - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + public static int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, Span dest) + { + unsafe + { + fixed (byte* destPtr = dest) + { + return llama_model_meta_key_by_index_native(model, index, destPtr, dest.Length); + } + } + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_key_by_index")] + static extern unsafe int llama_model_meta_key_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + } /// /// Get metadata value as a string by index /// - /// - /// - /// - /// + /// Model to fetch from + /// Index of val to fetch + /// Buffer to write result into /// The length of the string on success, or -1 on failure - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern unsafe int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, Span dest) + { + unsafe + { + fixed (byte* destPtr = dest) + { + return llama_model_meta_val_str_by_index_native(model, index, destPtr, dest.Length); + } + } + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_model_meta_val_str_by_index")] + static extern unsafe int llama_model_meta_val_str_by_index_native(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + } /// /// Get a string describing the model type diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 2280250ec..fa71af3f1 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -221,21 +221,17 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) unsafe { // Check if the key exists, without getting any bytes of data - keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)0, 0); + keyLength = NativeApi.llama_model_meta_key_by_index(this, index, Array.Empty()); if (keyLength < 0) return null; } // get a buffer large enough to hold it var buffer = new byte[keyLength + 1]; - unsafe - { - using var pin = buffer.AsMemory().Pin(); - keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)pin.Pointer, buffer.Length); - Debug.Assert(keyLength >= 0); + keyLength = NativeApi.llama_model_meta_key_by_index(this, index, buffer); + Debug.Assert(keyLength >= 0); - return buffer.AsMemory().Slice(0, keyLength); - } + return buffer.AsMemory().Slice(0, keyLength); } /// @@ -245,25 +241,17 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) /// The value, null if there is no such value or if the buffer was too small public Memory? MetadataValueByIndex(int index) { - int valueLength; - unsafe - { - // Check if the key exists, without getting any bytes of data - valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)0, 0); - if (valueLength < 0) - return null; - } + // Check if the key exists, without getting any bytes of data + var valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, Array.Empty()); + if (valueLength < 0) + return null; // get a buffer large enough to hold it var buffer = new byte[valueLength + 1]; - unsafe - { - using var pin = buffer.AsMemory().Pin(); - valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)pin.Pointer, buffer.Length); - Debug.Assert(valueLength >= 0); + valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, buffer); + Debug.Assert(valueLength >= 0); - return buffer.AsMemory().Slice(0, valueLength); - } + return buffer.AsMemory().Slice(0, valueLength); } internal IReadOnlyDictionary ReadMetadata()