Add Optional API Key Authentication to FastAPI Services #541

Merged · 2 commits · Jan 4, 2025
9 changes: 9 additions & 0 deletions README.md
@@ -1025,6 +1025,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-file | ./book.txt | Initial input file |
| --log-level | INFO | Logging level |
| --key | none | API key to protect the LightRAG service |

#### Ollama Server Options

@@ -1042,6 +1043,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-file | ./book.txt | Initial input file |
| --log-level | INFO | Logging level |
| --key | none | API key to protect the LightRAG service |

#### OpenAI Server Options

@@ -1056,6 +1058,7 @@ Each server has its own specific configuration options:
| --max-embed-tokens | 8192 | Maximum embedding token size |
| --input-dir | ./inputs | Input directory for documents |
| --log-level | INFO | Logging level |
| --key | none | API key to protect the LightRAG service |

#### OpenAI AZURE Server Options

@@ -1071,8 +1074,10 @@ Each server has its own specific configuration options:
| --input-dir | ./inputs | Input directory for documents |
| --enable-cache | True | Enable response cache |
| --log-level | INFO | Logging level |
| --key | none | API key to protect the LightRAG service |


To protect the server with an authentication key, you can also set the `LIGHTRAG_API_KEY` environment variable instead of passing `--key`.
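
A minimal sketch, assuming a POSIX shell and a placeholder key value; per the server code below (`os.getenv("LIGHTRAG_API_KEY") or args.key`), the environment variable takes precedence over `--key`:

```bash
# LIGHTRAG_API_KEY is read at startup and takes precedence over --key
export LIGHTRAG_API_KEY=my-secret-key
lollms-lightrag-server --model mistral-nemo:latest
```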
### Example Usage

#### LoLLMs RAG Server
@@ -1083,6 +1088,10 @@ lollms-lightrag-server --model mistral-nemo --port 8080 --working-dir ./custom_rag

# Using specific models (ensure they are installed in your LoLLMs instance)
lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024

# Using specific models and an authentication key
lollms-lightrag-server --model mistral-nemo:latest --embedding-model bge-m3 --embedding-dim 1024 --key ky-mykey

```

#### Ollama RAG Server
100 changes: 88 additions & 12 deletions lightrag/api/azure_openai_lightrag_server.py
@@ -21,6 +21,12 @@
import json
from fastapi.responses import StreamingResponse

from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
from fastapi.middleware.cors import CORSMiddleware

from starlette.status import HTTP_403_FORBIDDEN

load_dotenv()

AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
@@ -93,6 +99,13 @@ def parse_args():
help="Logging level (default: INFO)",
)

parser.add_argument(
"--key",
type=str,
help="API key for authentication. This protects lightrag server against unauthorized access",
default=None,
)

return parser.parse_args()


@@ -155,6 +168,31 @@ class InsertResponse(BaseModel):
document_count: int


def get_api_key_dependency(api_key: Optional[str]):
if not api_key:
# If no API key is configured, return a dummy dependency that always succeeds
async def no_auth():
return None

return no_auth

# If API key is configured, use proper authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

async def api_key_auth(api_key_header_value: str | None = Security(api_key_header)):
if not api_key_header_value:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="API Key required"
)
if api_key_header_value != api_key:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Invalid API Key"
)
return api_key_header_value

return api_key_auth
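
With a key configured, clients must send it in the `X-API-Key` header defined above; a minimal sketch with curl, assuming the server listens on port 8080 and using a placeholder key (a missing or mismatched key yields HTTP 403):

```bash
# Allowed: header matches the key configured on the server
curl -H "X-API-Key: my-secret-key" http://localhost:8080/health

# Rejected with 403: no X-API-Key header sent
curl http://localhost:8080/health
```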


async def get_embedding_dim(embedding_model: str) -> int:
"""Get embedding dimensions for the specified model"""
test_text = ["This is a test sentence."]
@@ -168,12 +206,32 @@ def create_app(args):
format="%(levelname)s:%(message)s", level=getattr(logging, args.log_level)
)

# Initialize FastAPI app
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key

# Initialize FastAPI
app = FastAPI(
title="LightRAG API",
description="API for querying text using LightRAG with OpenAI integration",
description="API for querying text using LightRAG with separate storage and input directories"
+ "(With authentication)"
if api_key
else "",
version="1.0.0",
openapi_tags=[{"name": "api"}],
)

# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Create the optional API key dependency
optional_api_key = get_api_key_dependency(api_key)

# Create working directory if it doesn't exist
Path(args.working_dir).mkdir(parents=True, exist_ok=True)

@@ -239,7 +297,7 @@ async def startup_event():
except Exception as e:
logging.error(f"Error during startup indexing: {str(e)}")

@app.post("/documents/scan")
@app.post("/documents/scan", dependencies=[Depends(optional_api_key)])
async def scan_for_new_documents():
"""Manually trigger scanning for new documents"""
try:
@@ -264,7 +322,7 @@ async def scan_for_new_documents():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/resetcache")
@app.post("/resetcache", dependencies=[Depends(optional_api_key)])
async def reset_cache():
"""Manually reset cache"""
try:
@@ -276,7 +334,7 @@ async def reset_cache():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/documents/upload")
@app.post("/documents/upload", dependencies=[Depends(optional_api_key)])
async def upload_to_input_dir(file: UploadFile = File(...)):
"""Upload a file to the input directory"""
try:
@@ -304,7 +362,9 @@ async def upload_to_input_dir(file: UploadFile = File(...)):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/query", response_model=QueryResponse)
@app.post(
"/query", response_model=QueryResponse, dependencies=[Depends(optional_api_key)]
)
async def query_text(request: QueryRequest):
try:
response = await rag.aquery(
@@ -319,7 +379,7 @@ async def query_text(request: QueryRequest):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/query/stream")
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
async def query_text_stream(request: QueryRequest):
try:
response = await rag.aquery(
@@ -345,7 +405,11 @@ async def stream_generator():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/documents/text", response_model=InsertResponse)
@app.post(
"/documents/text",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_text(request: InsertTextRequest):
try:
await rag.ainsert(request.text)
@@ -357,7 +421,11 @@ async def insert_text(request: InsertTextRequest):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/documents/file", response_model=InsertResponse)
@app.post(
"/documents/file",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
try:
content = await file.read()
@@ -381,7 +449,11 @@ async def insert_file(file: UploadFile = File(...), description: str = Form(None)):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.post("/documents/batch", response_model=InsertResponse)
@app.post(
"/documents/batch",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def insert_batch(files: List[UploadFile] = File(...)):
try:
inserted_count = 0
@@ -411,7 +483,11 @@ async def insert_batch(files: List[UploadFile] = File(...)):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.delete("/documents", response_model=InsertResponse)
@app.delete(
"/documents",
response_model=InsertResponse,
dependencies=[Depends(optional_api_key)],
)
async def clear_documents():
try:
rag.text_chunks = []
@@ -425,7 +501,7 @@ async def clear_documents():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
@app.get("/health", dependencies=[Depends(optional_api_key)])
async def get_status():
"""Get current system status"""
return {