diff --git a/husky_directory/app.py b/husky_directory/app.py
index 61e2d03..126dd43 100755
--- a/husky_directory/app.py
+++ b/husky_directory/app.py
@@ -24,6 +24,7 @@
 )
 from husky_directory.blueprints.search import SearchBlueprint
 from husky_directory.models.search import SearchDirectoryInput
+from husky_directory.services.object_store import ObjectStoreInjectorModule
 from husky_directory.util import MetricsClient
 from .app_config import (
     ApplicationConfig,
@@ -45,6 +46,7 @@ def get_app_injector_modules() -> List[Type[Module]]:
     return [
         ApplicationConfigInjectorModule,
         IdentityProviderModule,
+        ObjectStoreInjectorModule,
     ]
@@ -133,6 +135,7 @@ def provide_app(
         app_blueprint: AppBlueprint,
         saml_blueprint: SAMLBlueprint,
         mock_saml_blueprint: MockSAMLBlueprint,
+        redis: Redis,
     ) -> Flask:
         # First we have to do some logging configuration, before the
         # app instance is created.
@@ -160,7 +163,7 @@ def provide_app(
         # our dependencies appropriate for each request.
         FlaskInjector(app=app, injector=injector)
         FlaskJSONLogger(app)
-        self._configure_app_session(app, app_settings)
+        self._configure_app_session(app, app_settings, redis)
         self._configure_prometheus(app, app_settings, injector)
         attach_app_error_handlers(app)
         self.register_jinja_extensions(app)
@@ -208,7 +211,9 @@ def verify_credentials(username: str, password: str):
         injector.binder.bind(MetricsClient, metrics, scope=singleton)
 
     @staticmethod
-    def _configure_app_session(app: Flask, app_settings: ApplicationConfig) -> NoReturn:
+    def _configure_app_session(
+        app: Flask, app_settings: ApplicationConfig, redis: Redis
+    ) -> None:
         # There is something wrong with the flask_session implementation that
         # is supposed to translate flask config values into redis settings;
         # also, it doesn't support authorization (what?!) so we have to
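Note for readers: `ObjectStoreInjectorModule` joins the module list above, which is how `provide_app`'s new `redis` parameter (and any `ObjectStorageInterface` dependency) gets satisfied. A minimal, self-contained sketch of the `injector` provider pattern in play here (hypothetical `Greeter` names, not code from this PR):

```python
from injector import Injector, Module, provider


class Greeter:
    def __init__(self, greeting: str):
        self.greeting = greeting


class GreeterModule(Module):
    # The return annotation is the binding key: any dependency
    # annotated `Greeter` is constructed by this provider.
    @provider
    def provide_greeter(self) -> Greeter:
        return Greeter("hello")


injector = Injector([GreeterModule()])
assert injector.get(Greeter).greeting == "hello"
```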
@@ -220,20 +225,27 @@ def _configure_app_session(app: Flask, app_settings: ApplicationConfig) -> NoRet
         if app.config["SESSION_TYPE"] == "redis":
             redis_settings = app_settings.redis_settings
             app.logger.info(
-                f"Setting up redis cache with settings: {redis_settings.flask_config_values}"
+                f"Setting up redis session cache with settings: {redis_settings.flask_config_values}"
             )
             app.session_interface = RedisSessionInterface(
-                redis=Redis(
-                    host=redis_settings.host,
-                    port=redis_settings.port,
-                    username=redis_settings.namespace,
-                    password=redis_settings.password.get_secret_value(),
-                ),
+                redis,
                 key_prefix=redis_settings.flask_config_values["SESSION_KEY_PREFIX"],
             )
         else:
             Session(app)
 
+    @provider
+    def provide_redis(self, app_settings: ApplicationConfig) -> Redis:
+        redis_settings = app_settings.redis_settings
+        if not redis_settings.password:
+            return None
+        return Redis(
+            host=redis_settings.host,
+            port=redis_settings.port,
+            username=redis_settings.namespace,
+            password=redis_settings.password.get_secret_value(),
+        )
+
 
 def create_app(injector: Optional[Injector] = None) -> Flask:
     injector = injector or create_app_injector()
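Review note: `provide_redis` is annotated `-> Redis` but returns `None` when no password is configured (the local-development case). Tightening the annotation to `Optional[Redis]` would likely change the binding key `injector` uses and break `provide_app(..., redis: Redis)`, so the loose annotation may be deliberate. A sanity check of the current behavior, assuming the provider's module is installed by `create_app_injector` (it defines `provide_app`, so it should be):

```python
from redis import Redis

from husky_directory.app import create_app_injector
from husky_directory.app_config import ApplicationConfig

injector = create_app_injector()
config = injector.get(ApplicationConfig)

# With no REDIS_PASSWORD configured, the provider short-circuits:
config.redis_settings.password = None
assert injector.get(Redis) is None
```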
diff --git a/husky_directory/app_config.py b/husky_directory/app_config.py
index e76f7e4..82c1051 100644
--- a/husky_directory/app_config.py
+++ b/husky_directory/app_config.py
@@ -4,10 +4,9 @@
 import string
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, Optional, Type, TypeVar, Union, cast
+from typing import Any, Dict, Optional, TypeVar, cast
 
-import yaml
-from injector import Module, inject, provider, singleton
+from injector import Module, provider, singleton
 from pydantic import BaseSettings, Field, SecretStr, validator
 
 logger = logging.getLogger("app_config")
@@ -102,6 +101,9 @@ class RedisSettings(FlaskConfigurationSettings):
     port: str = Field("6379", env="REDIS_PORT")
     namespace: str = Field(None, env="REDIS_NAMESPACE")
     password: SecretStr = Field(None, env="REDIS_PASSWORD")
+    default_cache_expire_seconds: Optional[int] = Field(
+        None, env="REDIS_CACHE_DEFAULT_EXPIRE_SECONDS"
+    )
 
     @property
     def flask_config_values(self) -> Dict[str, Any]:
@@ -177,6 +179,59 @@ class ApplicationSecrets(BaseSettings):
     prometheus_password: Optional[SecretStr] = Field(None, env="PROMETHEUS_PASSWORD")
 
 
+class CacheExpirationSettings(BaseSettings):
+    class Config:
+        """
+        Any of these settings can be tweaked by updating
+        the environment variables in the gcp-k8 helm release
+        for this app using this prefix (or in any environment
+        running this app).
+
+        e.g.: QUERY_CACHE_IN_PROGRESS_STATUS_EXPIRATION=60
+        """
+
+        env_prefix = "QUERY_CACHE_"
+
+    # We should never expect a query to take more than
+    # 5 minutes to complete. If a query has taken /that/ long,
+    # we simply delete the lock. This can mean that if a
+    # query does take longer than 5 minutes,
+    # the next request for it will be allowed
+    # to proceed in its own process.
+    in_progress_status_expiration: int = 300  # Five minutes
+
+    # We do not want to cache completed queries for very long, because
+    # we want updates to user profiles to be reflected in
+    # near real-time. But we want the value to persist
+    # long enough that if the user mashes the 'search' button
+    # while their browser is rendering thousands of
+    # entries as HTML, it'll still be there, and the results will
+    # seem to come faster to the user the second time around.
+    # Therefore, the value is a brief 7 seconds.
+    completed_status_expiration: int = 7
+
+    # The error status expiration lets us check for and/or
+    # alarm on issues directly, without relying on logs, if the
+    # issue happens during the query process (which is the
+    # most likely place for an issue to occur). Remembering that
+    # if the query is re-attempted, its status will revert to
+    # 'in progress', there is not much point in keeping these for
+    # very long.
+    error_status_expiration: int = 3600  # 1 hour
+
+    # The error message expiration lasts longer so that incident
+    # responders can access the error messages that were logged
+    # via the cache, instead of poring through logs, if necessary,
+    # for easier investigation. However, this is only a minor
+    # convenience, as the JSON logging will already contain a
+    # lot of information.
+
+    # To find errors in a redis cache, you can use the command:
+    #     'keys *:status:message'
+    # to get a list of all relevant keys in the shared cache.
+    error_message_expiration: int = 3600 * 24  # 24 hours
+
+
 @singleton
 class ApplicationConfig(FlaskConfigurationSettings):
     """
@@ -208,8 +263,9 @@ class ApplicationConfig(FlaskConfigurationSettings):
     pws_settings: PWSSettings = PWSSettings()
     auth_settings: AuthSettings = AuthSettings()
     session_settings: SessionSettings = SessionSettings()
-    redis_settings: Optional[RedisSettings]
+    redis_settings: RedisSettings = RedisSettings()
     metrics_settings: MetricsSettings = MetricsSettings()
+    cache_expiration_settings: CacheExpirationSettings = CacheExpirationSettings()
     secrets: ApplicationSecrets = ApplicationSecrets()
 
     @validator("redis_settings")
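Because `CacheExpirationSettings` extends pydantic's `BaseSettings` with `env_prefix = "QUERY_CACHE_"`, any field can be tuned per environment; the environment is read when the settings object is instantiated. A quick illustration (values are arbitrary):

```python
import os

from husky_directory.app_config import CacheExpirationSettings

os.environ["QUERY_CACHE_COMPLETED_STATUS_EXPIRATION"] = "30"
settings = CacheExpirationSettings()
assert settings.completed_status_expiration == 30
assert settings.in_progress_status_expiration == 300  # default; not overridden
```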
- """ - filename = os.path.join(self.settings_dir, f"{settings_name}.yml") - stage = self.app_config.stage - with open(filename) as f: - try: - settings = yaml.load(f, yaml.SafeLoader)[stage] - except KeyError as e: - raise KeyError( - f"{filename} has no configuration for stage '{stage}': {str(e)}" - ) - - if output_type is Dict: - return settings - return output_type.parse_obj(settings) diff --git a/husky_directory/services/object_store.py b/husky_directory/services/object_store.py new file mode 100644 index 0000000..df4c0fc --- /dev/null +++ b/husky_directory/services/object_store.py @@ -0,0 +1,122 @@ +import json +import time +from abc import ABC, abstractmethod +from copy import copy +from enum import Enum +from typing import Any, Optional + +from flask_injector import request +from injector import Module, provider +from pydantic import BaseModel +from redis import Redis + +from husky_directory.app_config import ApplicationConfig +from husky_directory.util import AppLoggerMixIn + + +class ObjectStorageInterface(AppLoggerMixIn, ABC): + """ + Basic interface that does nothing but declare + abstractions. + + It also provides a utility method that can convert + anything* into a string. + + *if the thing you want to convert can't be converted, + add a case in the normalize_object_data implementation below. + """ + + @abstractmethod + def get(self, key: str) -> Optional[str]: + return None + + @abstractmethod + def put(self, key: str, obj: Any, expire_after_seconds: Optional[int] = None): + pass + + @staticmethod + def normalize_object_data(obj: Any) -> str: + if isinstance(obj, BaseModel): + return obj.json() + if isinstance(obj, Enum): + obj = obj.value + if not isinstance(obj, str): + return json.dumps(obj) + return obj + + +class InMemoryObjectStorage(ObjectStorageInterface): + """ + Used when testing locally using flask itself, + cannot be shared between processes. This is a very + basic implementation which checks for key expiration + on every `put`. 
+ """ + + def __init__(self): + self.__store__ = {} + self.__key_expirations__ = {} + + def validate_key_expiration(self, key: str): + expiration = self.__key_expirations__.get(key) + now = time.time() + if expiration: + max_elapsed = expiration["max"] + if not max_elapsed: + return + elapsed = now - expiration["stored"] + if elapsed > max_elapsed: + del self.__key_expirations__[key] + if key in self.__store__: + del self.__store__[key] + + def expire_keys(self): + for key in copy(self.__key_expirations__): + self.validate_key_expiration(key) + + def get(self, key: str) -> Optional[str]: + self.validate_key_expiration(key) + return self.__store__.get(key) + + def put(self, key: str, obj: Any, expire_after_seconds: Optional[int] = None): + self.expire_keys() + self.__store__[key] = self.normalize_object_data(obj) + now = time.time() + self.__key_expirations__[key] = {"stored": now, "max": expire_after_seconds} + return key + + +class RedisObjectStorage(ObjectStorageInterface): + def __init__(self, redis: Redis, config: ApplicationConfig): + self.redis = redis + self.prefix = f"{config.redis_settings.namespace}:obj" + + def normalize_key(self, key: str) -> str: + """Normalizes the key using the configured namespace.""" + if not key.startswith(self.prefix): + key = f"{self.prefix}:{key}" + return key + + def put(self, key: str, obj: Any, expire_after_seconds: Optional[int] = None): + key = self.normalize_key(key) + self.redis.set(key, self.normalize_object_data(obj), ex=expire_after_seconds) + return key + + def get(self, key: str) -> Optional[str]: + val = self.redis.get(self.normalize_key(key)) + if val: + if isinstance(val, bytes): + return val.decode("UTF-8") + return val + return None + + +class ObjectStoreInjectorModule(Module): + @request + @provider + def provide_object_store( + self, redis: Redis, config: ApplicationConfig + ) -> ObjectStorageInterface: + if config.redis_settings.host: + return RedisObjectStorage(redis, config) + return InMemoryObjectStorage() diff --git a/husky_directory/services/query_synchronizer.py b/husky_directory/services/query_synchronizer.py new file mode 100644 index 0000000..3899c60 --- /dev/null +++ b/husky_directory/services/query_synchronizer.py @@ -0,0 +1,117 @@ +import hashlib +import time +from contextlib import contextmanager +from enum import Enum + +from pydantic import BaseModel + +from husky_directory.app_config import CacheExpirationSettings +from husky_directory.services.object_store import ObjectStorageInterface + + +class QueryStatus(Enum): + in_progress = "in_progress" + completed = "completed" + not_found = "not_found" + error = "error" + + +class QuerySynchronizer: + """ + This service class provides functionality to lock queries to a given + processor (app worker), and other processes to subscribe to those results. + When users give us a query that takes longer than they expect, they often + will interrupt and retry the query. + + While this does not speed up the query in question, it does mean that + subsequent retries while the initial request is still in process will + cost essentially zero extra compute resources. 
+ """ + + def __init__( + self, object_store: ObjectStorageInterface, config: CacheExpirationSettings + ): + self.cache = object_store + self.config = config + + def get_status(self, query_id: str) -> QueryStatus: + return QueryStatus( + self.cache.get(f"{query_id}:status") or QueryStatus.not_found + ) + + @contextmanager + def lock( + self, + query_id: str, + ): + """ + This is a `with` context that creates a status lock for the + given id, which is updated upon completion. Any attached + processes waiting for the query to complete can then + parse and return the results. + + If an error occurs in the calling code, the error message will + be stored in the cache for traceability. + + use: + sync = QuerySynchronizer(object_store, config) + with sync.lock('foo'): + result = do_processing_work() + sync.cache.put('foo', result) + """ + status_key = f"{query_id}:status" + self.cache.put( + status_key, + QueryStatus.in_progress.value, + expire_after_seconds=self.config.in_progress_status_expiration, + ) + try: + yield + self.cache.put( + status_key, + QueryStatus.completed.value, + expire_after_seconds=self.config.completed_status_expiration, + ) + except Exception as e: + self.cache.put( + status_key, + QueryStatus.error.value, + expire_after_seconds=self.config.error_status_expiration, + ) + self.cache.put( + f"{status_key}:message", + str(e), + expire_after_seconds=self.config.error_message_expiration, + ) + raise + + def attach(self, query_id: str) -> bool: + """ + Returns True iff the query was found and now has results waiting, + False otherwise. If the status is found to be in progress already, + it will ping for a new status every second. + + Use: + sync = QuerySynchronizer(...) + if sync.attach('foo'): + return ResultModel.parse_raw(sync.cache.get('foo')) + """ + query_status = self.get_status(query_id) + while query_status == QueryStatus.in_progress: + time.sleep(1) + query_status = self.get_status(query_id) + + return query_status == QueryStatus.completed + + @staticmethod + def get_model_digest(query_model: BaseModel) -> str: + """ + Creates a deterministic query id for a given input, + which allows requests on many servers to share a + query process. 
+ """ + return hashlib.md5( + query_model.json( + exclude_unset=True, exclude_none=True, by_alias=True + ).encode("UTF-8") + ).hexdigest() diff --git a/husky_directory/services/search.py b/husky_directory/services/search.py index ace28a3..9ebcffe 100644 --- a/husky_directory/services/search.py +++ b/husky_directory/services/search.py @@ -5,6 +5,7 @@ from flask_injector import request from injector import inject +from husky_directory.app_config import ApplicationConfig from husky_directory.models.enum import AffiliationState from husky_directory.models.pws import ( ListPersonsInput, @@ -18,9 +19,12 @@ SearchDirectoryInput, SearchDirectoryOutput, ) +from husky_directory.models.transforms import ResultBucket from husky_directory.services.auth import AuthService +from husky_directory.services.object_store import ObjectStorageInterface from husky_directory.services.pws import PersonWebServiceClient from husky_directory.services.query_generator import SearchQueryGenerator +from husky_directory.services.query_synchronizer import QuerySynchronizer from husky_directory.services.reducer import NameSearchResultReducer from husky_directory.services.translator import ( ListPersonsOutputTranslator, @@ -38,12 +42,21 @@ def __init__( pws_translator: ListPersonsOutputTranslator, auth_service: AuthService, reducer: NameSearchResultReducer, + object_store: ObjectStorageInterface, + config: ApplicationConfig, ): self._pws = pws self.query_generator = query_generator self.pws_translator = pws_translator self.auth_service = auth_service self.reducer = reducer + self.object_store = object_store + self.cache_expiration_seconds = ( + config.redis_settings.default_cache_expire_seconds + ) + self.query_sync = QuerySynchronizer( + object_store, config.cache_expiration_settings + ) def get_listing(self, href: str) -> Person: return self.pws_translator.translate_person( @@ -90,8 +103,41 @@ def search_directory_experimental( num_user_search_tokens=len(request_input.name.split()), ) query = " ".join(f"*{token}*" for token in request_input.name.split()) - results = {} + results = self._process_query( + query, + request_input, + statistics, + ) + + statistics.num_duplicates_found = self.reducer.duplicate_hit_count + timer.context["statistics"] = statistics.dict(by_alias=True) + timer.stop(emit_log=True) + + return SearchDirectoryOutput( + scenarios=[ + DirectoryQueryScenarioOutput( + description=b.description, + populations=self.pws_translator.translate_bucket(b), + ) + for b in results.values() + ] + ) + + def _process_query( + self, + query: str, + request_input: SearchDirectoryInput, + statistics: ListPersonsRequestStatistics, + ) -> Dict[str, ResultBucket]: + """ + Factored out the meat of 'search_directory_experimental' for + a little better modularization. + Returns a dictionary of k:v, where k is the + queried population, and v is a ResultBucket instance containing + the scenario results. 
+ """ + results = {} for population in request_input.requested_populations: pws_output: ListPersonsOutput = self._pws.list_persons( ListPersonsInput( @@ -118,21 +164,9 @@ def search_directory_experimental( results = self.reducer.reduce_output( pws_output, request_input.name, results ) - statistics.aggregate(pws_output.request_statistics) - statistics.num_duplicates_found = self.reducer.duplicate_hit_count - timer.context["statistics"] = statistics.dict(by_alias=True) - timer.stop(emit_log=True) - - return SearchDirectoryOutput( - scenarios=[ - DirectoryQueryScenarioOutput( - description=b.description, - populations=self.pws_translator.translate_bucket(b), - ) - for b in results.values() - ] - ) + statistics.aggregate(pws_output.request_statistics) + return results def search_directory_classic( self, request_input: SearchDirectoryInput @@ -152,7 +186,6 @@ def search_directory_classic( statistics = ListPersonsRequestStatistics() scenarios: List[DirectoryQueryScenarioOutput] = [] scenario_description_indexes: Dict[str, int] = {} - for generated in self.query_generator.generate(request_input): self.logger.debug( f"Querying: {generated.description} with " @@ -195,22 +228,50 @@ def search_directory_classic( scenarios.append(scenario_output) scenario_description_indexes[generated.description] = len(scenarios) - 1 - timer.context["statistics"] = statistics.dict(by_alias=True) - timer.stop(emit_log=True) + timer.context["statistics"] = statistics.dict(by_alias=True) + timer.stop(emit_log=True) return SearchDirectoryOutput(scenarios=scenarios) def search_directory( self, request_input: SearchDirectoryInput, ) -> SearchDirectoryOutput: - """The main interface for this service. Submits a query to PWS, filters and translates the output, - and returns a DirectoryQueryScenarioOutput.""" - - if ( - # Only name search is implemented in experimental mode right now. - request_input.name - # Wildcard searches are already accounted for in "classic" mode. - and "*" not in request_input.name - ): - return self.search_directory_experimental(request_input) - return self.search_directory_classic(request_input) + """ + The main interface for this service. + First, checks to see if this query was already submitted in another + request and currently running. If so, it will wait until that query completes, and + return the results. + + Otherwise, creates a lock for this query and + submits the query to PWS. The results are sorted and filtered + to ensure user privacy, data accuracy, and "scenario" categorization of results. + + The results are then cached for the configured amount of time + and returns a DirectoryQueryScenarioOutput instance. + """ + query_id = self.query_sync.get_model_digest(request_input) + query_key = f"query:{query_id}" + + if self.query_sync.attach(query_key): + return SearchDirectoryOutput.parse_raw(self.object_store.get(query_key)) + + with self.query_sync.lock(query_key): + if ( + # Only name search is implemented in experimental mode right now. + request_input.name + # Wildcard searches are already accounted for in "classic" mode. + and "*" not in request_input.name + ): + result = self.search_directory_experimental(request_input) + else: + result = self.search_directory_classic(request_input) + # Keep this `put` inside the `with` context, so that the + # status is not updated until the information + # has been written to memory, otherwise we run the risk of + # trying to pull the object off the cache before it's done + # being written. 
diff --git a/pyproject.toml b/pyproject.toml
index 6442663..3d905ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "uw-husky-directory"
-version = "2.1.10"
+version = "2.2.0"
 description = "An updated version of the UW Directory"
 authors = ["Thomas Thorogood "]
 license = "MIT"
diff --git a/tests/blueprints/test_search_blueprint.py b/tests/blueprints/test_search_blueprint.py
index 7aa0e45..4149e9c 100644
--- a/tests/blueprints/test_search_blueprint.py
+++ b/tests/blueprints/test_search_blueprint.py
@@ -4,11 +4,10 @@
 from typing import cast
 from unittest import mock
 
-from flask import Response
-from flask.testing import FlaskClient
-
 import pytest
 from bs4 import BeautifulSoup
+from flask import Response
+from flask.testing import FlaskClient
 from inflection import titleize
 from werkzeug import exceptions
 from werkzeug.local import LocalProxy
diff --git a/tests/services/test_object_store.py b/tests/services/test_object_store.py
new file mode 100644
index 0000000..e70acd7
--- /dev/null
+++ b/tests/services/test_object_store.py
@@ -0,0 +1,69 @@
+import time
+from typing import cast
+
+import pytest
+from pydantic import create_model
+from redis import Redis
+
+from husky_directory.app_config import ApplicationConfig
+from husky_directory.services.object_store import (
+    InMemoryObjectStorage,
+    ObjectStorageInterface,
+    RedisObjectStorage,
+)
+from husky_directory.services.query_synchronizer import QueryStatus
+
+
+@pytest.mark.parametrize(
+    "redis_host, expected_type",
+    [
+        (None, InMemoryObjectStorage),
+        ("localhost", RedisObjectStorage),
+    ],
+)
+def test_object_store_interface(injector, redis_host, expected_type):
+    settings = injector.get(ApplicationConfig)
+    settings.redis_settings.host = redis_host
+    assert type(injector.get(ObjectStorageInterface)) == expected_type
+
+
+@pytest.mark.parametrize(
+    "obj, expected",
+    [
+        (True, "true"),
+        (False, "false"),
+        (QueryStatus.completed, "completed"),
+        ({"foo": "bar"}, '{"foo": "bar"}'),
+        (create_model("FooModel", foo=(str, "bar"))(), '{"foo": "bar"}'),
+    ],
+)
+def test_normalize_object_data(obj, expected):
+    assert ObjectStorageInterface.normalize_object_data(obj) == expected
+
+
+def test_local_interface():
+    store = InMemoryObjectStorage()
+    store.put("foo", True, expire_after_seconds=1)
+    assert store.get("foo") == "true"
+    time.sleep(1.1)
+    assert not store.get("foo")
+
+
+class MockRedis(InMemoryObjectStorage):
+    def set(self, key, val, ex=None):
+        val = self.normalize_object_data(val)
+        self.put(key, val, expire_after_seconds=ex)
+
+
+def test_redis_interface(injector):
+    mock_redis_ = MockRedis()
+    cfg = injector.get(ApplicationConfig)
+    cfg.redis_settings.namespace = "uw-directory"
+
+    store = RedisObjectStorage(
+        cast(Redis, mock_redis_),
+        cfg,
+    )
+    store.put("hello", True, expire_after_seconds=None)
+    assert "uw-directory:obj:hello" in mock_redis_.__store__
+    assert store.get("hello") == "true"
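One detail the test above exercises implicitly: `RedisObjectStorage.normalize_key` only prepends the `<namespace>:obj` prefix when it is missing, so it is idempotent and safe to call on already-normalized keys:

```python
# Continuing from test_redis_interface above (namespace "uw-directory"):
assert store.normalize_key("hello") == "uw-directory:obj:hello"
assert store.normalize_key("uw-directory:obj:hello") == "uw-directory:obj:hello"
```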
diff --git a/tests/services/test_query_synchronizer.py b/tests/services/test_query_synchronizer.py
new file mode 100644
index 0000000..ced3647
--- /dev/null
+++ b/tests/services/test_query_synchronizer.py
@@ -0,0 +1,53 @@
+import time
+
+import pytest
+
+from husky_directory.models.search import SearchDirectoryInput
+from husky_directory.services.object_store import InMemoryObjectStorage
+from husky_directory.services.query_synchronizer import QueryStatus, QuerySynchronizer
+
+
+class TestQuerySynchronizer:
+    @pytest.fixture(autouse=True)
+    def initialize(self, injector, app_config):
+        self.config = app_config
+        self.cache = InMemoryObjectStorage()
+        self.request = SearchDirectoryInput(name="foo")
+        self.sync = QuerySynchronizer(self.cache, app_config.cache_expiration_settings)
+        self.query_id = self.sync.get_model_digest(self.request)
+
+    def test_query_sync(self):
+        with self.sync.lock(self.query_id):
+            assert self.sync.get_status(self.query_id) == QueryStatus.in_progress
+
+        assert self.sync.get_status(self.query_id) == QueryStatus.completed
+
+    def test_query_sync_error(self):
+        with pytest.raises(RuntimeError):
+            with self.sync.lock(self.query_id):
+                raise RuntimeError("oh dear!")
+
+        assert self.sync.get_status(self.query_id) == QueryStatus.error
+        assert self.cache.get(f"{self.query_id}:status:message") == "oh dear!"
+
+    def test_attach_in_progress(self):
+        self.cache.put(
+            f"{self.query_id}:status", QueryStatus.in_progress, expire_after_seconds=2
+        )
+        before = time.time()
+        assert self.sync.attach(self.query_id) is False
+        # Make sure that we waited for the query status to change,
+        # which took about 2 seconds.
+        assert time.time() - before > 1
+
+    def test_attach_completed(self):
+        self.cache.put(
+            f"{self.query_id}:status",
+            QueryStatus.completed,
+            expire_after_seconds=2,
+        )
+        before = time.time()
+        assert self.sync.attach(self.query_id) is True
+        # Make sure that we did not sleep, because the
+        # query was already completed.
+        assert time.time() - before < 1
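`test_attach_in_progress` really sleeps for about two seconds. If that ever proves slow or flaky in CI, one hypothetical variant (not part of this PR) stubs `time.sleep` so the first poll flips the status itself, covering the loop without wall-clock waits:

```python
    def test_attach_in_progress_fast(self, monkeypatch):
        self.cache.put(f"{self.query_id}:status", QueryStatus.in_progress)

        def fake_sleep(_seconds):
            # Simulate another worker finishing while we "slept".
            self.cache.put(f"{self.query_id}:status", QueryStatus.completed)

        monkeypatch.setattr(time, "sleep", fake_sleep)
        assert self.sync.attach(self.query_id) is True
```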
diff --git a/tests/services/test_reducer.py b/tests/services/test_reducer.py
index b698371..4d29bb8 100644
--- a/tests/services/test_reducer.py
+++ b/tests/services/test_reducer.py
@@ -2,8 +2,8 @@
 from husky_directory.models.pws import ListPersonsOutput, NamedIdentity
 from husky_directory.services.reducer import (
-    NameSearchResultReducer,
     NameQueryResultAnalyzer,
+    NameSearchResultReducer,
 )
diff --git a/tests/test_app.py b/tests/test_app.py
index 7b9e281..1b8b256 100644
--- a/tests/test_app.py
+++ b/tests/test_app.py
@@ -2,10 +2,11 @@
 from unittest import mock
 
 import pytest
+from flask_session import RedisSessionInterface
 from pydantic import SecretStr
 
-from husky_directory.app import create_app_injector
-from husky_directory.app_config import ApplicationConfig
+from husky_directory.app import create_app, create_app_injector
+from husky_directory.app_config import ApplicationConfig, SessionType
 from husky_directory.services.pws import PersonWebServiceClient
 from husky_directory.services.search import DirectorySearchService
@@ -83,6 +84,17 @@ def test_internal_server_error(client, injector, mock_injected):
     assert response.status_code == 500
 
 
+def test_session_interface_configuration():
+    injector = create_app_injector()
+    app_config = injector.get(ApplicationConfig)
+    app_config.session_settings.session_type = SessionType.redis
+    app_config.redis_settings.host = "localhost"
+    app_config.redis_settings.password = SecretStr("s3kr1t")
+    assert isinstance(
+        getattr(create_app(injector), "session_interface", None), RedisSessionInterface
+    )
+
+
 @pytest.mark.parametrize("auth_required", (True, False))
 def test_prometheus_configuration(
     app_config: ApplicationConfig, client, auth_required: bool
diff --git a/tests/test_app_config.py b/tests/test_app_config.py
index b39f8b1..05e38ab 100644
--- a/tests/test_app_config.py
+++ b/tests/test_app_config.py
@@ -8,7 +8,6 @@
     ApplicationConfig,
     ApplicationSecrets,
     RedisSettings,
-    YAMLSettingsLoader,
 )
@@ -35,6 +34,7 @@ def test_app_config_has_fields(self):
             "secrets",
             "metrics_settings",
             "show_experimental",
+            "cache_expiration_settings",
         }
         assert set(self.app_config.dict().keys()) == expected_field_names
@@ -55,22 +55,6 @@ def test_app_secrets_has_fields(self):
         assert set(self.app_secrets.dict().keys()) == expected_field_names
 
 
-class TestYAMLSettingsLoader:
-    @pytest.fixture(autouse=True)
-    def configure_base(self, injector: Injector, test_root_path: str):
-        self.loader = injector.get(YAMLSettingsLoader)
-        self.loader.app_config.settings_dir = os.path.join(test_root_path, "data")
-        self.loader.app_config.stage = "development"
-
-    def test_yaml_settings_loader(self):
-        assert self.loader.load_settings("testconfig")["foo"] == "bar"
-
-    def test_yaml_settings_loader_fails_wrong_stage(self):
-        self.loader.app_config.stage = "error"
-        with pytest.raises(KeyError):
-            self.test_yaml_settings_loader()
-
-
 class TestRedisSettings:
     def test_flask_config_values(self):
         settings = RedisSettings(host="redis", namespace="directory")