From cdffb5b0b6aff7283a4cf8fd46e97c67ababa36c Mon Sep 17 00:00:00 2001
From: Ye Cao
Date: Mon, 18 Nov 2024 13:43:26 +0800
Subject: [PATCH] Support async put for vineyard client.

Signed-off-by: Ye Cao
---
Note: this patch was recovered from a whitespace-collapsed copy. Indentation
and blank lines of the context lines were reconstructed from the hunk counts
and Python syntax; use `git apply --ignore-whitespace` if a context line does
not match. The commit and `index` hashes are those of the original patch.

Review changes on top of the original patch:
  * `Client.put(as_async=True)` reports the outcome through `logging`
    instead of printing to stdout from library code.
  * `test_async_put_and_get` now fails when a worker thread or an
    asynchronous put fails, and checks the fetched values.

 python/vineyard/core/client.py            | 52 ++++++++++++++++++++++--
 python/vineyard/core/tests/test_client.py | 55 ++++++++++++++++++++++++
 2 files changed, 105 insertions(+), 2 deletions(-)

diff --git a/python/vineyard/core/client.py b/python/vineyard/core/client.py
index 2bb0b1c2f..cd08edd30 100644
--- a/python/vineyard/core/client.py
+++ b/python/vineyard/core/client.py
@@ -168,6 +168,7 @@ def __init__(
         session: int = None,
         username: str = None,
         password: str = None,
+        max_workers: int = 8,
         config: str = None,
     ):
         """Connects to the vineyard IPC socket and RPC socket.
@@ -211,6 +212,8 @@ def __init__(
                 is enabled.
             password: Optional, the required password of vineyardd when authentication
                 is enabled.
+            max_workers: Optional, the maximum number of threads that can be used to
+                asynchronously put objects to vineyard. Default is 8.
             config: Optional, can either be a path to a YAML configuration file or
                 a path to a directory containing the default config file
                 `vineyard-config.yaml`. Also, the environment variable
@@ -290,6 +293,9 @@ def __init__(
             except VineyardException:
                 continue
 
+        self._max_workers = max_workers
+        self._put_thread_pool = None
+
         self._spread = False
         self._compression = True
         if self._ipc_client is None and self._rpc_client is None:
@@ -347,6 +353,13 @@ def rpc_client(self) -> RPCClient:
         assert self._rpc_client is not None, "RPC client is not available."
         return self._rpc_client
 
+    @property
+    def put_thread_pool(self) -> ThreadPoolExecutor:
+        """Lazy initialization of the thread pool for asynchronous put."""
+        if self._put_thread_pool is None:
+            self._put_thread_pool = ThreadPoolExecutor(max_workers=self._max_workers)
+        return self._put_thread_pool
+
     def has_ipc_client(self):
         return self._ipc_client is not None
 
@@ -820,8 +833,7 @@ def get(
     ):
         return get(self, object_id, name, resolver, fetch, **kwargs)
 
-    @_apply_docstring(put)
-    def put(
+    def _put_internal(
         self,
         value: Any,
         builder: Optional[BuilderContext] = None,
@@ -858,6 +870,42 @@ def put(
                 self.compression = previous_compression_state
         return put(self, value, builder, persist, name, **kwargs)
 
+    @_apply_docstring(put)
+    def put(
+        self,
+        value: Any,
+        builder: Optional[BuilderContext] = None,
+        persist: bool = False,
+        name: Optional[str] = None,
+        as_async: bool = False,
+        **kwargs,
+    ):
+        # Synchronous path: delegate directly to `_put_internal`.
+        if not as_async:
+            return self._put_internal(value, builder, persist, name, **kwargs)
+
+        # Local import: only the asynchronous path needs the logger.
+        import logging
+
+        def _log_put_outcome(future):
+            # Report through `logging` rather than `print` so that library
+            # users control where (and whether) the messages show up.
+            async_logger = logging.getLogger(__name__)
+            try:
+                result = future.result()
+            except Exception as e:  # pylint: disable=broad-except
+                async_logger.warning("Failed to put object: %s", e)
+            else:
+                async_logger.debug("Successfully put object %s", result)
+
+        # With `as_async=True` a `concurrent.futures.Future` is returned; its
+        # `result()` is whatever `_put_internal` returns (or raises).
+        future = self.put_thread_pool.submit(
+            self._put_internal, value, builder, persist, name, **kwargs
+        )
+        future.add_done_callback(_log_put_outcome)
+        return future
+
     @contextlib.contextmanager
     def with_compression(self, enabled: bool = True):
         """Disable compression for the following put operations."""
diff --git a/python/vineyard/core/tests/test_client.py b/python/vineyard/core/tests/test_client.py
index ee38eabca..c5a13a9ee 100644
--- a/python/vineyard/core/tests/test_client.py
+++ b/python/vineyard/core/tests/test_client.py
@@ -19,8 +19,10 @@
 import itertools
 import multiprocessing
 import random
+import time
 import traceback
 from concurrent.futures import ThreadPoolExecutor
+from threading import Thread
 
 import numpy as np
 
@@ -317,3 +319,56 @@ def test_memory_trim(vineyard_client):
 
     # there might be some fragmentation overhead
     assert parse_shared_memory_usage() <= original_memory_usage + 2 * data_kbytes
+
+
+def test_async_put_and_get(vineyard_client):
+    data = np.ones((100, 100, 16))
+    object_nums = 100
+    # exceptions raised inside a `Thread` target never reach pytest, so the
+    # workers record them here and the main thread re-checks after `join()`
+    errors = []
+
+    def producer(vineyard_client):
+        try:
+            start_time = time.time()
+            client = vineyard_client.fork()
+            futures = [
+                client.put(data, name="test" + str(i), as_async=True, persist=True)
+                for i in range(object_nums)
+            ]
+            client.put(data)
+            end_time = time.time()
+            print("Producer time: ", end_time - start_time)
+            # surface failures of the asynchronous puts
+            for future in futures:
+                future.result()
+        except Exception as e:  # pylint: disable=broad-except
+            errors.append(e)
+
+    def consumer(vineyard_client):
+        try:
+            start_time = time.time()
+            client = vineyard_client.fork()
+            for i in range(object_nums):
+                object_id = client.get_name(name="test" + str(i), wait=True)
+                np.testing.assert_array_equal(client.get(object_id), data)
+            end_time = time.time()
+            print("Consumer time: ", end_time - start_time)
+        except Exception as e:  # pylint: disable=broad-except
+            errors.append(e)
+
+    producer_thread = Thread(target=producer, args=(vineyard_client,))
+    consumer_thread = Thread(target=consumer, args=(vineyard_client,))
+
+    start_time = time.time()
+
+    producer_thread.start()
+    consumer_thread.start()
+
+    producer_thread.join()
+    consumer_thread.join()
+
+    end_time = time.time()
+    print("Total time: ", end_time - start_time)
+
+    assert not errors, errors