From 8a1430f2e7576788876c2722970afa02edbfeaae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 4 Dec 2023 15:59:50 +0100 Subject: [PATCH] Return `CompletionUsage` from API --- outlines/models/openai.py | 21 +++++++++------------ pyproject.toml | 2 +- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/outlines/models/openai.py b/outlines/models/openai.py index dd1c73b24..16126acb0 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from openai import AsyncOpenAI + from openai.types.completion_usage import CompletionUsage @dataclass(frozen=True) @@ -281,23 +282,19 @@ async def generate_chat( system_prompt: Union[str, None], client: "AsyncOpenAI", config: OpenAIConfig, -) -> Tuple[np.ndarray, int]: +) -> Tuple[np.ndarray, "CompletionUsage"]: + system_message = ( + [{"role": "system", "content": system_prompt}] if system_prompt else [] + ) + responses = await client.chat.completions.create( - messages=( - [{"role": "system", "content": system_prompt}] if system_prompt else [] - ) - + [{"role": "user", "content": prompt}], + messages=system_message + [{"role": "user", "content": prompt}], **asdict(config), # type: ignore ) - if config.n == 1: - results = np.array([responses.choices[0].message.content]) - else: - results = np.array( - [responses.choices[i].message.content for i in range(config.n)] - ) + results = np.array([responses.choices[i].message.content for i in range(config.n)]) - return results, responses.usage.total_tokens + return results, responses.usage openai = OpenAI diff --git a/pyproject.toml b/pyproject.toml index 87493cb03..48f3e9dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ module = [ "jinja2", "joblib.*", "jsonschema.*", - "openai", + "openai.*", "nest_asyncio", "numpy.*", "perscache.*",