From 5fc8b3dcbf1fdcfa8faddc6c028628b8668b5449 Mon Sep 17 00:00:00 2001 From: Mira Date: Sat, 23 Mar 2024 08:03:16 +0100 Subject: [PATCH 1/5] update ruff linter version --- .pre-commit-config.yaml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4841426..06890bb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,17 +1,14 @@ repos: -# taken from: https://black.readthedocs.io/en/stable/integrations/source_version_control.html -# Using this mirror lets us use mypyc-compiled black, which is about 2x faster + +# black formatter - repo: https://github.com/psf/black-pre-commit-mirror rev: 23.9.1 hooks: - id: black - # It is recommended to specify the latest version of Python - # supported by your project here, or alternatively use - # pre-commit's default_language_version, see - # https://pre-commit.com/#top_level-default_language_version language_version: python3.10 -# taken from: https://pypi.org/project/ruff/ + +# ruff linter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.291 + rev: v0.3.4 hooks: - id: ruff \ No newline at end of file From 275d82956da466d4c2306ff2767dda71fd18edc5 Mon Sep 17 00:00:00 2001 From: Mira Date: Mon, 25 Mar 2024 10:56:30 +0100 Subject: [PATCH 2/5] add mypy static type checking for frontend to pre-commit --- .pre-commit-config.yaml | 12 +++++++++++- frontend/mypy.ini | 10 ++++++++++ frontend/streamlit_app.py | 36 +++++++++++++++++++----------------- frontend/utils/__init__.py | 0 frontend/utils/helpers.py | 22 +++++++++++++++++++--- 5 files changed, 59 insertions(+), 21 deletions(-) create mode 100644 frontend/mypy.ini create mode 100644 frontend/utils/__init__.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06890bb..7108725 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,4 +11,14 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.3.4 hooks: - - id: ruff \ No newline at end of file + - id: ruff + +# mypy static type checker +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + name: mypy-frontend + files: ^frontend/ + args: [--config-file=frontend/mypy.ini] + additional_dependencies: [types-requests] diff --git a/frontend/mypy.ini b/frontend/mypy.ini new file mode 100644 index 0000000..c1ed88c --- /dev/null +++ b/frontend/mypy.ini @@ -0,0 +1,10 @@ +[mypy] +# https://mypy.readthedocs.io/en/stable/config_file.html + +python_version = 3.10 +disallow_incomplete_defs = True +no_implicit_optional = True +install_types = True +exclude = frontend/utils/request_wrapper.py +ignore_missing_imports = True + diff --git a/frontend/streamlit_app.py b/frontend/streamlit_app.py index 7a4fbc0..20daff0 100644 --- a/frontend/streamlit_app.py +++ b/frontend/streamlit_app.py @@ -1,17 +1,22 @@ # run command from root: streamlit run streamlit_app.py +import os +import sys +import logging import pathlib import random -from fastapi import UploadFile +from typing import Callable + +from streamlit.runtime.uploaded_file_manager import UploadedFile + +# from fastapi import UploadFile import streamlit as st from streamlit_option_menu import option_menu from streamlit_extras.add_vertical_space import add_vertical_space from streamlit_extras.stylable_container import stylable_container import requests -import os -import sys -import logging -import certifi + from PIL import Image +import certifi from dotenv import load_dotenv import sentry_sdk @@ -35,7 +40,7 @@ 
logging.info(f"{API_URL=}") APP_TITLE = "Quaigle" -MAIN_PAGE = {} +PAGE_REGISTRY_DICT: dict[str, Callable[..., None]] = {} cfd = pathlib.Path(__file__).parent logging.basicConfig(stream=sys.stdout, level=logging.INFO) @@ -169,21 +174,18 @@ def display_options_menu(): st.session_state["redirect_page"] = None -def make_get_request(route: str): +def make_get_request(route: str) -> requests.Response: return requests.get(os.path.join(API_URL, route)) def post_data_to_backend( - route: str, url: str = "", uploaded_file: UploadFile | None = None -): + route: str, url: str = "", uploaded_file: UploadedFile | None = None +) -> None: with st.spinner("Waiting for openai API response"): try: if url: data = {"upload_url": url} - files = {"upload_file": ("", None)} - response = requests.post( - os.path.join(API_URL, route), data=data, files=files - ) + response = requests.post(os.path.join(API_URL, route), data=data) elif uploaded_file: files = {"upload_file": (uploaded_file.name, uploaded_file)} data = {"upload_url": ""} @@ -287,7 +289,7 @@ def display_sidemenu(): st.experimental_rerun() -@register_page(MAIN_PAGE) +@register_page(PAGE_REGISTRY_DICT) def questionai(): with st.container(): for message in st.session_state.messages: @@ -335,7 +337,7 @@ def questionai(): ) -@register_page(MAIN_PAGE) +@register_page(PAGE_REGISTRY_DICT) def quizme(): with st.container(): if st.session_state["chat_mode"] == "database": @@ -392,7 +394,7 @@ def quizme(): ) -@register_page(MAIN_PAGE) +@register_page(PAGE_REGISTRY_DICT) def statistics(): import pandas as pd @@ -436,7 +438,7 @@ def main(): display_header() display_sidemenu() # implement the selected page from options menu - MAIN_PAGE[st.session_state.selected_page]() + PAGE_REGISTRY_DICT[st.session_state.selected_page]() if __name__ == "__main__": diff --git a/frontend/utils/__init__.py b/frontend/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/frontend/utils/helpers.py b/frontend/utils/helpers.py index 5e0d95e..e00fafe 100644 --- a/frontend/utils/helpers.py +++ b/frontend/utils/helpers.py @@ -1,7 +1,23 @@ -def register_page(page_dict: dict[str, callable]) -> callable: - """decorator to register page automatically in page dict""" +from typing import Callable, TypeVar + +F = TypeVar("F", bound=Callable[..., None]) + + +def register_page(page_registry_dict: dict[str, F]) -> Callable[[F], F]: + """ + Streamlit page decorator (factory) which takes the page registry dict + as argument and returns the actual decorator, which itself registers + all decorated page functions automatically in page dict. + + Args: + page_registry_dict: the dict where to register the page + + Returns: + Callable: the actual decorator, a callable which takes a callable as + argument and returns it. 
+    """

     def inner(func):
-        page_dict[func.__name__] = func
+        page_registry_dict[func.__name__] = func
+        return func
 
     return inner

From eb997e6632a57e769004c3046b1600a06735422a Mon Sep 17 00:00:00 2001
From: Mira
Date: Mon, 25 Mar 2024 21:12:58 +0100
Subject: [PATCH 3/5] add mypy static type checking for backend to pre-commit

---
 .pre-commit-config.yaml        |   6 ++
 backend/fastapi_app.py         | 186 ++++++++++++---------------------
 backend/models.py              |  66 ++++++++++++
 backend/mypy.ini               |  14 +++
 backend/script_RAG.py          |  27 ++---
 backend/script_SQL_querying.py |  43 ++++++--
 6 files changed, 198 insertions(+), 144 deletions(-)
 create mode 100644 backend/models.py
 create mode 100644 backend/mypy.ini

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7108725..4cecfe4 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,3 +22,9 @@ repos:
     files: ^frontend/
     args: [--config-file=frontend/mypy.ini]
     additional_dependencies: [types-requests]
+  - id: mypy
+    name: mypy-backend
+    files: ^backend/
+    args: [--config-file=backend/mypy.ini]
+    additional_dependencies: [types-requests]
+
diff --git a/backend/fastapi_app.py b/backend/fastapi_app.py
index ceb4add..5dc47d8 100644
--- a/backend/fastapi_app.py
+++ b/backend/fastapi_app.py
@@ -1,17 +1,14 @@
 # command to run: uvicorn backend.fastapi_app:app --reload
 import os
 import re
-from typing import List
+import logging
+import sys
+from pathlib import Path
+
 from fastapi import FastAPI, HTTPException, UploadFile, Form
 from llama_index import ServiceContext
-from pydantic import BaseModel, Field
-
-# from llama_index.callbacks import CallbackManager, TokenCountingHandler
 from requests.exceptions import MissingSchema
-import logging
-import sys
 from dotenv import load_dotenv
-from pathlib import Path
 import errno
 import certifi
@@ -21,16 +18,23 @@
     AITextDocument,
     AIPdfDocument,
     AIHtmlDocument,
-    CustomLlamaIndexChatEngineWrapper,
     set_up_text_chatbot,
 )
 from .script_SQL_querying import (
     AIDataBase,
-    DataChatBotWrapper,
     set_up_database_chatbot,
 )
-
+from .models import (
+    DoubleUploadException,
+    NoUploadException,
+    EmptyQuestionException,
+    TextSummaryModel,
+    QuestionModel,
+    QAResponseModel,
+    TextResponseModel,
+    MultipleChoiceTest,
+    ErrorResponse,
+)
 from .helpers import load_aws_secrets
 
 # workaround for mac to solve "SSL: CERTIFICATE_VERIFY_FAILED Error"
@@ -69,124 +73,60 @@
 # Set-up Chat Engine:
 # - LlamaIndex CondenseQuestionChatEngine with RetrieverQueryEngine for text files
 # - or querying a database with langchain SQLDatabaseChain and Runnables
-app.chat_engine: CustomLlamaIndexChatEngineWrapper | DataChatBotWrapper | None = None
-app.callback_manager = None
-app.token_counter = None
-
-
-class DoubleUploadException(Exception):
-    pass
-
-
-class NoUploadException(Exception):
-    pass
-
-
-class EmptyQuestionException(Exception):
-    pass
-
-
-class TextSummaryModel(BaseModel):
-    file_name: str
-    text_category: str
-    summary: str
-    used_tokens: int
-
-
-class QuestionModel(BaseModel):
-    prompt: str
-    temperature: float
-
-
-class QAResponseModel(BaseModel):
-    user_question: str
-    ai_answer: str
-    used_tokens: int
-
-
-class TextResponseModel(BaseModel):
-    message: str
-
-
-class MultipleChoiceQuestion(BaseModel):
-    """Data Model for a multiple choice question"""
-
-    question: str = Field(
-        ...,
-        description="""An interesting and unique question related to the main
-        subject of the article. 
-        
- """, - ) - correct_answer: str = Field(..., description="Correct answer to question") - wrong_answer_1: str = Field( - ..., description="a unique wrong answer to the question" - ) - wrong_answer_2: str = Field( - ..., - description="""a unique wrong answer to the question which is different - from wrong_answer_1 and not an empty string - """, - ) - - -class MultipleChoiceTest(BaseModel): - """Data Model for a multiple choice test""" - - questions: List[MultipleChoiceQuestion] = [] - - -class ErrorResponse(BaseModel): - detail: str +app.state.chat_engine = None +app.state.callback_manager = None +app.state.token_counter = None def load_text_chat_engine() -> None: - if not app.chat_engine or app.chat_engine.data_category == "database": + if not app.state.chat_engine or app.state.chat_engine.data_category == "database": logging.debug("setting up text chatbot") logging.debug(f"Debug: {DEBUG_MODE}") ( - app.chat_engine, - app.callback_manager, - app.token_counter, + app.state.chat_engine, + app.state.callback_manager, + app.state.token_counter, ) = set_up_text_chatbot() def load_database_chat_engine() -> None: - if not app.chat_engine or app.chat_engine.data_category != "database": + if not app.state.chat_engine or app.state.chat_engine.data_category != "database": logging.debug("setting up database chatbot") ( - app.chat_engine, - app.callback_manager, # is None in database mode - app.token_counter, + app.state.chat_engine, + app.state.callback_manager, # is None in database mode + app.state.token_counter, ) = set_up_database_chatbot() async def handle_uploadfile( upload_file: UploadFile, -) -> AITextDocument | AIDataBase | None: - file_name = upload_file.filename +) -> AITextDocument | AIDataBase | AIPdfDocument | None: + if not (file_name := upload_file.filename): + return None with open(cfd / data_dir / file_name, "wb") as f: f.write(await upload_file.read()) match upload_file.filename.split(".")[-1]: case "txt": load_text_chat_engine() - return AITextDocument(file_name, LLM_NAME, app.callback_manager) + return AITextDocument(file_name, LLM_NAME, app.state.callback_manager) case "pdf": load_text_chat_engine() - return AIPdfDocument(file_name, LLM_NAME, app.callback_manager) + return AIPdfDocument(file_name, LLM_NAME, app.state.callback_manager) case "sqlite" | "db": uri = f"sqlite:///{app_dir}/{data_dir}/{file_name}" logging.debug(f"uri: {uri} debug {DEBUG_MODE}") load_database_chat_engine() - document: AIDataBase = AIDataBase.from_uri(uri) - return document + return AIDataBase.from_uri(uri) + return None -async def handle_upload_url(upload_url) -> None: +async def handle_upload_url(upload_url: str) -> AITextDocument | AIHtmlDocument: match re.split(r"[./]", upload_url): case [*_, dir, file_name, "txt"] if dir == "data": try: load_text_chat_engine() - return AITextDocument(file_name, LLM_NAME, app.callback_manager) + return AITextDocument(file_name, LLM_NAME, app.state.callback_manager) except OSError: raise FileNotFoundError( errno.ENOENT, @@ -195,7 +135,7 @@ async def handle_upload_url(upload_url) -> None: ) case [http, *_] if "http" in http.lower(): load_text_chat_engine() - return AIHtmlDocument(upload_url, LLM_NAME, app.callback_manager) + return AIHtmlDocument(upload_url, LLM_NAME, app.state.callback_manager) case _: raise MissingSchema @@ -203,16 +143,22 @@ async def handle_upload_url(upload_url) -> None: @app.post("/upload", response_model=TextSummaryModel) async def upload_file( upload_file: UploadFile | None = None, upload_url: str = Form("") -): +) -> TextSummaryModel: message 
= "" text_category = "" - file_name = "" + file_name: str | None = "" used_tokens = 0 try: if upload_file: if upload_url: raise DoubleUploadException("You can not provide both, file and URL.") - file_name = upload_file.filename + if not (file_name := upload_file.filename): + return TextSummaryModel( + file_name="", + text_category=text_category, + summary=message, + used_tokens=used_tokens, + ) destination_file = Path(cfd / "data" / file_name) destination_file.parent.mkdir(exist_ok=True, parents=True) document = await handle_uploadfile(upload_file) @@ -224,11 +170,11 @@ async def upload_file( raise NoUploadException( "You must provide either a file or URL to upload.", ) - if app.chat_engine and document: - app.chat_engine.add_document(document) + if app.state.chat_engine and document: + app.state.chat_engine.add_document(document) message = document.summary text_category = document.category - used_tokens = app.token_counter.total_llm_token_count + used_tokens = app.state.token_counter.total_llm_token_count except MissingSchema: raise HTTPException( status_code=400, @@ -243,7 +189,7 @@ async def upload_file( status_code=400, detail=f"There was an unexpected OSError on uploading the file:{e}", ) - logging.debug(f"engine_up?: {app.chat_engine is not None}") + logging.debug(f"engine_up?: {app.state.chat_engine is not None}") logging.debug(f"message: {message}") return TextSummaryModel( file_name=file_name, @@ -254,18 +200,18 @@ async def upload_file( @app.post("/qa_text", response_model=QAResponseModel) -async def qa_text(question: QuestionModel): - logging.debug(f"engine_up?: {app.chat_engine is not None}") +async def qa_text(question: QuestionModel) -> QAResponseModel: + logging.debug(f"engine_up?: {app.state.chat_engine is not None}") if not question.prompt: raise EmptyQuestionException( "Your Question is empty, please type a message and resend it." ) - if app.chat_engine: - app.token_counter.reset_counts() - app.chat_engine.update_temp(question.temperature) - response = app.chat_engine.answer_question(question) + if app.state.chat_engine: + app.state.token_counter.reset_counts() + app.state.chat_engine.update_temp(question.temperature) + response = app.state.chat_engine.answer_question(question) ai_answer = str(response) - used_tokens = app.token_counter.total_llm_token_count + used_tokens = app.state.token_counter.total_llm_token_count else: ai_answer = "Sorry, no context loaded. Please upload a file or url." 
used_tokens = 0
 
@@ -279,22 +225,22 @@ async def qa_text(question: QuestionModel):
 
 @app.get("/clear_storage", response_model=TextResponseModel)
 async def clear_storage():
-    if app.chat_engine:
-        app.chat_engine.clear_data_storage()
+    if app.state.chat_engine:
+        app.state.chat_engine.clear_data_storage()
         logging.info("chat engine cleared...")
     if (cfd / "data").exists():
         for file in Path(cfd / "data").iterdir():
             os.remove(file)
-    app.chat_engine = None
-    app.token_counter = None
-    app.callback_manager = None
+    app.state.chat_engine = None
+    app.state.token_counter = None
+    app.state.callback_manager = None
     return TextResponseModel(message="Knowledge base successfully cleared")
 
 
 @app.get("/clear_history", response_model=TextResponseModel)
 async def clear_history():
-    if app.chat_engine:
-        message = app.chat_engine.clear_chat_history()
+    if app.state.chat_engine:
+        message = app.state.chat_engine.clear_chat_history()
         # logging.debug("chat history cleared...")
         return TextResponseModel(message=message)
     return TextResponseModel(
@@ -310,13 +256,13 @@
     },
 )
 def get_quiz():
-    if not app.chat_engine or not app.chat_engine.vector_index.ref_doc_info:
+    if not app.state.chat_engine or not app.state.chat_engine.vector_index.ref_doc_info:
         raise HTTPException(
             status_code=400,
             detail="No context provided, please provide a url or a text file!",
         )
-
-    if app.chat_engine.data_category == "database":
+    if app.state.chat_engine.data_category == "database":
         raise HTTPException(
             status_code=400,
             detail="""A database is loaded, but no valid context for a quiz.
@@ -344,7 +290,7 @@ def generate_quiz_from_context():
     from llama_index.prompts import PromptTemplate
     from llama_index.response import Response
 
-    vector_index = app.chat_engine.vector_index
+    vector_index = app.state.chat_engine.vector_index
 
     lc_output_parser = PydanticOutputParser(pydantic_object=MultipleChoiceTest)
     output_parser = LangchainOutputParser(lc_output_parser)
diff --git a/backend/models.py b/backend/models.py
new file mode 100644
index 0000000..1eef414
--- /dev/null
+++ b/backend/models.py
@@ -0,0 +1,66 @@
+from pydantic import BaseModel, Field
+
+
+class DoubleUploadException(Exception):
+    pass
+
+
+class NoUploadException(Exception):
+    pass
+
+
+class EmptyQuestionException(Exception):
+    pass
+
+
+class TextSummaryModel(BaseModel):
+    file_name: str
+    text_category: str
+    summary: str
+    used_tokens: int
+
+
+class QuestionModel(BaseModel):
+    prompt: str
+    temperature: float
+
+
+class QAResponseModel(BaseModel):
+    user_question: str
+    ai_answer: str
+    used_tokens: int
+
+
+class TextResponseModel(BaseModel):
+    message: str
+
+
+class MultipleChoiceQuestion(BaseModel):
+    """Data Model for a multiple choice question"""
+
+    question: str = Field(
+        ...,
+        description="""An interesting and unique question related to the main
+        subject of the article. 
+        
+ """, + ) + correct_answer: str = Field(..., description="Correct answer to question") + wrong_answer_1: str = Field( + ..., description="a unique wrong answer to the question" + ) + wrong_answer_2: str = Field( + ..., + description="""a unique wrong answer to the question which is different + from wrong_answer_1 and not an empty string + """, + ) + + +class MultipleChoiceTest(BaseModel): + """Data Model for a multiple choice test""" + + questions: list[MultipleChoiceQuestion] = [] + + +class ErrorResponse(BaseModel): + detail: str diff --git a/backend/mypy.ini b/backend/mypy.ini new file mode 100644 index 0000000..2aa6e3b --- /dev/null +++ b/backend/mypy.ini @@ -0,0 +1,14 @@ +[mypy] +python_version = 3.10 +disallow_incomplete_defs = True +no_implicit_optional = True +ignore_missing_imports = True +exclude = (?x)( + backend/outdated/* + | \/*venv\/* + | backend/data/* + | backend/outdated/* + | backend/storage/* + ) + + diff --git a/backend/script_RAG.py b/backend/script_RAG.py index 7b2db24..f528108 100644 --- a/backend/script_RAG.py +++ b/backend/script_RAG.py @@ -1,3 +1,8 @@ +import pathlib +import tiktoken +import logging +import os + from llama_index import ( SimpleWebPageReader, VectorStoreIndex, @@ -9,7 +14,7 @@ get_response_synthesizer, ) from llama_index.readers import BeautifulSoupWebReader - +from llama_index.schema import Document from llama_index.llms import OpenAI from llama_index.node_parser import SimpleNodeParser from llama_index.text_splitter import TokenTextSplitter @@ -24,19 +29,15 @@ from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine from llama_index.callbacks import CallbackManager, TokenCountingHandler from llama_index.memory import ChatMemoryBuffer - from llama_index.vector_stores.types import MetadataInfo, VectorStoreInfo from marvin import ai_model from marvin import settings as marvin_settings from llama_index.bridge.pydantic import BaseModel as LlamaBaseModel from llama_index.bridge.pydantic import Field as LlamaField -import pathlib -import tiktoken -import logging -import os from .document_categories import CATEGORY_LABELS +from .models import QuestionModel marvin_settings.openai.api_key = os.getenv("OPENAI_API_KEY") @@ -54,7 +55,7 @@ def __init__( document_name: str, llm_str: str, callback_manager: CallbackManager | None = None, - ): + ) -> None: self.callback_manager: CallbackManager | None = callback_manager self.document = self._load_document(document_name) self.nodes = self.split_document_and_extract_metadata(llm_str) @@ -64,7 +65,7 @@ def __init__( question about "{text_subject}".' 
@classmethod - def _load_document(cls, identifier: str): + def _load_document(cls, identifier: str) -> Document: """loads only the data of the specified name identifier: name of the text file as str @@ -129,7 +130,7 @@ class AIMarvinDocument(LlamaBaseModel): class AIPdfDocument(AITextDocument): @classmethod - def _load_document(cls, identifier: str): + def _load_document(cls, identifier: str) -> Document: # loader = PDFReader() # return loader.load_data( # file=pathlib.Path(str(AITextDocument.cfd / identifier)) @@ -141,7 +142,7 @@ def _load_document(cls, identifier: str): class AIHtmlDocument(AITextDocument): @classmethod - def _load_document_simplewebpageReader(cls, identifier: str): + def _load_document_simplewebpageReader(cls, identifier: str) -> Document: """loads the data of a simple static website at a given url identifier: url of the html file as str """ @@ -152,7 +153,7 @@ def _load_document_simplewebpageReader(cls, identifier: str): )[0] @classmethod - def _load_document_BeautifulSoupWebReader(cls, identifier: str): + def _load_document_BeautifulSoupWebReader(cls, identifier: str) -> Document: """loads the data of an html file at a given url identifier: url of the html file as str """ @@ -165,7 +166,7 @@ def _load_document_BeautifulSoupWebReader(cls, identifier: str): return BeautifulSoupWebReader().load_data(urls=[identifier])[0] @classmethod - def _load_document(cls, identifier: str): + def _load_document(cls, identifier: str) -> Document: """loads the data of an html file at a given url identifier: url of the html file as str """ @@ -316,7 +317,7 @@ def update_temp(self, temperature): {"temperature": temperature} ) - def answer_question(self, question: str) -> str: + def answer_question(self, question: QuestionModel) -> str: return self.chat_engine.chat(question.prompt) diff --git a/backend/script_SQL_querying.py b/backend/script_SQL_querying.py index 7c319e5..734ffaf 100644 --- a/backend/script_SQL_querying.py +++ b/backend/script_SQL_querying.py @@ -1,16 +1,24 @@ # https://python.langchain.com/docs/expression_language/cookbook/sql_db +from __future__ import annotations import logging import re import sys +from typing import Any +from operator import itemgetter + from langchain.chat_models import ChatOpenAI from langchain.utilities import SQLDatabase from langchain.schema.output_parser import StrOutputParser -from langchain.schema.runnable import RunnableLambda, RunnableMap, RunnablePassthrough +from langchain.schema.runnable import ( + RunnableLambda, + RunnableMap, + RunnablePassthrough, + RunnableSequence, +) from langchain.prompts import ChatPromptTemplate from langchain.callbacks import get_openai_callback -from operator import itemgetter logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.getLogger(__name__).addHandler(logging.StreamHandler(stream=sys.stdout)) @@ -48,6 +56,16 @@ def __init__(self, *args, **kwargs): + re.sub(r"/\*((.|\n)*?)\*/", "", self.get_table_info()).strip() ) + @classmethod + def from_uri( + cls, database_uri: str, engine_args: dict | None = None, **kwargs: Any + ) -> AIDataBase: + """Construct a SQLAlchemy engine from URI.""" + from sqlalchemy import create_engine + + _engine_args = engine_args or {} + return cls(create_engine(database_uri, **_engine_args), **kwargs) + def get_schema(self, _): return self.get_table_info() @@ -60,10 +78,10 @@ def ask_a_question(self, question: str, token_callback: CustomTokenCounter) -> s llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo") # logging.debug(self.get_table_info()) with 
get_openai_callback() as callback: - query_generator = ( + query_generator: RunnableSequence[Any, Any] = ( RunnableMap( { - "schema": RunnableLambda(self.get_schema), + "schema": RunnableLambda(self.get_schema), # type: ignore "question": itemgetter("question"), } ) @@ -106,10 +124,10 @@ def ask_a_question(self, question: str, token_callback: CustomTokenCounter) -> s class DataChatBotWrapper: def __init__(self, callback_manager: CustomTokenCounter): self.data_category: str = "database" - self.token_callback = callback_manager - self.document: AIDataBase = None + self.token_callback: CustomTokenCounter = callback_manager + self.document: AIDataBase | None = None - def add_document(self, document) -> None: + def add_document(self, document: AIDataBase) -> None: self.document = document def clear_chat_history(self) -> str: @@ -118,13 +136,15 @@ def clear_chat_history(self) -> str: def clear_data_storage(self) -> None: del self.document self.document = None - # ToDo delete db file ? - def update_temp(self, temperature) -> None: + def update_temp(self, temperature) -> None: # type: ignore pass def answer_question(self, question: str) -> str: - return self.document.ask_a_question(question, self.token_callback) + if self.document: + return self.document.ask_a_question(question, self.token_callback) + else: + raise AttributeError("no document loaded") def set_up_database_chatbot(): @@ -151,7 +171,8 @@ def set_up_database_chatbot(): chat_engine, callback_manager, token_counter = set_up_database_chatbot() - document: AIDataBase = AIDataBase.from_uri("sqlite:///data/database.sqlite") + document = AIDataBase.from_uri("sqlite:///data/database.sqlite") + print("________") print(document.summary) From cb60280d1f59a74430c369f37670587890b6d0e0 Mon Sep 17 00:00:00 2001 From: Mira Date: Mon, 25 Mar 2024 21:40:11 +0100 Subject: [PATCH 4/5] fix deprecated typing.Callable function annotation --- frontend/utils/helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/utils/helpers.py b/frontend/utils/helpers.py index e00fafe..0e68cbc 100644 --- a/frontend/utils/helpers.py +++ b/frontend/utils/helpers.py @@ -1,4 +1,5 @@ -from typing import Callable, TypeVar +from typing import TypeVar +from collections.abc import Callable F = TypeVar("F", bound=Callable[..., None]) From a32a39c5f45c036e5c9a39ab38fbc6c52b7178b5 Mon Sep 17 00:00:00 2001 From: Mira Date: Wed, 27 Mar 2024 09:27:10 +0100 Subject: [PATCH 5/5] update tests and add mypy static type checking for them --- .pre-commit-config.yaml | 5 + backend/fastapi_app.py | 2 +- pytest.ini | 2 +- tests/mypy.ini | 8 ++ tests/test_backend/test_fastapi_app.py | 114 ++++++++---------- tests/test_backend/test_script_RAG.py | 2 +- .../test_backend/test_script_SQL_querying.py | 2 +- 7 files changed, 66 insertions(+), 69 deletions(-) create mode 100644 tests/mypy.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4cecfe4..3b53ee7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,4 +27,9 @@ repos: files: ^backend/ args: [--config-file=backend/mypy.ini] additional_dependencies: [types-requests] + - id: mypy + name: mypy-test + files: ^tests/ + args: [--config-file=tests/mypy.ini] + additional_dependencies: [types-requests] diff --git a/backend/fastapi_app.py b/backend/fastapi_app.py index 5dc47d8..95d4def 100644 --- a/backend/fastapi_app.py +++ b/backend/fastapi_app.py @@ -102,7 +102,7 @@ def load_database_chat_engine() -> None: async def handle_uploadfile( upload_file: UploadFile, ) -> AITextDocument | 
AIDataBase | AIPdfDocument | None:
-    if not (file_name := upload_file.filename):
+    if not (file_name := Path(upload_file.filename or "").name):
         return None
     with open(cfd / data_dir / file_name, "wb") as f:
         f.write(await upload_file.read())
diff --git a/pytest.ini b/pytest.ini
index ba36b78..a3705ad 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,5 +1,5 @@
 [pytest]
 markers =
     ai_call: marks tests where any openai API call is used
-    ai_gpt_35: marks tests where openai API is called for gpt3.5 usage(deselect with '-m "not gpt_35"')
+    ai_gpt35: marks tests where openai API is called for gpt3.5 usage (deselect with '-m "not ai_gpt35"')
     ai_embeddings: marks tests where openai API is called for embedding only
\ No newline at end of file
diff --git a/tests/mypy.ini b/tests/mypy.ini
new file mode 100644
index 0000000..e5b2eff
--- /dev/null
+++ b/tests/mypy.ini
@@ -0,0 +1,8 @@
+[mypy]
+python_version = 3.10
+disallow_incomplete_defs = True
+no_implicit_optional = True
+# install_types = True
+exclude = example_upload_files/*
+ignore_missing_imports = True
+
diff --git a/tests/test_backend/test_fastapi_app.py b/tests/test_backend/test_fastapi_app.py
index c37078d..0a0ba21 100644
--- a/tests/test_backend/test_fastapi_app.py
+++ b/tests/test_backend/test_fastapi_app.py
@@ -1,16 +1,15 @@
 # python -m tests.test_backend.test_fastapi_app.py
 import logging
 import os
+import io
 import pytest
-
 from pathlib import Path
-import shutil
 import sys
 
 from fastapi.testclient import TestClient
 
-from backend.fastapi_app import (
-    app,
+from backend.fastapi_app import app
+from backend.models import (
     EmptyQuestionException,
     DoubleUploadException,
     NoUploadException,
@@ -26,6 +25,8 @@
 
 
 # Todo: put fixtures into conftest.py
+
+
 @pytest.fixture
 def text_file():
     file_name = "example.txt"
@@ -34,11 +35,11 @@ def text_file():
         yield upload_file
     # clean-up app storage after tests
     finally:
-        if app.chat_engine:
-            app.chat_engine.clear_data_storage()
+        if app.state.chat_engine:
+            app.state.chat_engine.clear_data_storage()
             logging.debug("chat engine cleared...")
-            app.chat_engine = None
-        if file := Path(backend_dir / "data" / file_name).is_file():
+            app.state.chat_engine = None
+        if (file := Path(backend_dir / "data" / file_name)).is_file():
             os.remove(file)
 
 
 @pytest.fixture
 def url():
     yield "https://de.wikipedia.org/wiki/Don’t_repeat_yourself"
     # clean-up (clear_storage)
-    if app.chat_engine:
-        app.chat_engine.clear_data_storage()
+    if app.state.chat_engine:
+        app.state.chat_engine.clear_data_storage()
         logging.debug("chat engine cleared...")
-    app.chat_engine = None
-    app.callback_manager = None
-    app.token_counter = None
+    app.state.chat_engine = None
+    app.state.callback_manager = None
+    app.state.token_counter = None
 
 
 @pytest.fixture
-def url_db():
+def db_file():
     file_name = "database.sqlite"
-    app_data_dir = "data"
-    destination_file = Path(backend_dir / app_data_dir / file_name)
-    destination_file.parent.mkdir(exist_ok=True, parents=True)
-
-    url_db_copied_to_app = f"sqlite:///{app_data_dir}/{file_name}"
-    shutil.copy(example_file_dir / file_name, destination_file)
-    yield url_db_copied_to_app
-    # clean-up (clear_storage)
-    if app.chat_engine:
-        app.chat_engine.clear_data_storage()
-        logging.debug("chat engine cleared...")
-    app.chat_engine = None
-    app.token_counter = None
-    if destination_file.is_file():
-        os.remove(destination_file)
+    with Path(example_file_dir, file_name).open("rb") as upload_file:
+        try:
+            yield upload_file
+        finally:
+            if app.state.chat_engine:
+                app.state.chat_engine.clear_data_storage()
+                
logging.debug("chat engine cleared...") + app.state.chat_engine = None + if (file := Path(backend_dir / "data" / file_name)).is_file(): + os.remove(file) @pytest.mark.ai_call @@ -104,40 +100,30 @@ def test_upload_url_webpage(url): assert data.get("used_tokens", None) is not None -# def test_test_route(): -# #data = {"upload_url": "sqlite:///data/database.sqlite"} -# file_name = "example.txt" -# file = Path(example_file_dir, file_name).open("rb") -# response = client.post( -# "/upload", -# files={"upload_file": ("example.txt", file)} -# ) - -# assert response.status_code == 200 -# assert response.json().get("detail") == "example" - - -def test_upload_url_database(url_db): - data = {"upload_url": url_db} - response = client.post("/upload", data=data, files=None) +def test_upload_database_file(db_file): + db_file_name = Path(db_file.name).name + response = client.post( + "/upload", + data={"upload_url": ""}, + files={"upload_file": (db_file_name, db_file)}, + ) - assert app.chat_engine is not None - assert app.chat_engine.data_category == "database" - # assert app.callback_manager is not None # is None in database mode - assert app.token_counter is not None + assert app.state.chat_engine is not None + assert app.state.chat_engine.data_category == "database" + assert app.state.callback_manager is None # is None in database mode + assert app.state.token_counter is not None assert response.status_code == 200 data = response.json() # test if keys in response and if not None - assert data.get("file_name", None) == url_db - # assert data.get("text_category", None) is not None + assert data.get("file_name", None) == db_file_name assert data.get("text_category") == "database" assert data.get("summary", None) is not None assert len(data.get("summary")) > 13 assert data.get("used_tokens", None) is not None -def test_upload_url_and_file(url: str, text_file): +def test_upload_url_and_file(url: str, text_file: io.BytesIO) -> None: with pytest.raises(DoubleUploadException): client.post( "/upload", @@ -157,20 +143,16 @@ def test_upload_bad_url(): response = client.post("/upload", data={"upload_url": url}, files=None) assert response.status_code == 400 assert response.json() == { - "detail": f"There was a problem with the provided url (MissingSchema): {url}" + "detail": f"""There was a problem with the provided url (MissingSchema): + {url} + """ } -# def test_upload_database(): -# """currently no small ( <200MB) database availabe""" -# response = client.post("/upload") -# assert response.status_code == 200 - - @pytest.mark.ai_call @pytest.mark.ai_gpt35 def test_ask_question_about_given_text(text_file): - """Caution: test takes some time since openai API call required""" + """Caution: openai API call required""" client.post( "/upload", data={"upload_url": ""}, @@ -193,11 +175,13 @@ def test_ask_question_about_given_text(text_file): @pytest.mark.ai_call @pytest.mark.ai_gpt35 -def test_ask_question_about_given_database(url_db): - """Caution: test takes some time since openai API call required""" - data = {"upload_url": url_db} - res1 = client.post("/upload", data=data, files=None) - assert res1.status_code == 200 +def test_ask_question_about_given_database(db_file): + """Caution: openai API call required""" + # data = {"upload_url": url_db} + upload_response = client.post( + "/upload", files={"upload_file": (Path(db_file.name).name, db_file)} + ) + assert upload_response.status_code == 200 response = client.post( "/qa_text", diff --git a/tests/test_backend/test_script_RAG.py 
b/tests/test_backend/test_script_RAG.py index f2c118d..1f1e086 100644 --- a/tests/test_backend/test_script_RAG.py +++ b/tests/test_backend/test_script_RAG.py @@ -1,4 +1,4 @@ -from script_RAG import ( +from backend.script_RAG import ( set_up_text_chatbot, ) diff --git a/tests/test_backend/test_script_SQL_querying.py b/tests/test_backend/test_script_SQL_querying.py index e26b406..198fba0 100644 --- a/tests/test_backend/test_script_SQL_querying.py +++ b/tests/test_backend/test_script_SQL_querying.py @@ -1,4 +1,4 @@ -from script_SQL_querying import ( +from backend.script_SQL_querying import ( set_up_database_chatbot, )
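
A minimal usage sketch of the page registry introduced in
frontend/utils/helpers.py, based on the calls visible in
frontend/streamlit_app.py above; the import path and the page body are
assumptions for illustration, not code from the patches:

    from collections.abc import Callable

    from utils.helpers import register_page  # assumed import path

    PAGE_REGISTRY_DICT: dict[str, Callable[..., None]] = {}

    @register_page(PAGE_REGISTRY_DICT)
    def questionai() -> None:
        """Placeholder page body; the real page renders the chat UI."""

    # main() then dispatches on the menu selection kept in session state:
    PAGE_REGISTRY_DICT["questionai"]()

Because the three mypy hooks share the hook id "mypy", a single
"pre-commit run mypy --all-files" invocation runs the frontend, backend,
and test checks together.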