Commit 5d28bb8: Miscellaneous updates for different models

ilyannn committed Nov 29, 2023
1 parent 976d970 commit 5d28bb8
Showing 1 changed file with 11 additions and 10 deletions.
de_wiki_context.py: 21 changes (11 additions, 10 deletions)

@@ -20,24 +20,25 @@

 INITIAL_QUESTIONS = [
     "How many wives can a man have in Germany?",
     "What are the parties in current German parliament?",
     "Who is in the current German government?",
     "Wer ist ein Schöffe bzw eine Schöffin?",
     "Was waren die deutsch-französischen Beziehungen im 19. Jhd?",
     "Why was the Berlin wall built?",
     "Wer ist das aktuelle Staatsoberhaupt in Deutschland?",
 ]

 DATASET_SOURCE = "Cohere/wikipedia-22-12"
 DATASET_PATH = "data/de-wiki-22-12-cohere-by-views"

 # much better than the default one for German text
 EMBEDDINGS_MODEL = "intfloat/multilingual-e5-large"
-EMBEDDINGS_HOW_MANY_K = 1500  # the total size of the dataset is 15m embeddings
+EMBEDDINGS_HOW_MANY_K = 1500  # note the total size of the dataset is 15m embeddings
 EMBEDDINGS_PATH = f"data/de-wiki-multilingual-e5-large-top-{EMBEDDINGS_HOW_MANY_K}k"

 CONTEXT_CHOICES = 20
-OPENAI_MODEL = "pulze"
-OPENAI_MODEL_CONTEXT_LENGTH = 8191
+MODEL = "pulze"
+MODEL_CONTEXT_LENGTH = 8192
+
+MAX_ANSWER_TOKENS = min(4096, MODEL_CONTEXT_LENGTH)


 def load_data_embeddings():
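
Note: MAX_ANSWER_TOKENS caps how long a generated answer may be, and the
min() keeps that cap valid even for a model whose context window is smaller
than 4096 tokens. A minimal sketch of the budgeting idea (answer_budget is a
hypothetical helper, not part of this commit):

    MODEL_CONTEXT_LENGTH = 8192
    MAX_ANSWER_TOKENS = min(4096, MODEL_CONTEXT_LENGTH)

    def answer_budget(prompt_tokens: int) -> int:
        # An answer may use at most MAX_ANSWER_TOKENS, and never more than
        # what remains of the context window once the prompt is counted.
        return max(0, min(MAX_ANSWER_TOKENS, MODEL_CONTEXT_LENGTH - prompt_tokens))
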
@@ -112,7 +113,7 @@ def question_prompt(query, context_string=None):
 def run_loop(client, data, embeddings, question):
     """Run an interactive loop to test the context retrieval"""
     try:
-        encoding = tiktoken.encoding_for_model(OPENAI_MODEL)
+        encoding = tiktoken.encoding_for_model(MODEL)
     except KeyError:
         encoding = tiktoken.encoding_for_model("gpt-4")

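
Note: tiktoken only ships encodings for OpenAI model names, so looking up
"pulze" raises KeyError and the gpt-4 encoding serves as a stand-in for token
counting. A self-contained sketch of the same fallback, assuming only that
the tiktoken package is installed:

    import tiktoken

    def encoding_for(model_name: str):
        # tiktoken recognizes only OpenAI model names; for anything else
        # (here: "pulze") fall back to the gpt-4 encoding, a reasonable
        # approximation for estimating prompt budgets.
        try:
            return tiktoken.encoding_for_model(model_name)
        except KeyError:
            return tiktoken.encoding_for_model("gpt-4")

    print(len(encoding_for("pulze").encode("Why was the Berlin wall built?")))
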
Expand All @@ -125,8 +126,8 @@ def complete(prompt):
"content": prompt,
}
],
model=OPENAI_MODEL,
max_tokens=8192,
model=MODEL,
max_tokens=MAX_ANSWER_TOKENS,
)
.choices[0]
.message.content
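
Note: in OpenAI-style chat APIs the prompt and the completion share the
context window, so the old hard-coded max_tokens=8192 could push a request
past an 8192-token model's limit; MAX_ANSWER_TOKENS leaves room for the
prompt.
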
Expand All @@ -148,7 +149,7 @@ def format_chunk(chunk_id):
         )
         prompt_length = len(encoding.encode(rescoring_prompt))
         logging.debug(rescoring_prompt)
-        if prompt_length <= OPENAI_MODEL_CONTEXT_LENGTH:
+        if prompt_length <= MODEL_CONTEXT_LENGTH:
             break
         ids_scores = ids_scores[: len(ids_scores) // 2]

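
Note: the enclosing loop (only partly visible in this hunk) keeps halving the
candidate chunk list until the rescoring prompt fits the model's context
window. A self-contained sketch of that shrink-to-fit pattern, where
build_prompt, encode, and candidates stand in for the script's own objects:

    def shrink_to_fit(candidates, build_prompt, encode, limit):
        # Halve the candidate list until the rendered prompt fits within
        # the token limit, then return the survivors and their prompt.
        while candidates:
            prompt = build_prompt(candidates)
            if len(encode(prompt)) <= limit:
                return candidates, prompt
            candidates = candidates[: len(candidates) // 2]
        raise ValueError("even a single candidate does not fit the window")
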
Expand All @@ -161,7 +162,7 @@ def format_chunk(chunk_id):
     # While ChatGPT correctly returned only the ids of accepted chunks,
     # other models may add text before or after the chunk id list.
     accepted_id_string = next(
-        s for s in completion.split() if s and s[0].isdigit()
+        s for s in completion.split("\n") if s and s[0].isdigit()
     )
     print("---- Accepted ----")

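
Note: splitting on "\n" instead of arbitrary whitespace makes the generator
pick the first whole line that starts with a digit, rather than the first
whitespace-separated token, so a line listing several ids survives intact. A
small illustration (the sample completion text is invented):

    completion = "Here are the relevant chunks:\n3, 7, 12\nHope this helps!"

    # completion.split() would have yielded "3," on its own; splitting on
    # newlines keeps the whole id line together.
    accepted = next(s for s in completion.split("\n") if s and s[0].isdigit())
    print(accepted)  # -> "3, 7, 12"
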