diff --git a/distributed/core.py b/distributed/core.py
index c4446826715..02d1d6fefa3 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -7,7 +7,6 @@
 import sys
 import traceback
 import types
-import uuid
 import warnings
 import weakref
 from collections import defaultdict
@@ -145,7 +144,6 @@ class Server:
     default_ip: ClassVar[str] = ""
     default_port: ClassVar[int] = 0

-    id: str
     blocked_handlers: list[str]
     handlers: dict[str, Callable]
     stream_handlers: dict[str, Callable]
@@ -174,20 +172,15 @@ def __init__(
         timeout=None,
     ):
         self.handlers = {
-            "identity": self.identity,
             "echo": self.echo,
+            "identity": self.identity,
             "connection_stream": self.handle_stream,
         }
         self.handlers.update(handlers)
-        if blocked_handlers is None:
-            blocked_handlers = dask.config.get(
-                "distributed.%s.blocked-handlers" % type(self).__name__.lower(), []
-            )
-        self.blocked_handlers = blocked_handlers
+        self.blocked_handlers = blocked_handlers or {}
         self.stream_handlers = {}
         self.stream_handlers.update(stream_handlers or {})
-        self.id = type(self).__name__ + "-" + str(uuid.uuid4())
         self._address = None
         self._listen_address = None
         self._port = None
@@ -350,7 +343,7 @@ def port(self):
         return self._port

     def identity(self) -> dict[str, str]:
-        return {"type": type(self).__name__, "id": self.id}
+        return {"type": type(self).__name__, "id": str(id(self))}

     def echo(self, data=None):
         return data
diff --git a/distributed/event.py b/distributed/event.py
index 145abbf3857..9cad4f0d59a 100644
--- a/distributed/event.py
+++ b/distributed/event.py
@@ -50,7 +50,7 @@ def __init__(self, scheduler):
         # we can remove the event
         self._waiter_count = defaultdict(int)

-        self.scheduler.handlers.update(
+        self.scheduler.server.handlers.update(
             {
                 "event_wait": self.event_wait,
                 "event_set": self.event_set,
diff --git a/distributed/http/scheduler/json.py b/distributed/http/scheduler/json.py
index 932734f56a7..086cfbbbaae 100644
--- a/distributed/http/scheduler/json.py
+++ b/distributed/http/scheduler/json.py
@@ -55,7 +55,7 @@ def get(self):

 class IdentityJSON(RequestHandler):
     def get(self):
-        self.write(self.server.identity())
+        self.write(self.identity())


 class IndexJSON(RequestHandler):
diff --git a/distributed/lock.py b/distributed/lock.py
index 99ec34cd6f7..42d326dd413 100644
--- a/distributed/lock.py
+++ b/distributed/lock.py
@@ -27,7 +27,7 @@ def __init__(self, scheduler):
         self.events = defaultdict(deque)
         self.ids = dict()

-        self.scheduler.handlers.update(
+        self.scheduler.server.handlers.update(
             {"lock_acquire": self.acquire, "lock_release": self.release}
         )

diff --git a/distributed/nanny.py b/distributed/nanny.py
index ef0d9d913fc..7b0a3ea0089 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -334,7 +334,7 @@ async def start_unsafe(self):
             security=self.security,
         )
         try:
-            await self.listen(
+            await self.server.listen(
                 start_address, **self.security.get_listen_args("worker")
             )
         except OSError as e:
diff --git a/distributed/node.py b/distributed/node.py
index 7075565598f..5a3d052c901 100644
--- a/distributed/node.py
+++ b/distributed/node.py
@@ -546,12 +546,20 @@ def __init__(

         _handlers = {
             "dump_state": self._to_dict,
+            "identity": self.identity,
         }
         if handlers:
             _handlers.update(handlers)

+        import uuid
+        self.id = type(self).__name__ + "-" + str(uuid.uuid4())
+
+        if blocked_handlers is None:
+            blocked_handlers = dask.config.get(
+                "distributed.%s.blocked-handlers" % type(self).__name__.lower(), []
+            )
         self.server = Server(
-            handlers=handlers,
+            handlers=_handlers,
             blocked_handlers=blocked_handlers,
             stream_handlers=stream_handlers,
             connection_limit=connection_limit,
@@ -566,6 +574,17 @@ def __init__(
             needs_workdir=needs_workdir,
         )

+    def identity(self) -> dict[str, str]:
+        return {"type": type(self).__name__, "id": self.id}
+
+    @property
+    def port(self):
+        return self.server.port
+
+    @property
+    def listen_address(self):
+        return self.server.address
+
     @property
     def address(self):
         return self.server.address
@@ -574,10 +593,6 @@ def address(self):
     def address_safe(self):
         return self.server.address_safe

-    @property
-    def id(self):
-        return self.server.id
-
     async def start_unsafe(self):
         await self.server
         await super().start_unsafe()
@@ -592,6 +607,7 @@ async def close(self, reason: str | None = None) -> None:
             # Close network connections and background tasks
             await self.server.close()
             await Node.close(self, reason=reason)
+            self.status = Status.closed
         finally:
             self._event_finished.set()

@@ -605,7 +621,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
         Client.dump_cluster_state
         distributed.utils.recursive_to_dict
         """
-        info: dict[str, Any] = self.server.identity()
+        info: dict[str, Any] = self.identity()
         extra = {
             "address": self.server.address,
             "status": self.status.name,
diff --git a/distributed/pubsub.py b/distributed/pubsub.py
index de5ca9f401d..aee993354b2 100644
--- a/distributed/pubsub.py
+++ b/distributed/pubsub.py
@@ -122,7 +122,7 @@ class PubSubWorkerExtension:
     def __init__(self, worker):
         self.worker = worker

-        self.worker.stream_handlers.update(
+        self.worker.server.stream_handlers.update(
             {
                 "pubsub-add-subscriber": self.add_subscriber,
                 "pubsub-remove-subscriber": self.remove_subscriber,
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 551a70db1e0..68bc72d2347 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -50,6 +50,7 @@
     valmap,
 )
 from tornado.ioloop import IOLoop
+from typing_extensions import Self

 import dask
 from dask.core import get_deps, validate_key
@@ -4027,7 +4028,7 @@ def get_worker_service_addr(
         else:
             return ws.host, port

-    async def start_unsafe(self):
+    async def start_unsafe(self) -> Self:
         """Clear out old state and restart all running coroutines"""
         await super().start_unsafe()

@@ -4042,7 +4043,7 @@ async def start_unsafe(self):
             handshake_overrides={"pickle-protocol": 4, "compression": None},
             **self.security.get_listen_args("scheduler"),
         )
-        self.ip = get_address_host(self.listen_address)
+        self.ip = get_address_host(self.server.listen_address)
         listen_ip = self.ip

         if listen_ip == "0.0.0.0":
@@ -4054,7 +4055,7 @@ async def start_unsafe(self):
         # Services listen on all addresses
         self.start_services(listen_ip)

-        for listener in self.listeners:
+        for listener in self.server.listeners:
             logger.info(" Scheduler at: %25s", listener.contact_address)
         for name, server in self.services.items():
             if name == "dashboard":
@@ -4089,8 +4090,10 @@ def del_scheduler_file():
         if self.jupyter:
             # Allow insecure communications from local users
             if self.server.address.startswith("tls://"):
-                await self.listen("tcp://localhost:0")
-            os.environ["DASK_SCHEDULER_ADDRESS"] = self.listeners[-1].contact_address
+                await self.server.listen("tcp://localhost:0")
+            os.environ["DASK_SCHEDULER_ADDRESS"] = self.server.listeners[
+                -1
+            ].contact_address

         await asyncio.gather(
             *[plugin.start(self) for plugin in list(self.plugins.values())]
diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py
index 39ae24c2dc2..2865bda3528 100644
--- a/distributed/tests/test_core.py
+++ b/distributed/tests/test_core.py
@@ -16,7 +16,6 @@
 from distributed.batched import BatchedSend
 from distributed.comm.core import CommClosedError
-from distributed.comm.registry import backends
 from distributed.comm.tcp import TCPBackend, TCPListener
 from distributed.core import (
     AsyncTaskGroup,
@@ -1287,15 +1286,6 @@ class TCPAsyncListenerBackend(TCPBackend):
     _listener_class = AsyncStopTCPListener


-@gen_test()
-async def test_async_listener_stop(monkeypatch):
-    monkeypatch.setitem(backends, "tcp", TCPAsyncListenerBackend())
-    with pytest.warns(DeprecationWarning):
-        async with Server({}) as s:
-            await s.listen(0)
-            assert s.listeners
-
-
 @gen_test()
 async def test_messages_are_ordered_bsend():
     ledger = []
diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py
index 74d049a35d4..1be66db2541 100644
--- a/distributed/tests/test_nanny.py
+++ b/distributed/tests/test_nanny.py
@@ -77,7 +77,6 @@ async def test_nanny_process_failure(c, s):
     assert not os.path.exists(second_dir)
     assert not os.path.exists(first_dir)
     assert first_dir != n.worker_dir
-    s.stop()


 @gen_cluster(nthreads=[])
@@ -201,10 +200,9 @@ def func(dask_worker):
 @gen_test()
 async def test_scheduler_file():
     with tmpfile() as fn:
-        s = await Scheduler(scheduler_file=fn, dashboard_address=":0")
-        async with Nanny(scheduler_file=fn) as n:
-            assert set(s.workers) == {n.worker_address}
-        s.stop()
+        async with Scheduler(scheduler_file=fn, dashboard_address=":0") as s:
+            async with Nanny(scheduler_file=fn) as n:
+                assert set(s.workers) == {n.worker_address}


 @pytest.mark.xfail(
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index b12cedc8452..8d4590e799a 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -824,8 +824,8 @@ async def test_retire_workers_concurrently(c, s, w1, w2):
 async def test_server_listens_to_other_ops(s, a, b):
     async with rpc(s.address) as r:
         ident = await r.identity()
-        assert ident["type"] == "Scheduler"
-        assert ident["id"].lower().startswith("scheduler")
+        assert ident["type"] == "Scheduler", ident["type"]
+        assert ident["id"].lower().startswith("scheduler"), ident["id"]


 @gen_cluster(client=True)
@@ -928,7 +928,7 @@ def func(scheduler):
     nthreads=[], config={"distributed.scheduler.blocked-handlers": ["test-handler"]}
 )
 async def test_scheduler_init_pulls_blocked_handlers_from_config(s):
-    assert s.blocked_handlers == ["test-handler"]
+    assert s.server.blocked_handlers == ["test-handler"]


 @gen_cluster()
@@ -1326,7 +1326,7 @@ async def test_broadcast_nanny(s, a, b):

 @gen_cluster(config={"distributed.comm.timeouts.connect": "200ms"})
 async def test_broadcast_on_error(s, a, b):
-    a.stop()
+    a.server.stop()

     with pytest.raises(OSError):
         await s.broadcast(msg={"op": "ping"}, on_error="raise")
@@ -2007,7 +2007,7 @@ async def test_profile_metadata_timeout(c, s, a, b):
     def raise_timeout(*args, **kwargs):
         raise TimeoutError

-    b.handlers["profile_metadata"] = raise_timeout
+    b.server.handlers["profile_metadata"] = raise_timeout

     futures = c.map(slowinc, range(10), delay=0.05, workers=a.address)
     await wait(futures)
@@ -2071,7 +2071,7 @@ async def test_statistical_profiling_failure(c, s, a, b):
     def raise_timeout(*args, **kwargs):
         raise TimeoutError

-    b.handlers["profile"] = raise_timeout
+    b.server.handlers["profile"] = raise_timeout

     await wait(futures)
     profile = await s.get_profile()
@@ -3050,7 +3050,7 @@ async def connect(self, *args, **kwargs):
 async def test_gather_failing_cnn_recover(c, s, a, b):
     x = await c.scatter({"x": 1}, workers=a.address)
     rpc = await FlakyConnectionPool(failing_connections=1)
-    with mock.patch.object(s, "rpc", rpc), dask.config.set(
+    with mock.patch.object(s.server, "rpc", rpc), dask.config.set(
         {"distributed.comm.retry.count": 1}
     ), captured_handler(
         logging.getLogger("distributed").handlers[0]
@@ -3068,7 +3068,7 @@ async def test_gather_failing_cnn_error(c, s, a, b):
     x = await c.scatter({"x": 1}, workers=a.address)
     rpc = await FlakyConnectionPool(failing_connections=10)
-    with mock.patch.object(s, "rpc", rpc):
+    with mock.patch.object(s.server, "rpc", rpc):
         res = await s.gather(keys=["x"])
     assert res["status"] == "error"
     assert list(res["keys"]) == ["x"]
@@ -3101,7 +3101,7 @@ async def test_gather_bad_worker(c, s, a, direct):
     """
     x = c.submit(inc, 1, key="x")
     c.rpc = await FlakyConnectionPool(failing_connections=3)
-    s.rpc = await FlakyConnectionPool(failing_connections=1)
+    s.server.rpc = await FlakyConnectionPool(failing_connections=1)

     with captured_logger("distributed.scheduler") as sched_logger:
         with captured_logger("distributed.client") as client_logger:
@@ -3116,12 +3116,12 @@ async def test_gather_bad_worker(c, s, a, direct):
         # 3. try direct=True again; fail
         # 4. fall back to direct=False again; success
         assert c.rpc.cnn_count == 2
-        assert s.rpc.cnn_count == 2
+        assert s.server.rpc.cnn_count == 2
     else:
         # 1. try direct=False; fail
         # 2. try again direct=False; success
         assert c.rpc.cnn_count == 0
-        assert s.rpc.cnn_count == 2
+        assert s.server.rpc.cnn_count == 2


 @gen_cluster(client=True)
@@ -3152,8 +3152,8 @@ async def test_multiple_listeners(dashboard_link_template, expected_dashboard_li
         async with Scheduler(
             dashboard_address=":0", protocol=["inproc", "tcp"]
         ) as s:
-            async with Worker(s.listeners[0].contact_address) as a:
-                async with Worker(s.listeners[1].contact_address) as b:
+            async with Worker(s.server.listeners[0].contact_address) as a:
+                async with Worker(s.server.listeners[1].contact_address) as b:
                     assert a.address.startswith("inproc")
                     assert a.scheduler.address.startswith("inproc")
                     assert b.address.startswith("tcp")
@@ -4602,7 +4602,7 @@ class BrokenGatherDep(Worker):
         async def gather_dep(self, worker, *args, **kwargs):
             w = workers.pop(worker, None)
             if w is not None and workers:
-                w.listener.stop()
+                w.server.listener.stop()
                 s.stream_comms[worker].abort()

             return await super().gather_dep(worker, *args, **kwargs)
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 94e18af460c..b2a09ef02bf 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -856,7 +856,6 @@ async def end_worker(w):

     await asyncio.gather(*(end_worker(w) for w in workers))
     await s.close()  # wait until scheduler stops completely
-    s.stop()
     check_invalid_worker_transitions(s)
     check_invalid_task_states(s)
     check_worker_fail_hard(s)
diff --git a/distributed/worker.py b/distributed/worker.py
index 08b0ba581e5..cb77c2c0510 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -43,6 +43,7 @@
 from tlz import keymap, pluck
 from tornado.ioloop import IOLoop
+from typing_extensions import Self

 import dask
 from dask.core import istask
@@ -1363,7 +1364,7 @@ def get_monitor_info(self, recent: bool = False, start: int = 0) -> dict[str, An
     # Lifecycle #
     #############

-    async def start_unsafe(self):
+    async def start_unsafe(self) -> Self:
         await super().start_unsafe()

         enable_gc_diagnosis()
@@ -1384,7 +1385,7 @@ async def start_unsafe(self):
                 get_address_host(self.scheduler.address)
             )
             try:
-                await self.listen(start_address, **kwargs)
+                await self.server.listen(start_address, **kwargs)
             except OSError as e:
                 if len(ports) > 1 and e.errno == errno.EADDRINUSE:
                     continue
@@ -1431,9 +1432,13 @@ async def start_unsafe(self):
         self.start_services(self.ip)

         try:
-            listening_address = "%s%s:%d" % (self.listener.prefix, self.ip, self.port)
+            listening_address = "%s%s:%d" % (
+                self.server.listener.prefix,
+                self.ip,
+                self.server.port,
+            )
         except Exception:
-            listening_address = f"{self.listener.prefix}{self.ip}"
+            listening_address = f"{self.server.listener.prefix}{self.ip}"

         logger.info(" Start worker at: %26s", self.server.address)
         logger.info(" Listening to: %26s", listening_address)