From 518a8a726a800a7fbad354fa60cd0fc21f66d335 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sat, 4 Jan 2025 02:21:37 +0100 Subject: [PATCH 1/2] Added servers protection using an API key to restrict access to only authenticated entities. --- README.md | 5 ++ lightrag/api/azure_openai_lightrag_server.py | 83 +++++++++++++++--- lightrag/api/lollms_lightrag_server.py | 79 ++++++++++++++--- lightrag/api/ollama_lightrag_server.py | 79 ++++++++++++++--- lightrag/api/openai_lightrag_server.py | 91 +++++++++++++++++--- 5 files changed, 292 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index fbd0fdf6..28991389 100644 --- a/README.md +++ b/README.md @@ -1025,6 +1025,7 @@ Each server has its own specific configuration options: | --max-embed-tokens | 8192 | Maximum embedding token size | | --input-file | ./book.txt | Initial input file | | --log-level | INFO | Logging level | +| --key | none | Access Key to protect the lightrag service | #### Ollama Server Options @@ -1042,6 +1043,7 @@ Each server has its own specific configuration options: | --max-embed-tokens | 8192 | Maximum embedding token size | | --input-file | ./book.txt | Initial input file | | --log-level | INFO | Logging level | +| --key | none | Access Key to protect the lightrag service | #### OpenAI Server Options @@ -1056,6 +1058,7 @@ Each server has its own specific configuration options: | --max-embed-tokens | 8192 | Maximum embedding token size | | --input-dir | ./inputs | Input directory for documents | | --log-level | INFO | Logging level | +| --key | none | Access Key to protect the lightrag service | #### OpenAI AZURE Server Options @@ -1071,8 +1074,10 @@ Each server has its own specific configuration options: | --input-dir | ./inputs | Input directory for documents | | --enable-cache | True | Enable response cache | | --log-level | INFO | Logging level | +| --key | none | Access Key to protect the lightrag service | +For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`. ### Example Usage #### LoLLMs RAG Server diff --git a/lightrag/api/azure_openai_lightrag_server.py b/lightrag/api/azure_openai_lightrag_server.py index 6b20caac..a145d6d6 100644 --- a/lightrag/api/azure_openai_lightrag_server.py +++ b/lightrag/api/azure_openai_lightrag_server.py @@ -20,6 +20,19 @@ import inspect import json from fastapi.responses import StreamingResponse +from fastapi import FastAPI, HTTPException +import os +from typing import Optional + +from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi.security import APIKeyHeader +import os +import argparse +from typing import Optional +from fastapi.middleware.cors import CORSMiddleware + +from starlette.status import HTTP_403_FORBIDDEN +from fastapi import HTTPException load_dotenv() @@ -93,6 +106,9 @@ def parse_args(): help="Logging level (default: INFO)", ) + parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None) + + return parser.parse_args() @@ -154,6 +170,31 @@ class InsertResponse(BaseModel): message: str document_count: int +def get_api_key_dependency(api_key: Optional[str]): + if not api_key: + # If no API key is configured, return a dummy dependency that always succeeds + async def no_auth(): + return None + return no_auth + + # If API key is configured, use proper authentication + api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): + if not api_key_header_value: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="API Key required" + ) + if api_key_header_value != api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="Invalid API Key" + ) + return api_key_header_value + + return api_key_auth + async def get_embedding_dim(embedding_model: str) -> int: """Get embedding dimensions for the specified model""" @@ -168,12 +209,30 @@ def create_app(args): format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) ) - # Initialize FastAPI app + + # Check if API key is provided either through env var or args + api_key = os.getenv("LIGHTRAG_API_KEY") or args.key + + # Initialize FastAPI app = FastAPI( title="LightRAG API", - description="API for querying text using LightRAG with OpenAI integration", + description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "", + version="1.0.0", + openapi_tags=[{"name": "api"}] + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) + # Create the optional API key dependency + optional_api_key = get_api_key_dependency(api_key) + # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -239,7 +298,7 @@ async def startup_event(): except Exception as e: logging.error(f"Error during startup indexing: {str(e)}") - @app.post("/documents/scan") + @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) async def scan_for_new_documents(): """Manually trigger scanning for new documents""" try: @@ -264,7 +323,7 @@ async def scan_for_new_documents(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/resetcache") + @app.post("/resetcache", dependencies=[Depends(optional_api_key)]) async def reset_cache(): """Manually reset cache""" try: @@ -276,7 +335,7 @@ async def reset_cache(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/upload") + @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir(file: UploadFile = File(...)): """Upload a file to the input directory""" try: @@ -304,7 +363,7 @@ async def upload_to_input_dir(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query", response_model=QueryResponse) + @app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) async def query_text(request: QueryRequest): try: response = await rag.aquery( @@ -319,7 +378,7 @@ async def query_text(request: QueryRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query/stream") + @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) async def query_text_stream(request: QueryRequest): try: response = await rag.aquery( @@ -345,7 +404,7 @@ async def stream_generator(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/text", response_model=InsertResponse) + @app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_text(request: InsertTextRequest): try: await rag.ainsert(request.text) @@ -357,7 +416,7 @@ async def insert_text(request: InsertTextRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/file", response_model=InsertResponse) + @app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): try: content = await file.read() @@ -381,7 +440,7 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/batch", response_model=InsertResponse) + @app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_batch(files: List[UploadFile] = File(...)): try: inserted_count = 0 @@ -411,7 +470,7 @@ async def insert_batch(files: List[UploadFile] = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.delete("/documents", response_model=InsertResponse) + @app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def clear_documents(): try: rag.text_chunks = [] @@ -425,7 +484,7 @@ async def clear_documents(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.get("/health") + @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" return { diff --git a/lightrag/api/lollms_lightrag_server.py b/lightrag/api/lollms_lightrag_server.py index 1ce7b259..5b5fabd5 100644 --- a/lightrag/api/lollms_lightrag_server.py +++ b/lightrag/api/lollms_lightrag_server.py @@ -11,7 +11,19 @@ import shutil import aiofiles from ascii_colors import trace_exception +from fastapi import FastAPI, HTTPException +import os +from typing import Optional +from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi.security import APIKeyHeader +import os +import argparse +from typing import Optional +from fastapi.middleware.cors import CORSMiddleware + +from starlette.status import HTTP_403_FORBIDDEN +from fastapi import HTTPException def parse_args(): parser = argparse.ArgumentParser( @@ -86,6 +98,9 @@ def parse_args(): help="Logging level (default: INFO)", ) + parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None) + + return parser.parse_args() @@ -147,6 +162,31 @@ class InsertResponse(BaseModel): message: str document_count: int +def get_api_key_dependency(api_key: Optional[str]): + if not api_key: + # If no API key is configured, return a dummy dependency that always succeeds + async def no_auth(): + return None + return no_auth + + # If API key is configured, use proper authentication + api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): + if not api_key_header_value: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="API Key required" + ) + if api_key_header_value != api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="Invalid API Key" + ) + return api_key_header_value + + return api_key_auth + def create_app(args): # Setup logging @@ -154,11 +194,28 @@ def create_app(args): format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) ) - # Initialize FastAPI app + # Check if API key is provided either through env var or args + api_key = os.getenv("LIGHTRAG_API_KEY") or args.key + + # Initialize FastAPI app = FastAPI( title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories", + description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "", + version="1.0.0", + openapi_tags=[{"name": "api"}] ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Create the optional API key dependency + optional_api_key = get_api_key_dependency(api_key) # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -209,7 +266,7 @@ async def startup_event(): except Exception as e: logging.error(f"Error during startup indexing: {str(e)}") - @app.post("/documents/scan") + @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) async def scan_for_new_documents(): """Manually trigger scanning for new documents""" try: @@ -234,7 +291,7 @@ async def scan_for_new_documents(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/upload") + @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir(file: UploadFile = File(...)): """Upload a file to the input directory""" try: @@ -262,7 +319,7 @@ async def upload_to_input_dir(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query", response_model=QueryResponse) + @app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) async def query_text(request: QueryRequest): try: response = await rag.aquery( @@ -284,7 +341,7 @@ async def query_text(request: QueryRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query/stream") + @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) async def query_text_stream(request: QueryRequest): try: response = rag.query( @@ -304,7 +361,7 @@ async def stream_generator(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/text", response_model=InsertResponse) + @app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_text(request: InsertTextRequest): try: rag.insert(request.text) @@ -316,7 +373,7 @@ async def insert_text(request: InsertTextRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/file", response_model=InsertResponse) + @app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): try: content = await file.read() @@ -340,7 +397,7 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/batch", response_model=InsertResponse) + @app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_batch(files: List[UploadFile] = File(...)): try: inserted_count = 0 @@ -370,7 +427,7 @@ async def insert_batch(files: List[UploadFile] = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.delete("/documents", response_model=InsertResponse) + @app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def clear_documents(): try: rag.text_chunks = [] @@ -384,7 +441,7 @@ async def clear_documents(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.get("/health") + @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" return { diff --git a/lightrag/api/ollama_lightrag_server.py b/lightrag/api/ollama_lightrag_server.py index 40f617c6..bb9d6e15 100644 --- a/lightrag/api/ollama_lightrag_server.py +++ b/lightrag/api/ollama_lightrag_server.py @@ -11,6 +11,19 @@ import shutil import aiofiles from ascii_colors import trace_exception +from fastapi import FastAPI, HTTPException +import os +from typing import Optional + +from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi.security import APIKeyHeader +import os +import argparse +from typing import Optional +from fastapi.middleware.cors import CORSMiddleware + +from starlette.status import HTTP_403_FORBIDDEN +from fastapi import HTTPException def parse_args(): @@ -85,6 +98,7 @@ def parse_args(): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level (default: INFO)", ) + parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None) return parser.parse_args() @@ -147,6 +161,31 @@ class InsertResponse(BaseModel): message: str document_count: int +def get_api_key_dependency(api_key: Optional[str]): + if not api_key: + # If no API key is configured, return a dummy dependency that always succeeds + async def no_auth(): + return None + return no_auth + + # If API key is configured, use proper authentication + api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): + if not api_key_header_value: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="API Key required" + ) + if api_key_header_value != api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="Invalid API Key" + ) + return api_key_header_value + + return api_key_auth + def create_app(args): # Setup logging @@ -154,12 +193,30 @@ def create_app(args): format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) ) - # Initialize FastAPI app + # Check if API key is provided either through env var or args + api_key = os.getenv("LIGHTRAG_API_KEY") or args.key + + # Initialize FastAPI app = FastAPI( title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories", + description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "", + version="1.0.0", + openapi_tags=[{"name": "api"}] + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) + # Create the optional API key dependency + optional_api_key = get_api_key_dependency(api_key) + + # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -209,7 +266,7 @@ async def startup_event(): except Exception as e: logging.error(f"Error during startup indexing: {str(e)}") - @app.post("/documents/scan") + @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) async def scan_for_new_documents(): """Manually trigger scanning for new documents""" try: @@ -234,7 +291,7 @@ async def scan_for_new_documents(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/upload") + @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir(file: UploadFile = File(...)): """Upload a file to the input directory""" try: @@ -262,7 +319,7 @@ async def upload_to_input_dir(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query", response_model=QueryResponse) + @app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) async def query_text(request: QueryRequest): try: response = await rag.aquery( @@ -284,7 +341,7 @@ async def query_text(request: QueryRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query/stream") + @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) async def query_text_stream(request: QueryRequest): try: response = rag.query( @@ -304,7 +361,7 @@ async def stream_generator(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/text", response_model=InsertResponse) + @app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_text(request: InsertTextRequest): try: await rag.ainsert(request.text) @@ -316,7 +373,7 @@ async def insert_text(request: InsertTextRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/file", response_model=InsertResponse) + @app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): try: content = await file.read() @@ -340,7 +397,7 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/batch", response_model=InsertResponse) + @app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_batch(files: List[UploadFile] = File(...)): try: inserted_count = 0 @@ -370,7 +427,7 @@ async def insert_batch(files: List[UploadFile] = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.delete("/documents", response_model=InsertResponse) + @app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def clear_documents(): try: rag.text_chunks = [] @@ -384,7 +441,7 @@ async def clear_documents(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.get("/health") + @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" return { diff --git a/lightrag/api/openai_lightrag_server.py b/lightrag/api/openai_lightrag_server.py index 119d6900..d47b25f2 100644 --- a/lightrag/api/openai_lightrag_server.py +++ b/lightrag/api/openai_lightrag_server.py @@ -14,6 +14,20 @@ from ascii_colors import trace_exception import nest_asyncio +from fastapi import FastAPI, HTTPException +import os +from typing import Optional + +from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi.security import APIKeyHeader +import os +import argparse +from typing import Optional +from fastapi.middleware.cors import CORSMiddleware + +from starlette.status import HTTP_403_FORBIDDEN +from fastapi import HTTPException + # Apply nest_asyncio to solve event loop issues nest_asyncio.apply() @@ -75,6 +89,9 @@ def parse_args(): help="Logging level (default: INFO)", ) + parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None) + + return parser.parse_args() @@ -136,6 +153,31 @@ class InsertResponse(BaseModel): message: str document_count: int +def get_api_key_dependency(api_key: Optional[str]): + if not api_key: + # If no API key is configured, return a dummy dependency that always succeeds + async def no_auth(): + return None + return no_auth + + # If API key is configured, use proper authentication + api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): + if not api_key_header_value: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="API Key required" + ) + if api_key_header_value != api_key: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="Invalid API Key" + ) + return api_key_header_value + + return api_key_auth + async def get_embedding_dim(embedding_model: str) -> int: """Get embedding dimensions for the specified model""" @@ -150,10 +192,37 @@ def create_app(args): format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) ) - # Initialize FastAPI app + + # Check if API key is provided either through env var or args + api_key = os.getenv("LIGHTRAG_API_KEY") or args.key + + # Initialize FastAPI app = FastAPI( title="LightRAG API", - description="API for querying text using LightRAG with OpenAI integration", + description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "", + version="1.0.0", + openapi_tags=[{"name": "api"}] + ) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Create the optional API key dependency + optional_api_key = get_api_key_dependency(api_key) + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], ) # Create working directory if it doesn't exist @@ -213,7 +282,7 @@ async def startup_event(): except Exception as e: logging.error(f"Error during startup indexing: {str(e)}") - @app.post("/documents/scan") + @app.post("/documents/scan", dependencies=[Depends(optional_api_key)]) async def scan_for_new_documents(): """Manually trigger scanning for new documents""" try: @@ -238,7 +307,7 @@ async def scan_for_new_documents(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/upload") + @app.post("/documents/upload", dependencies=[Depends(optional_api_key)]) async def upload_to_input_dir(file: UploadFile = File(...)): """Upload a file to the input directory""" try: @@ -266,7 +335,7 @@ async def upload_to_input_dir(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query", response_model=QueryResponse) + @app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) async def query_text(request: QueryRequest): try: response = await rag.aquery( @@ -288,7 +357,7 @@ async def query_text(request: QueryRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query/stream") + @app.post("/query/stream", dependencies=[Depends(optional_api_key)]) async def query_text_stream(request: QueryRequest): try: response = rag.query( @@ -308,7 +377,7 @@ async def stream_generator(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/text", response_model=InsertResponse) + @app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_text(request: InsertTextRequest): try: rag.insert(request.text) @@ -320,7 +389,7 @@ async def insert_text(request: InsertTextRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/file", response_model=InsertResponse) + @app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): try: content = await file.read() @@ -344,7 +413,7 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/batch", response_model=InsertResponse) + @app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def insert_batch(files: List[UploadFile] = File(...)): try: inserted_count = 0 @@ -374,7 +443,7 @@ async def insert_batch(files: List[UploadFile] = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.delete("/documents", response_model=InsertResponse) + @app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) async def clear_documents(): try: rag.text_chunks = [] @@ -388,7 +457,7 @@ async def clear_documents(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.get("/health") + @app.get("/health", dependencies=[Depends(optional_api_key)]) async def get_status(): """Get current system status""" return { From b15c398889999f38d795f6a708a755ff66ac34b8 Mon Sep 17 00:00:00 2001 From: Saifeddine ALOUI Date: Sat, 4 Jan 2025 02:23:39 +0100 Subject: [PATCH 2/2] applyed linting --- README.md | 6 +- lightrag/api/azure_openai_lightrag_server.py | 71 ++++++++++++-------- lightrag/api/lollms_lightrag_server.py | 70 ++++++++++++------- lightrag/api/ollama_lightrag_server.py | 69 ++++++++++++------- lightrag/api/openai_lightrag_server.py | 70 ++++++++++++------- 5 files changed, 182 insertions(+), 104 deletions(-) diff --git a/README.md b/README.md index 28991389..54a84323 100644 --- a/README.md +++ b/README.md @@ -1077,7 +1077,7 @@ Each server has its own specific configuration options: | --key | none | Access Key to protect the lightrag service | -For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`. +For protecting the server using an authentication key, you can also use an environment variable named `LIGHTRAG_API_KEY`. ### Example Usage #### LoLLMs RAG Server @@ -1088,6 +1088,10 @@ lollms-lightrag-server --model mistral-nemo --port 8080 --working-dir ./custom_r # Using specific models (ensure they are installed in your LoLLMs instance) lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 + +# Using specific models and an authentication key +lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 --key ky-mykey + ``` #### Ollama RAG Server diff --git a/lightrag/api/azure_openai_lightrag_server.py b/lightrag/api/azure_openai_lightrag_server.py index a145d6d6..abe3f738 100644 --- a/lightrag/api/azure_openai_lightrag_server.py +++ b/lightrag/api/azure_openai_lightrag_server.py @@ -20,19 +20,12 @@ import inspect import json from fastapi.responses import StreamingResponse -from fastapi import FastAPI, HTTPException -import os -from typing import Optional -from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi import Depends, Security from fastapi.security import APIKeyHeader -import os -import argparse -from typing import Optional from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN -from fastapi import HTTPException load_dotenv() @@ -106,8 +99,12 @@ def parse_args(): help="Logging level (default: INFO)", ) - parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None) - + parser.add_argument( + "--key", + type=str, + help="API key for authentication. This protects lightrag server against unauthorized access", + default=None, + ) return parser.parse_args() @@ -170,29 +167,29 @@ class InsertResponse(BaseModel): message: str document_count: int + def get_api_key_dependency(api_key: Optional[str]): if not api_key: # If no API key is configured, return a dummy dependency that always succeeds async def no_auth(): return None + return no_auth - + # If API key is configured, use proper authentication api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): if not api_key_header_value: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="API Key required" + status_code=HTTP_403_FORBIDDEN, detail="API Key required" ) if api_key_header_value != api_key: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="Invalid API Key" + status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" ) return api_key_header_value - + return api_key_auth @@ -209,18 +206,20 @@ def create_app(args): format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) ) - # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - + # Initialize FastAPI app = FastAPI( title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "", + description="API for querying text using LightRAG with separate storage and input directories" + + "(With authentication)" + if api_key + else "", version="1.0.0", - openapi_tags=[{"name": "api"}] + openapi_tags=[{"name": "api"}], ) - + # Add CORS middleware app.add_middleware( CORSMiddleware, @@ -363,7 +362,9 @@ async def upload_to_input_dir(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + ) async def query_text(request: QueryRequest): try: response = await rag.aquery( @@ -404,7 +405,11 @@ async def stream_generator(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/text", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_text(request: InsertTextRequest): try: await rag.ainsert(request.text) @@ -416,7 +421,11 @@ async def insert_text(request: InsertTextRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/file", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): try: content = await file.read() @@ -440,7 +449,11 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_batch(files: List[UploadFile] = File(...)): try: inserted_count = 0 @@ -470,7 +483,11 @@ async def insert_batch(files: List[UploadFile] = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.delete( + "/documents", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def clear_documents(): try: rag.text_chunks = [] diff --git a/lightrag/api/lollms_lightrag_server.py b/lightrag/api/lollms_lightrag_server.py index 5b5fabd5..8a2804a0 100644 --- a/lightrag/api/lollms_lightrag_server.py +++ b/lightrag/api/lollms_lightrag_server.py @@ -11,19 +11,14 @@ import shutil import aiofiles from ascii_colors import trace_exception -from fastapi import FastAPI, HTTPException import os -from typing import Optional -from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi import Depends, Security from fastapi.security import APIKeyHeader -import os -import argparse -from typing import Optional from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN -from fastapi import HTTPException + def parse_args(): parser = argparse.ArgumentParser( @@ -98,8 +93,12 @@ def parse_args(): help="Logging level (default: INFO)", ) - parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None) - + parser.add_argument( + "--key", + type=str, + help="API key for authentication. This protects lightrag server against unauthorized access", + default=None, + ) return parser.parse_args() @@ -162,29 +161,29 @@ class InsertResponse(BaseModel): message: str document_count: int + def get_api_key_dependency(api_key: Optional[str]): if not api_key: # If no API key is configured, return a dummy dependency that always succeeds async def no_auth(): return None + return no_auth - + # If API key is configured, use proper authentication api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): if not api_key_header_value: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="API Key required" + status_code=HTTP_403_FORBIDDEN, detail="API Key required" ) if api_key_header_value != api_key: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="Invalid API Key" + status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" ) return api_key_header_value - + return api_key_auth @@ -196,15 +195,18 @@ def create_app(args): # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - + # Initialize FastAPI app = FastAPI( title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "", + description="API for querying text using LightRAG with separate storage and input directories" + + "(With authentication)" + if api_key + else "", version="1.0.0", - openapi_tags=[{"name": "api"}] + openapi_tags=[{"name": "api"}], ) - + # Add CORS middleware app.add_middleware( CORSMiddleware, @@ -319,7 +321,9 @@ async def upload_to_input_dir(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + ) async def query_text(request: QueryRequest): try: response = await rag.aquery( @@ -361,7 +365,11 @@ async def stream_generator(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/text", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_text(request: InsertTextRequest): try: rag.insert(request.text) @@ -373,7 +381,11 @@ async def insert_text(request: InsertTextRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/file", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): try: content = await file.read() @@ -397,7 +409,11 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_batch(files: List[UploadFile] = File(...)): try: inserted_count = 0 @@ -427,7 +443,11 @@ async def insert_batch(files: List[UploadFile] = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.delete( + "/documents", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def clear_documents(): try: rag.text_chunks = [] diff --git a/lightrag/api/ollama_lightrag_server.py b/lightrag/api/ollama_lightrag_server.py index bb9d6e15..b3140aba 100644 --- a/lightrag/api/ollama_lightrag_server.py +++ b/lightrag/api/ollama_lightrag_server.py @@ -11,19 +11,13 @@ import shutil import aiofiles from ascii_colors import trace_exception -from fastapi import FastAPI, HTTPException import os -from typing import Optional -from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi import Depends, Security from fastapi.security import APIKeyHeader -import os -import argparse -from typing import Optional from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN -from fastapi import HTTPException def parse_args(): @@ -98,7 +92,12 @@ def parse_args(): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level (default: INFO)", ) - parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None) + parser.add_argument( + "--key", + type=str, + help="API key for authentication. This protects lightrag server against unauthorized access", + default=None, + ) return parser.parse_args() @@ -161,29 +160,29 @@ class InsertResponse(BaseModel): message: str document_count: int + def get_api_key_dependency(api_key: Optional[str]): if not api_key: # If no API key is configured, return a dummy dependency that always succeeds async def no_auth(): return None + return no_auth - + # If API key is configured, use proper authentication api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): if not api_key_header_value: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="API Key required" + status_code=HTTP_403_FORBIDDEN, detail="API Key required" ) if api_key_header_value != api_key: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="Invalid API Key" + status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" ) return api_key_header_value - + return api_key_auth @@ -195,15 +194,18 @@ def create_app(args): # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - + # Initialize FastAPI app = FastAPI( title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "", + description="API for querying text using LightRAG with separate storage and input directories" + + "(With authentication)" + if api_key + else "", version="1.0.0", - openapi_tags=[{"name": "api"}] + openapi_tags=[{"name": "api"}], ) - + # Add CORS middleware app.add_middleware( CORSMiddleware, @@ -216,7 +218,6 @@ def create_app(args): # Create the optional API key dependency optional_api_key = get_api_key_dependency(api_key) - # Create working directory if it doesn't exist Path(args.working_dir).mkdir(parents=True, exist_ok=True) @@ -319,7 +320,9 @@ async def upload_to_input_dir(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + ) async def query_text(request: QueryRequest): try: response = await rag.aquery( @@ -361,7 +364,11 @@ async def stream_generator(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/text", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_text(request: InsertTextRequest): try: await rag.ainsert(request.text) @@ -373,7 +380,11 @@ async def insert_text(request: InsertTextRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/file", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): try: content = await file.read() @@ -397,7 +408,11 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_batch(files: List[UploadFile] = File(...)): try: inserted_count = 0 @@ -427,7 +442,11 @@ async def insert_batch(files: List[UploadFile] = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.delete( + "/documents", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def clear_documents(): try: rag.text_chunks = [] diff --git a/lightrag/api/openai_lightrag_server.py b/lightrag/api/openai_lightrag_server.py index d47b25f2..349c09da 100644 --- a/lightrag/api/openai_lightrag_server.py +++ b/lightrag/api/openai_lightrag_server.py @@ -14,19 +14,13 @@ from ascii_colors import trace_exception import nest_asyncio -from fastapi import FastAPI, HTTPException import os -from typing import Optional -from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi import Depends, Security from fastapi.security import APIKeyHeader -import os -import argparse -from typing import Optional from fastapi.middleware.cors import CORSMiddleware from starlette.status import HTTP_403_FORBIDDEN -from fastapi import HTTPException # Apply nest_asyncio to solve event loop issues nest_asyncio.apply() @@ -89,8 +83,12 @@ def parse_args(): help="Logging level (default: INFO)", ) - parser.add_argument('--key', type=str, help='API key for authentication. This protects lightrag server against unauthorized access', default=None) - + parser.add_argument( + "--key", + type=str, + help="API key for authentication. This protects lightrag server against unauthorized access", + default=None, + ) return parser.parse_args() @@ -153,29 +151,29 @@ class InsertResponse(BaseModel): message: str document_count: int + def get_api_key_dependency(api_key: Optional[str]): if not api_key: # If no API key is configured, return a dummy dependency that always succeeds async def no_auth(): return None + return no_auth - + # If API key is configured, use proper authentication api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) - + async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)): if not api_key_header_value: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="API Key required" + status_code=HTTP_403_FORBIDDEN, detail="API Key required" ) if api_key_header_value != api_key: raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="Invalid API Key" + status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key" ) return api_key_header_value - + return api_key_auth @@ -192,18 +190,20 @@ def create_app(args): format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level) ) - # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - + # Initialize FastAPI app = FastAPI( title="LightRAG API", - description="API for querying text using LightRAG with separate storage and input directories"+"(With authentication)" if api_key else "", + description="API for querying text using LightRAG with separate storage and input directories" + + "(With authentication)" + if api_key + else "", version="1.0.0", - openapi_tags=[{"name": "api"}] + openapi_tags=[{"name": "api"}], ) - + # Add CORS middleware app.add_middleware( CORSMiddleware, @@ -335,7 +335,9 @@ async def upload_to_input_dir(file: UploadFile = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)] + ) async def query_text(request: QueryRequest): try: response = await rag.aquery( @@ -377,7 +379,11 @@ async def stream_generator(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/text", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/text", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_text(request: InsertTextRequest): try: rag.insert(request.text) @@ -389,7 +395,11 @@ async def insert_text(request: InsertTextRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/file", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/file", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_file(file: UploadFile = File(...), description: str = Form(None)): try: content = await file.read() @@ -413,7 +423,11 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.post("/documents/batch", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.post( + "/documents/batch", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def insert_batch(files: List[UploadFile] = File(...)): try: inserted_count = 0 @@ -443,7 +457,11 @@ async def insert_batch(files: List[UploadFile] = File(...)): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @app.delete("/documents", response_model=InsertResponse, dependencies=[Depends(optional_api_key)]) + @app.delete( + "/documents", + response_model=InsertResponse, + dependencies=[Depends(optional_api_key)], + ) async def clear_documents(): try: rag.text_chunks = []