From 00459f506e4193b33f2c75d0557a5fac2a044d96 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Thu, 7 Mar 2024 18:20:16 +0530 Subject: [PATCH 1/3] add: multi-modal chat support --- pyproject.toml | 1 + src/wandbot/adcopy/adcopy.py | 11 +++-- src/wandbot/api/app.py | 2 +- src/wandbot/api/client.py | 16 ++++--- src/wandbot/api/routers/adcopy.py | 1 + src/wandbot/api/routers/chat.py | 4 ++ src/wandbot/api/routers/content_navigator.py | 16 ++++--- src/wandbot/api/routers/database.py | 4 +- src/wandbot/apps/slack/__main__.py | 6 ++- src/wandbot/apps/slack/config.py | 8 ++-- src/wandbot/apps/slack/handlers/ad_copy.py | 10 ++-- .../apps/slack/handlers/content_navigator.py | 18 ++++--- src/wandbot/apps/slack/handlers/docsbot.py | 1 + .../apps/slack/handlers/youtube_chat.py | 6 ++- src/wandbot/chat/chat.py | 31 +++++++----- src/wandbot/chat/rag.py | 47 ++++++++++++++++++- src/wandbot/chat/schemas.py | 3 +- src/wandbot/database/client.py | 2 +- src/wandbot/evaluation/eval/__main__.py | 18 +++---- src/wandbot/multi_modal/__init__.py | 0 src/wandbot/multi_modal/multi_modal.py | 14 ++++++ src/wandbot/rag/retrieval.py | 1 + src/wandbot/retriever/base.py | 9 ++-- src/wandbot/youtube_chat/chat_utils.py | 6 +-- 24 files changed, 168 insertions(+), 67 deletions(-) create mode 100644 src/wandbot/multi_modal/__init__.py create mode 100644 src/wandbot/multi_modal/multi_modal.py diff --git a/pyproject.toml b/pyproject.toml index d7d7c08..5ea8411 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ simsimd = "^3.7.4" youtube-transcript-api = "^0.6.2" pytube = "^15.0.0" pydub = "^0.25.1" +langchain-google-genai = "^0.0.9" [tool.poetry.dev-dependencies] #fasttext = {git = "https://github.com/cfculhane/fastText"} # FastText doesn't come with pybind11 and we need to use this workaround. diff --git a/src/wandbot/adcopy/adcopy.py b/src/wandbot/adcopy/adcopy.py index 3bc1607..a9c0a7f 100644 --- a/src/wandbot/adcopy/adcopy.py +++ b/src/wandbot/adcopy/adcopy.py @@ -1,6 +1,6 @@ import json -import random import logging +import random from operator import itemgetter from typing import Any, Dict, List @@ -8,6 +8,7 @@ from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import Runnable, RunnableParallel from langchain_openai import ChatOpenAI + from wandbot.chat.chat import Chat from wandbot.chat.schemas import ChatRequest from wandbot.rag.utils import ChatModel @@ -79,7 +80,9 @@ def build_prompt_input_variables( self, query: str, persona: str, action: str ) -> Dict[str, Any]: wandbot_response = self.query_wandbot(query) - additional_context = "\n".join(random.choices(self.contexts[action], k=2)) + additional_context = "\n".join( + random.choices(self.contexts[action], k=2) + ) persona_prompt = ( TECHNICAL_PROMPT if persona == "technical" else EXECUTIVE_PROMPT ) @@ -122,7 +125,9 @@ def _load_chain(self, model: ChatOpenAI) -> Runnable: return chain def __call__(self, query: str, persona: str, action: str) -> str: - logging.info(f"Generating ad copy for {persona} {action} with query: '{query}'") + logging.info( + f"Generating ad copy for {persona} {action} with query: '{query}'" + ) inputs = self.build_inputs_for_ad_formats(query, persona, action) outputs = self.chain.batch(inputs) str_output = "" diff --git a/src/wandbot/api/app.py b/src/wandbot/api/app.py index ec2ad45..d162fa7 100644 --- a/src/wandbot/api/app.py +++ b/src/wandbot/api/app.py @@ -33,9 +33,9 @@ from datetime import datetime, timezone import pandas as pd -import wandb from fastapi import FastAPI +import wandb from wandbot.api.routers import adcopy as adcopy_router from wandbot.api.routers import chat as chat_router from wandbot.api.routers import content_navigator as content_navigator_router diff --git a/src/wandbot/api/client.py b/src/wandbot/api/client.py index 23fa1e9..d60715f 100644 --- a/src/wandbot/api/client.py +++ b/src/wandbot/api/client.py @@ -15,12 +15,13 @@ import aiohttp import requests + +from wandbot.api.routers.adcopy import AdCopyRequest, AdCopyResponse +from wandbot.api.routers.chat import APIQueryRequest, APIQueryResponse from wandbot.api.routers.content_navigator import ( ContentNavigatorRequest, ContentNavigatorResponse, ) -from wandbot.api.routers.adcopy import AdCopyRequest, AdCopyResponse -from wandbot.api.routers.chat import APIQueryRequest, APIQueryResponse from wandbot.api.routers.database import ( APIFeedbackRequest, APIFeedbackResponse, @@ -66,7 +67,9 @@ def __init__(self, url: str): ) self.retrieve_endpoint = urljoin(str(self.url), "retrieve") self.generate_ads_endpoint = urljoin(str(self.url), "generate_ads") - self.generate_content_suggestions_endpoint = urljoin(str(self.url), "generate_content_suggestions") + self.generate_content_suggestions_endpoint = urljoin( + str(self.url), "generate_content_suggestions" + ) def _get_chat_thread( self, request: APIGetChatThreadRequest @@ -640,9 +643,9 @@ async def generate_ads( response = await self._generate_ads(request) return response - + async def _generate_content_suggestions( - self, request: ContentNavigatorRequest + self, request: ContentNavigatorRequest ) -> ContentNavigatorResponse | None: """Call the content navigator API. @@ -664,7 +667,7 @@ async def _generate_content_suggestions( return None async def generate_content_suggestions( - self, user_id: str, query: str + self, user_id: str, query: str ) -> ContentNavigatorResponse: """Generates content suggestions given query. @@ -683,4 +686,3 @@ async def generate_content_suggestions( response = await self._generate_content_suggestions(request) return response - diff --git a/src/wandbot/api/routers/adcopy.py b/src/wandbot/api/routers/adcopy.py index 34cdcef..360edf2 100644 --- a/src/wandbot/api/routers/adcopy.py +++ b/src/wandbot/api/routers/adcopy.py @@ -3,6 +3,7 @@ from fastapi import APIRouter from pydantic import BaseModel from starlette import status + from wandbot.adcopy.adcopy import AdCopyEngine diff --git a/src/wandbot/api/routers/chat.py b/src/wandbot/api/routers/chat.py index 735188d..310ba43 100644 --- a/src/wandbot/api/routers/chat.py +++ b/src/wandbot/api/routers/chat.py @@ -1,3 +1,5 @@ +from typing import Optional + from fastapi import APIRouter from starlette import status @@ -38,12 +40,14 @@ def query( Returns: The APIQueryResponse object containing the result of the query. """ + logger.info(request.images[0][:10]) result = chat( ChatRequest( question=request.question, chat_history=request.chat_history, language=request.language, application=request.application, + images=request.images, ), ) result = APIQueryResponse(**result.model_dump()) diff --git a/src/wandbot/api/routers/content_navigator.py b/src/wandbot/api/routers/content_navigator.py index 590ef43..22df118 100644 --- a/src/wandbot/api/routers/content_navigator.py +++ b/src/wandbot/api/routers/content_navigator.py @@ -1,10 +1,12 @@ -import httpx - +import httpx from fastapi import APIRouter from pydantic import BaseModel from starlette import status -CONTENT_NAVIGATOR_ENDPOINT = "https://wandb-content-navigator.replit.app/get_content" +CONTENT_NAVIGATOR_ENDPOINT = ( + "https://wandb-content-navigator.replit.app/get_content" +) + class ContentNavigatorRequest(BaseModel): """A user query to be used by the content navigator app""" @@ -12,6 +14,7 @@ class ContentNavigatorRequest(BaseModel): user_id: str = None query: str + class ContentNavigatorResponse(BaseModel): """Response from the content navigator app""" @@ -26,7 +29,9 @@ class ContentNavigatorResponse(BaseModel): ) -@router.post("/", response_model=ContentNavigatorResponse, status_code=status.HTTP_200_OK) +@router.post( + "/", response_model=ContentNavigatorResponse, status_code=status.HTTP_200_OK +) async def generate_content_suggestions(request: ContentNavigatorRequest): async with httpx.AsyncClient(timeout=1200.0) as content_client: response = await content_client.post( @@ -34,7 +39,7 @@ async def generate_content_suggestions(request: ContentNavigatorRequest): json={"query": request.query, "user_id": request.user_id}, ) response_data = response.json() - + slack_response = response_data.get("slack_response", "") rejected_slack_response = response_data.get("rejected_slack_response", "") response_items_count = response_data.get("response_items_count", 0) @@ -48,4 +53,3 @@ async def generate_content_suggestions(request: ContentNavigatorRequest): rejected_slack_response=rejected_slack_response, response_items_count=response_items_count, ) - diff --git a/src/wandbot/api/routers/database.py b/src/wandbot/api/routers/database.py index 7174af6..4a8f840 100644 --- a/src/wandbot/api/routers/database.py +++ b/src/wandbot/api/routers/database.py @@ -1,8 +1,8 @@ -import wandb from fastapi import APIRouter from starlette import status from starlette.responses import Response +import wandb from wandbot.database.client import DatabaseClient from wandbot.database.database import engine from wandbot.database.models import Base @@ -13,8 +13,8 @@ FeedbackCreate, QuestionAnswer, QuestionAnswerCreate, - YoutubeAssistantThreadCreate, YoutubeAssistantThread, + YoutubeAssistantThreadCreate, ) from wandbot.utils import get_logger diff --git a/src/wandbot/apps/slack/__main__.py b/src/wandbot/apps/slack/__main__.py index 6c2955e..5d066f2 100644 --- a/src/wandbot/apps/slack/__main__.py +++ b/src/wandbot/apps/slack/__main__.py @@ -15,6 +15,7 @@ from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler from slack_bolt.async_app import AsyncApp + from wandbot.api.client import AsyncAPIClient from wandbot.apps.slack.config import SlackAppEnConfig, SlackAppJaConfig from wandbot.apps.slack.handlers.ad_copy import ( @@ -62,6 +63,7 @@ api_client = AsyncAPIClient(url=config.WANDBOT_API_URL) slack_client = app.client + def get_init_block(user: str) -> List[Dict[str, Any]]: initial_block = [ { @@ -133,6 +135,7 @@ def get_init_block(user: str) -> List[Dict[str, Any]]: ] return initial_block + # -------------------------------------- # Main Wandbot Mention Handler # -------------------------------------- @@ -212,10 +215,11 @@ async def handle_mention(event: dict, say: callable) -> None: ) ) + async def main(): handler = AsyncSocketModeHandler(app, config.SLACK_APP_TOKEN) await handler.start_async() if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/src/wandbot/apps/slack/config.py b/src/wandbot/apps/slack/config.py index 5b903df..92692df 100644 --- a/src/wandbot/apps/slack/config.py +++ b/src/wandbot/apps/slack/config.py @@ -48,12 +48,12 @@ "#support チャンネルにいるwandbチームに質問してください。この答えは役に立ったでしょうか?下のボタンでお知らせ下さい。" ) -JA_ERROR_MESSAGE = "「おっと、問題が発生しました。しばらくしてからもう一度お試しください。」" - -JA_FALLBACK_WARNING_MESSAGE = ( - "**警告: {model}** にフォールバックします。これらの結果は **gpt-4** ほど良くない可能性があります\n\n" +JA_ERROR_MESSAGE = ( + "「おっと、問題が発生しました。しばらくしてからもう一度お試しください。」" ) +JA_FALLBACK_WARNING_MESSAGE = "**警告: {model}** にフォールバックします。これらの結果は **gpt-4** ほど良くない可能性があります\n\n" + class SlackAppEnConfig(BaseSettings): APPLICATION: str = Field("Slack_EN") diff --git a/src/wandbot/apps/slack/handlers/ad_copy.py b/src/wandbot/apps/slack/handlers/ad_copy.py index 111c211..903b4f1 100644 --- a/src/wandbot/apps/slack/handlers/ad_copy.py +++ b/src/wandbot/apps/slack/handlers/ad_copy.py @@ -1,8 +1,9 @@ -import re import logging +import re from typing import Any, Dict, List from slack_sdk.web.async_client import AsyncWebClient + from wandbot.api.client import AsyncAPIClient from wandbot.apps.slack.utils import get_initial_message_from_thread @@ -131,8 +132,11 @@ async def handle_adcopy_action( query = re.sub(r"\<@\w+\>", "", query).strip() logger.info(f"Initial message: {initial_message}") - await say(f"Working on generating ads for '{persona}' focussed on '{action}' \ -for the query: '{query}'...", thread_ts=thread_ts) + await say( + f"Working on generating ads for '{persona}' focussed on '{action}' \ +for the query: '{query}'...", + thread_ts=thread_ts, + ) api_response = await api_client.generate_ads( query=query, action=action, persona=persona diff --git a/src/wandbot/apps/slack/handlers/content_navigator.py b/src/wandbot/apps/slack/handlers/content_navigator.py index 482e234..bef0560 100644 --- a/src/wandbot/apps/slack/handlers/content_navigator.py +++ b/src/wandbot/apps/slack/handlers/content_navigator.py @@ -1,6 +1,7 @@ import logging from slack_sdk.web.async_client import AsyncWebClient + from wandbot.api.client import AsyncAPIClient from wandbot.apps.slack.utils import get_initial_message_from_thread @@ -32,15 +33,20 @@ async def handle_content_navigator_action( if api_response.response_items_count > 0: await say(api_response.slack_response, thread_ts=thread_ts) else: - await say("No content suggestions found. Try rephrasing your query, but note \ + await say( + "No content suggestions found. Try rephrasing your query, but note \ there may also not be any relevant pieces of content for this query. Add '--debug' to \ your query and try again to see a detailed resoning for each suggestion.", - thread_ts=thread_ts) - + thread_ts=thread_ts, + ) + # if debug mode is enabled, send the rejected suggestions as well if len(api_response.rejected_slack_response) > 1: - await say("REJECTED SUGGESTIONS:\n{api_response.rejected_slack_response}", thread_ts=thread_ts) - + await say( + "REJECTED SUGGESTIONS:\n{api_response.rejected_slack_response}", + thread_ts=thread_ts, + ) + def create_content_navigator_handler( slack_client: AsyncWebClient, api_client: AsyncAPIClient @@ -57,4 +63,4 @@ async def executive_signups_handler( logger=logger, ) - return executive_signups_handler \ No newline at end of file + return executive_signups_handler diff --git a/src/wandbot/apps/slack/handlers/docsbot.py b/src/wandbot/apps/slack/handlers/docsbot.py index 3626ad8..7901a46 100644 --- a/src/wandbot/apps/slack/handlers/docsbot.py +++ b/src/wandbot/apps/slack/handlers/docsbot.py @@ -2,6 +2,7 @@ from slack_sdk.web import SlackResponse from slack_sdk.web.async_client import AsyncWebClient + from wandbot.api.client import AsyncAPIClient from wandbot.apps.slack.config import SlackAppEnConfig, SlackAppJaConfig from wandbot.apps.slack.formatter import MrkdwnFormatter diff --git a/src/wandbot/apps/slack/handlers/youtube_chat.py b/src/wandbot/apps/slack/handlers/youtube_chat.py index 0e6b8ed..d47bc88 100644 --- a/src/wandbot/apps/slack/handlers/youtube_chat.py +++ b/src/wandbot/apps/slack/handlers/youtube_chat.py @@ -3,6 +3,7 @@ from pytube.exceptions import RegexMatchError from slack_sdk.web.async_client import AsyncWebClient + from wandbot.apps.slack.utils import get_initial_message_from_thread from wandbot.youtube_chat.video_utils import YoutubeVideoInfo @@ -201,7 +202,10 @@ async def handle_youtube_chat_input( await ack() logger.info(f"Received message: {body}") url = body["actions"][0]["value"] - await say("Working on in it...", thread_ts=body["message"]["thread_ts"],) + await say( + "Working on in it...", + thread_ts=body["message"]["thread_ts"], + ) video_confirmation_block = get_video_confirmation_blocks(url) await say( blocks=video_confirmation_block, diff --git a/src/wandbot/chat/chat.py b/src/wandbot/chat/chat.py index ff0f0d7..a4f72c1 100644 --- a/src/wandbot/chat/chat.py +++ b/src/wandbot/chat/chat.py @@ -24,7 +24,10 @@ print(f"WandBot: {response.answer}") print(f"Time taken: {response.time_taken}") """ -from typing import List + +from typing import List, Optional + +from weave.monitoring import StreamTable import wandb from wandbot.chat.config import ChatConfig @@ -33,7 +36,6 @@ from wandbot.database.schemas import QuestionAnswer from wandbot.retriever import VectorStore from wandbot.utils import Timer, get_logger -from weave.monitoring import StreamTable logger = get_logger(__name__) @@ -65,23 +67,27 @@ def __init__( job_type="chat", ) self.run._label(repo="wandbot") - self.chat_table = StreamTable( - table_name="chat_logs", - project_name=self.config.wandb_project, - entity_name=self.config.wandb_entity, - ) + # self.chat_table = StreamTable( + # table_name="chat_logs", + # project_name=self.config.wandb_project, + # entity_name=self.config.wandb_entity, + # ) self.rag_pipeline = RAGPipeline(vector_store=vector_store) def _get_answer( - self, question: str, chat_history: List[QuestionAnswer] + self, + question: str, + chat_history: List[QuestionAnswer], + images: Optional[List[str]] = None, ) -> RAGPipelineOutput: history = [] for item in chat_history: history.append(("user", item.question)) history.append(("assistant", item.answer)) + # TODO: Add image prompts to history - result = self.rag_pipeline(question, history) + result = self.rag_pipeline(question, history, images=images) return result @@ -95,8 +101,11 @@ def __call__(self, chat_request: ChatRequest) -> ChatResponse: An instance of `ChatResponse` representing the chat response. """ try: + logger.info(chat_request.images[0][:10]) result = self._get_answer( - chat_request.question, chat_request.chat_history or [] + chat_request.question, + chat_request.chat_history or [], + images=chat_request.images, ) result_dict = result.model_dump() @@ -108,7 +117,7 @@ def __call__(self, chat_request: ChatRequest) -> ChatResponse: } result_dict.update({"application": chat_request.application}) self.run.log(usage_stats) - self.chat_table.log(result_dict) + # self.chat_table.log(result_dict) return ChatResponse(**result_dict) except Exception as e: with Timer() as timer: diff --git a/src/wandbot/chat/rag.py b/src/wandbot/chat/rag.py index 10590ba..eb05517 100644 --- a/src/wandbot/chat/rag.py +++ b/src/wandbot/chat/rag.py @@ -1,8 +1,12 @@ import datetime -from typing import List, Tuple +from typing import List, Optional, Tuple from langchain_community.callbacks import get_openai_callback +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_openai import ChatOpenAI from pydantic import BaseModel + from wandbot.rag.query_handler import QueryEnhancer from wandbot.rag.response_synthesis import ResponseSynthesizer from wandbot.rag.retrieval import FusionRetrieval @@ -58,12 +62,51 @@ def __init__( ) self.response_synthesizer = ResponseSynthesizer() + def generate_multi_modal_initial_response( + self, question: str, images: List[str] + ) -> str: + model = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=500) + system_message = SystemMessage( + content="""You are a Weights & Biases support expert. + Your goal is to answer the user's question and provide them with the best possible solution. + You are provided with a support ticket and a set of screenshots related to the issue. + Provide a detailed solution to the user query based on the ticket and the screenshots.""" + ) + prompt = [ + { + "type": "text", + "text": question, + } + ] + for img in images: + prompt += [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img}"}, + } + ] + message = HumanMessage(content=prompt) + response = model.invoke([system_message, message]) + logger.info(response.content) + return response.content + def __call__( - self, question: str, chat_history: List[Tuple[str, str]] | None = None + self, + question: str, + chat_history: List[Tuple[str, str]] | None = None, + images: Optional[List[str]] = None, ): if chat_history is None: chat_history = [] + if images is not None: + multi_modal_response = self.generate_multi_modal_initial_response( + question, images + ) + question = ( + f"Is the following statement correct?\n{multi_modal_response}" + ) + with get_openai_callback() as query_enhancer_cb, Timer() as query_enhancer_tb: enhanced_query = self.query_enhancer.chain.invoke( {"query": question, "chat_history": chat_history} diff --git a/src/wandbot/chat/schemas.py b/src/wandbot/chat/schemas.py index f952ded..bce7b6f 100644 --- a/src/wandbot/chat/schemas.py +++ b/src/wandbot/chat/schemas.py @@ -17,7 +17,7 @@ """ from datetime import datetime -from typing import List +from typing import List, Optional from pydantic import BaseModel @@ -46,6 +46,7 @@ class ChatRequest(BaseModel): chat_history: List[QuestionAnswer] | None = None application: str | None = None language: str = "en" + images: Optional[List[str]] = None class ChatResponse(BaseModel): diff --git a/src/wandbot/database/client.py b/src/wandbot/database/client.py index 06813ac..0e527bb 100644 --- a/src/wandbot/database/client.py +++ b/src/wandbot/database/client.py @@ -15,7 +15,7 @@ import json from datetime import datetime, timedelta -from typing import Any, List, Collection +from typing import Any, Collection, List from sqlalchemy.future import create_engine from sqlalchemy.orm import sessionmaker diff --git a/src/wandbot/evaluation/eval/__main__.py b/src/wandbot/evaluation/eval/__main__.py index 3302f59..78de74c 100644 --- a/src/wandbot/evaluation/eval/__main__.py +++ b/src/wandbot/evaluation/eval/__main__.py @@ -95,9 +95,9 @@ def get_answer_correctness(row_str: str) -> str: reference_notes=row["reference_notes"], ) result = parse_answer_eval("answer_correctness", result.dict()) - result[ - "answer_correctness_score_(ragas)" - ] = metrics.answer_correctness.score_single(row) + result["answer_correctness_score_(ragas)"] = ( + metrics.answer_correctness.score_single(row) + ) result = json.dumps(result) return result @@ -113,9 +113,9 @@ def get_answer_relevancy(row_str: str) -> str: reference=row["ground_truths"], ) result = parse_answer_eval("answer_relevancy", result.dict()) - result[ - "answer_relevancy_score_(ragas)" - ] = metrics.answer_relevancy.score_single(row) + result["answer_relevancy_score_(ragas)"] = ( + metrics.answer_relevancy.score_single(row) + ) result = json.dumps(result) return result @@ -132,9 +132,9 @@ def get_answer_faithfulness(row_str: str) -> str: ) result = parse_answer_eval("answer_faithfulness", result.dict()) - result[ - "answer_faithfulness_score_(ragas)" - ] = metrics.faithfulness.score_single(row) + result["answer_faithfulness_score_(ragas)"] = ( + metrics.faithfulness.score_single(row) + ) result = json.dumps(result) return result diff --git a/src/wandbot/multi_modal/__init__.py b/src/wandbot/multi_modal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/wandbot/multi_modal/multi_modal.py b/src/wandbot/multi_modal/multi_modal.py new file mode 100644 index 0000000..200188f --- /dev/null +++ b/src/wandbot/multi_modal/multi_modal.py @@ -0,0 +1,14 @@ +import base64 + +from langchain.schema.messages import AIMessage, HumanMessage +from langchain_openai import ChatOpenAI + + +class MultiModalQueryEngine: + + def __init__(self, model: str = "gpt-4-vision-preview"): + self.model = ChatOpenAI(model_name=model) + + def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") diff --git a/src/wandbot/rag/retrieval.py b/src/wandbot/rag/retrieval.py index 065c791..1a72ac9 100644 --- a/src/wandbot/rag/retrieval.py +++ b/src/wandbot/rag/retrieval.py @@ -3,6 +3,7 @@ from langchain.retrievers.document_compressors import CohereRerank from langchain_core.documents import Document from langchain_core.runnables import Runnable, RunnablePassthrough + from wandbot.rag.utils import get_web_contexts from wandbot.retriever.base import VectorStore from wandbot.retriever.web_search import YouSearch, YouSearchConfig diff --git a/src/wandbot/retriever/base.py b/src/wandbot/retriever/base.py index 3982ca6..8c08a47 100644 --- a/src/wandbot/retriever/base.py +++ b/src/wandbot/retriever/base.py @@ -1,18 +1,17 @@ from operator import itemgetter from typing import List -import wandb +from chromadb.config import Settings as ChromaSettings from langchain_community.document_transformers import EmbeddingsRedundantFilter from langchain_community.vectorstores.chroma import Chroma from langchain_core.documents import Document from langchain_core.runnables import RunnableLambda, RunnableParallel + +import wandb from wandbot.ingestion.config import VectorStoreConfig from wandbot.retriever.reranking import CohereRerankChain from wandbot.retriever.utils import OpenAIEmbeddingsModel -from chromadb.config import Settings as ChromaSettings - - class VectorStore: embeddings_model: OpenAIEmbeddingsModel = OpenAIEmbeddingsModel( @@ -34,7 +33,7 @@ def __init__( collection_name=collection_name, embedding_function=self.embeddings_model, # type: ignore persist_directory=persist_dir, - client_settings=ChromaSettings(anonymized_telemetry=False) + client_settings=ChromaSettings(anonymized_telemetry=False), ) @classmethod diff --git a/src/wandbot/youtube_chat/chat_utils.py b/src/wandbot/youtube_chat/chat_utils.py index f9ce1b4..b546f6c 100644 --- a/src/wandbot/youtube_chat/chat_utils.py +++ b/src/wandbot/youtube_chat/chat_utils.py @@ -9,14 +9,12 @@ from tenacity import ( retry, retry_if_result, - wait_exponential, stop_after_attempt, + wait_exponential, ) from wandbot.database.client import DatabaseClient -from wandbot.database.schemas import ( - YoutubeAssistantThreadCreate, -) +from wandbot.database.schemas import YoutubeAssistantThreadCreate from wandbot.utils import get_logger logger = get_logger(__name__) From 94983e61eeb873cb24f123e7696c3bb0997f6308 Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Thu, 7 Mar 2024 19:44:58 +0530 Subject: [PATCH 2/3] update: multi-modal chat --- src/wandbot/chat/rag.py | 27 +++++++++++++++------------ src/wandbot/rag/query_handler.py | 11 +++++++++-- src/wandbot/rag/utils.py | 5 +++++ 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/wandbot/chat/rag.py b/src/wandbot/chat/rag.py index eb05517..21d6228 100644 --- a/src/wandbot/chat/rag.py +++ b/src/wandbot/chat/rag.py @@ -1,6 +1,7 @@ import datetime from typing import List, Optional, Tuple +from langchain_anthropic import ChatAnthropic from langchain_community.callbacks import get_openai_callback from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate @@ -65,12 +66,13 @@ def __init__( def generate_multi_modal_initial_response( self, question: str, images: List[str] ) -> str: - model = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=500) + # model = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=500) + model = ChatAnthropic(model="claude-3-opus-20240229") system_message = SystemMessage( content="""You are a Weights & Biases support expert. - Your goal is to answer the user's question and provide them with the best possible solution. + Your goal is to describe the attached screenshots in the context of the user query. You are provided with a support ticket and a set of screenshots related to the issue. - Provide a detailed solution to the user query based on the ticket and the screenshots.""" + Provide a detailed description of the image in the context of the query such that the ticket can be answered correctly while incorporating the image info.""" ) prompt = [ { @@ -87,7 +89,6 @@ def generate_multi_modal_initial_response( ] message = HumanMessage(content=prompt) response = model.invoke([system_message, message]) - logger.info(response.content) return response.content def __call__( @@ -99,17 +100,19 @@ def __call__( if chat_history is None: chat_history = [] - if images is not None: - multi_modal_response = self.generate_multi_modal_initial_response( - question, images - ) - question = ( - f"Is the following statement correct?\n{multi_modal_response}" - ) + multi_modal_response = ( + self.generate_multi_modal_initial_response(question, images) + if images is not None + else "" + ) with get_openai_callback() as query_enhancer_cb, Timer() as query_enhancer_tb: enhanced_query = self.query_enhancer.chain.invoke( - {"query": question, "chat_history": chat_history} + { + "query": question, + "chat_history": chat_history, + "image_context": multi_modal_response, + } ) with Timer() as retrieval_tb: diff --git a/src/wandbot/rag/query_handler.py b/src/wandbot/rag/query_handler.py index e311fc7..851f5a4 100644 --- a/src/wandbot/rag/query_handler.py +++ b/src/wandbot/rag/query_handler.py @@ -192,7 +192,10 @@ def avoid_query(self) -> bool: ) def parse_output( - self, query: str, chat_history: Optional[List[Tuple[str, str]]] = None + self, + query: str, + chat_history: Optional[List[Tuple[str, str]]] = None, + image_context: Optional[str] = None, ) -> Dict[str, Any]: """Parse the output of the model""" question = clean_question(query) @@ -237,6 +240,7 @@ def parse_output( "avoid_query": self.avoid_query, "chat_history": chat_history, "all_queries": all_queries, + "image_context": image_context, } @@ -299,10 +303,13 @@ def _load_chain(self, model: ChatOpenAI) -> Runnable: query=itemgetter("query"), chat_history=itemgetter("chat_history"), enhanced_query=full_query_enhancer_chain, + image_context=itemgetter("image_context"), ) chain = intermediate_chain | RunnableLambda( lambda x: x["enhanced_query"].parse_output( - x["query"], convert_to_messages(x["chat_history"]) + x["query"], + convert_to_messages(x["chat_history"]), + image_context=x["image_context"], ) ) diff --git a/src/wandbot/rag/utils.py b/src/wandbot/rag/utils.py index f2270a4..264de4c 100644 --- a/src/wandbot/rag/utils.py +++ b/src/wandbot/rag/utils.py @@ -47,6 +47,10 @@ def __set__(self, obj, value): Sub-queries to consider answering: {sub_queries} + +Context from attached images: + +{image_context} """ ) @@ -58,6 +62,7 @@ def create_query_str(enhanced_query, document_prompt=DEFAULT_QUESTION_PROMPT): "intents": enhanced_query["intents"], "sub_queries": "\t" + "\n\t".join(enhanced_query["sub_queries"]).strip(), + "image_context": enhanced_query["image_context"], } doc = Document(page_content=page_content, metadata=metadata) doc = clean_document_content(doc) From 034dd747e5a34bdd1f719b32f33c131c0222099a Mon Sep 17 00:00:00 2001 From: Soumik Rakshit <19soumik.rakshit96@gmail.com> Date: Thu, 7 Mar 2024 19:46:35 +0530 Subject: [PATCH 3/3] update: system prompt for multi-modal chat --- src/wandbot/chat/rag.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wandbot/chat/rag.py b/src/wandbot/chat/rag.py index 21d6228..02f471a 100644 --- a/src/wandbot/chat/rag.py +++ b/src/wandbot/chat/rag.py @@ -70,9 +70,8 @@ def generate_multi_modal_initial_response( model = ChatAnthropic(model="claude-3-opus-20240229") system_message = SystemMessage( content="""You are a Weights & Biases support expert. - Your goal is to describe the attached screenshots in the context of the user query. - You are provided with a support ticket and a set of screenshots related to the issue. - Provide a detailed description of the image in the context of the query such that the ticket can be answered correctly while incorporating the image info.""" + Your goal is to describe the attached screenshots in the context of the user query. You are provided with a support ticket and screenshots related to the issue. + Provide a detailed description of the image in the context of the query so that the ticket can be answered correctly while incorporating the image info.""" ) prompt = [ {