diff --git a/outlines/models/openai.py b/outlines/models/openai.py
index 7d7ac61c8..520581039 100644
--- a/outlines/models/openai.py
+++ b/outlines/models/openai.py
@@ -79,6 +79,8 @@ def __init__(
         model_name: str,
         api_key: Optional[str] = None,
         max_retries: int = 6,
+        timeout: Optional[float] = None,
+        system_prompt: Optional[str] = None,
         config: Optional[OpenAIConfig] = None,
     ):
         """Create an `OpenAI` instance.
@@ -93,6 +95,10 @@ def __init__(
             `openai.api_key`.
         max_retries
             The maximum number of retries when calls to the API fail.
+        timeout
+            Duration after which the request times out.
+        system_prompt
+            The content of the system message that precedes the user's prompt.
         config
             An instance of `OpenAIConfig`. Can be useful to specify some
             parameters that cannot be set by calling this class' methods.
@@ -120,7 +126,16 @@ def __init__(
         else:
             self.config = OpenAIConfig(model=model_name)
 
-        self.client = openai.AsyncOpenAI(api_key=api_key, max_retries=max_retries)
+        self.client = openai.AsyncOpenAI(
+            api_key=api_key, max_retries=max_retries, timeout=timeout
+        )
+        self.system_prompt = system_prompt
+
+        # We count the total number of prompt and generated tokens as returned
+        # by the OpenAI API, summed over all the requests performed with this
+        # model instance.
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
 
     def __call__(
         self,
@@ -158,7 +173,13 @@ def __call__(
             )
         )
         if "gpt-" in self.config.model:
-            return generate_chat(prompt, self.client, config)
+            response, usage = generate_chat(
+                prompt, self.system_prompt, self.client, config
+            )
+            self.prompt_tokens += usage["prompt_tokens"]
+            self.completion_tokens += usage["completion_tokens"]
+
+            return response
 
     def generate_choice(
         self, prompt: str, choices: List[str], max_tokens: Optional[int] = None
@@ -210,7 +231,13 @@ def generate_choice(
                 break
 
             config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
-            response = generate_chat(prompt, self.client, config)
+
+            response, usage = generate_chat(
+                prompt, self.system_prompt, self.client, config
+            )
+            self.completion_tokens += usage["completion_tokens"]
+            self.prompt_tokens += usage["prompt_tokens"]
+
             encoded_response = tokenizer.encode(response)
 
             if encoded_response in encoded_choices_left:
@@ -255,22 +282,46 @@ def __repr__(self):
 
 
 @cache(ignore="client")
-@functools.partial(outlines.vectorize, signature="(),(),()->(s)")
+@functools.partial(outlines.vectorize, signature="(),(),(),()->(s),()")
 async def generate_chat(
-    prompt: str, client: "AsyncOpenAI", config: OpenAIConfig
-) -> np.ndarray:
+    prompt: str,
+    system_prompt: Union[str, None],
+    client: "AsyncOpenAI",
+    config: OpenAIConfig,
+) -> Tuple[np.ndarray, Dict]:
+    """Call OpenAI's Chat Completion API.
+
+    Parameters
+    ----------
+    prompt
+        The prompt we use to start the generation. Passed to the model
+        with the "user" role.
+    system_prompt
+        The system prompt, passed to the model with the "system" role
+        before the prompt.
+    client
+        The API client.
+    config
+        An `OpenAIConfig` instance.
+
+    Returns
+    -------
+    A tuple that contains the model's response(s) and usage statistics.
+ + """ + system_message = ( + [{"role": "system", "content": system_prompt}] if system_prompt else [] + ) + user_message = [{"role": "user", "content": prompt}] + responses = await client.chat.completions.create( - messages=[{"role": "user", "content": prompt}], **asdict(config) # type: ignore + messages=system_message + user_message, + **asdict(config), # type: ignore ) - if config.n == 1: - results = np.array([responses.choices[0].message.content]) - else: - results = np.array( - [responses.choices[i].message.content for i in range(config.n)] - ) + results = np.array([responses.choices[i].message.content for i in range(config.n)]) - return results + return results, responses.usage.model_dump() openai = OpenAI @@ -292,8 +343,8 @@ def find_response_choices_intersection( choices. Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices - `[[1, 2], [1, 2, 3], [6, 7, 8]` then the function will return `[1, 2]` as the - intersection, and `[1, 2, 3]` as the choice that is left. + `[[1, 2], [1, 2, 3], [6, 7, 8]` then the function will return `[1, 2, 3]` as the + intersection, and `[[]]` as the list of choices left. Parameters ---------- @@ -305,7 +356,8 @@ def find_response_choices_intersection( Returns ------- A tuple that contains the longest intersection between the response and the - different choices, and the choices which start with this intersection. + different choices, and the choices which start with this intersection, with the + intersection removed. """ max_len_prefix = 0 diff --git a/pyproject.toml b/pyproject.toml index 87493cb03..48f3e9dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ module = [ "jinja2", "joblib.*", "jsonschema.*", - "openai", + "openai.*", "nest_asyncio", "numpy.*", "perscache.*",