Commit 0873bd5

Use Langfuse for observability

ilyannn committed Jan 9, 2024
1 parent d888f74 commit 0873bd5
Showing 3 changed files with 79 additions and 29 deletions.
99 changes: 73 additions & 26 deletions de_wiki_context.py
100755 → 100644
@@ -10,9 +10,11 @@
 import logging
 import os
 import random
+from datetime import datetime
 
 import dotenv
-import openai
+from langfuse import Langfuse
+from langfuse.openai import openai
 from datasets import load_dataset, load_from_disk
 from openai import OpenAI
 from txtai import Embeddings
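
The import swap in this hunk is the heart of the change: langfuse.openai is Langfuse's drop-in replacement for the OpenAI SDK, so every completion request made through it is traced without touching the call sites. A minimal sketch of the pattern (model name taken from this commit; Langfuse and OpenAI credentials are assumed to be set in the environment):

    from langfuse.openai import openai  # instead of: import openai

    # Same call signature as the vanilla SDK; the wrapper also records the
    # request, response, token usage, and latency as a Langfuse generation.
    completion = openai.chat.completions.create(
        model="gpt-4-1106-preview",
        messages=[{"role": "user", "content": "Wer war Alexander von Humboldt?"}],
    )
    print(completion.choices[0].message.content)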
@@ -37,15 +39,15 @@
 
 CONTEXT_CHOICES = 20
 
-MODEL = "pulze"
+MODEL = "gpt-4-1106-preview"
 MODEL_CONTEXT_LENGTH = 8192
 MAX_ANSWER_TOKENS = min(4096, MODEL_CONTEXT_LENGTH)
 
 
 class Corpus:
     def __init__(
         self,
-        data: dict[int:dict[str:str]],
+        data: dict[int : dict[str:str]],
         embeddings: Embeddings,
     ):
         self.data = data
@@ -128,57 +130,87 @@ def get_context_ids(
     question: str,
     corpus: Corpus,
     llm: LLM,
+    trace,
 ) -> (list[int], list[int]):
     """
     :param question: The question for which we want to find the context.
     :param corpus: Corpus within which we look for context.
     :param llm: The language model abstraction used for completion.
+    :param trace: Langfuse trace object for observation purposes
     :return: A tuple containing suggested context IDs and IDs rejected when scoring.
     This method searches for context IDs within the provided embeddings based on the given question.
     It then performs a rescore with the language model.
     If any invented (hallucinated) IDs are found, they are logged.
     Finally, the method returns the accepted and rejected IDs as a tuple or a (None, None) pair
     """
+    span = trace.span(
+        name="embedding-search",
+        metadata={"database": "corpus"},
+        input={"query": question},
+    )
+
     ids_scores = corpus.embeddings.search(question, limit=CONTEXT_CHOICES)
+    span.end(output=ids_scores)
+
     for row_id, score in ids_scores:
         logging.debug(score, corpus.data[row_id])
 
     while True:
-        rescoring_prompt = context_rescoring_prompt(
-            question,
-            (corpus.data[row_id] for row_id, _ in ids_scores),
-        )
+        rescoring_context = [corpus.data[row_id] for row_id, _ in ids_scores]
+        rescoring_prompt = context_rescoring_prompt(question, rescoring_context)
         prompt_length = len(llm.encoding.encode(rescoring_prompt))
         logging.debug(rescoring_prompt)
         if prompt_length <= MODEL_CONTEXT_LENGTH:
             break
         ids_scores = ids_scores[: len(ids_scores) // 2]
 
     try:
-        completion = llm.answer(rescoring_prompt, output_json=True)
+        # creates generation
+        generation = trace.generation(
+            name="context-rescoring",
+            model=MODEL,
+            # model_parameters={"maxTokens": "1000", "temperature": "0.9"},
+        )
+
+        completion = llm.answer(
+            rescoring_prompt,
+            output_json=True,
+            name="de-wiki-context",
+            metadata={"question": question, "rescoring_context": rescoring_context},
+        )
+
+        generation.end(
+            output=completion,
+        )
     except openai.BadRequestError as e:
         logging.error("API wasn't happy: %s", e)
     else:
         try:
-            # While ChatGPT mostly correctly returns only the ids in JSON format,
-            # some other models may add text before and after the chunk id list.
-            accepted_id_string = next(
-                s
-                for s in completion.split("\n")
-                if s
-                and all(
-                    all(ch.isdigit() or ch in "[]," for ch in sub)
-                    for sub in s.split()
+            try:
+                _ = json.loads(completion)
+                accepted_id_string = completion
+            except json.JSONDecodeError:
+                # While ChatGPT mostly correctly returns only the ids in JSON format,
+                # some other models may add text before and after the chunk id list.
+                accepted_id_string = next(
+                    s
+                    for s in completion.split("\n")
+                    if s
+                    and all(
+                        all(ch.isdigit() or ch in "[]," for ch in sub)
+                        for sub in s.split()
+                    )
                 )
-            )
 
-            if "], [" in accepted_id_string:
-                # Another output format bug with Claude
-                accepted_id_string = accepted_id_string.replace("], [", ", ")
+                if "], [" in accepted_id_string:
+                    # Another output format bug with Claude
+                    accepted_id_string = accepted_id_string.replace("], [", ", ")
 
             try:
                 returned_ids = json.loads(accepted_id_string)
                 while isinstance(returned_ids, dict):
                     returned_ids = list(returned_ids.values())[0]
                 assert isinstance(returned_ids, list) and all(
                     isinstance(i, int) for i in returned_ids
                 )
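
The reworked parsing above makes the happy path explicit: if the completion already parses as JSON it is used verbatim; only otherwise does the code scan for a line that looks like a bare id list and apply the Claude workaround. A self-contained sketch of the same fallback logic, simplified to return the parsed ids directly:

    import json

    def extract_ids(completion: str) -> list:
        """Recover a JSON list of chunk ids from possibly chatty model output."""
        try:
            json.loads(completion)
            accepted_id_string = completion
        except json.JSONDecodeError:
            # Some models wrap the id list in prose; keep the first line whose
            # tokens contain only digits and the characters "[],".
            accepted_id_string = next(
                s
                for s in completion.split("\n")
                if s
                and all(
                    all(ch.isdigit() or ch in "[]," for ch in sub)
                    for sub in s.split()
                )
            )
            if "], [" in accepted_id_string:
                # Claude sometimes emits "[1, 2], [3]"; merge into one list.
                accepted_id_string = accepted_id_string.replace("], [", ", ")
        return json.loads(accepted_id_string)

    assert extract_ids("[3, 1, 4]") == [3, 1, 4]
    assert extract_ids("Sure!\n[3, 1], [4]\nHope this helps.") == [3, 1, 4]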
@@ -215,10 +247,20 @@ def get_context_ids(
 def run_loop(llm: LLM, corpus: Corpus, question: str):
     """Run an interactive loop to test the context retrieval"""
 
+    langfuse = Langfuse()
+    langfuse.auth_check()
+    session_id = datetime.now().strftime("%Y%m%d-%H%M")
+
     while question:
+        trace = langfuse.trace(
+            name="de-wiki-context",
+            input={"question": question},
+            session_id=session_id,
+        )
         logging.info("Answering '%s'", question)
+        logging.info("Monitor trace in Langfuse: %s", trace.get_trace_url())
 
-        context_ids, rejected_ids = get_context_ids(question, corpus, llm)
+        context_ids, rejected_ids = get_context_ids(question, corpus, llm, trace)
 
         if context_ids:
             print("---- Accepted ----")
@@ -232,19 +274,24 @@ def run_loop(llm: LLM, corpus: Corpus, question: str):
         context = build_context(corpus.data[cid] for cid in context_ids)
 
         print("---- Without context ----")
-        print(llm.answer(question_prompt(question)))
+        print(llm.answer(question_prompt(question), name="context-off"))
 
         print("---- With context ----")
-        print(llm.answer(question_prompt(question, context)))
+        print(llm.answer(question_prompt(question, context), name="context-on"))
 
         question = input("---- Question: ")
 
+    langfuse.flush()
+
 
 if __name__ == "__main__":
     logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
     dotenv.load_dotenv()
 
-    env = dotenv.dotenv_values()
-    client = OpenAI(api_key=env["PULZE_API_KEY"], base_url="https://api.pulze.ai/v1")
+    if pulze_key := os.environ.get("PULZE_KEY"):
+        client = OpenAI(api_key=pulze_key, base_url="https://api.pulze.ai/v1")
+    else:
+        client = OpenAI()
 
     initial_question = random.choice(INITIAL_QUESTIONS)
     run_loop(LLM(client, MODEL, MAX_ANSWER_TOKENS), load_corpus(), initial_question)
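
Taken together, the run_loop changes set up the Langfuse object hierarchy this commit relies on: one authenticated client per process, one trace per question (grouped into a session for the whole run), spans for non-LLM work, generations for model calls, and a final flush because events are shipped in the background. A condensed sketch of that lifecycle, with names mirroring the diff and LANGFUSE_* credentials assumed to be in the environment:

    from datetime import datetime

    from langfuse import Langfuse

    langfuse = Langfuse()  # reads LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY
    langfuse.auth_check()  # fail fast on bad credentials

    session_id = datetime.now().strftime("%Y%m%d-%H%M")
    trace = langfuse.trace(
        name="de-wiki-context",
        input={"question": "Wer war Alexander von Humboldt?"},
        session_id=session_id,  # groups all questions of this run in the UI
    )

    # Non-LLM steps are recorded as spans...
    span = trace.span(name="embedding-search", input={"query": "..."})
    span.end(output=[(42, 0.87)])

    # ...while model calls are recorded as generations.
    generation = trace.generation(name="context-rescoring", model="gpt-4-1106-preview")
    generation.end(output="[42]")

    print(trace.get_trace_url())  # link to inspect this trace in the Langfuse UI
    langfuse.flush()  # events are queued asynchronously; flush before exiting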
Empty file modified gen_context.py
100755 → 100644
9 changes: 6 additions & 3 deletions llm.py
@@ -1,5 +1,7 @@
 import re
 
+from openai import OpenAI
+
 import tiktoken
 
 ANSWER_REGEX = re.compile(r"<answer>(.*?)</answer>", flags=re.DOTALL)
@@ -22,7 +24,7 @@ class LLM:
     - answer(self, prompt, output_json=False): Generates an answer based on the prompt.
     """
 
-    def __init__(self, client, model_name, max_answer_tokens):
+    def __init__(self, client: OpenAI, model_name, max_answer_tokens):
         self.client = client
         self.model_name = model_name
         self.max_answer_tokens = max_answer_tokens
@@ -51,7 +53,7 @@ def claude_prompt_fix(self, prompt):
             Assistant: <answer>"""
         )
 
-    def answer(self, prompt, output_json: bool = False):
+    def answer(self, prompt, output_json: bool = False, **kwargs):
         """Ask LLM and parse the answer.
 
         :param prompt: The prompt for generating the answer.
@@ -68,8 +70,9 @@ def answer(self, prompt, output_json: bool = False):
                 ],
                 model=self.model_name,
                 # This parameter is not supported by Pulze
-                response_format=("json_object" if output_json else "text"),
+                response_format={"type": "json_object" if output_json else "text"},
                 max_tokens=self.max_answer_tokens,
+                **kwargs,
             )
             .choices[0]
             .message.content
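
Beyond observability, two details of this hunk matter: the OpenAI v1 SDK expects response_format as a dict rather than a bare string, and the new **kwargs lets callers forward Langfuse-specific arguments such as name and metadata through LLM.answer to the wrapped client. A standalone sketch of the resulting call, assuming OpenAI and Langfuse credentials in the environment:

    from langfuse.openai import openai

    completion = openai.chat.completions.create(
        messages=[{"role": "user", "content": 'Return {"ids": []} as JSON.'}],
        model="gpt-4-1106-preview",
        response_format={"type": "json_object"},  # dict form required by the v1 SDK
        max_tokens=256,
        name="context-rescoring",    # Langfuse-only kwarg: labels the generation
        metadata={"example": True},  # Langfuse-only kwarg: stored with the trace
    )
    print(completion.choices[0].message.content)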
