diff --git a/outlines/models/openai.py b/outlines/models/openai.py index dd1c73b24..16126acb0 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from openai import AsyncOpenAI + from openai.types.completion_usage import CompletionUsage @dataclass(frozen=True) @@ -281,23 +282,19 @@ async def generate_chat( system_prompt: Union[str, None], client: "AsyncOpenAI", config: OpenAIConfig, -) -> Tuple[np.ndarray, int]: +) -> Tuple[np.ndarray, "CompletionUsage"]: + system_message = ( + [{"role": "system", "content": system_prompt}] if system_prompt else [] + ) + responses = await client.chat.completions.create( - messages=( - [{"role": "system", "content": system_prompt}] if system_prompt else [] - ) - + [{"role": "user", "content": prompt}], + messages=system_message + [{"role": "user", "content": prompt}], **asdict(config), # type: ignore ) - if config.n == 1: - results = np.array([responses.choices[0].message.content]) - else: - results = np.array( - [responses.choices[i].message.content for i in range(config.n)] - ) + results = np.array([responses.choices[i].message.content for i in range(config.n)]) - return results, responses.usage.total_tokens + return results, responses.usage openai = OpenAI diff --git a/pyproject.toml b/pyproject.toml index 87493cb03..48f3e9dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ module = [ "jinja2", "joblib.*", "jsonschema.*", - "openai", + "openai.*", "nest_asyncio", "numpy.*", "perscache.*",