From 53dfd35fa5cc10543ff93fcdd053663526a7cb6e Mon Sep 17 00:00:00 2001 From: yashbonde Date: Fri, 15 Mar 2024 13:13:30 +0530 Subject: [PATCH 1/2] [chore] remove pycryptodome dependency --- README.md | 13 ++ server/chainfury_server/api/chains.py | 5 +- server/chainfury_server/engine.py | 201 +++++++++++++------------- server/chainfury_server/utils.py | 57 +------- 4 files changed, 125 insertions(+), 151 deletions(-) diff --git a/README.md b/README.md index 6cd6c95..64fe613 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,19 @@ pip install chainfury_server python3 -m chainfury_server ``` +You can also start the server with worker mode for long running chains: +```bash +# in different tab +redis-server + +# in different tab +celery --app chainfury_server.engine.app worker --queues cfs + +# set env variable to enable this +export CFS_ENABLE_CELERY=1 +python3 -m chainfury_server +``` + ### Run Docker Easiest way to run the server is to use docker. You can use the following command to run ChainFury: diff --git a/server/chainfury_server/api/chains.py b/server/chainfury_server/api/chains.py index 822ca8c..16ac631 100644 --- a/server/chainfury_server/api/chains.py +++ b/server/chainfury_server/api/chains.py @@ -245,7 +245,10 @@ def run_chain( if as_task: # when run as a task this will return a task ID that will be submitted - # raise HTTPException(501, detail="Not implemented yet") + if not Env.CFS_ENABLE_CELERY(): + raise HTTPException( + 400, "Celery is not enabled on this server, cannot run as task" + ) result = engine.submit( chatbot=chatbot, prompt=prompt, diff --git a/server/chainfury_server/engine.py b/server/chainfury_server/engine.py index ec5c278..b78af37 100644 --- a/server/chainfury_server/engine.py +++ b/server/chainfury_server/engine.py @@ -13,117 +13,117 @@ from chainfury.utils import SimplerTimes import chainfury_server.database as DB -from chainfury_server.utils import logger +from chainfury_server.utils import logger, Env -from celery import Celery from sqlalchemy.pool import NullPool from sqlalchemy import create_engine +if Env.CFS_ENABLE_CELERY(): + from celery import Celery -app = Celery() + app = Celery() - -@app.task(name="chainfury_server.engine.run_chain") -def run_chain( - chatbot_id: str, - prompt_id: str, - prompt_data: Dict, - store_ir: bool, - store_io: bool, - worker_id: str, -): - start = SimplerTimes.get_now_fp64() - - # create the DB session - sess = DB.get_local_session( - create_engine( - DB.db, - poolclass=NullPool, + @app.task(name="chainfury_server.engine.run_chain") + def run_chain( + chatbot_id: str, + prompt_id: str, + prompt_data: Dict, + store_ir: bool, + store_io: bool, + worker_id: str, + ): + start = SimplerTimes.get_now_fp64() + + # create the DB session + sess = DB.get_local_session( + create_engine( + DB.db, + poolclass=NullPool, + ) ) - ) - db = sess() - - # get the db object - chatbot = db.query(DB.ChatBot).filter(DB.ChatBot.id == chatbot_id).first() # type: ignore - prompt_row: DB.Prompt = db.query(DB.Prompt).filter(DB.Prompt.id == prompt_id).first() # type: ignore - if prompt_row is None: - time.sleep(2) - prompt_row = db.query(DB.Prompt).filter(DB.Prompt.id == prompt_id).first() # type: ignore - if prompt_row is None: - raise RuntimeError(f"Prompt {prompt_id} not found") - - # Create a Fury chain then run the chain while logging all the intermediate steps - dag = T.Dag(**chatbot.dag) # type: ignore - chain = Chain.from_dag(dag, check_server=False) - callback = FuryThoughtsCallback(db, prompt_row.id) - - # print( - # f"starting chain execution: [{prompt_row.meta.get('task_id')=}] [{worker_id=}]" - # ) - iterator = chain.stream( - data=prompt_data, - thoughts_callback=callback, - print_thoughts=False, - ) - mainline_out = "" - last_db = 0 - for ir, done in iterator: - if done: - mainline_out = ir - break - - if store_ir: - # in case of stream, every item is a fundamentally a step - data = { - "outputs": [ - { - "name": k.split("/")[-1], - "data": v, - } - for k, v in ir.items() - ] - } - k = next(iter(ir)).split("/")[0] - db_chainlog = DB.ChainLog( - prompt_id=prompt_row.id, - created_at=SimplerTimes.get_now_datetime(), - node_id=k, - worker_id=worker_id, - message="step", - data=data, - ) # type: ignore - db.add(db_chainlog) + db = sess() - # update the DB every 5 seconds - if time.time() - last_db > 5: - db.commit() - last_db = time.time() + # get the db object + chatbot = db.query(DB.ChatBot).filter(DB.ChatBot.id == chatbot_id).first() # type: ignore + prompt_row: DB.Prompt = db.query(DB.Prompt).filter(DB.Prompt.id == prompt_id).first() # type: ignore + if prompt_row is None: + time.sleep(2) + prompt_row = db.query(DB.Prompt).filter(DB.Prompt.id == prompt_id).first() # type: ignore + if prompt_row is None: + raise RuntimeError(f"Prompt {prompt_id} not found") + + # Create a Fury chain then run the chain while logging all the intermediate steps + dag = T.Dag(**chatbot.dag) # type: ignore + chain = Chain.from_dag(dag, check_server=False) + callback = FuryThoughtsCallback(db, prompt_row.id) + + # print( + # f"starting chain execution: [{prompt_row.meta.get('task_id')=}] [{worker_id=}]" + # ) + iterator = chain.stream( + data=prompt_data, + thoughts_callback=callback, + print_thoughts=False, + ) + mainline_out = "" + last_db = 0 + for ir, done in iterator: + if done: + mainline_out = ir + break - result = T.ChainResult( - result=str(mainline_out), - prompt_id=prompt_row.id, # type: ignore - ) + if store_ir: + # in case of stream, every item is a fundamentally a step + data = { + "outputs": [ + { + "name": k.split("/")[-1], + "data": v, + } + for k, v in ir.items() + ] + } + k = next(iter(ir)).split("/")[0] + db_chainlog = DB.ChainLog( + prompt_id=prompt_row.id, + created_at=SimplerTimes.get_now_datetime(), + node_id=k, + worker_id=worker_id, + message="step", + data=data, + ) # type: ignore + db.add(db_chainlog) + + # update the DB every 5 seconds + if time.time() - last_db > 5: + db.commit() + last_db = time.time() + + result = T.ChainResult( + result=str(mainline_out), + prompt_id=prompt_row.id, # type: ignore + ) - db_chainlog = DB.ChainLog( - prompt_id=prompt_row.id, - created_at=SimplerTimes.get_now_datetime(), - node_id="end", - worker_id=worker_id, - message="completed", - ) # type: ignore - db.add(db_chainlog) + db_chainlog = DB.ChainLog( + prompt_id=prompt_row.id, + created_at=SimplerTimes.get_now_datetime(), + node_id="end", + worker_id=worker_id, + message="completed", + ) # type: ignore + db.add(db_chainlog) - # commit the prompt to DB - if store_io: - prompt_row.response = result.result # type: ignore - prompt_row.time_taken = float(time.time() - start) # type: ignore + # commit the prompt to DB + if store_io: + prompt_row.response = result.result # type: ignore + prompt_row.time_taken = float(time.time() - start) # type: ignore - # update the DB after sleeping a bit - st = time.time() - last_db - if st < 2: - time.sleep(2 - st) # be nice to the db - db.commit() + # update the DB after sleeping a bit + st = time.time() - last_db + if st < 2: + time.sleep(2 - st) # be nice to the db + db.commit() class FuryEngine: @@ -304,6 +304,11 @@ def submit( raise HTTPException( status_code=400, detail="prompt cannot have both new_message and data" ) + if not Env.CFS_ENABLE_CELERY(): + raise HTTPException( + status_code=400, + detail="submit is only available when using celery. Set CFS_ENABLE_CELERY=1 in the environment", + ) try: logger.debug("Adding prompt to database") prompt_row = create_prompt(db, chatbot.id, prompt.new_message if store_io else "", prompt.session_id) # type: ignore diff --git a/server/chainfury_server/utils.py b/server/chainfury_server/utils.py index 22c5832..2e37b0d 100644 --- a/server/chainfury_server/utils.py +++ b/server/chainfury_server/utils.py @@ -1,8 +1,7 @@ # Copyright © 2023- Frello Technology Private Limited import os -from Cryptodome.Cipher import AES -from base64 import b64decode, b64encode +from snowflake import SnowflakeGenerator # WARNING: do not import anything from anywhere here, this is the place where chainfury_server starts. # importing anything can cause the --pre and --post flags to fail when starting server. @@ -38,6 +37,7 @@ class Env: ] CFS_DISABLE_UI = lambda: os.getenv("CFS_DISABLE_UI", "0") == "1" CFS_DISABLE_DOCS = lambda: os.getenv("CFS_DISABLE_DOCS", "0") == "1" + CFS_ENABLE_CELERY = lambda: os.getenv("CFS_ENABLE_CELERY", "0") == "1" def folder(x: str) -> str: @@ -50,57 +50,10 @@ def joinp(x: str, *args) -> str: return os.path.join(x, *args) -class Crypt: - - def __init__(self, salt="SlTKeYOpHygTYkP3"): - self.salt = salt.encode("utf8") - self.enc_dec_method = "utf-8" - - def encrypt(self, str_to_enc, str_key): - try: - aes_obj = AES.new(str_key.encode("utf-8"), AES.MODE_CFB, self.salt) - hx_enc = aes_obj.encrypt(str_to_enc.encode("utf8")) - mret = b64encode(hx_enc).decode(self.enc_dec_method) - return mret - except ValueError as value_error: - if value_error.args[0] == "IV must be 16 bytes long": - raise ValueError("Encryption Error: SALT must be 16 characters long") - elif ( - value_error.args[0] == "AES key must be either 16, 24, or 32 bytes long" - ): - raise ValueError( - "Encryption Error: Encryption key must be either 16, 24, or 32 characters long" - ) - else: - raise ValueError(value_error) - - def decrypt(self, enc_str, str_key): - try: - aes_obj = AES.new(str_key.encode("utf8"), AES.MODE_CFB, self.salt) - str_tmp = b64decode(enc_str.encode(self.enc_dec_method)) - str_dec = aes_obj.decrypt(str_tmp) - mret = str_dec.decode(self.enc_dec_method) - return mret - except ValueError as value_error: - if value_error.args[0] == "IV must be 16 bytes long": - raise ValueError("Decryption Error: SALT must be 16 characters long") - elif ( - value_error.args[0] == "AES key must be either 16, 24, or 32 bytes long" - ): - raise ValueError( - "Decryption Error: Encryption key must be either 16, 24, or 32 characters long" - ) - else: - raise ValueError(value_error) - - -CURRENT_EPOCH_START = 1705905900000 # UTC timezone -"""Start of the current epoch, used for generating snowflake ids""" - -from snowflake import SnowflakeGenerator - - class SFGen: + CURRENT_EPOCH_START = 1705905900000 # UTC timezone + """Start of the current epoch, used for generating snowflake ids""" + def __init__(self, instance, epoch=CURRENT_EPOCH_START): self.gen = SnowflakeGenerator(instance, epoch=epoch) From 42219c5a24914f6c89e431803585daa9e7e23175 Mon Sep 17 00:00:00 2001 From: yashbonde Date: Fri, 15 Mar 2024 13:51:45 +0530 Subject: [PATCH 2/2] [chore] remove qdrant + improve server code --- chainfury/components/qdrant/__init__.py | 353 ------------------------ pyproject.toml | 8 +- server/chainfury_server/api/chains.py | 37 +-- server/chainfury_server/api/user.py | 2 +- 4 files changed, 12 insertions(+), 388 deletions(-) delete mode 100644 chainfury/components/qdrant/__init__.py diff --git a/chainfury/components/qdrant/__init__.py b/chainfury/components/qdrant/__init__.py deleted file mode 100644 index dd29d1f..0000000 --- a/chainfury/components/qdrant/__init__.py +++ /dev/null @@ -1,353 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -from uuid import uuid4 -from functools import lru_cache -from typing import List, Dict, Tuple, Optional, Union - -try: - from qdrant_client import models, QdrantClient - - QDRANT_CLIENT_INSTALLED = True -except ImportError: - QDRANT_CLIENT_INSTALLED = False - -from chainfury import Secret, memory_registry, logger -from chainfury.components.const import Env, ComponentMissingError - -# https://qdrant.tech/documentation/concepts/filtering -# Must : "must" : AND -# Should : "should" : OR -# Must Not: "must_not" : NOT -# Match: = -# Match Any: IN -# Match Except: NOT IN - - -@lru_cache(maxsize=1) -def _get_qdrant_client( - qdrant_url: Secret = Secret(), qdrant_api_key: Secret = Secret() -): - """Create a qdrant client and cache it - - Args: - qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`. - qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`. - - Returns: - qdrant_client.QdrantClient: qdrant client - """ - qdrant_url = Secret(Env.QDRANT_API_URL(qdrant_url.value)).value # type: ignore - qdrant_api_key = Secret(Env.QDRANT_API_KEY(qdrant_api_key.value)).value # type: ignore - if not qdrant_url: - raise Exception( - "Qdrant URL is not set. Please pass `qdrant_url` or env var `QDRANT_API_URL=`" - ) - if not qdrant_api_key: - raise Exception( - "Qdrant API KEY is not set. Please pass `qdrant_api_key` or env var `QDRANT_API_KEY=`" - ) - logger.info("Creating Qdrant client") - return QdrantClient(url=qdrant_url, api_key=qdrant_api_key) # type: ignore - - -def qdrant_write( - embeddings: List[List[float]], - collection_name: str, - qdrant_url: Secret = Secret(""), - qdrant_api_key: Secret = Secret(""), - extra_payload: List[Dict[str, str]] = [], - wait: bool = True, - create_if_not_present: bool = True, - distance: str = "cosine", -) -> Tuple[str, Optional[Exception]]: - """ - Write to the Qdrant DB using the Qdrant client. In order to use this, access via the `memory_registry`: - - Example: - >>> from chainfury import memory_registry - >>> mem = memory_registry.get_write("qdrant") - >>> sentence = "C.P. Cavafy is widely considered the most distinguished Greek poet of the 20th century." - >>> out, err = mem( - { - "items": [sentence], - "extra_payload": [ - {"data": sentence}, - ], - "collection_name": "my_test_collection", - "embedding_model": "openai-embedding", - "create_if_not_present": True, - } - ) - >>> if err: - print("TRACE:", out) - else: - print(out) - - Args: - embeddings (List[List[float]]): list of embeddings - collection_name (str): collection name - qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`. - qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`. - extra_payload (List[Dict[str, str]], optional): extra payload. Defaults to []. - wait (bool, optional): wait for the response. Defaults to True. - create_if_not_present (bool, optional): create collection if not present. Defaults to True. - distance (str, optional): distance metric. Defaults to "cosine". - - Returns: - Tuple[str, Optional[Exception]]: status and error - """ - # client check - if not QDRANT_CLIENT_INSTALLED: - raise ComponentMissingError( - "Qdrant client is not installed. Please install it with `pip install qdrant-client`" - ) - - # checks - if not (len(embeddings) and len(embeddings[0]) and type(embeddings[0][0]) == float): - raise Exception("Embeddings should be a list of lists of floats") - if extra_payload and len(extra_payload) != len(embeddings): - raise Exception("Length of extra_payload should be equal to embeddings") - - client: QdrantClient = _get_qdrant_client(qdrant_url, qdrant_api_key) # type: ignore - - # next we create points and upsert them into the DB - points = [] - for i, embedding in enumerate(embeddings): - payload = {} - if extra_payload: - payload = extra_payload[i] - points.append(models.PointStruct(id=str(uuid4()), payload=payload, vector=embedding)) # type: ignore - batch = models.Batch( # type: ignore - ids=[point.id for point in points], - vectors=[point.vector for point in points], - payloads=[point.payload for point in points], - ) - - def _insert(): - try: - result = client.upsert( - collection_name=collection_name, - points=batch, - wait=wait, - ) - except Exception as e: - return e.content, e # type: ignore - return result.status.lower(), None - - status, err = _insert() - if err and err.status_code == 404 and create_if_not_present: # type: ignore - collection = client.recreate_collection( - collection_name=collection_name, - vectors_config=models.VectorParams( # type: ignore - size=len(embeddings[0]), - distance=getattr(models.Distance, distance.upper()), # type: ignore - ), - ) - logger.info(f"Created collection {collection}") - status, err = _insert() - return status, err - - -memory_registry.register_write( - component_name="qdrant", - fn=qdrant_write, - outputs={"status": 0}, - vector_key="embeddings", - description="Write to the Qdrant DB using the Qdrant client", -) - - -def qdrant_read( - embeddings: List[List[float]], - collection_name: str, - cutoff_score: float = 0.0, - top: int = 5, - limit: int = 0, - offset: int = 0, - filters: Dict[str, Dict[str, str]] = {}, - qdrant_url: Secret = Secret(""), - qdrant_api_key: Secret = Secret(""), - qdrant_search_hnsw_ef: int = 0, - qdrant_search_exact: bool = False, - batch_search: bool = False, -) -> Tuple[Dict[str, List[Dict[str, Union[float, int]]]], Optional[Exception]]: - """ - Read from the Qdrant DB using the Qdrant client. In order to use this access via the `memory_registry`: - - Example: - >>> from chainfury import memory_registry - >>> mem = memory_registry.get_read("qdrant") - >>> sentence = "Who was the Cafavy?" - >>> out, err = mem( - { - "items": [sentence], - "collection_name": "my_test_collection", - "embedding_model": "openai-embedding" - } - ) - >>> if err: - print("TRACE:", out) - else: - print(out) - - Note: - `batch_search` is not implemented yet. There's some issues from the `qdrant_client` library. - - Args: - embeddings (List[List[float]]): list of embeddings - collection_name (str): collection name - cutoff_score (float, optional): cutoff score. Defaults to 0.0. - limit (int, optional): limit. Defaults to 3. - offset (int, optional): offset. Defaults to 0. - qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`. - qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`. - qdrant_search_hnsw_ef (int, optional): qdrant search beam size, the larger the beam size the more accurate the search, - if not set uses default value. - qdrant_search_exact (bool, optional): qdrant search exact. Defaults to False. - batch_search (bool, optional): batch search. Defaults to False. - - Returns: - Tuple[List[Dict[str, Union[float, int]]], Optional[Exception]]: list of results and error - """ - # client check - if not QDRANT_CLIENT_INSTALLED: - raise ComponentMissingError( - "Qdrant client is not installed. Please install it with `pip install qdrant-client`" - ) - - # checks - if not (len(embeddings) and len(embeddings[0]) and type(embeddings[0][0]) == float): - raise Exception("Embeddings should be a list of lists of floats") - if batch_search: - raise NotImplementedError("Batch search is not implemented yet") - if not batch_search and len(embeddings) > 1: - raise Exception( - "Batch search is not enabled, but multiple embeddings are passed" - ) - if not top and not limit: - raise Exception("Either top or limit should be set") - - client: QdrantClient = _get_qdrant_client(qdrant_url, qdrant_api_key) # type: ignore - - search_params = models.SearchParams() # type: ignore - if qdrant_search_hnsw_ef: - search_params.hnsw_ef = qdrant_search_hnsw_ef - if qdrant_search_exact: - search_params.exact = qdrant_search_exact - - if batch_search: - # this is not implemented, this fails when we try to pass a list of vectors - search_queries = [ - models.SearchRequest( - vector=x, limit=limit, offset=offset, params=search_params - ) - for x in embeddings - ] - out = client.search_batch( - collection_name=collection_name, - requests=search_queries, - ) - res = [[_x.dict(skip_defaults=False) for _x in x] for x in out] # type: ignore - - query_filter = None - if filters: - query_filter = models.Filter(**filters) # type: ignore - - out = client.search( - collection_name=collection_name, - query_vector=embeddings[0], - query_filter=query_filter, - limit=max(limit, top), - offset=offset, - search_params=search_params, - ) - out = [x for x in out if x.score > cutoff_score] - res = [_x.dict(skip_defaults=False) for _x in out] # type: ignore - return {"data": res}, None - - -memory_registry.register_read( - component_name="qdrant", - fn=qdrant_read, - outputs={"items": 0}, - vector_key="embeddings", - description="Function to read from the Qdrant DB using the Qdrant client", -) - - -# helper functions - - -def recreate_collection(collection_name: str, embedding_dim: int) -> bool: - """ - Deletes and recreates a collection - - Note: - This will delete all the data in the collection, use with caution - - Args: - collection_name (str): collection name - embedding_dim (int): embedding dimension - - Returns: - bool: success - """ - client: QdrantClient = _get_qdrant_client() # type: ignore - return client.recreate_collection( - collection_name=collection_name, - vectors_config=models.VectorParams( # type: ignore - size=embedding_dim, - distance=models.Distance.COSINE, # type: ignore - ), - optimizers_config=models.OptimizersConfigDiff( # type: ignore - indexing_threshold=0, - ), - ) - - -def enable_indexing(collection_name: str, indexing_threshold: int = 20000) -> bool: - """ - Enable indexing for a collection, use this in conjunction with `disable_indexing`. Read more - `here `_. - - Example: - >>> from chainfury.components.qdrant import enable_indexing, disable_indexing, qdrant_write - >>> disable_indexing("my_collection") - >>> qdrant_write([[1, 2, 3] for _ in range(100)], "my_collection") - >>> enable_indexing("my_collection") - - - Args: - collection_name (str): collection name - indexing_threshold (int, optional): indexing threshold. Defaults to 20000. - - Returns: - bool: success - """ - client: QdrantClient = _get_qdrant_client() # type: ignore - return client.update_collection( - collection_name=collection_name, - optimizer_config=models.OptimizersConfigDiff( # type: ignore - indexing_threshold=indexing_threshold, - ), - ) - - -def disable_indexing(collection_name: str): - """ - Disable indexing for a collection, use this in conjunction with `enable_indexing`. Read more - `here `_. - - Args: - collection_name (str): collection name - - Returns: - bool: success - """ - client: QdrantClient = _get_qdrant_client() # type: ignore - return client.update_collection( - collection_name=collection_name, - optimizer_config=models.OptimizersConfigDiff( # type: ignore - indexing_threshold=0, - ), - ) diff --git a/pyproject.toml b/pyproject.toml index 05bedbe..2e44725 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,15 +18,11 @@ requests = "^2.31.0" python-dotenv = "1.0.0" urllib3 = ">=1.26.18" tabulate = "0.9.0" -"cryptography" = ">=41.0.6" -stability-sdk = { version = "0.8.3", optional = true } -qdrant-client = { version = "1.5.4", optional = true } +cryptography = ">=41.0.6" boto3 = { version = "1.29.6", optional = true } [tool.poetry.extras] -all = ["stability-sdk", "qdrant-client", "boto3"] -stability = ["stability-sdk"] -qdrant = ["qdrant-client"] +all = ["boto3"] [tool.poetry.group.dev.dependencies] sphinx = "7.2.5" diff --git a/server/chainfury_server/api/chains.py b/server/chainfury_server/api/chains.py index 16ac631..7575ab1 100644 --- a/server/chainfury_server/api/chains.py +++ b/server/chainfury_server/api/chains.py @@ -16,8 +16,6 @@ def create_chain( - req: Request, - resp: Response, token: Annotated[str, Header()], chatbot_data: T.ApiCreateChainRequest, db: Session = Depends(DB.fastapi_db_session), @@ -27,8 +25,7 @@ def create_chain( # validate chatbot if not chatbot_data.name: - resp.status_code = 400 - return T.ApiResponse(message="Name not specified") + raise HTTPException(status_code=400, detail="Name not specified") if chatbot_data.dag: for n in chatbot_data.dag.nodes: if len(n.id) > Env.CFS_MAXLEN_CF_NODE(): @@ -55,8 +52,6 @@ def create_chain( def get_chain( - req: Request, - resp: Response, token: Annotated[str, Header()], id: str, tag_id: str = "", @@ -75,16 +70,13 @@ def get_chain( filters.append(DB.ChatBot.tag_id == tag_id) chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: - resp.status_code = 404 - return T.ApiResponse(message="ChatBot not found") + raise HTTPException(status_code=404, detail="Chain not found") # return return chatbot.to_ApiChain() def update_chain( - req: Request, - resp: Response, token: Annotated[str, Header()], id: str, chatbot_data: T.ApiChain, @@ -96,14 +88,14 @@ def update_chain( # validate chatbot update if not len(chatbot_data.update_keys): - resp.status_code = 400 - return T.ApiResponse(message="No keys to update") + raise HTTPException(status_code=400, detail="No keys to update") unq_keys = set(chatbot_data.update_keys) valid_keys = {"name", "description", "dag"} if not unq_keys.issubset(valid_keys): - resp.status_code = 400 - return T.ApiResponse(message=f"Invalid keys {unq_keys.difference(valid_keys)}") + raise HTTPException( + status_code=400, detail=f"Invalid keys {unq_keys.difference(valid_keys)}" + ) # DB Call filters = [ @@ -115,8 +107,7 @@ def update_chain( filters.append(DB.ChatBot.tag_id == tag_id) chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: - resp.status_code = 404 - return T.ApiResponse(message="ChatBot not found") + raise HTTPException(status_code=404, detail="Chain not found") for field in unq_keys: if field == "name": @@ -133,8 +124,6 @@ def update_chain( def delete_chain( - req: Request, - resp: Response, token: Annotated[str, Header()], id: str, tag_id: str = "", @@ -153,8 +142,7 @@ def delete_chain( filters.append(DB.ChatBot.tag_id == tag_id) chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: - resp.status_code = 404 - return T.ApiResponse(message="ChatBot not found") + raise HTTPException(status_code=404, detail="Chain not found") chatbot.deleted_at = datetime.now() db.commit() @@ -163,8 +151,6 @@ def delete_chain( def list_chains( - req: Request, - resp: Response, token: Annotated[str, Header()], skip: int = 0, limit: int = 10, @@ -190,8 +176,6 @@ def list_chains( def run_chain( - req: Request, - resp: Response, id: str, token: Annotated[str, Header()], prompt: T.ApiPromptBody, @@ -234,8 +218,7 @@ def run_chain( ] chatbot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: - resp.status_code = 404 - return T.ApiResponse(message="ChatBot not found") + raise HTTPException(status_code=404, detail="Chain not found") # call the engine engine = FuryEngine() @@ -292,8 +275,6 @@ def _get_streaming_response(result): def get_chain_metrics( - req: Request, - resp: Response, id: str, token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), diff --git a/server/chainfury_server/api/user.py b/server/chainfury_server/api/user.py index 633124c..77aacc8 100644 --- a/server/chainfury_server/api/user.py +++ b/server/chainfury_server/api/user.py @@ -85,8 +85,8 @@ def change_password( def create_secret( - token: Annotated[str, Header()], inputs: T.ApiToken, + token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiResponse: # validate user