From 3356d52191f57c41ea8ba457aecaddb0051fed82 Mon Sep 17 00:00:00 2001
From: Ivan Herreros
Date: Tue, 5 Dec 2023 12:49:31 +0100
Subject: [PATCH] return dict instead of pydantic object in generate chat

---
 outlines/models/openai.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/outlines/models/openai.py b/outlines/models/openai.py
index cb5bddae6..45509a9f4 100644
--- a/outlines/models/openai.py
+++ b/outlines/models/openai.py
@@ -15,7 +15,6 @@
 
 if TYPE_CHECKING:
     from openai import AsyncOpenAI
-    from openai.types.completion_usage import CompletionUsage
 
 
 @dataclass(frozen=True)
@@ -177,8 +176,8 @@ def __call__(
         response, usage = generate_chat(
             prompt, self.system_prompt, self.client, config
         )
-        self.prompt_tokens += usage.prompt_tokens
-        self.completion_tokens += usage.completion_tokens
+        self.prompt_tokens += usage["prompt_tokens"]
+        self.completion_tokens += usage["completion_tokens"]
 
         return response
 
@@ -236,8 +235,8 @@ def generate_choice(
             response, usage = generate_chat(
                 prompt, self.system_prompt, self.client, config
             )
-            self.completion_tokens += usage.completion_tokens
-            self.prompt_tokens += usage.completion_tokens
+            self.completion_tokens += usage["completion_tokens"]
+            self.prompt_tokens += usage["prompt_tokens"]
 
             encoded_response = tokenizer.encode(response)
 
@@ -289,7 +288,7 @@ async def generate_chat(
     system_prompt: Union[str, None],
     client: "AsyncOpenAI",
     config: OpenAIConfig,
-) -> Tuple[np.ndarray, "CompletionUsage"]:
+) -> Tuple[np.ndarray, Dict]:
     """Call OpenAI's Chat Completion API.
 
     Parameters
@@ -322,7 +321,7 @@
 
     results = np.array([responses.choices[i].message.content for i in range(config.n)])
 
-    return results, responses.usage
+    return results, responses.usage.model_dump()
 
 
 openai = OpenAI
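
---

Note (not part of the patch): a minimal standalone sketch of the behavior
change, assuming pydantic v2 and the CompletionUsage model from the
openai-python v1 client (the same model the removed TYPE_CHECKING import
referenced). It shows why callers switch from attribute access to dict
indexing once generate_chat returns usage.model_dump():

    from openai.types.completion_usage import CompletionUsage

    # Construct a usage object shaped like what the Chat Completions API returns.
    usage = CompletionUsage(prompt_tokens=12, completion_tokens=34, total_tokens=46)

    # Before this patch, generate_chat returned the pydantic object itself,
    # so callers used attribute access:
    assert usage.prompt_tokens == 12

    # After this patch, generate_chat returns usage.model_dump(), a plain
    # dict, so callers index by key instead:
    usage_dict = usage.model_dump()
    assert usage_dict["prompt_tokens"] == 12
    assert usage_dict["completion_tokens"] == 34

Returning a plain dict decouples the callers from the openai pydantic types,
which is also why the TYPE_CHECKING-only import of CompletionUsage can be
dropped in the first hunk.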