track only total_tokens
HerrIvan committed Dec 4, 2023
1 parent 7efa6b7 commit ab09bb3
Showing 1 changed file with 15 additions and 41 deletions.
56 changes: 15 additions & 41 deletions outlines/models/openai.py
@@ -17,20 +17,6 @@
     from openai import AsyncOpenAI
 
 
-@dataclass(frozen=True)
-class OpenAIUsage:
-    prompt_tokens: Optional[int] = 0
-    completion_tokens: Optional[int] = 0
-    total_tokens: Optional[int] = 0
-
-    def __add__(self, other):
-        return OpenAIUsage(
-            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
-            completion_tokens=self.completion_tokens + other.completion_tokens,
-            total_tokens=self.total_tokens + other.total_tokens,
-        )
-
-
 @dataclass(frozen=True)
 class OpenAIConfig:
     """Represents the parameters of the OpenAI API.
@@ -142,9 +128,8 @@ def __init__(
             api_key=api_key, max_retries=max_retries, timeout=timeout
         )
         self.system_prompt = role
-
-        self.total_usage = OpenAIUsage()
-        self.last_usage: Union[OpenAIUsage, None] = None
+        self.total_tokens = 0
+        self.last_tokens: int
 
     def __call__(
         self,
@@ -182,7 +167,11 @@ def __call__(
                 )
             )
         if "gpt-" in self.config.model:
-            return self.generate_chat(prompt, config)
+            response, self.last_tokens = generate_chat(
+                prompt, self.system_prompt, self.client, config
+            )
+            self.total_tokens += self.last_tokens
+            return response
 
     def generate_choice(
         self, prompt: str, choices: List[str], max_tokens: Optional[int] = None
@@ -234,7 +223,12 @@ def generate_choice(
                     break
 
             config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
-            response = self.generate_chat(prompt, config)
+
+            response, self.last_tokens = generate_chat(
+                prompt, self.system_prompt, self.client, config
+            )
+            self.total_tokens += self.last_tokens
+
             encoded_response = tokenizer.encode(response)
 
             if encoded_response in encoded_choices_left:
Expand Down Expand Up @@ -267,26 +261,6 @@ def generate_choice(

return choice

def generate_chat(
self, prompt: Union[str, List[str]], config: OpenAIConfig
) -> np.ndarray:
"""Call the async function to generate a chat response and keeps track of usage data.
Parameters
----------
prompt
A string used to prompt the model as user message
config
An instance of `OpenAIConfig`.
"""
results, usage = generate_chat(prompt, self.system_prompt, self.client, config)

self.last_usage = OpenAIUsage(**usage)
self.total_usage += self.last_usage

return results

def generate_json(self):
"""Call the OpenAI API to generate a JSON object."""
raise NotImplementedError
@@ -305,7 +279,7 @@ async def generate_chat(
     system_prompt: Union[str, None],
     client: "AsyncOpenAI",
     config: OpenAIConfig,
-) -> Tuple[np.ndarray, Dict]:
+) -> Tuple[np.ndarray, int]:
     responses = await client.chat.completions.create(
         messages=(
             [{"role": "system", "content": system_prompt}] if system_prompt else []
@@ -321,7 +295,7 @@ async def generate_chat(
         [responses.choices[i].message.content for i in range(config.n)]
     )
 
-    return results, responses.usage.model_dump()
+    return results, responses.usage.total_tokens
 
 
 openai = OpenAI
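A note on the new bookkeeping, added here for review context rather than as part of the commit. The deleted OpenAIUsage dataclass overloaded __add__ to sum usage records field by field; after this change the same accumulation is plain integer addition on total_tokens. Below is a minimal sketch of the new counter semantics, where the hypothetical record() helper stands in for the inline bookkeeping that __call__ and generate_choice now perform:

# Hypothetical sketch, not part of the commit: record() mimics the inline
# bookkeeping now done by OpenAI.__call__ and generate_choice.
class UsageTracker:
    def __init__(self):
        self.total_tokens = 0
        # Annotation only: the attribute does not exist until the first call,
        # mirroring `self.last_tokens: int` in the diff above.
        self.last_tokens: int

    def record(self, tokens_used: int) -> None:
        self.last_tokens = tokens_used          # usage of the latest call
        self.total_tokens += self.last_tokens   # running total across calls


tracker = UsageTracker()
tracker.record(42)  # first response reports 42 total tokens
tracker.record(17)  # second response reports 17
assert tracker.last_tokens == 17
assert tracker.total_tokens == 59

One behavioral difference worth flagging: because `self.last_tokens: int` is a bare annotation, reading it before the first request completes raises AttributeError, whereas the old `last_usage` was explicitly initialized to None.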

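On the final hunk: in the openai>=1.0 Python SDK, `responses.usage` is a pydantic CompletionUsage model exposing prompt_tokens, completion_tokens, and total_tokens; the old code serialized it with model_dump() and rebuilt an OpenAIUsage from the resulting dict, while the new code reads the single integer directly. A small sketch of the two return styles, assuming the v1 SDK's openai.types import path (the token counts are invented for illustration):

# Illustration only: CompletionUsage per the openai>=1.0 SDK; numbers invented.
from openai.types import CompletionUsage

usage = CompletionUsage(prompt_tokens=30, completion_tokens=12, total_tokens=42)

before = usage.model_dump()  # {'prompt_tokens': 30, 'completion_tokens': 12, 'total_tokens': 42}
after = usage.total_tokens   # 42 -- what generate_chat now returns alongside results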