diff --git a/api/ma1sd-extender/main.py b/api/ma1sd-extender/main.py index b86b9eb..1106ec5 100644 --- a/api/ma1sd-extender/main.py +++ b/api/ma1sd-extender/main.py @@ -76,16 +76,5 @@ async def userDirectory(request: Request, cache: InMemoryCacheBackend = Depends( - display_name: str - Display name of each gathered user - user_id : str - Matrix ID of each gathered user """ - try: - body = await request.json() - searchTerm = body["search_term"] - inCache = await cache.get(searchTerm) - if not inCache: - response = await search.findUsers(request) - await cache.set(searchTerm, response.body) - await cache.expire(searchTerm, 86400) - return response - else: - return JSONResponse(status_code=200, content=loads(inCache)) - except: - return JSONResponse(status_code=400, content={'errcode': 'M_UNKNOWN', 'error': 'search_term` is required field', 'success': False}) + response = await search.findUsers(request, cache) + return response diff --git a/api/ma1sd-extender/search.py b/api/ma1sd-extender/search.py index e2e66a0..a9faf81 100644 --- a/api/ma1sd-extender/search.py +++ b/api/ma1sd-extender/search.py @@ -1,5 +1,6 @@ -from fastapi import Request, Header +from fastapi import Request, Header, Depends from fastapi.responses import JSONResponse +from fastapi_cache.backends.memory import InMemoryCacheBackend from os import getenv from requests import get, post from json import loads @@ -7,7 +8,8 @@ import asyncio -async def findUsers(request: Request): + +async def findUsers(request: Request, cache: InMemoryCacheBackend): # Declare constants MA1SD_EXTENDER_USERNAME, MA1SD_EXTENDER_PASSWORD, MA1SD_EXTENDER_FEDERATED_DOMAINS, MA1SD_EXTENDER_MATRIX_DOMAIN = getenv("MA1SD_EXTENDER_USERNAME", ""), getenv("MA1SD_EXTENDER_PASSWORD", ""), loads(getenv("MA1SD_EXTENDER_FEDERATED_DOMAINS", "[]").replace('\'', '"')), getenv("MA1SD_EXTENDER_MATRIX_DOMAIN", "") MA1SD_EXTENDER_ACCESS_TOKEN = getAccessToken(MA1SD_EXTENDER_MATRIX_DOMAIN, MA1SD_EXTENDER_USERNAME, MA1SD_EXTENDER_PASSWORD) @@ -22,6 +24,7 @@ async def findUsers(request: Request): # Extract the search term from the request body or return 400 client error. parsedBody = "" + searchTerm = "" try: parsedBody = loads(body) except: @@ -34,18 +37,28 @@ async def findUsers(request: Request): if not validAccessToken: return JSONResponse(status_code=400, content={"errcode": "M_UNKNOWN_TOKEN", "error": "Invalid macaroon passed.", "soft_logout": False}) - # Retrieve local results. - localDirectoryListings = list(filter(None, [getDirectorySearch('http', 'localhost:8090', f"Bearer {MA1SD_EXTENDER_ACCESS_TOKEN}", parsedBody)])) + searchTerm = parsedBody["search_term"] + inCache = await cache.get(searchTerm) + + if not inCache: + # Retrieve local results. + localDirectoryListings = list(filter(None, [getDirectorySearch('http', 'localhost:8090', f"Bearer {MA1SD_EXTENDER_ACCESS_TOKEN}", parsedBody)])) + + remoteDirectoryListings = [] - remoteDirectoryListings = [] + # Do recursive directory listing if requested. + if "no_recursion" not in parsedBody.keys(): + parsedBody["no_recursion"] = True + remoteDirectoryListings = list(filter(None, [getDirectorySearch('https', domain, headers["authorization"], parsedBody) for domain in MA1SD_EXTENDER_FEDERATED_DOMAINS])) - # Do recursive directory listing if requested. - if "no_recursion" not in parsedBody.keys(): - parsedBody["no_recursion"] = True - remoteDirectoryListings = list(filter(None, [getDirectorySearch('https', domain, headers["authorization"], parsedBody) for domain in MA1SD_EXTENDER_FEDERATED_DOMAINS])) + response = parseDirectoryListings(list(chain(localDirectoryListings, remoteDirectoryListings))) + await cache.set(searchTerm, response) + await cache.expire(searchTerm, 86400) + else: + response = inCache # Get final response and return it. - return JSONResponse(status_code=200, content=parseDirectoryListings(list(chain(localDirectoryListings, remoteDirectoryListings)))) + return JSONResponse(status_code=200, content=response) def getAccessToken(domain, username, password):