Skip to content

Commit

Permalink
Improve output format for ChatGPT and Claude 2
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyannn committed Dec 5, 2023
1 parent a2b78ee commit 53fa684
Showing 1 changed file with 54 additions and 26 deletions.
80 changes: 54 additions & 26 deletions de_wiki_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
For each question, we retrieve the pieces of context and rescore them using Pulze API.
We demonstrate what the answers look like with and without using the context.
"""

import json
import logging
import os
import random
from json import JSONDecodeError

import dotenv
import openai
Expand All @@ -35,7 +36,7 @@
EMBEDDINGS_PATH = f"data/de-wiki-multilingual-e5-large-top-{EMBEDDINGS_HOW_MANY_K}k"

CONTEXT_CHOICES = 20
MODEL = "pulze"
MODEL = "anthropic/claude-2"
MODEL_CONTEXT_LENGTH = 8192

MAX_ANSWER_TOKENS = min(4096, MODEL_CONTEXT_LENGTH)
Expand Down Expand Up @@ -80,34 +81,44 @@ def build_context(context_chunks):
def context_rescoring_prompt(query, context_chunks):
"""Prepare a rescoring prompt for context chunks"""
return f"""
You are part of a text retrieval engine for German language. Your goal is to check whether
the context, retrieved from the vector database, is helpful when answering the
query asked.
Human:
You are part of a text retrieval engine for German language. Your goal is to check whether the context, retrieved from the vector database, is helpful when answering the query asked.
The query: {query}
Context pieces, taken from Wikipedia articles, that you need to check:
{build_context(context_chunks)}
Provide the list of ids of context pieces that help answer the question posed,
separated by space. Do not give any other output. Example: 7682345 23876423 324123
"""
Provide the list of ids of context pieces that help answer the question posed, in the JSON format. Do not give any other output. Example output:
[76, 23, 32344123]
Please output your answer within <answer></answer> tags.
Assistant: <answer>"""


def question_prompt(query, context_string=None):
"""Prepare a question prompt that optionally includes a context"""
return f"""
Human:
You are a question-answer engine who takes great care to provide the most accurate answer.
Answer the following question in German to the best of your ability: {query}
Aim at several paragraphs that show clear and reasoned thinking.
""" + (
""
if not context_string
else """
else f"""
The following context pieces, taken from recent Wikipedia articles, might be helpful in the answer:
"""
+ context_string
)
{context_string}
""" ) + """
Please output your answer within <answer></answer> tags.
Assistant: <answer>"""


def run_loop(client, data, embeddings, question):
Expand All @@ -117,7 +128,7 @@ def run_loop(client, data, embeddings, question):
except KeyError:
encoding = tiktoken.encoding_for_model("gpt-4")

def complete(prompt):
def complete(prompt, output_json: bool=False):
return (
client.chat.completions.create(
messages=[
Expand All @@ -127,15 +138,18 @@ def complete(prompt):
}
],
model=MODEL,
response_format=("json_object" if output_json else "text"),
max_tokens=MAX_ANSWER_TOKENS,
)
.choices[0]
.message.content
)
).removesuffix("</answer>")

def format_chunk(chunk_id):
return f"""{chunk_id} [{data[chunk_id]["title"]}] {data[chunk_id]["text"]}"""

decode_json = json.JSONDecoder(strict=True)

while question:
logging.info("Answering '%s'", question)

Expand All @@ -145,7 +159,7 @@ def format_chunk(chunk_id):

while True:
rescoring_prompt = context_rescoring_prompt(
question, (data[row_id] for row_id, _ in ids_scores)
question, (data[row_id] for row_id, _ in ids_scores),
)
prompt_length = len(encoding.encode(rescoring_prompt))
logging.debug(rescoring_prompt)
Expand All @@ -154,21 +168,35 @@ def format_chunk(chunk_id):
ids_scores = ids_scores[: len(ids_scores) // 2]

try:
completion = complete(rescoring_prompt)
completion = complete(rescoring_prompt, output_json=True)
except openai.BadRequestError as e:
logging.error("API wasn't happy: %s", e)
else:
try:
# While ChatGPT correctly returned only the ids of accepted chunks,
# other models may add text before or after the chunk id list.
accepted_id_string = next(
s
for s in completion.split("\n")
if s and all(all(ch.isdigit() for ch in sub) for sub in s.split())
)
if completion[0] == '[' or completion[0].isdigit():
accepted_id_string = completion
else:
# While ChatGPT correctly returns only the ids of accepted chunks in JSON format,
# other models may add text before the chunk id list.
accepted_id_string = next(
s
for s in completion.split("\n")
if s and all(all(ch.isdigit() or ch in "[]," for ch in sub) for sub in s.split())
)

try:
returned_ids = json.loads(accepted_id_string)
assert isinstance(returned_ids, list) and all(isinstance(i, int) for i in returned_ids)
except (AssertionError, json.JSONDecodeError):
returned_ids = [int(s) for s in accepted_id_string.split()]

assert isinstance(returned_ids, list) and all(isinstance(i, int) for i in returned_ids)

if invented_ids := set(returned_ids) - {row_id for row_id, _ in ids_scores}:
logging.info(f"The model invented following context IDs: {invented_ids}")

print("---- Accepted ----")

accepted_ids = [int(s) for s in accepted_id_string.split()]
accepted_ids = [row_id for row_id in returned_ids if row_id not in invented_ids]
for cid in accepted_ids:
print(format_chunk(cid))

Expand All @@ -185,7 +213,7 @@ def format_chunk(chunk_id):
print("---- With context ----")
print(complete(question_prompt(question, context)))

except ValueError:
except (ValueError, AssertionError, StopIteration):
logging.warning(
"Received a response to '%s' that I cannot parse: '%s'",
rescoring_prompt,
Expand Down

0 comments on commit 53fa684

Please sign in to comment.