Refactor away a separate LLM class
ilyannn committed Dec 10, 2023
1 parent d3bb0e4 commit b71e650
Showing 2 changed files with 174 additions and 135 deletions.
246 changes: 111 additions & 135 deletions de_wiki_context.py
@@ -10,15 +10,15 @@
import logging
import os
import random
-import re

import dotenv
import openai
-import tiktoken
from datasets import load_dataset, load_from_disk
from openai import OpenAI
from txtai import Embeddings

+from llm import LLM
+
INITIAL_QUESTIONS = [
    "How many wives can a man have in Germany?",
    "Wer ist ein Schöffe bzw eine Schöffin?",
@@ -38,7 +38,6 @@
CONTEXT_CHOICES = 20
MODEL = "pulze"
MODEL_CONTEXT_LENGTH = 8192
-MODEL_CLAUDE_FIX = "claude" in MODEL or "pulze" in MODEL

MAX_ANSWER_TOKENS = min(4096, MODEL_CONTEXT_LENGTH)

@@ -79,29 +78,9 @@ def build_context(context_chunks):
    )


-def claude_prompt_fix(prompt):
-    """This seems to give better results for Anthropic models"""
-    return (
-        prompt
-        if not MODEL_CLAUDE_FIX
-        else f"""
-Human:
-{prompt}
-Please output your answer within <answer></answer> tags.
-Assistant: <answer>"""
-    )
-
-
def context_rescoring_prompt(query, context_chunks):
    """Prepare a rescoring prompt for context chunks"""
-    return claude_prompt_fix(
-        f"""
+    return f"""
You are part of a text retrieval engine for German language. Your goal is to check whether the context, retrieved from the vector database, is helpful when answering the query asked.
The query: {query}
@@ -111,7 +90,6 @@ def context_rescoring_prompt(query, context_chunks):
Provide the list of ids of context pieces that help answer the question posed, in the JSON format. Do not give any other output. Do not add any ticks or other symbols around JSON. Example output:
[76, 23, 32344123]"""
-    )


def question_prompt(query, context_string=None):
@@ -126,137 +104,133 @@ def question_prompt(query, context_string=None):
"""
    )

-    return claude_prompt_fix(
-        f"""You are a question-answer engine who takes great care to provide the most accurate answer.
+    return f"""You are a question-answer engine who takes great care to provide the most accurate answer.
Answer the following question in German to the best of your ability: {query}
Aim at several paragraphs that show clear and reasoned thinking.
{context_query}
"""
-    )


-def run_loop(client, data, embeddings, question):
-    """Run an interactive loop to test the context retrieval"""
-    try:
-        encoding = tiktoken.encoding_for_model(MODEL)
-    except KeyError:
-        encoding = tiktoken.encoding_for_model("gpt-4")
-
-    def complete(prompt, output_json: bool = False):
-        response_content = (
-            client.chat.completions.create(
-                messages=[
-                    {
-                        "role": "user",
-                        "content": prompt,
-                    }
-                ],
-                model=MODEL,
-                # This parameter is not supported by Pulze
-                response_format=("json_object" if output_json else "text"),
-                max_tokens=MAX_ANSWER_TOKENS,
-            )
-            .choices[0]
-            .message.content
-        )
-
-        # Sometimes we get "bla bla bla <answer>good stuff</answer> bla bla bla"
-        # Sometimes we get "bla bla bla: good stuff</answer>"
-        if "<answer>" not in response_content:
-            return response_content.removesuffix("</answer>")
-        return re.search(r"<answer>(.*?)</answer>", response_content).group(1)
-
-    def format_chunk(chunk_id):
-        return f"""{chunk_id} [{data[chunk_id]["title"]}] {data[chunk_id]["text"]}"""
-
-    while question:
-        logging.info("Answering '%s'", question)
-
-        ids_scores = embeddings.search(question, limit=CONTEXT_CHOICES)
-        for row_id, score in ids_scores:
-            logging.debug(score, data[row_id])
-
-        while True:
-            rescoring_prompt = context_rescoring_prompt(
-                question,
-                (data[row_id] for row_id, _ in ids_scores),
-            )
-            prompt_length = len(encoding.encode(rescoring_prompt))
-            logging.debug(rescoring_prompt)
-            if prompt_length <= MODEL_CONTEXT_LENGTH:
-                break
-            ids_scores = ids_scores[: len(ids_scores) // 2]
-
-        try:
-            completion = complete(rescoring_prompt, output_json=True)
-        except openai.BadRequestError as e:
-            logging.error("API wasn't happy: %s", e)
-        else:
-            try:
-                if completion[0] == "[" or completion[0].isdigit():
-                    accepted_id_string = completion
-                else:
-                    # While ChatGPT mostly correctly returns only the ids in JSON format,
-                    # some other models may add text before and after the chunk id list.
-                    accepted_id_string = next(
-                        s
-                        for s in completion.split("\n")
-                        if s
-                        and all(
-                            all(ch.isdigit() or ch in "[]," for ch in sub)
-                            for sub in s.split()
-                        )
-                    )
-
-                if "], [" in accepted_id_string:
-                    # Another output format bug with Claude
-                    accepted_id_string = accepted_id_string.replace("], [", ", ")
-
-                try:
-                    returned_ids = json.loads(accepted_id_string)
-                    assert isinstance(returned_ids, list) and all(
-                        isinstance(i, int) for i in returned_ids
-                    )
-                except (AssertionError, json.JSONDecodeError):
-                    returned_ids = [int(s) for s in accepted_id_string.split()]
-
-                assert isinstance(returned_ids, list) and all(
-                    isinstance(i, int) for i in returned_ids
-                )
-
-                if invented_ids := set(returned_ids) - {
-                    row_id for row_id, _ in ids_scores
-                }:
-                    logging.info(
-                        f"The model invented following context IDs: {invented_ids}"
-                    )
-
-                print("---- Accepted ----")
-                accepted_ids = [
-                    row_id for row_id in returned_ids if row_id not in invented_ids
-                ]
-                for cid in accepted_ids:
-                    print(format_chunk(cid))
-
-                print("---- Rejected ----")
-                rejected_ids = set(cid for cid, _ in ids_scores) - set(accepted_ids)
-                for cid in rejected_ids:
-                    print(format_chunk(cid))
-
-                context = build_context(data[cid] for cid in accepted_ids)
-
-                print("---- Without context ----")
-                print(complete(question_prompt(question)))
-
-                print("---- With context ----")
-                print(complete(question_prompt(question, context)))
-
-            except (ValueError, AssertionError, StopIteration):
-                logging.warning(
-                    "Received a response to '%s' that I cannot parse: '%s'",
-                    rescoring_prompt,
-                    completion,
-                )
-
+def get_context_ids(
+    llm: LLM,
+    question: str,
+    data: dict[int:str],
+    embeddings: Embeddings,
+) -> (list[int], list[int]):
+    """
+    :param llm: The language model abstraction used for completion.
+    :param question: The question for which we want to find the context.
+    :param data: Chunks of context within which we look for context.
+    :param embeddings: The embeddings of data.
+    :return: A tuple containing suggested context IDs and IDs rejected when scoring.
+    This method searches for context IDs within the provided embeddings based on the given question.
+    It then performs a rescore with the language model.
+    If any invented (hallucinated) IDs are found, they are logged.
+    Finally, the method returns the accepted and rejected IDs as a tuple or a (None, None) pair
+    """
+    ids_scores = embeddings.search(question, limit=CONTEXT_CHOICES)
+    for row_id, score in ids_scores:
+        logging.debug(score, data[row_id])
+
+    while True:
+        rescoring_prompt = context_rescoring_prompt(
+            question,
+            (data[row_id] for row_id, _ in ids_scores),
+        )
+        prompt_length = len(llm.encoding.encode(rescoring_prompt))
+        logging.debug(rescoring_prompt)
+        if prompt_length <= MODEL_CONTEXT_LENGTH:
+            break
+        ids_scores = ids_scores[: len(ids_scores) // 2]
+
+    try:
+        completion = llm.answer(rescoring_prompt, output_json=True)
+    except openai.BadRequestError as e:
+        logging.error("API wasn't happy: %s", e)
+    else:
+        try:
+            if completion[0] == "[" or completion[0].isdigit():
+                accepted_id_string = completion
+            else:
+                # While ChatGPT mostly correctly returns only the ids in JSON format,
+                # some other models may add text before and after the chunk id list.
+                accepted_id_string = next(
+                    s
+                    for s in completion.split("\n")
+                    if s
+                    and all(
+                        all(ch.isdigit() or ch in "[]," for ch in sub)
+                        for sub in s.split()
+                    )
+                )
+
+            if "], [" in accepted_id_string:
+                # Another output format bug with Claude
+                accepted_id_string = accepted_id_string.replace("], [", ", ")
+
+            try:
+                returned_ids = json.loads(accepted_id_string)
+                assert isinstance(returned_ids, list) and all(
+                    isinstance(i, int) for i in returned_ids
+                )
+            except (AssertionError, json.JSONDecodeError):
+                returned_ids = [int(s) for s in accepted_id_string.split()]
+
+            assert isinstance(returned_ids, list) and all(
+                isinstance(i, int) for i in returned_ids
+            )
+
+            if invented_ids := set(returned_ids) - {row_id for row_id, _ in ids_scores}:
+                logging.info(
+                    f"The model invented following context IDs: {invented_ids}"
+                )
+
+            accepted_ids = [
+                row_id for row_id in returned_ids if row_id not in invented_ids
+            ]
+
+            rejected_ids = set(cid for cid, _ in ids_scores) - set(accepted_ids)
+
+            return accepted_ids, rejected_ids
+
+        except (ValueError, AssertionError, StopIteration):
+            logging.warning(
+                "Received a response to '%s' that I cannot parse: '%s'",
+                rescoring_prompt,
+                completion,
+            )
+
+    return [], []
+
+
+def run_loop(llm: LLM, data, embeddings, question):
+    """Run an interactive loop to test the context retrieval"""
+
+    def format_chunk(chunk_id):
+        return f"""{chunk_id} [{data[chunk_id]["title"]}] {data[chunk_id]["text"]}"""
+
+    while question:
+        logging.info("Answering '%s'", question)
+
+        context_ids, rejected_ids = get_context_ids(llm, question, data, embeddings)
+
+        if context_ids:
+            print("---- Accepted ----")
+            for cid in context_ids:
+                print(format_chunk(cid))
+
+            print("---- Rejected ----")
+            for cid in rejected_ids:
+                print(format_chunk(cid))
+
+            context = build_context(data[cid] for cid in context_ids)
+
+            print("---- Without context ----")
+            print(llm.answer(question_prompt(question)))
+
+            print("---- With context ----")
+            print(llm.answer(question_prompt(question, context)))
+
        question = input("---- Question: ")

Expand All @@ -265,8 +239,10 @@ def format_chunk(chunk_id):
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)

env = dotenv.dotenv_values()
client_ = OpenAI(api_key=env["PULZE_API_KEY"], base_url="https://api.pulze.ai/v1")
client = OpenAI(api_key=env["PULZE_API_KEY"], base_url="https://api.pulze.ai/v1")
llm_ = LLM(client, MODEL, MAX_ANSWER_TOKENS)

data_, embeddings_ = load_data_embeddings()

initial_question = random.choice(INITIAL_QUESTIONS)
run_loop(client_, data_, embeddings_, initial_question)
run_loop(llm_, data_, embeddings_, initial_question)
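
Because the retrieval logic now receives the model wrapper as an argument, get_context_ids() can be exercised without a live API. The sketch below is illustrative only and not part of the commit — FakeLLM, FakeEmbeddings, and the sample data are made up — but it shows the seam this refactoring opens up for testing:

import tiktoken


class FakeEmbeddings:
    """Stands in for txtai.Embeddings: returns fixed (id, score) pairs."""

    def search(self, question, limit):
        return [(1, 0.9), (2, 0.5)]


class FakeLLM:
    """Stands in for llm.LLM: get_context_ids only touches .encoding and .answer()."""

    encoding = tiktoken.get_encoding("cl100k_base")

    def answer(self, prompt, output_json=False):
        return "[1]"  # pretend the model accepted only chunk 1


data = {1: {"title": "A", "text": "aaa"}, 2: {"title": "B", "text": "bbb"}}
accepted, rejected = get_context_ids(FakeLLM(), "Testfrage?", data, FakeEmbeddings())
# accepted == [1], rejected == {2}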
63 changes: 63 additions & 0 deletions llm.py
@@ -0,0 +1,63 @@
import re
import tiktoken


class LLM:
    def __init__(self, client, model_name, max_answer_tokens):
        self.client = client
        self.model_name = model_name
        self.max_answer_tokens = max_answer_tokens
        self.use_claude_fix = "claude" in model_name or "pulze" in model_name

        try:
            self.encoding = tiktoken.encoding_for_model(self.model_name)
        except KeyError:
            self.encoding = tiktoken.encoding_for_model("gpt-4")

    def claude_prompt_fix(self, prompt):
        """This seems to give better results for Anthropic models"""
        return (
            prompt
            if not self.use_claude_fix
            else f"""
Human:
{prompt}
Please output your answer within <answer></answer> tags.
Assistant: <answer>"""
        )

    def answer(self, prompt, output_json: bool = False):
        """Ask LLM and parse the answer.
        :param prompt: The prompt for generating the answer.
        :param output_json: A boolean indicating whether the response should be returned as JSON. Default is False.
        :return: The generated answer.
        """
        response_content = (
            self.client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": self.claude_prompt_fix(prompt),
                    }
                ],
                model=self.model_name,
                # This parameter is not supported by Pulze
                response_format=("json_object" if output_json else "text"),
                max_tokens=self.max_answer_tokens,
            )
            .choices[0]
            .message.content
        )

        # Sometimes we get "bla bla bla <answer>good stuff</answer> bla bla bla"
        # Sometimes we get "bla bla bla: good stuff</answer>"
        if "<answer>" not in response_content:
            return response_content.removesuffix("</answer>")
        return re.search(r"<answer>(.*?)</answer>", response_content).group(1)
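
For reference, the new wrapper is driven the same way the __main__ block of de_wiki_context.py above drives it. A minimal standalone sketch, illustrative only — the prompts and the max_answer_tokens value are made up, and a PULZE_API_KEY entry in .env is assumed:

import dotenv
from openai import OpenAI

from llm import LLM

env = dotenv.dotenv_values()
client = OpenAI(api_key=env["PULZE_API_KEY"], base_url="https://api.pulze.ai/v1")
llm = LLM(client, "pulze", max_answer_tokens=1024)

# Plain-text completion; the <answer>-tag workaround for Anthropic-style models is applied internally.
print(llm.answer("Nenne drei Bundesländer in Deutschland."))

# JSON-oriented completion, as used by the context rescoring step.
print(llm.answer("Gib eine JSON-Liste der Zahlen 1 bis 3 zurück.", output_json=True))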
