Skip to content

Commit

Permalink
Make the script nicer
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyannn committed Nov 27, 2023
1 parent 1698118 commit 1108b35
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 77 deletions.
93 changes: 56 additions & 37 deletions de_wiki_context.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,47 @@
#!/usr/bin/env python
"""Generate context using German Wikipedia articles.
We take top 10% paragraphs (ranked by article views), embed them using a large
multilingual model, and cache the embeddings. For each question we retrieve the pieces
of context and rescore them using Pulze API. We demonstrate what the answers look with
and without using the context.
"""

import logging
import os
import random

import dotenv
import openai
import tiktoken
import dotenv

from datasets import load_dataset, load_from_disk
from txtai import Embeddings
from openai import OpenAI
from txtai import Embeddings

INITIAL_QUESTIONS = [
"How many wives can a man have in Germany?",
"What are the parties in current German parliament?",
"Who is in the current German government?",
"Wer ist ein Schöffe bzw eine Schöffin?",
"Was waren die deutsch-französischen Beziehungen im 19. Jhd?",
"Why was the Berlin wall built?",
]

DATASET_SOURCE = "Cohere/wikipedia-22-12"
DATASET_PATH = "data/de-wiki-22-12-cohere-by-views"

# much better than the default one for German text
EMBEDDINGS_MODEL = "intfloat/multilingual-e5-large"
EMBEDDINGS_HOW_MANY_K = 1500 # total size of the dataset is 15M embeddings
EMBEDDINGS_HOW_MANY_K = 1500 # total size of the dataset is 15M embeddings
EMBEDDINGS_PATH = f"data/de-wiki-multilingual-e5-large-top-{EMBEDDINGS_HOW_MANY_K}k"

CONTEXT_CHOICES = 20
OPENAI_MODEL = "gpt-3.5-turbo"
OPENAI_MODEL = "pulze"
OPENAI_MODEL_CONTEXT_LENGTH = 8191


def load_data_embeddings():
"""Load and cache the dataset and its embeddings."""
try:
data = load_from_disk(DATASET_PATH, keep_in_memory=True)
logging.info(f"Loaded data of shape {data.shape} from {DATASET_PATH}")
Expand Down Expand Up @@ -53,12 +70,14 @@ def load_data_embeddings():


def build_context(context_chunks):
"""Prepare a context string out of the suggested content chunks"""
return "\n".join(
f"""{c["id"]} (from '{c["title"]}'): {c["text"]}""" for c in context_chunks
)


def context_rescoring_prompt(query, context_chunks):
"""Prepare a rescoring prompt for context chunks"""
return f"""
You are part of a text retrieval engine for German language. Your goal is to check whether
the context, retrieved from the vector database, is helpful when answering the
Expand All @@ -75,32 +94,42 @@ def context_rescoring_prompt(query, context_chunks):


def question_prompt(query, context_string=None):
"""Prepare a question prompt that optionally includes a context"""
return f"""
You are a question-answer engine who takes great care to provide the most accurate answer.
Answer the following question in German to the best of your ability: {query}
Aim at several paragraphs that show clear and reasoned thinking.
""" + ("" if not context_string else """
""" + (
""
if not context_string
else """
The following context pieces, taken from recent Wikipedia articles, might be helpful in the answer:
""" + context_string)
"""
+ context_string
)


def run_loop(client, data, embeddings, question):
"""Run an interactive loop to test the context retrieval"""
try:
encoding = tiktoken.encoding_for_model(OPENAI_MODEL)
except KeyError:
encoding = tiktoken.encoding_for_model('gpt-4')

def complete(prompt):
return client.chat.completions.create(
messages=[
{
"role": "user",
"content": prompt,
}
],
model=OPENAI_MODEL,
).choices[0].message.content
encoding = tiktoken.encoding_for_model("gpt-4")

def complete(prompt):
return (
client.chat.completions.create(
messages=[
{
"role": "user",
"content": prompt,
}
],
model=OPENAI_MODEL,
)
.choices[0]
.message.content
)

def format_chunck(chunk_id):
return f"""{chunk_id} [{data[chunk_id]["title"]}] {data[chunk_id]["text"]}"""
Expand Down Expand Up @@ -138,7 +167,6 @@ def format_chunck(chunk_id):
for cid in rejected_ids:
print(format_chunck(cid))

# print("---- Context: ----")
context = build_context(data[cid] for cid in accepted_ids)

print("---- Without context ----")
Expand All @@ -148,28 +176,19 @@ def format_chunck(chunk_id):
print(complete(question_prompt(question, context)))

except ValueError:
logging.warning("Received a response that I cannot parse: %s", completion)

logging.warning(
"Received a response that I cannot parse: %s", completion
)

question = input("Question: ")


if __name__ == "__main__":
env = dotenv.dotenv_values()
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)

env = dotenv.dotenv_values()
client_ = OpenAI(api_key=env["PULZE_API_KEY"], base_url="https://api.pulze.ai/v1")

logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
data_, embeddings_ = load_data_embeddings()
run_loop(
client_,
data_,
embeddings_,
question=
# "How many wives can a man have in Germany?"
# "What are the parties in current German parliament?"
# "Who is in the current German government?"
# "Wer ist ein Schöffe bzw eine Schöffin?"
"Was waren die deutsch-französischen Beziehungen im 19. Jhd?",
# "Why was the Berlin wall built?"
)

initial_question = random.choice(INITIAL_QUESTIONS)
run_loop(client_, data_, embeddings_, initial_question)
40 changes: 0 additions & 40 deletions example_embeddings.py

This file was deleted.

0 comments on commit 1108b35

Please sign in to comment.