From f59b60a00a0c88a8c3c7abee38954bc860e8c4e6 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 4 Sep 2024 02:33:53 +0100 Subject: [PATCH 1/2] Throwing an exception when `llama_get_logits_ith` returns `null`. --- LLama/Exceptions/RuntimeError.cs | 21 ++++++++++++++++++++- LLama/Native/SafeLLamaContextHandle.cs | 3 +++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/LLama/Exceptions/RuntimeError.cs b/LLama/Exceptions/RuntimeError.cs index 0feb53665..822a01772 100644 --- a/LLama/Exceptions/RuntimeError.cs +++ b/LLama/Exceptions/RuntimeError.cs @@ -1,4 +1,4 @@ -using System; +using System; using LLama.Native; namespace LLama.Exceptions; @@ -56,4 +56,23 @@ public LLamaDecodeError(DecodeResult returnCode) { ReturnCode = returnCode; } +} + +/// +/// `llama_get_logits_ith` returned null, indicating that the index was invalid +/// +public class NoLogitsException + : RuntimeError +{ + /// + /// The incorrect index passed to the `llama_get_logits_ith` call + /// + public int Index { get; } + + /// + public NoLogitsException(int index) + : base($"llama_get_logits_ith({index}) returned null") + { + Index = index; + } } \ No newline at end of file diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index dee74f590..dbe81b08a 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -472,6 +472,9 @@ public Span GetLogitsIth(int i) unsafe { var logits = llama_get_logits_ith(this, i); + if (logits == null) + throw new NoLogitsException(i); + return new Span(logits, model.VocabCount); } } From 846d2dc06665e7f4bb11391938ece412b05121f7 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Wed, 4 Sep 2024 02:37:36 +0100 Subject: [PATCH 2/2] Renamed exception --- LLama/Exceptions/RuntimeError.cs | 4 ++-- LLama/Native/SafeLLamaContextHandle.cs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/LLama/Exceptions/RuntimeError.cs b/LLama/Exceptions/RuntimeError.cs index 822a01772..4db77911e 100644 --- a/LLama/Exceptions/RuntimeError.cs +++ b/LLama/Exceptions/RuntimeError.cs @@ -61,7 +61,7 @@ public LLamaDecodeError(DecodeResult returnCode) /// /// `llama_get_logits_ith` returned null, indicating that the index was invalid /// -public class NoLogitsException +public class GetLogitsInvalidIndexException : RuntimeError { /// @@ -70,7 +70,7 @@ public class NoLogitsException public int Index { get; } /// - public NoLogitsException(int index) + public GetLogitsInvalidIndexException(int index) : base($"llama_get_logits_ith({index}) returned null") { Index = index; diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index dbe81b08a..61c3c0bbc 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -473,7 +473,7 @@ public Span GetLogitsIth(int i) { var logits = llama_get_logits_ith(this, i); if (logits == null) - throw new NoLogitsException(i); + throw new GetLogitsInvalidIndexException(i); return new Span(logits, model.VocabCount); }