diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 5f57b4b77b8af..8ebb03975795b 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -6,22 +6,18 @@ import logging import os import ssl -from typing import Optional, cast +from typing import Any, Optional, cast from aiohttp import web from aiohttp.web_exceptions import HTTPMovedPermanently import voluptuous as vol -from homeassistant.const import ( - EVENT_HOMEASSISTANT_START, - EVENT_HOMEASSISTANT_STOP, - SERVER_PORT, -) +from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT from homeassistant.core import Event, HomeAssistant from homeassistant.helpers import storage import homeassistant.helpers.config_validation as cv from homeassistant.loader import bind_hass -from homeassistant.setup import ATTR_COMPONENT, EVENT_COMPONENT_LOADED +from homeassistant.setup import async_start_setup, async_when_setup_or_start import homeassistant.util as hass_util from homeassistant.util import ssl as ssl_util @@ -161,36 +157,17 @@ async def async_setup(hass, config): ssl_profile=ssl_profile, ) - startup_listeners = [] - async def stop_server(event: Event) -> None: """Stop the server.""" await server.stop() - async def start_server(event: Event) -> None: + async def start_server(*_: Any) -> None: """Start the server.""" + with async_start_setup(hass, ["http"]): + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server) + await start_http_server_and_save_config(hass, dict(conf), server) - for listener in startup_listeners: - listener() - - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server) - - await start_http_server_and_save_config(hass, dict(conf), server) - - async def async_wait_frontend_load(event: Event) -> None: - """Wait for the frontend to load.""" - - if event.data[ATTR_COMPONENT] != "frontend": - return - - await start_server(event) - - startup_listeners.append( - hass.bus.async_listen(EVENT_COMPONENT_LOADED, async_wait_frontend_load) - ) - startup_listeners.append( - hass.bus.async_listen(EVENT_HOMEASSISTANT_START, start_server) - ) + async_when_setup_or_start(hass, "frontend", start_server) hass.http = server diff --git a/homeassistant/setup.py b/homeassistant/setup.py index 6af20e21905f5..bead16c1d789e 100644 --- a/homeassistant/setup.py +++ b/homeassistant/setup.py @@ -10,7 +10,11 @@ from homeassistant import config as conf_util, core, loader, requirements from homeassistant.config import async_notify_setup_error -from homeassistant.const import EVENT_COMPONENT_LOADED, PLATFORM_FORMAT +from homeassistant.const import ( + EVENT_COMPONENT_LOADED, + EVENT_HOMEASSISTANT_START, + PLATFORM_FORMAT, +) from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.typing import ConfigType from homeassistant.util import dt as dt_util, ensure_unique_string @@ -379,6 +383,27 @@ def async_when_setup( when_setup_cb: Callable[[core.HomeAssistant, str], Awaitable[None]], ) -> None: """Call a method when a component is setup.""" + _async_when_setup(hass, component, when_setup_cb, False) + + +@core.callback +def async_when_setup_or_start( + hass: core.HomeAssistant, + component: str, + when_setup_cb: Callable[[core.HomeAssistant, str], Awaitable[None]], +) -> None: + """Call a method when a component is setup or state is fired.""" + _async_when_setup(hass, component, when_setup_cb, True) + + +@core.callback +def _async_when_setup( + hass: core.HomeAssistant, + component: str, + when_setup_cb: Callable[[core.HomeAssistant, str], Awaitable[None]], + start_event: bool, +) -> None: + """Call a method when a component is setup or the start event fires.""" async def when_setup() -> None: """Call the callback.""" @@ -387,22 +412,28 @@ async def when_setup() -> None: except Exception: # pylint: disable=broad-except _LOGGER.exception("Error handling when_setup callback for %s", component) - # Running it in a new task so that it always runs after if component in hass.config.components: hass.async_create_task(when_setup()) return - unsub = None + listeners: list[Callable] = [] - async def loaded_event(event: core.Event) -> None: - """Call the callback.""" - if event.data[ATTR_COMPONENT] != component: - return - - unsub() # type: ignore + async def _matched_event(event: core.Event) -> None: + """Call the callback when we matched an event.""" + for listener in listeners: + listener() await when_setup() - unsub = hass.bus.async_listen(EVENT_COMPONENT_LOADED, loaded_event) + async def _loaded_event(event: core.Event) -> None: + """Call the callback if we loaded the expected component.""" + if event.data[ATTR_COMPONENT] == component: + await _matched_event(event) + + listeners.append(hass.bus.async_listen(EVENT_COMPONENT_LOADED, _loaded_event)) + if start_event: + listeners.append( + hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _matched_event) + ) @core.callback diff --git a/tests/test_setup.py b/tests/test_setup.py index 72613722ca1f1..d245c9818363b 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -556,6 +556,41 @@ async def mock_callback(hass, component): assert calls == ["test", "test"] +async def test_async_when_setup_or_start_already_loaded(hass): + """Test when setup or start.""" + calls = [] + + async def mock_callback(hass, component): + """Mock callback.""" + calls.append(component) + + setup.async_when_setup_or_start(hass, "test", mock_callback) + await hass.async_block_till_done() + assert calls == [] + + hass.config.components.add("test") + hass.bus.async_fire(EVENT_COMPONENT_LOADED, {"component": "test"}) + await hass.async_block_till_done() + assert calls == ["test"] + + # Event listener should be gone + hass.bus.async_fire(EVENT_COMPONENT_LOADED, {"component": "test"}) + await hass.async_block_till_done() + assert calls == ["test"] + + # Should be called right away + setup.async_when_setup_or_start(hass, "test", mock_callback) + await hass.async_block_till_done() + assert calls == ["test", "test"] + + setup.async_when_setup_or_start(hass, "not_loaded", mock_callback) + await hass.async_block_till_done() + assert calls == ["test", "test"] + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + await hass.async_block_till_done() + assert calls == ["test", "test", "not_loaded"] + + async def test_setup_import_blows_up(hass): """Test that we handle it correctly when importing integration blows up.""" with patch(