diff --git a/de_wiki_context.py b/de_wiki_context.py
index 0defe91..cccd6fe 100755
--- a/de_wiki_context.py
+++ b/de_wiki_context.py
@@ -1,30 +1,47 @@
 #!/usr/bin/env python
+"""Generate context using German Wikipedia articles.
+
+We take the top 10% of paragraphs (ranked by article views), embed them using a large
+multilingual model, and cache the embeddings. For each question we retrieve the pieces
+of context and rescore them using the Pulze API. We demonstrate what the answers look
+like with and without using the context.
+
+"""
 import logging
 import os
+import random
+import dotenv
 import openai
 import tiktoken
-import dotenv
-
 from datasets import load_dataset, load_from_disk
-from txtai import Embeddings
 from openai import OpenAI
+from txtai import Embeddings
+
+INITIAL_QUESTIONS = [
+    "How many wives can a man have in Germany?",
+    "What are the parties in current German parliament?",
+    "Who is in the current German government?",
+    "Wer ist ein Schöffe bzw eine Schöffin?",
+    "Was waren die deutsch-französischen Beziehungen im 19. Jhd?",
+    "Why was the Berlin wall built?",
+]
 
 DATASET_SOURCE = "Cohere/wikipedia-22-12"
 DATASET_PATH = "data/de-wiki-22-12-cohere-by-views"
 
 # much better than the default one for German text
 EMBEDDINGS_MODEL = "intfloat/multilingual-e5-large"
-EMBEDDINGS_HOW_MANY_K = 1500 # total size of the dataset is 15M embeddings
+EMBEDDINGS_HOW_MANY_K = 1500  # total size of the dataset is 15M embeddings
 EMBEDDINGS_PATH = f"data/de-wiki-multilingual-e5-large-top-{EMBEDDINGS_HOW_MANY_K}k"
 CONTEXT_CHOICES = 20
 
-OPENAI_MODEL = "gpt-3.5-turbo"
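+# Note: "pulze" is not an OpenAI model name; Pulze's OpenAI-compatible endpoint
+# (see base_url below) routes each request to a model of its own choosing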
+OPENAI_MODEL = "pulze"
 OPENAI_MODEL_CONTEXT_LENGTH = 8191
 
 
 def load_data_embeddings():
+    """Load and cache the dataset and its embeddings."""
     try:
         data = load_from_disk(DATASET_PATH, keep_in_memory=True)
         logging.info(f"Loaded data of shape {data.shape} from {DATASET_PATH}")
@@ -53,12 +70,14 @@ def load_data_embeddings():
 
 
 def build_context(context_chunks):
+    """Prepare a context string out of the suggested content chunks"""
     return "\n".join(
         f"""{c["id"]} (from '{c["title"]}'): {c["text"]}""" for c in context_chunks
     )
 
 
 def context_rescoring_prompt(query, context_chunks):
+    """Prepare a rescoring prompt for context chunks"""
     return f"""
 You are part of a text retrieval engine for German language. Your goal is to check whether
 the context, retrieved from the vector database, is helpful when answering the
@@ -75,32 +94,42 @@ def context_rescoring_prompt(query, context_chunks):
 
 
 def question_prompt(query, context_string=None):
+    """Prepare a question prompt that optionally includes a context"""
     return f"""
 You are a question-answer engine who takes great care to provide the most accurate answer.
 Answer the following question in German to the best of your ability: {query}
 Aim at several paragraphs that show clear and reasoned thinking.
-    """ + ("" if not context_string else """
+    """ + (
+        ""
+        if not context_string
+        else """
 The following context pieces, taken from recent Wikipedia articles, might be helpful in
 the answer:
-""" + context_string)
+"""
+        + context_string
+    )
 
 
 def run_loop(client, data, embeddings, question):
+    """Run an interactive loop to test the context retrieval"""
     try:
         encoding = tiktoken.encoding_for_model(OPENAI_MODEL)
     except KeyError:
-        encoding = tiktoken.encoding_for_model('gpt-4')
-
-    def complete(prompt):
-        return client.chat.completions.create(
-            messages=[
-                {
-                    "role": "user",
-                    "content": prompt,
-                }
-            ],
-            model=OPENAI_MODEL,
-        ).choices[0].message.content
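+        # tiktoken has no tokenizer registered for "pulze"; gpt-4's encoding is
+        # assumed to be a close-enough stand-in for counting prompt tokens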
+        encoding = tiktoken.encoding_for_model("gpt-4")
+
+    def complete(prompt):
+        return (
+            client.chat.completions.create(
+                messages=[
+                    {
+                        "role": "user",
+                        "content": prompt,
+                    }
+                ],
+                model=OPENAI_MODEL,
+            )
+            .choices[0]
+            .message.content
+        )
 
     def format_chunck(chunk_id):
         return f"""{chunk_id} [{data[chunk_id]["title"]}] {data[chunk_id]["text"]}"""
@@ -138,7 +167,6 @@ def format_chunck(chunk_id):
             for cid in rejected_ids:
                 print(format_chunck(cid))
 
-            # print("---- Context: ----")
             context = build_context(data[cid] for cid in accepted_ids)
 
             print("---- Without context ----")
@@ -148,28 +176,19 @@ def format_chunck(chunk_id):
             print(complete(question_prompt(question, context)))
 
         except ValueError:
-            logging.warning("Received a response that I cannot parse: %s", completion)
-
+            logging.warning(
+                "Received a response that I cannot parse: %s", completion
+            )
 
         question = input("Question: ")
 
 
 if __name__ == "__main__":
-    env = dotenv.dotenv_values()
+    logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
 
+    env = dotenv.dotenv_values()
     client_ = OpenAI(api_key=env["PULZE_API_KEY"], base_url="https://api.pulze.ai/v1")
-
-    logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
 
     data_, embeddings_ = load_data_embeddings()
 
-    run_loop(
-        client_,
-        data_,
-        embeddings_,
-        question=
-        # "How many wives can a man have in Germany?"
-        # "What are the parties in current German parliament?"
-        # "Who is in the current German government?"
-        # "Wer ist ein Schöffe bzw eine Schöffin?"
-        "Was waren die deutsch-französischen Beziehungen im 19. Jhd?",
-        # "Why was the Berlin wall built?"
-    )
+
+    initial_question = random.choice(INITIAL_QUESTIONS)
+    run_loop(client_, data_, embeddings_, initial_question)
diff --git a/example_embeddings.py b/example_embeddings.py
deleted file mode 100644
index f537a7d..0000000
--- a/example_embeddings.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from txtai.embeddings import Embeddings
-
-# Create embeddings model, backed by sentence-transformers & transformers
-embeddings = Embeddings(path="sentence-transformers/nli-mpnet-base-v2")
-
-data = [
-    "US tops 5 million confirmed virus cases",
-    "Canada's last fully intact ice shelf has suddenly collapsed, "
-    + "forming a Manhattan-sized iceberg",
-    "Beijing mobilises invasion craft along coast as Taiwan tensions escalate",
-    "The National Park Service warns against sacrificing slower friends "
-    + "in a bear attack",
-    "Maine man wins $1M from $25 lottery ticket",
-    "Make huge profits without work, earn up to $100,000 a day",
-]
-
-# Index the list of text
-embeddings.index(data)
-
-print(f"{'Query':20} Best Match")
-print("-" * 50)
-
-# Run an embeddings search for each query
-for query in (
-    "feel good story",
-    "climate change",
-    "public health story",
-    "war",
-    "wildlife",
-    "asia",
-    "lucky",
-    "dishonest junk",
-    "exemplary journalism",
-):
-    # Extract uid of first result
-    # search result format: (uid, score)
-    uid = embeddings.search(query, 1)[0][0]
-
-    # Print text
-    print(f"{query:20} {data[uid]}")