
Return CompletionUsage from API
rlouf committed Dec 4, 2023
1 parent f94b6ce commit 8a1430f
Showing 2 changed files with 10 additions and 13 deletions.

outlines/models/openai.py (9 additions, 12 deletions)

@@ -15,6 +15,7 @@
 
 if TYPE_CHECKING:
     from openai import AsyncOpenAI
+    from openai.types.completion_usage import CompletionUsage
 
 
 @dataclass(frozen=True)
@@ -281,23 +282,19 @@ async def generate_chat(
     system_prompt: Union[str, None],
     client: "AsyncOpenAI",
     config: OpenAIConfig,
-) -> Tuple[np.ndarray, int]:
+) -> Tuple[np.ndarray, "CompletionUsage"]:
+    system_message = (
+        [{"role": "system", "content": system_prompt}] if system_prompt else []
+    )
+
     responses = await client.chat.completions.create(
-        messages=(
-            [{"role": "system", "content": system_prompt}] if system_prompt else []
-        )
-        + [{"role": "user", "content": prompt}],
+        messages=system_message + [{"role": "user", "content": prompt}],
         **asdict(config),  # type: ignore
     )
 
-    if config.n == 1:
-        results = np.array([responses.choices[0].message.content])
-    else:
-        results = np.array(
-            [responses.choices[i].message.content for i in range(config.n)]
-        )
+    results = np.array([responses.choices[i].message.content for i in range(config.n)])
 
-    return results, responses.usage.total_tokens
+    return results, responses.usage
 
 
 openai = OpenAI
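
For readers skimming the diff, a minimal sketch of what the new second return value carries. CompletionUsage is the model the OpenAI Python client attaches to responses as `responses.usage`; the field values below are made up for illustration, and the unpacking shown is an assumed downstream pattern, not code from this commit.

    from openai.types.completion_usage import CompletionUsage

    # Sketch only: construct a CompletionUsage by hand to show its fields;
    # in practice it arrives as `responses.usage` on the API response.
    usage = CompletionUsage(prompt_tokens=12, completion_tokens=30, total_tokens=42)

    # Callers that previously received the bare total_tokens integer can still
    # read it from the object, and now also get the per-side breakdown.
    print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)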

pyproject.toml (1 addition, 1 deletion)

@@ -92,7 +92,7 @@ module = [
     "jinja2",
     "joblib.*",
     "jsonschema.*",
-    "openai",
+    "openai.*",
     "nest_asyncio",
     "numpy.*",
     "perscache.*",
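
Presumably the broader pattern is needed because the change above imports from a submodule, openai.types.completion_usage: in mypy's per-module overrides a bare name such as "openai" matches only that module, while "openai.*" also matches its submodules, so the existing override keeps applying to the new import.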
