Keep track of prompt and completion tokens separately
rlouf committed Dec 4, 2023
Parent: 8a1430f · Commit: 4039550
Showing 1 changed file with 35 additions and 7 deletions.
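The practical effect of the change: callers can read prompt and completion usage independently, instead of a single lump sum. A minimal sketch of the intended use, assuming the library exposes a `models.openai(...)` constructor (the model name and prompt here are illustrative, not taken from the commit):

```python
from outlines import models

# Illustrative: build a model instance backed by OpenAI's chat API.
model = models.openai("gpt-3.5-turbo")

answer = model("What is the capital of France?")

# After any number of calls, the counters hold API-reported totals,
# summed over every request made through this instance.
print(model.prompt_tokens)      # tokens sent as prompts so far
print(model.completion_tokens)  # tokens generated by the model so far
```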
outlines/models/openai.py (35 additions, 7 deletions)
@@ -130,10 +130,14 @@ def __init__(
         self.client = openai.AsyncOpenAI(
             api_key=api_key, max_retries=max_retries, timeout=timeout
         )
-        self.total_tokens = 0
-        self.last_tokens: int
         self.system_prompt = system_prompt
 
+        # We count the total number of prompt and generated tokens as returned
+        # by the OpenAI API, summed over all the requests performed with this
+        # model instance.
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+
     def __call__(
         self,
         prompt: Union[str, List[str]],
@@ -170,10 +174,12 @@ def __call__(
             )
         )
         if "gpt-" in self.config.model:
-            response, self.last_tokens = generate_chat(
+            response, usage = generate_chat(
                 prompt, self.system_prompt, self.client, config
             )
-            self.total_tokens += self.last_tokens
+            self.prompt_tokens += usage.prompt_tokens
+            self.completion_tokens += usage.completion_tokens
+
             return response
 
     def generate_choice(
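Note the change in the second return value: `generate_chat` previously returned a bare token count (stored into `self.last_tokens`), and now returns the OpenAI SDK's usage object. For reference, a sketch of that object's shape in the v1 `openai` package (field names as documented for `CompletionUsage`; the numbers are made up):

```python
from openai.types import CompletionUsage

# The chat completions endpoint reports usage in this form.
usage = CompletionUsage(prompt_tokens=12, completion_tokens=34, total_tokens=46)

# total_tokens is the sum of the two counters now tracked separately.
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
```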
@@ -227,10 +233,11 @@ def generate_choice(
 
             config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
 
-            response, self.last_tokens = generate_chat(
+            response, usage = generate_chat(
                 prompt, self.system_prompt, self.client, config
             )
-            self.total_tokens += self.last_tokens
+            self.completion_tokens += usage.completion_tokens
+            self.prompt_tokens += usage.prompt_tokens
 
             encoded_response = tokenizer.encode(response)
 
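For context on the hunk above: `generate_choice` steers the model toward a fixed set of answers by passing a `logit_bias` mask to the API. A hypothetical sketch of such a mask (the token ids are made up; OpenAI's API accepts bias values between -100 and 100, and a bias of 100 effectively restricts sampling to the listed tokens):

```python
# Illustrative token ids for the tokenized choice prefixes.
allowed_token_ids = [9642, 2822]

# A strong positive bias makes these tokens overwhelmingly likely to be sampled.
mask = {token_id: 100 for token_id in allowed_token_ids}
```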
@@ -283,12 +290,33 @@ async def generate_chat(
     client: "AsyncOpenAI",
     config: OpenAIConfig,
 ) -> Tuple[np.ndarray, "CompletionUsage"]:
+    """Call OpenAI's Chat Completion API.
+
+    Parameters
+    ----------
+    prompt
+        The prompt we use to start the generation. Passed to the model
+        with the "user" role.
+    system_prompt
+        The system prompt, passed to the model with the "system" role
+        before the prompt.
+    client
+        The API client.
+    config
+        An `OpenAIConfig` instance.
+
+    Returns
+    -------
+    A tuple that contains the model's response(s) and usage statistics.
+
+    """
     system_message = (
         [{"role": "system", "content": system_prompt}] if system_prompt else []
     )
+    user_message = [{"role": "user", "content": prompt}]
 
     responses = await client.chat.completions.create(
-        messages=system_message + [{"role": "user", "content": prompt}],
+        messages=system_message + user_message,
         **asdict(config),  # type: ignore
     )
