fix!: Separate Sampling struct for chat
pacman82 authored and markus-klein-aa committed Feb 4, 2025
1 parent 7e88e7f commit 0431a04
Showing 3 changed files with 62 additions and 42 deletions.
59 changes: 50 additions & 9 deletions src/chat.rs
@@ -2,7 +2,7 @@ use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::{Sampling, Stopping, StreamTask, Task};
use crate::{Stopping, StreamTask, Task};

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
@@ -34,7 +34,7 @@ pub struct TaskChat<'a> {
/// Controls in which circumstances the model will stop generating new tokens.
pub stopping: Stopping<'a>,
/// Sampling controls how the tokens ("words") are selected for the completion.
pub sampling: Sampling,
pub sampling: ChatSampling,
}

impl<'a> TaskChat<'a> {
@@ -49,7 +49,7 @@
pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
TaskChat {
messages,
sampling: Sampling::default(),
sampling: ChatSampling::default(),
stopping: Stopping::default(),
}
}
@@ -67,6 +67,52 @@
}
}

/// Sampling controls how the tokens ("words") are selected for the completion. This is different
/// from [`crate::Sampling`], because it does **not** support the `top_k` parameter.
pub struct ChatSampling {
/// A higher temperature encourages the model to produce less probable outputs ("be more creative").
/// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
/// response.
pub temperature: Option<f64>,
/// Introduces random sampling for generated tokens by randomly selecting the next token from
/// the smallest possible set of tokens whose cumulative probability exceeds the probability
/// `top_p`. Set to 0 to get the same behaviour as `None`.
pub top_p: Option<f64>,
/// When specified, this number will decrease (or increase) the likelihood of repeating tokens
/// that were mentioned prior in the completion. The penalty is cumulative. The more a token
/// is mentioned in the completion, the more its probability will decrease.
/// A negative value will increase the likelihood of repeating tokens.
pub frequency_penalty: Option<f64>,
/// The presence penalty reduces the likelihood of generating tokens that are already present
/// in the generated text (repetition_penalties_include_completion=true) or in the prompt
/// (repetition_penalties_include_prompt=true), respectively. The presence penalty is independent
/// of the number of occurrences. Increase the value to reduce the likelihood of repeating text.
/// An operation like the following is applied:
///
/// logits[t] -> logits[t] - 1 * penalty
///
/// where logits[t] is the logits for any given token. Note that the formula is independent
/// of the number of times that a token appears.
pub presence_penalty: Option<f64>,
}

impl ChatSampling {
/// Always chooses the token most likely to come next. Choose this if you want close to
/// deterministic behaviour and do not want to apply any penalties to avoid repetitions.
pub const MOST_LIKELY: Self = ChatSampling {
temperature: None,
top_p: None,
frequency_penalty: None,
presence_penalty: None,
};
}

impl Default for ChatSampling {
fn default() -> Self {
Self::MOST_LIKELY
}
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ChatOutput {
pub message: Message<'static>,
@@ -117,19 +163,14 @@ impl<'a> ChatBody<'a> {
stop_sequences,
},
sampling:
Sampling {
ChatSampling {
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
},
} = task;

if top_k.is_some() {
panic!("The top_k parameter is not supported for chat completions.");
}

Self {
model,
messages,
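A minimal usage sketch (not part of this diff) of how a chat request can be assembled with the new `ChatSampling` struct. The helper constructors `Message::user` and `Stopping::from_maximum_tokens` appear in the tests further down; the concrete values are illustrative only.

use aleph_alpha_client::{ChatSampling, Message, Stopping, TaskChat};

fn build_chat_task() -> TaskChat<'static> {
    // Chat-specific sampling: in contrast to `Sampling`, there is no `top_k` field to set.
    let sampling = ChatSampling {
        temperature: Some(0.7),
        presence_penalty: Some(1.0),
        ..ChatSampling::MOST_LIKELY
    };
    TaskChat {
        messages: vec![Message::user("Haiku about oat milk!")],
        stopping: Stopping::from_maximum_tokens(64),
        sampling,
    }
}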
9 changes: 5 additions & 4 deletions src/lib.rs
@@ -42,10 +42,11 @@ use std::{pin::Pin, time::Duration};
use tokenizers::Tokenizer;

pub use self::{
chat::{ChatEvent, ChatStreamChunk},
chat::{ChatOutput, Message, TaskChat},
completion::{CompletionEvent, CompletionSummary, StreamChunk, StreamSummary},
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
chat::{ChatEvent, ChatOutput, ChatSampling, ChatStreamChunk, Message, TaskChat},
completion::{
CompletionEvent, CompletionOutput, CompletionSummary, Sampling, Stopping, StreamChunk,
StreamSummary, TaskCompletion,
},
detokenization::{DetokenizationOutput, TaskDetokenization},
explanation::{
Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation,
36 changes: 7 additions & 29 deletions tests/integration.rs
@@ -1,10 +1,10 @@
use std::{fs::File, io::BufReader};

use aleph_alpha_client::{
cosine_similarity, Client, CompletionEvent, Granularity, How, ImageScore, ItemExplanation,
Message, Modality, Prompt, PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation,
TaskSemanticEmbedding, TaskTokenization, TextScore,
cosine_similarity, ChatSampling, Client, CompletionEvent, Granularity, How, ImageScore,
ItemExplanation, Message, Modality, Prompt, PromptGranularity, Sampling,
SemanticRepresentation, Stopping, Task, TaskBatchSemanticEmbedding, TaskChat, TaskCompletion,
TaskDetokenization, TaskExplanation, TaskSemanticEmbedding, TaskTokenization, TextScore,
};
use dotenvy::dotenv;
use futures_util::StreamExt;
@@ -600,7 +600,7 @@ async fn frequency_penalty_request() {
let model = "pharia-1-llm-7b-control";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let message = Message::user("Haiku about oat milk!");
let sampling = Sampling {
let sampling = ChatSampling {
frequency_penalty: Some(-10.0),
..Default::default()
};
@@ -635,7 +635,7 @@ async fn presence_penalty_request() {
let model = "pharia-1-llm-7b-control";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let message = Message::user("Haiku about oat milk!");
let sampling = Sampling {
let sampling = ChatSampling {
presence_penalty: Some(-10.0),
..Default::default()
};
@@ -678,7 +678,7 @@ async fn stop_sequences_request() {
let task = TaskChat {
messages: vec![message],
stopping,
sampling: Sampling::MOST_LIKELY,
sampling: ChatSampling::MOST_LIKELY,
};

// When the response is requested
@@ -691,25 +691,3 @@
// Actually, it should be `stop`, but the api scheduler is inconsistent here
assert_eq!(response.finish_reason, "content_filter");
}

#[tokio::test]
#[should_panic(expected = "The top_k parameter is not supported for chat completions.")]
async fn chat_does_not_support_top_k() {
// Given a high negative frequency penalty
let model = "pharia-1-llm-7b-control";
let client = Client::with_auth(inference_url(), pharia_ai_token()).unwrap();
let message = Message::user("Haiku about oat milk!");

// When
let sampling = Sampling {
top_k: Some(10),
..Default::default()
};
let stopping = Stopping::from_maximum_tokens(20);
let task = TaskChat {
messages: vec![message],
stopping,
sampling,
};
client.chat(&task, model, &How::default()).await.unwrap();
}
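The `chat_does_not_support_top_k` test above is dropped because the restriction no longer needs a runtime check: `ChatSampling` has no `top_k` field, so the old call pattern does not compile in the first place. For completion requests, `top_k` remains available through the unchanged `Sampling` struct, roughly as in this sketch (illustrative only, not part of the diff):

use aleph_alpha_client::Sampling;

fn completion_sampling() -> Sampling {
    // `top_k` is still a valid knob for completions; chat requests use `ChatSampling`,
    // which has no such field, so the removed panic path can no longer be reached.
    Sampling {
        top_k: Some(10),
        ..Sampling::MOST_LIKELY
    }
}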
