diff --git a/outlines/models/openai.py b/outlines/models/openai.py
index 16126acb0..f87a26f7e 100644
--- a/outlines/models/openai.py
+++ b/outlines/models/openai.py
@@ -130,10 +130,14 @@ def __init__(
         self.client = openai.AsyncOpenAI(
             api_key=api_key, max_retries=max_retries, timeout=timeout
         )
-        self.total_tokens = 0
-        self.last_tokens: int
         self.system_prompt = system_prompt
 
+        # We count the total number of prompt and generated tokens as returned
+        # by the OpenAI API, summed over all the requests performed with this
+        # model instance.
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+
     def __call__(
         self,
         prompt: Union[str, List[str]],
@@ -170,10 +174,12 @@ def __call__(
             )
         )
         if "gpt-" in self.config.model:
-            response, self.last_tokens = generate_chat(
+            response, usage = generate_chat(
                 prompt, self.system_prompt, self.client, config
             )
-            self.total_tokens += self.last_tokens
+            self.prompt_tokens += usage.prompt_tokens
+            self.completion_tokens += usage.completion_tokens
+
         return response
 
     def generate_choice(
@@ -227,10 +233,11 @@ def generate_choice(
 
             config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
 
-            response, self.last_tokens = generate_chat(
+            response, usage = generate_chat(
                 prompt, self.system_prompt, self.client, config
             )
-            self.total_tokens += self.last_tokens
+            self.completion_tokens += usage.completion_tokens
+            self.prompt_tokens += usage.prompt_tokens
 
             encoded_response = tokenizer.encode(response)
 
@@ -283,12 +290,33 @@ async def generate_chat(
     client: "AsyncOpenAI",
     config: OpenAIConfig,
 ) -> Tuple[np.ndarray, "CompletionUsage"]:
+    """Call OpenAI's Chat Completion API.
+
+    Parameters
+    ----------
+    prompt
+        The prompt we use to start the generation. Passed to the model
+        with the "user" role.
+    system_prompt
+        The system prompt, passed to the model with the "system" role
+        before the prompt.
+    client
+        The API client.
+    config
+        An `OpenAIConfig` instance.
+
+    Returns
+    -------
+    A tuple that contains the model's response(s) and usage statistics.
+
+    """
     system_message = (
         [{"role": "system", "content": system_prompt}] if system_prompt else []
     )
+    user_message = [{"role": "user", "content": prompt}]
 
     responses = await client.chat.completions.create(
-        messages=system_message + [{"role": "user", "content": prompt}],
+        messages=system_message + user_message,
         **asdict(config),  # type: ignore
     )
 
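
A minimal usage sketch of the new counters (not part of the patch; it assumes outlines' public `models.openai` constructor, and the model name and prompts are illustrative placeholders):

```python
# Sketch under assumptions: `models.openai` is the usual outlines constructor;
# "gpt-4" and the prompts below are placeholders.
from outlines import models

model = models.openai("gpt-4")

model("Write a haiku about version control.")
model("Write another one about code review.")

# Both counters are summed over every request made with this model instance,
# as reported in the usage statistics returned by the OpenAI API.
print(model.prompt_tokens, model.completion_tokens)
```

Because the two counts are now kept separately, a caller can estimate cost with per-token input and output prices instead of a single blended rate over `total_tokens`.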