diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7108725..4cecfe4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,3 +22,9 @@ repos: files: ^frontend/ args: [--config-file=frontend/mypy.ini] additional_dependencies: [types-requests] + - id: mypy + name: mypy-backend + files: ^backend/ + args: [--config-file=backend/mypy.ini] + additional_dependencies: [types-requests] + diff --git a/backend/fastapi_app.py b/backend/fastapi_app.py index ceb4add..496bdfc 100644 --- a/backend/fastapi_app.py +++ b/backend/fastapi_app.py @@ -1,17 +1,18 @@ # command to run: uvicorn backend.fastapi_app:app --reload import os import re -from typing import List +import logging +import sys +from pathlib import Path + from fastapi import FastAPI, HTTPException, UploadFile, Form from llama_index import ServiceContext -from pydantic import BaseModel, Field + +# from pydantic import BaseModel, Field # from llama_index.callbacks import CallbackManager, TokenCountingHandler from requests.exceptions import MissingSchema -import logging -import sys from dotenv import load_dotenv -from pathlib import Path import errno import certifi @@ -21,16 +22,23 @@ AITextDocument, AIPdfDocument, AIHtmlDocument, - CustomLlamaIndexChatEngineWrapper, set_up_text_chatbot, ) - from .script_SQL_querying import ( AIDataBase, - DataChatBotWrapper, set_up_database_chatbot, ) - +from .models import ( + DoubleUploadException, + NoUploadException, + EmptyQuestionException, + TextSummaryModel, + QuestionModel, + QAResponseModel, + TextResponseModel, + MultipleChoiceTest, + ErrorResponse, +) from .helpers import load_aws_secrets # workaround for mac to solve "SSL: CERTIFICATE_VERIFY_FAILED Error" @@ -69,124 +77,60 @@ # Set-up Chat Engine: # - LlamaIndex CondenseQuestionChatEngine with RetrieverQueryEngine for text files # - or querying a database with langchain SQLDatabaseChain and Runnables -app.chat_engine: CustomLlamaIndexChatEngineWrapper | DataChatBotWrapper | None = None -app.callback_manager = None -app.token_counter = None - - -class DoubleUploadException(Exception): - pass - - -class NoUploadException(Exception): - pass - - -class EmptyQuestionException(Exception): - pass - - -class TextSummaryModel(BaseModel): - file_name: str - text_category: str - summary: str - used_tokens: int - - -class QuestionModel(BaseModel): - prompt: str - temperature: float - - -class QAResponseModel(BaseModel): - user_question: str - ai_answer: str - used_tokens: int - - -class TextResponseModel(BaseModel): - message: str - - -class MultipleChoiceQuestion(BaseModel): - """Data Model for a multiple choice question""" - - question: str = Field( - ..., - description="""An interesting and unique question related to the main - subject of the article. - """, - ) - correct_answer: str = Field(..., description="Correct answer to question") - wrong_answer_1: str = Field( - ..., description="a unique wrong answer to the question" - ) - wrong_answer_2: str = Field( - ..., - description="""a unique wrong answer to the question which is different - from wrong_answer_1 and not an empty string - """, - ) - - -class MultipleChoiceTest(BaseModel): - """Data Model for a multiple choice test""" - - questions: List[MultipleChoiceQuestion] = [] - - -class ErrorResponse(BaseModel): - detail: str +app.state.chat_engine = None +app.state.callback_manager = None +app.state.token_counter = None def load_text_chat_engine() -> None: - if not app.chat_engine or app.chat_engine.data_category == "database": + if not app.state.chat_engine or app.state.chat_engine.data_category == "database": logging.debug("setting up text chatbot") logging.debug(f"Debug: {DEBUG_MODE}") ( - app.chat_engine, - app.callback_manager, - app.token_counter, + app.state.chat_engine, + app.state.callback_manager, + app.state.token_counter, ) = set_up_text_chatbot() def load_database_chat_engine() -> None: - if not app.chat_engine or app.chat_engine.data_category != "database": + if not app.state.chat_engine or app.state.chat_engine.data_category != "database": logging.debug("setting up database chatbot") ( - app.chat_engine, - app.callback_manager, # is None in database mode - app.token_counter, + app.state.chat_engine, + app.state.callback_manager, # is None in database mode + app.state.token_counter, ) = set_up_database_chatbot() async def handle_uploadfile( upload_file: UploadFile, -) -> AITextDocument | AIDataBase | None: - file_name = upload_file.filename +) -> AITextDocument | AIDataBase | AIPdfDocument | None: + if not (file_name := upload_file.filename): + return None with open(cfd / data_dir / file_name, "wb") as f: f.write(await upload_file.read()) match upload_file.filename.split(".")[-1]: case "txt": load_text_chat_engine() - return AITextDocument(file_name, LLM_NAME, app.callback_manager) + return AITextDocument(file_name, LLM_NAME, app.state.callback_manager) case "pdf": load_text_chat_engine() - return AIPdfDocument(file_name, LLM_NAME, app.callback_manager) + return AIPdfDocument(file_name, LLM_NAME, app.state.callback_manager) case "sqlite" | "db": uri = f"sqlite:///{app_dir}/{data_dir}/{file_name}" logging.debug(f"uri: {uri} debug {DEBUG_MODE}") load_database_chat_engine() - document: AIDataBase = AIDataBase.from_uri(uri) - return document + return AIDataBase.from_uri(uri) + return None -async def handle_upload_url(upload_url) -> None: +async def handle_upload_url(upload_url: str) -> AITextDocument | AIHtmlDocument: match re.split(r"[./]", upload_url): case [*_, dir, file_name, "txt"] if dir == "data": try: load_text_chat_engine() - return AITextDocument(file_name, LLM_NAME, app.callback_manager) + return AITextDocument(file_name, LLM_NAME, app.state.callback_manager) except OSError: raise FileNotFoundError( errno.ENOENT, @@ -195,7 +139,7 @@ async def handle_upload_url(upload_url) -> None: ) case [http, *_] if "http" in http.lower(): load_text_chat_engine() - return AIHtmlDocument(upload_url, LLM_NAME, app.callback_manager) + return AIHtmlDocument(upload_url, LLM_NAME, app.state.callback_manager) case _: raise MissingSchema @@ -203,16 +147,22 @@ async def handle_upload_url(upload_url) -> None: @app.post("/upload", response_model=TextSummaryModel) async def upload_file( upload_file: UploadFile | None = None, upload_url: str = Form("") -): +) -> TextSummaryModel: message = "" text_category = "" - file_name = "" + file_name: str | None = "" used_tokens = 0 try: if upload_file: if upload_url: raise DoubleUploadException("You can not provide both, file and URL.") - file_name = upload_file.filename + if not (file_name := upload_file.filename): + return TextSummaryModel( + file_name="", + text_category=text_category, + summary=message, + used_tokens=used_tokens, + ) destination_file = Path(cfd / "data" / file_name) destination_file.parent.mkdir(exist_ok=True, parents=True) document = await handle_uploadfile(upload_file) @@ -224,11 +174,11 @@ async def upload_file( raise NoUploadException( "You must provide either a file or URL to upload.", ) - if app.chat_engine and document: - app.chat_engine.add_document(document) + if app.state.chat_engine and document: + app.state.chat_engine.add_document(document) message = document.summary text_category = document.category - used_tokens = app.token_counter.total_llm_token_count + used_tokens = app.state.token_counter.total_llm_token_count except MissingSchema: raise HTTPException( status_code=400, @@ -243,7 +193,7 @@ async def upload_file( status_code=400, detail=f"There was an unexpected OSError on uploading the file:{e}", ) - logging.debug(f"engine_up?: {app.chat_engine is not None}") + logging.debug(f"engine_up?: {app.state.chat_engine is not None}") logging.debug(f"message: {message}") return TextSummaryModel( file_name=file_name, @@ -254,18 +204,18 @@ async def upload_file( @app.post("/qa_text", response_model=QAResponseModel) -async def qa_text(question: QuestionModel): - logging.debug(f"engine_up?: {app.chat_engine is not None}") +async def qa_text(question: QuestionModel) -> QAResponseModel: + logging.debug(f"engine_up?: {app.state.chat_engine is not None}") if not question.prompt: raise EmptyQuestionException( "Your Question is empty, please type a message and resend it." ) - if app.chat_engine: - app.token_counter.reset_counts() - app.chat_engine.update_temp(question.temperature) - response = app.chat_engine.answer_question(question) + if app.state.chat_engine: + app.state.token_counter.reset_counts() + app.state.chat_engine.update_temp(question.temperature) + response = app.state.chat_engine.answer_question(question) ai_answer = str(response) - used_tokens = app.token_counter.total_llm_token_count + used_tokens = app.state.token_counter.total_llm_token_count else: ai_answer = "Sorry, no context loaded. Please upload a file or url." used_tokens = 0 @@ -279,22 +229,22 @@ async def qa_text(question: QuestionModel): @app.get("/clear_storage", response_model=TextResponseModel) async def clear_storage(): - if app.chat_engine: - app.chat_engine.clear_data_storage() + if app.state.chat_engine: + app.state.chat_engine.clear_data_storage() logging.info("chat engine cleared...") if (cfd / "data").exists(): for file in Path(cfd / "data").iterdir(): os.remove(file) - app.chat_engine = None - app.token_counter = None - app.callback_manager = None + app.state.chat_engine = None + app.state.token_counter = None + app.state.callback_manager = None return TextResponseModel(message="Knowledge base succesfully cleared") @app.get("/clear_history", response_model=TextResponseModel) async def clear_history(): - if app.chat_engine: - message = app.chat_engine.clear_chat_history() + if app.state.chat_engine: + message = app.state.chat_engine.clear_chat_history() # logging.debug("chat history cleared...") return TextResponseModel(message=message) return TextResponseModel( @@ -310,13 +260,13 @@ async def clear_history(): }, ) def get_quiz(): - if not app.chat_engine or not app.chat_engine.vector_index.ref_doc_info: + if not app.state.chat_engine or not app.state.chat_engine.vector_index.ref_doc_info: raise HTTPException( status_code=400, detail="No context provided, please provide a url or a text file!", ) - if app.chat_engine.data_category == "database": + if app.state.chat_engine.data_category == "database": raise HTTPException( status_code=400, detail="""A database is loaded, but no valid context for a quiz. @@ -344,7 +294,7 @@ def generate_quiz_from_context(): from llama_index.prompts import PromptTemplate from llama_index.response import Response - vector_index = app.chat_engine.vector_index + vector_index = app.state.chat_engine.vector_index lc_output_parser = PydanticOutputParser(pydantic_object=MultipleChoiceTest) output_parser = LangchainOutputParser(lc_output_parser) diff --git a/backend/models.py b/backend/models.py new file mode 100644 index 0000000..1eef414 --- /dev/null +++ b/backend/models.py @@ -0,0 +1,66 @@ +from pydantic import BaseModel, Field + + +class DoubleUploadException(Exception): + pass + + +class NoUploadException(Exception): + pass + + +class EmptyQuestionException(Exception): + pass + + +class TextSummaryModel(BaseModel): + file_name: str + text_category: str + summary: str + used_tokens: int + + +class QuestionModel(BaseModel): + prompt: str + temperature: float + + +class QAResponseModel(BaseModel): + user_question: str + ai_answer: str + used_tokens: int + + +class TextResponseModel(BaseModel): + message: str + + +class MultipleChoiceQuestion(BaseModel): + """Data Model for a multiple choice question""" + + question: str = Field( + ..., + description="""An interesting and unique question related to the main + subject of the article. + """, + ) + correct_answer: str = Field(..., description="Correct answer to question") + wrong_answer_1: str = Field( + ..., description="a unique wrong answer to the question" + ) + wrong_answer_2: str = Field( + ..., + description="""a unique wrong answer to the question which is different + from wrong_answer_1 and not an empty string + """, + ) + + +class MultipleChoiceTest(BaseModel): + """Data Model for a multiple choice test""" + + questions: list[MultipleChoiceQuestion] = [] + + +class ErrorResponse(BaseModel): + detail: str diff --git a/backend/mypy.ini b/backend/mypy.ini new file mode 100644 index 0000000..2aa6e3b --- /dev/null +++ b/backend/mypy.ini @@ -0,0 +1,14 @@ +[mypy] +python_version = 3.10 +disallow_incomplete_defs = True +no_implicit_optional = True +ignore_missing_imports = True +exclude = (?x)( + backend/outdated/* + | \/*venv\/* + | backend/data/* + | backend/outdated/* + | backend/storage/* + ) + + diff --git a/backend/script_RAG.py b/backend/script_RAG.py index 7b2db24..f528108 100644 --- a/backend/script_RAG.py +++ b/backend/script_RAG.py @@ -1,3 +1,8 @@ +import pathlib +import tiktoken +import logging +import os + from llama_index import ( SimpleWebPageReader, VectorStoreIndex, @@ -9,7 +14,7 @@ get_response_synthesizer, ) from llama_index.readers import BeautifulSoupWebReader - +from llama_index.schema import Document from llama_index.llms import OpenAI from llama_index.node_parser import SimpleNodeParser from llama_index.text_splitter import TokenTextSplitter @@ -24,19 +29,15 @@ from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine from llama_index.callbacks import CallbackManager, TokenCountingHandler from llama_index.memory import ChatMemoryBuffer - from llama_index.vector_stores.types import MetadataInfo, VectorStoreInfo from marvin import ai_model from marvin import settings as marvin_settings from llama_index.bridge.pydantic import BaseModel as LlamaBaseModel from llama_index.bridge.pydantic import Field as LlamaField -import pathlib -import tiktoken -import logging -import os from .document_categories import CATEGORY_LABELS +from .models import QuestionModel marvin_settings.openai.api_key = os.getenv("OPENAI_API_KEY") @@ -54,7 +55,7 @@ def __init__( document_name: str, llm_str: str, callback_manager: CallbackManager | None = None, - ): + ) -> None: self.callback_manager: CallbackManager | None = callback_manager self.document = self._load_document(document_name) self.nodes = self.split_document_and_extract_metadata(llm_str) @@ -64,7 +65,7 @@ def __init__( question about "{text_subject}".' @classmethod - def _load_document(cls, identifier: str): + def _load_document(cls, identifier: str) -> Document: """loads only the data of the specified name identifier: name of the text file as str @@ -129,7 +130,7 @@ class AIMarvinDocument(LlamaBaseModel): class AIPdfDocument(AITextDocument): @classmethod - def _load_document(cls, identifier: str): + def _load_document(cls, identifier: str) -> Document: # loader = PDFReader() # return loader.load_data( # file=pathlib.Path(str(AITextDocument.cfd / identifier)) @@ -141,7 +142,7 @@ def _load_document(cls, identifier: str): class AIHtmlDocument(AITextDocument): @classmethod - def _load_document_simplewebpageReader(cls, identifier: str): + def _load_document_simplewebpageReader(cls, identifier: str) -> Document: """loads the data of a simple static website at a given url identifier: url of the html file as str """ @@ -152,7 +153,7 @@ def _load_document_simplewebpageReader(cls, identifier: str): )[0] @classmethod - def _load_document_BeautifulSoupWebReader(cls, identifier: str): + def _load_document_BeautifulSoupWebReader(cls, identifier: str) -> Document: """loads the data of an html file at a given url identifier: url of the html file as str """ @@ -165,7 +166,7 @@ def _load_document_BeautifulSoupWebReader(cls, identifier: str): return BeautifulSoupWebReader().load_data(urls=[identifier])[0] @classmethod - def _load_document(cls, identifier: str): + def _load_document(cls, identifier: str) -> Document: """loads the data of an html file at a given url identifier: url of the html file as str """ @@ -316,7 +317,7 @@ def update_temp(self, temperature): {"temperature": temperature} ) - def answer_question(self, question: str) -> str: + def answer_question(self, question: QuestionModel) -> str: return self.chat_engine.chat(question.prompt) diff --git a/backend/script_SQL_querying.py b/backend/script_SQL_querying.py index 7c319e5..734ffaf 100644 --- a/backend/script_SQL_querying.py +++ b/backend/script_SQL_querying.py @@ -1,16 +1,24 @@ # https://python.langchain.com/docs/expression_language/cookbook/sql_db +from __future__ import annotations import logging import re import sys +from typing import Any +from operator import itemgetter + from langchain.chat_models import ChatOpenAI from langchain.utilities import SQLDatabase from langchain.schema.output_parser import StrOutputParser -from langchain.schema.runnable import RunnableLambda, RunnableMap, RunnablePassthrough +from langchain.schema.runnable import ( + RunnableLambda, + RunnableMap, + RunnablePassthrough, + RunnableSequence, +) from langchain.prompts import ChatPromptTemplate from langchain.callbacks import get_openai_callback -from operator import itemgetter logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.getLogger(__name__).addHandler(logging.StreamHandler(stream=sys.stdout)) @@ -48,6 +56,16 @@ def __init__(self, *args, **kwargs): + re.sub(r"/\*((.|\n)*?)\*/", "", self.get_table_info()).strip() ) + @classmethod + def from_uri( + cls, database_uri: str, engine_args: dict | None = None, **kwargs: Any + ) -> AIDataBase: + """Construct a SQLAlchemy engine from URI.""" + from sqlalchemy import create_engine + + _engine_args = engine_args or {} + return cls(create_engine(database_uri, **_engine_args), **kwargs) + def get_schema(self, _): return self.get_table_info() @@ -60,10 +78,10 @@ def ask_a_question(self, question: str, token_callback: CustomTokenCounter) -> s llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo") # logging.debug(self.get_table_info()) with get_openai_callback() as callback: - query_generator = ( + query_generator: RunnableSequence[Any, Any] = ( RunnableMap( { - "schema": RunnableLambda(self.get_schema), + "schema": RunnableLambda(self.get_schema), # type: ignore "question": itemgetter("question"), } ) @@ -106,10 +124,10 @@ def ask_a_question(self, question: str, token_callback: CustomTokenCounter) -> s class DataChatBotWrapper: def __init__(self, callback_manager: CustomTokenCounter): self.data_category: str = "database" - self.token_callback = callback_manager - self.document: AIDataBase = None + self.token_callback: CustomTokenCounter = callback_manager + self.document: AIDataBase | None = None - def add_document(self, document) -> None: + def add_document(self, document: AIDataBase) -> None: self.document = document def clear_chat_history(self) -> str: @@ -118,13 +136,15 @@ def clear_chat_history(self) -> str: def clear_data_storage(self) -> None: del self.document self.document = None - # ToDo delete db file ? - def update_temp(self, temperature) -> None: + def update_temp(self, temperature) -> None: # type: ignore pass def answer_question(self, question: str) -> str: - return self.document.ask_a_question(question, self.token_callback) + if self.document: + return self.document.ask_a_question(question, self.token_callback) + else: + raise AttributeError("no document loaded") def set_up_database_chatbot(): @@ -151,7 +171,8 @@ def set_up_database_chatbot(): chat_engine, callback_manager, token_counter = set_up_database_chatbot() - document: AIDataBase = AIDataBase.from_uri("sqlite:///data/database.sqlite") + document = AIDataBase.from_uri("sqlite:///data/database.sqlite") + print("________") print(document.summary)