Skip to content

Commit

Permalink
return dict instead of pydantic object in generate chat
Browse files Browse the repository at this point in the history
  • Loading branch information
HerrIvan committed Dec 5, 2023
1 parent 4fe5fa9 commit 3356d52
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

if TYPE_CHECKING:
from openai import AsyncOpenAI
from openai.types.completion_usage import CompletionUsage


@dataclass(frozen=True)
Expand Down Expand Up @@ -177,8 +176,8 @@ def __call__(
response, usage = generate_chat(
prompt, self.system_prompt, self.client, config
)
self.prompt_tokens += usage.prompt_tokens
self.completion_tokens += usage.completion_tokens
self.prompt_tokens += usage["prompt_tokens"]
self.completion_tokens += usage["completion_tokens"]

return response

Expand Down Expand Up @@ -236,8 +235,8 @@ def generate_choice(
response, usage = generate_chat(
prompt, self.system_prompt, self.client, config
)
self.completion_tokens += usage.completion_tokens
self.prompt_tokens += usage.completion_tokens
self.completion_tokens += usage["completion_tokens"]
self.prompt_tokens += usage["prompt_tokens"]

encoded_response = tokenizer.encode(response)

Expand Down Expand Up @@ -289,7 +288,7 @@ async def generate_chat(
system_prompt: Union[str, None],
client: "AsyncOpenAI",
config: OpenAIConfig,
) -> Tuple[np.ndarray, "CompletionUsage"]:
) -> Tuple[np.ndarray, Dict]:
"""Call OpenAI's Chat Completion API.
Parameters
Expand Down Expand Up @@ -322,7 +321,7 @@ async def generate_chat(

results = np.array([responses.choices[i].message.content for i in range(config.n)])

return results, responses.usage
return results, responses.usage.model_dump()


openai = OpenAI
Expand Down

0 comments on commit 3356d52

Please sign in to comment.