"Optimistic" is_in generation for openai. #378

Merged 1 commit on Nov 19, 2023
116 changes: 97 additions & 19 deletions outlines/models/openai.py
@@ -1,7 +1,9 @@
"""Integration with OpenAI's API."""
import functools
import os
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
from collections import deque
from itertools import zip_longest
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np

@@ -85,6 +87,39 @@ async def generate_base(

return results

def longest_common_prefix(tokens1: List[int], tokens2: List[int]) -> List[int]:
    i = 0
    while i < len(tokens1) and i < len(tokens2) and tokens1[i] == tokens2[i]:
        i += 1
    return tokens1[:i]

def get_choices_with_longest_common_prefix(
    response: List[int], is_in: List[List[int]]
) -> Tuple[List[int], List[List[int]]]:
    max_len_prefix = 0
    is_in_left = []
    prefix = []
    for i in range(len(is_in)):
        len_prefix = len(longest_common_prefix(response, is_in[i]))

        if len_prefix > max_len_prefix:
            max_len_prefix = len_prefix
            is_in_left = [is_in[i][len_prefix:]]
            prefix = is_in[i][:len_prefix]

        elif len_prefix == max_len_prefix:
            is_in_left.append(is_in[i][len_prefix:])

    return prefix, is_in_left

def build_optimistic_mask(transposed: deque[Set]) -> Dict:
    # build the biggest mask possible, adding tokens left to right
    to_mask: Set[int] = set()
    while len(transposed) > 0 and len(to_mask | transposed[0]) <= 300:
        to_mask = to_mask | transposed.popleft()

    return {token: 100 for token in to_mask}

@functools.partial(outlines.vectorize, signature="(),(m),()->(s)")
async def generate_choice(
prompt: str,
@@ -95,12 +130,11 @@ async def generate_choice(

.. warning::

This function will call the API once for every token generated.
In the worst case, this function may call the API as many times as there are tokens in the response.

We tokenize every choice, iterate over the token lists, create a mask
with the current tokens and generate one token. We progressively
eliminate the choices that don't start with the currently decoded
sequence.
With the optimistic approach, we activate all tokens that could form any of the answers. If the returned
solution does not match any of the answers, we then call the API again with only the tokens that can be
accepted as the next token. On average, this approach needs fewer calls to the API.

"""
try:
@@ -111,36 +145,80 @@ async def generate_choice(
)

tokenizer = tiktoken.encoding_for_model(model_name)
encoded: List[List[int]] = [tokenizer.encode(word) for word in is_in]

decoded_samples = []
for _ in range(samples):
is_in_left = is_in.copy()
decoded: List[str] = []
for i in range(max([len(word) for word in encoded])):
mask = {}
for word, tokenized_word in zip(is_in, encoded):
if not word.startswith("".join(decoded)):
continue
try:
mask[tokenized_word[i]] = 100
except IndexError:
pass

greedy = False # we try to generate the full response at each iteration
Member:

We might as well remove the greedy case altogether

@HerrIvan (Contributor, Author), Nov 17, 2023:

Say you have two tokenized choices [T1, TX], [T2, TX]. You would unmask the three tokens and request two tokens from the API. If the model is very biased towards TX, it may answer [TX, TX], and you have gained nothing; repeating the same call, you would enter a loop. Thus, it makes sense to run a greedy step next with just T1 and T2 unmasked, so that you are guaranteed to move one token forward.

Member:

Ok, this makes sense.
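
A minimal sketch of the scenario described above, using hypothetical token IDs (T1=1, T2=2, TX=9) and assuming the helper added by this diff is importable:

from collections import deque
from itertools import zip_longest

from outlines.models.openai import build_optimistic_mask  # helper added in this diff

choices = [[1, 9], [2, 9]]  # hypothetical tokenizations of the two choices

# Optimistic step: bias T1, T2 and TX together and request two tokens.
transposed = deque({t for t in col if t is not None} for col in zip_longest(*choices))
optimistic_mask = build_optimistic_mask(transposed)  # {1: 100, 2: 100, 9: 100}
# A model strongly biased towards TX can still answer [9, 9], which matches
# neither choice, so the optimistic call makes no progress.

# Greedy fallback: bias only the first token of each remaining choice and request
# a single token, which guarantees the decoded prefix advances by one valid token.
greedy_mask = {tokens[0]: 100 for tokens in choices}  # {1: 100, 2: 100}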


while len(is_in_left) > 0:
encoded: List[List[int]] = [
tokenizer.encode(word) for word in is_in_left
]

max_tokens_left = max([len(tokens) for tokens in encoded])
transposed: deque[Set] = deque(
[
{item for item in subset if item is not None}
for subset in zip_longest(*encoded)
]
)

if not greedy:
mask = build_optimistic_mask(transposed)
else:
mask = {}
for token in transposed.popleft(): # build greedy mask
mask[token] = 100

if len(mask) == 0:
break

response = await call_api(
model_name,
format_prompt(prompt),
1,
max_tokens_left if not greedy else 1,
temperature,
[],
mask,
1,
)
decoded.append(extract_choice(response["choices"][0]))

prompt = prompt + "".join(decoded)
current_resp = extract_choice(response["choices"][0])

if current_resp in is_in_left:
decoded.append(current_resp)
break
else:
# map response to tokens
tokenized_resp = tokenizer.encode(current_resp)
(
tokenized_resp,
encoded,
) = get_choices_with_longest_common_prefix(
tokenized_resp, encoded
)

if len(tokenized_resp) == 0:
greedy = True # next iteration will be "greedy"
continue
else:
decoded.append("".join(tokenizer.decode(tokenized_resp)))

# map back to words
is_in_left = [
"".join(tokenizer.decode(tokens)) for tokens in encoded
]

if len(is_in_left) == 1: # only one choice left
decoded.append(is_in_left[0])
break

greedy = False # after each success, stay with (or switch to) "optimistic" approach

prompt = prompt + "".join(decoded)

decoded_samples.append("".join(decoded))
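
A rough sketch of how the helpers in this diff narrow the remaining choices when the optimistic call does not return an exact match; the token IDs below are made up rather than real tiktoken output:

from collections import deque
from itertools import zip_longest

from outlines.models.openai import (  # helpers added in this diff
    build_optimistic_mask,
    get_choices_with_longest_common_prefix,
)

choices = [[11, 12, 13], [11, 12, 14], [21, 22]]  # tokenized is_in entries

# Optimistic mask: union of the token columns, left to right, capped at 300 biases.
transposed = deque(
    {t for t in column if t is not None} for column in zip_longest(*choices)
)
mask = build_optimistic_mask(transposed)
# mask maps each of {11, 12, 13, 14, 21, 22} to a logit bias of 100

# Suppose the API answers with tokens [11, 12, 99]: it matches no choice exactly,
# so we keep only the choices sharing the longest common prefix with the response.
prefix, remaining = get_choices_with_longest_common_prefix([11, 12, 99], choices)
# prefix == [11, 12]; remaining == [[13], [14]]
# The prefix is decoded and appended to the output, and the next call only has to
# pick between the leftover suffixes [13] and [14].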
