diff --git a/outlines/models/openai.py b/outlines/models/openai.py
index 771b6db71..21fece2aa 100644
--- a/outlines/models/openai.py
+++ b/outlines/models/openai.py
@@ -17,20 +17,6 @@
     from openai import AsyncOpenAI
 
 
-@dataclass(frozen=True)
-class OpenAIUsage:
-    prompt_tokens: Optional[int] = 0
-    completion_tokens: Optional[int] = 0
-    total_tokens: Optional[int] = 0
-
-    def __add__(self, other):
-        return OpenAIUsage(
-            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
-            completion_tokens=self.completion_tokens + other.completion_tokens,
-            total_tokens=self.total_tokens + other.total_tokens,
-        )
-
-
 @dataclass(frozen=True)
 class OpenAIConfig:
     """Represents the parameters of the OpenAI API.
@@ -142,9 +128,8 @@ def __init__(
             api_key=api_key, max_retries=max_retries, timeout=timeout
         )
         self.system_prompt = role
-
-        self.total_usage = OpenAIUsage()
-        self.last_usage: Union[OpenAIUsage, None] = None
+        self.total_tokens = 0
+        self.last_tokens: int
 
     def __call__(
         self,
@@ -182,7 +167,11 @@ def __call__(
             )
         )
         if "gpt-" in self.config.model:
-            return self.generate_chat(prompt, config)
+            response, self.last_tokens = generate_chat(
+                prompt, self.system_prompt, self.client, config
+            )
+            self.total_tokens += self.last_tokens
+            return response
 
     def generate_choice(
         self, prompt: str, choices: List[str], max_tokens: Optional[int] = None
@@ -234,7 +223,12 @@ def generate_choice(
                 break
 
             config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
-            response = self.generate_chat(prompt, config)
+
+            response, self.last_tokens = generate_chat(
+                prompt, self.system_prompt, self.client, config
+            )
+            self.total_tokens += self.last_tokens
+
             encoded_response = tokenizer.encode(response)
 
             if encoded_response in encoded_choices_left:
@@ -267,26 +261,6 @@ def generate_choice(
 
         return choice
 
-    def generate_chat(
-        self, prompt: Union[str, List[str]], config: OpenAIConfig
-    ) -> np.ndarray:
-        """Call the async function to generate a chat response and keeps track of usage data.
-
-        Parameters
-        ----------
-        prompt
-            A string used to prompt the model as user message
-        config
-            An instance of `OpenAIConfig`.
-
-        """
-        results, usage = generate_chat(prompt, self.system_prompt, self.client, config)
-
-        self.last_usage = OpenAIUsage(**usage)
-        self.total_usage += self.last_usage
-
-        return results
-
     def generate_json(self):
         """Call the OpenAI API to generate a JSON object."""
         raise NotImplementedError
@@ -305,7 +279,7 @@ async def generate_chat(
     system_prompt: Union[str, None],
     client: "AsyncOpenAI",
     config: OpenAIConfig,
-) -> Tuple[np.ndarray, Dict]:
+) -> Tuple[np.ndarray, int]:
     responses = await client.chat.completions.create(
         messages=(
             [{"role": "system", "content": system_prompt}] if system_prompt else []
@@ -321,7 +295,7 @@ async def generate_chat(
         [responses.choices[i].message.content for i in range(config.n)]
    )
 
-    return results, responses.usage.model_dump()
+    return results, responses.usage.total_tokens
 
 
 openai = OpenAI
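
A minimal usage sketch of the new accounting, for reviewers. It assumes this branch of Outlines with `OPENAI_API_KEY` set in the environment; the model name and prompts are illustrative, not part of the patch.

```python
# Sketch of the new token accounting (assumes this branch of outlines;
# model name and prompts are illustrative).
from outlines import models

model = models.openai("gpt-4")

result = model("Write a haiku about tokenizers.")
print(result)

# `last_tokens` is the `total_tokens` the API reported for the most
# recent request; `total_tokens` is the running sum across requests.
# Note: `last_tokens` is only bound after the first request, since
# __init__ annotates it without assigning a value.
print(model.last_tokens)
print(model.total_tokens)

model("Write another one.")
print(model.total_tokens)  # now covers both requests
```

One consequence of returning `responses.usage.total_tokens` instead of `responses.usage.model_dump()`: callers that previously read per-request prompt/completion splits from `last_usage` now only get the combined total.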