From e91e60130443570b2e01cfece6e7f5293ab0511c Mon Sep 17 00:00:00 2001 From: Alexander Clausen Date: Tue, 17 May 2022 14:18:01 +0200 Subject: [PATCH] WIP: Working prototyp for multiprocessing-based executor Still needs to be extracted into an executor and probably an adjusted `UDFRunner`. - properly drain and close queues - use a single response queue - this simplifies result handling on the receiving side (might become bottleneck in many-core situations) - perform decoding in the worker processes - `ZeroMQReceiver`: zero-copy recv, at least for the payload - some hacks to inject the receiver into the partition --- .gitignore | 1 + prototypes/multip.py | 204 +++++++++++++----- src/libertem_live/channel.py | 103 +++++---- .../detectors/dectris/acquisition.py | 100 ++++++--- src/libertem_live/detectors/dectris/sim.py | 4 +- 5 files changed, 292 insertions(+), 120 deletions(-) diff --git a/.gitignore b/.gitignore index dc5e2a9b..ae8e925b 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ data .benchmarks junit.xml DEigerClient.py +profiles/ diff --git a/prototypes/multip.py b/prototypes/multip.py index 21cf77ab..9a00d069 100644 --- a/prototypes/multip.py +++ b/prototypes/multip.py @@ -1,16 +1,31 @@ +import os from contextlib import contextmanager +from threading import Thread +from queue import Empty import time -import multiprocessing as mp -from typing import Union +from typing import List, Union from typing_extensions import Literal +import multiprocessing as mp import numpy as np -from libertem.udf.base import UDFPartRunner, UDFRunner, UDFParams +from libertem.io.dataset.base.tiling_scheme import TilingScheme +from libertem.udf.sum import SumUDF +from libertem.udf.base import UDFPartRunner, UDFRunner, UDFParams, UDF from libertem.common import Shape +from libertem.common.executor import Environment +from regex import E from libertem_live.channel import WorkerPool, WorkerQueues -from libertem_live.detectors.dectris.acquisition import AcquisitionParams, DectrisAcquisition, DetectorConfig, Receiver, TriggerMode +from libertem_live.detectors.dectris.acquisition import ( + AcquisitionParams, DectrisAcquisition, DetectorConfig, Receiver, + TriggerMode, RawFrame +) + +try: + import prctl +except ImportError: + prctl = None MSG_TYPES = Union[ Literal['BEGIN_PARTITION'], @@ -20,6 +35,20 @@ ] +def set_thread_name(name: str): + """ + Set a thread name; mostly useful for using system tools for profiling + + Parameters + ---------- + name : str + The thread name + """ + if prctl is None: + return + prctl.set_name(name) + + class OfflineReceiver(Receiver): """ Mock Receiver that reads from a numpy array @@ -64,7 +93,7 @@ def acquire(self): try: self._acq_state = AcquisitionParams( sequence_id=42, - nimages=128, + nimages=self.data.shape[0], trigger_mode=self._trigger_mode ) self.trigger() # <-- this triggers, either via API or via HW trigger @@ -73,41 +102,54 @@ def acquire(self): self._acq_state = None -class DectrisUDFPartRunner(UDFPartRunner): - pass - - -def decode(payload): - return payload # TODO: move decoding step here - - def run_udf_on_frames(partition, frames, udfs, params): - for frame in frames: - pass # TODO + partition._receiver = frames # FIXME # proper interface for this one + runner = UDFPartRunner(udfs) + if True: + result = runner.run_for_partition( + partition=partition, + params=params, + env=Environment(threaded_executor=False, threads_per_worker=1), + ) + for frame in frames: + raise RuntimeError("frames should be fully consumed here!") + return result + if False: + for 
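# ---------------------------------------------------------------------------
# Editor's note: a minimal, self-contained sketch (not part of the patch) of the
# consumption contract that run_udf_on_frames() relies on: the partition runner
# is expected to pull every frame out of the generator, so the trailing check
# that raises must never fire. All names below are illustrative stand-ins and
# not LiberTEM API.
import numpy as np

def toy_frames(n, shape=(8, 8)):
    # stand-in for get_frames(): yields one decoded frame at a time
    for i in range(n):
        yield np.full(shape, i, dtype=np.float32)

def toy_run_on_frames(frames):
    # stand-in for the UDF part runner: consumes the generator completely
    acc = None
    for frame in frames:
        acc = frame if acc is None else acc + frame
    return acc

frames = toy_frames(4)
result = toy_run_on_frames(frames)
for _ in frames:
    # the generator is exhausted at this point, so this loop body never runs
    raise RuntimeError("frames should be fully consumed here!")
# ---------------------------------------------------------------------------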
frame in frames: + frame.sum() def get_frames(partition, request): """ - Consume all FRAME messages until we get an END_PARTITION message + Consume all FRAME messages from the request queue until we get an + END_PARTITION message (which we also consume) """ while True: with request.get() as msg: header, payload = msg header_type = header["type"] if header_type == "FRAME": - frame_arr = decode(payload) + raw_frame = RawFrame( + data=payload, + encoding=header['encoding'], + dtype=header['dtype'], + shape=header['shape'], + ) + frame_arr = raw_frame.decode() yield frame_arr elif header_type == "END_PARTITION": - print(f"partition {partition} done") + # print(f"partition {partition} done") return else: raise RuntimeError(f"invalid header type {header}; FRAME or END_PARTITION expected") -def worker(queues: WorkerQueues): - udfs = [] +def worker(queues: WorkerQueues, idx: int): + udfs: List[UDF] = [] params = None + set_thread_name(f"worker-{idx}") + while True: with queues.request.get() as msg: header, payload = msg @@ -124,57 +166,107 @@ def worker(queues: WorkerQueues): params = header["params"] continue elif header_type == "SHUTDOWN": + queues.request.close() + queues.response.close() break + elif header_type == "WARMUP": + env = Environment(threaded_executor=False, threads_per_worker=2) + with env.enter(): + pass else: raise RuntimeError(f"unknown message {header}") +def feed_workers(pool: WorkerPool, aq: DectrisAcquisition): + set_thread_name("feed_workers") + + print("distributing work") + t0 = time.time() + with aq.acquire(): + print("creating and starting receiver") + receiver = aq.get_receiver() + receiver.start() + idx = 0 + for partition in aq.get_partitions(): + qs = pool.get_worker_queues(idx) + partition._receiver = None + qs.request.put({"type": "START_PARTITION", "partition": partition}) + for frame_idx in range(partition.shape.nav.size): + raw_frame = next(receiver) + qs.request.put({ + "type": "FRAME", + "shape": raw_frame.shape, + "dtype": raw_frame.dtype, + "encoding": raw_frame.encoding, + }, payload=np.frombuffer(raw_frame.data, dtype=np.uint8)) + qs.request.put({"type": "END_PARTITION"}) + idx = (idx + 1) % pool.size + t1 = time.time() + print(f"finished feeding workers in {t1 - t0}s") + + def run_on_acquisition(aq: DectrisAcquisition): - pool = WorkerPool(processes=2, worker_fn=worker) + pool = WorkerPool(processes=7, worker_fn=worker) from perf_utils import perf + ts = TilingScheme.make_for_shape( + tileshape=Shape((24, 512, 512), sig_dims=2), + dataset_shape=Shape((128, 128, 512, 512), sig_dims=2), + ) + for qs in pool.all_worker_queues(): qs.request.put({ "type": "SET_UDFS_AND_PARAMS", - "udfs": [object()], - "params": UDFParams(corrections=None, roi=None, tiling_scheme=None, kwargs=[{}]) + "udfs": [SumUDF()], + "params": UDFParams(corrections=None, roi=None, tiling_scheme=ts, kwargs=[{}]) + }) + qs.request.put({ + "type": "WARMUP", }) - REPEATS = 1 - - receiver = aq.get_receiver() - receiver.start() + # if True: + with perf("multiprocess-dectris"): + for i in range(2): + msg_thread = Thread(target=feed_workers, args=(pool, aq)) + msg_thread.name = "feed_workers" + msg_thread.daemon = True + msg_thread.start() - # with perf("looped"): - if True: - for i in range(REPEATS): - receiver = aq.get_receiver() - receiver.start() + num_partitions = int(aq.shape.nav.size // aq._frames_per_partition) t0 = time.time() - idx = 0 - for partition in aq.get_partitions(): - qs = pool.get_worker_queues(idx) - partition._receiver = None - qs.request.put({"type": "START_PARTITION", 
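# ---------------------------------------------------------------------------
# Editor's note: a sketch (not part of the patch) of the FRAME / END_PARTITION
# protocol that feed_workers() produces and get_frames() consumes, reduced to a
# plain queue.Queue. Headers are plain dicts, payloads are raw bytes, and the
# decode step happens on the consumer side, as it does in the worker processes.
# All names here are illustrative only.
import queue

import numpy as np

def toy_produce(q, n_frames, shape=(4, 4)):
    for i in range(n_frames):
        frame = np.full(shape, i, dtype=np.uint16)
        q.put(({"type": "FRAME", "dtype": "uint16", "shape": shape}, frame.tobytes()))
    q.put(({"type": "END_PARTITION"}, None))

def toy_consume(q):
    while True:
        header, payload = q.get()
        if header["type"] == "FRAME":
            yield np.frombuffer(payload, dtype=header["dtype"]).reshape(header["shape"])
        elif header["type"] == "END_PARTITION":
            return
        else:
            raise RuntimeError(f"invalid header type {header}")

toy_q = queue.Queue()
toy_produce(toy_q, n_frames=3)
total = sum(frame.sum() for frame in toy_consume(toy_q))
# ---------------------------------------------------------------------------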
"partition": partition}) - for frame_idx in range(partition.shape.nav.size): - frame = next(receiver) - qs.request.put({"type": "FRAME"}, frame) - qs.request.put({"type": "END_PARTITION"}) - idx = (idx + 1) % pool.size - # synchronization: - print("waiting for response...") - with qs.response.get() as response: - print(response) + print("gathering responses...") + + num_responses = 0 + get_timeout = 0.1 + while num_responses < num_partitions: + try: + with pool.response_queue.get(block=True, timeout=get_timeout) as response: + resp_header, payload = response + assert payload is None + assert resp_header['type'] == "RESULT", f"resp_header == {resp_header}" + num_responses += 1 + except Empty: + continue t1 = time.time() print(t1-t0) + print(f"Max SHM usage: {qs.request._psa._used/1024/1024}MiB") - for qs in pool.all_worker_queues(): - qs.request.put({"type": "SHUTDOWN"}) + # after a while, the msg_thread has sent all partitions and exits: + msg_thread.join() + + pool.close_resp_queue() - pool.join_all() + # ... and we can shut down the workers: + for qs, p in pool.all_workers(): + qs.request.put({"type": "SHUTDOWN"}) + qs.request.close() + qs.response.close() + p.join() if __name__ == "__main__": + set_thread_name("main") + print(f"main pid {os.getpid()}") if False: dataset_shape = Shape((512, 512, 512), sig_dims=2) data = np.random.randn(*dataset_shape).astype(np.uint8) @@ -198,5 +290,17 @@ def run_on_acquisition(aq: DectrisAcquisition): trigger_mode="exte", ) aq.initialize(None) - with aq.acquire(): + print(aq.shape, aq.dtype) + + if False: + t0 = time.time() + with aq.acquire(): + r = aq.get_receiver() + r.start() + for frame in r: + pass + t1 = time.time() + print(t1-t0) + + if True: run_on_acquisition(aq) diff --git a/src/libertem_live/channel.py b/src/libertem_live/channel.py index 32f025f2..5abf73db 100644 --- a/src/libertem_live/channel.py +++ b/src/libertem_live/channel.py @@ -5,7 +5,7 @@ from multiprocessing import shared_memory from multiprocessing.managers import SharedMemoryManager from queue import Empty -from typing import Callable, NamedTuple, Optional, Tuple +from typing import Callable, List, NamedTuple, Optional, Tuple import numpy as np @@ -16,6 +16,23 @@ class PoolAllocation(NamedTuple): full_size: int # full size of the allocation, in bytes (including padding) req_size: int # requested allocation size + def resize(self, new_req_size) -> "PoolAllocation": + assert new_req_size <= self.full_size + return PoolAllocation( + shm_name=self.shm_name, + handle=self.handle, + full_size=self.full_size, + req_size=new_req_size, + ) + + +def drain_queue(q: mp.Queue): + while True: + try: + q.get_nowait() + except Empty: + break + class PoolShmClient: def __init__(self): @@ -92,6 +109,7 @@ def __init__(self): self.release_q = mp.Queue() self._psa = None self._psc = None + self._closed = False def put(self, header, payload: Optional[memoryview] = None): """ @@ -114,23 +132,26 @@ def _copy_to_shm(self, src_buffer: memoryview) -> str: """ if self._psa is None: # FIXME: config item size, pool size - self._psa = PoolShmAllocator(item_size=512*512*4*2, size_num_items=4096) + self._psa = PoolShmAllocator(item_size=512*512*4*2, size_num_items=128*128) size = src_buffer.nbytes try: alloc_handle = self.release_q.get_nowait() + alloc_handle = alloc_handle.resize(size) except Empty: - alloc_handle = self._psa.allocate(src_buffer.nbytes) + alloc_handle = self._psa.allocate(size) payload_shm = self._psa.get(alloc_handle) - payload_arr = np.frombuffer(src_buffer, dtype=np.uint8) - payload_arr_shm 
= np.frombuffer(payload_shm[:size], dtype=np.uint8) - payload_arr_shm[:] = payload_arr + assert payload_shm.nbytes == size, f"{payload_shm.nbytes} != {size}" + src_arr = np.frombuffer(src_buffer, dtype=np.uint8) + arr_shm = np.frombuffer(payload_shm, dtype=np.uint8) + assert arr_shm.size == size, f"{arr_shm.size} != {size}" + arr_shm[:] = src_arr return alloc_handle def _get_named_shm(self, name: str) -> shared_memory.SharedMemory: return shared_memory.SharedMemory(name=name, create=False) @contextlib.contextmanager - def get(self, timeout: Optional[float] = None): + def get(self, block: bool = True, timeout: Optional[float] = None): """ Receive a message. Memory of the payload will be cleaned up after the context manager scope, so don't keep references outside of it! @@ -142,30 +163,35 @@ def get(self, timeout: Optional[float] = None): """ if self._psc is None: self._psc = PoolShmClient() - while True: - try: - header, typ, payload_handle = self.q.get(timeout=timeout) - except Empty: - continue - try: - if payload_handle is not None: - payload_buf = self._psc.get(payload_handle) - payload_memview = memoryview(payload_buf) - else: - payload_buf = None - payload_memview = None - if typ == "bytes": - yield (pickle.loads(header), payload_memview) - finally: - if payload_memview is not None: - payload_memview.release() - if payload_buf is not None: - self.release_q.put(payload_handle) - break + header, typ, payload_handle = self.q.get(block=block, timeout=timeout) + try: + if payload_handle is not None: + payload_buf = self._psc.get(payload_handle) + payload_memview = memoryview(payload_buf) + else: + payload_buf = None + payload_memview = None + if typ == "bytes": + yield (pickle.loads(header), payload_memview) + finally: + if payload_memview is not None: + payload_memview.release() + if payload_handle is not None: + self.release_q.put(payload_handle) def empty(self): return self.q.empty() + def close(self): + if not self._closed: + drain_queue(self.q) + self.q.close() + self.q.join_thread() + drain_queue(self.release_q) + self.release_q.close() + self.release_q.join_thread() + self._closed = True + class ChannelManager: def __init__(self): @@ -184,11 +210,16 @@ class WorkerQueues(NamedTuple): class WorkerPool: def __init__(self, processes: int, worker_fn: Callable): self._cm = ChannelManager() - self._workers: Tuple[WorkerQueues, mp.Process] = [] + self._workers: List[Tuple[WorkerQueues, mp.Process]] = [] self._worker_fn = worker_fn self._processes = processes + self._response_q = ShmQueue() self._start_workers() + @property + def response_queue(self): + return self._response_q + @property def size(self): return self._processes @@ -196,7 +227,7 @@ def size(self): def _start_workers(self): for i in range(self._processes): queues = self._make_worker_queues() - p = mp.Process(target=self._worker_fn, args=(queues,)) + p = mp.Process(target=self._worker_fn, args=(queues, i)) p.start() self._workers.append((queues, p)) @@ -204,21 +235,21 @@ def all_worker_queues(self): for (qs, _) in self._workers: yield qs + def all_workers(self): + return self._workers + def join_all(self): for (_, p) in self._workers: p.join() + def close_resp_queue(self): + self._response_q.close() + def get_worker_queues(self, idx) -> WorkerQueues: return self._workers[idx][0] - def poll_all(self) -> Optional[WorkerQueues]: - # FIXME: is there a way to just poll all queues together? 
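# ---------------------------------------------------------------------------
# Editor's note: a sketch (not part of the patch) of the payload-lifetime rule
# behind ShmQueue.get(), shown with the stdlib shared_memory module directly.
# The receiving side only borrows a view into the shared block, so anything it
# wants to keep must be copied out before the view is released and the block
# goes back to the pool.
import numpy as np
from multiprocessing import shared_memory

payload = np.arange(16, dtype=np.uint8).tobytes()
shm = shared_memory.SharedMemory(create=True, size=len(payload))   # allocator side
shm.buf[:len(payload)] = payload

peer = shared_memory.SharedMemory(name=shm.name, create=False)     # consumer side
view = peer.buf[:len(payload)]   # borrowed, zero-copy view into shared memory
kept = bytes(view)               # copy while the view is still valid
view.release()                   # do not touch `view` after this point
peer.close()

shm.close()
shm.unlink()
assert kept == payload
# ---------------------------------------------------------------------------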
- for (qs, p) in self._workers: - if not qs.response.empty(): - return qs - def _make_worker_queues(self): return WorkerQueues( request=ShmQueue(), - response=ShmQueue(), + response=self._response_q, ) diff --git a/src/libertem_live/detectors/dectris/acquisition.py b/src/libertem_live/detectors/dectris/acquisition.py index 40216ad9..a325526c 100644 --- a/src/libertem_live/detectors/dectris/acquisition.py +++ b/src/libertem_live/detectors/dectris/acquisition.py @@ -40,6 +40,22 @@ class DetectorConfig(NamedTuple): bit_depth: int +class RawFrame: + def __init__(self, data, encoding, dtype, shape): + self.data = data + self.dtype = dtype + self.shape = shape + self.encoding = encoding + + def decode(self): + return decode( + data=self.data, + encoding=self.encoding, + shape=self.shape, + dtype=self.dtype + ) + + class Receiver(Protocol): def __iter__(self): return self @@ -47,7 +63,7 @@ def __iter__(self): def start(self): pass - def __next__(self) -> np.ndarray: + def __next__(self) -> RawFrame: raise NotImplementedError(f"{self.__class__.__name__}.__next__ is not implemented") @@ -64,6 +80,33 @@ def get_darray(darray) -> np.ndarray: return np.frombuffer(data, dtype=dtype).reshape(shape) +def dtype_from_frame_header(header): + return np.dtype(header['type']).newbyteorder(header['encoding'][-1]) + + +def shape_from_frame_header(header): + return tuple(reversed(header['shape'])) + + +def decode(data, encoding, shape, dtype): + size = prod(shape) * dtype.itemsize + if encoding in ('bs32-lz4<', 'bs16-lz4<', 'bs8-lz4<'): + decompressed = bitshuffle.decompress_lz4( + np.frombuffer(data[12:], dtype=np.uint8), + shape=shape, + dtype=dtype, + block_size=0 + ) + elif encoding == 'lz4<': + decompressed = lz4.block.decompress(data, uncompressed_size=size) + decompressed = np.frombuffer(decompressed, dtype=dtype).reshape(shape) + elif encoding == '<': + decompressed = np.frombuffer(data, dtype=dtype).reshape(shape) + else: + raise RuntimeError(f'Unsupported encoding {encoding}') + return decompressed + + class ZeroMQReceiver(Receiver): def __init__(self, socket: zmq.Socket, params: Optional[AcquisitionParams]): self._socket = socket @@ -83,61 +126,50 @@ def recv(self): res = 0 while not res: res = self._socket.poll(100) - return self._socket.recv() + msg = self._socket.recv(copy=False) + return msg.buffer def receive_acquisition_header(self): while True: data = self.recv() try: - header_header = json.loads(data) + header_header = json.loads(bytes(data)) except (json.JSONDecodeError, UnicodeDecodeError): continue if ('header_detail' in header_header and header_header['series'] == self._params.sequence_id): break - header = json.loads(self.recv()) + header = json.loads(bytes(self.recv())) return header_header, header def receive_frame(self, frame_id): assert self._running, "need to be running to receive frames!" 
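# ---------------------------------------------------------------------------
# Editor's note: a sketch (not part of the patch) of what the helper functions
# dtype_from_frame_header() and shape_from_frame_header() compute for a DECTRIS
# frame header. The header values below are made up for illustration; the real
# ones arrive on the zeromq data stream.
import numpy as np

toy_header = {
    "type": "uint32",          # element type as reported by the detector
    "encoding": "bs32-lz4<",   # bitshuffle-lz4; the trailing '<' is the byte order
    "shape": [1030, 514],      # the detector reports x, y; numpy wants rows, columns
}
dtype = np.dtype(toy_header["type"]).newbyteorder(toy_header["encoding"][-1])
shape = tuple(reversed(toy_header["shape"]))
size = int(np.prod(shape)) * dtype.itemsize   # uncompressed payload size in bytes
# -> dtype == np.dtype('<u4'), shape == (514, 1030), size == 2117680
# ---------------------------------------------------------------------------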
- header_header = json.loads(self.recv()) + header_header = json.loads(bytes(self.recv())) assert header_header['series'] == self._params.sequence_id assert header_header['frame'] == frame_id - header = json.loads(self.recv()) - shape = tuple(reversed(header['shape'])) - dtype = np.dtype(header['type']).newbyteorder(header['encoding'][-1]) - size = prod(shape) * dtype.itemsize + header = json.loads(bytes(self.recv())) data = self.recv() - if header['encoding'] in ('bs32-lz4<', 'bs16-lz4<', 'bs8-lz4<'): - decompressed = bitshuffle.decompress_lz4( - np.frombuffer(data[12:], dtype=np.uint8), - shape=shape, - dtype=dtype, - block_size=0 - ) - elif header['encoding'] == 'lz4<': - decompressed = lz4.block.decompress(data, uncompressed_size=size) - decompressed = np.frombuffer(decompressed, dtype=dtype).reshape(shape) - elif header['encoding'] == '<': - decompressed = np.frombuffer(data, dtype=dtype).reshape(shape) - else: - raise RuntimeError(f'Unsupported encoding {header["encoding"]}') - - footer = json.loads(self.recv()) - return header_header, header, decompressed, footer + footer = json.loads(bytes(self.recv())) + return header_header, header, data, footer def receive_acquisition_footer(self): footer = self.recv() return footer - def __next__(self) -> np.ndarray: + def __next__(self) -> RawFrame: + assert self._params is not None if self._frame_id >= self._params.nimages: self.receive_acquisition_footer() self._running = False raise StopIteration() - f_header_header, f_header, decompressed, f_footer = self.receive_frame(self._frame_id) + f_header_header, f_header, data, f_footer = self.receive_frame(self._frame_id) self._frame_id += 1 - return decompressed + return RawFrame( + data=data, + encoding=f_header['encoding'], + dtype=dtype_from_frame_header(f_header), + shape=shape_from_frame_header(f_header), + ) class DectrisAcquisition(AcquisitionMixin, DataSet): @@ -147,9 +179,9 @@ def __init__( api_port: int, data_host: str, data_port: int, + nav_shape: Optional[Tuple[int, ...]], trigger_mode: TriggerMode, trigger=lambda aq: None, - nav_shape: Optional[Tuple[int, ...]] = None, frames_per_partition: int = 128, enable_corrections: bool = False, name_pattern: Optional[str] = None, @@ -317,7 +349,8 @@ def get_partitions(self): slices = BasePartition.make_slices(self.shape, num_partitions) - receiver = self.get_receiver() + # receiver = self.get_receiver() + receiver = None for part_slice, start, stop in slices: yield DectrisLivePartition( @@ -332,7 +365,7 @@ def get_partitions(self): class DectrisLivePartition(Partition): def __init__( self, start_idx, end_idx, partition_slice, - meta, receiver: Receiver, + meta, receiver: Optional[Receiver] = None, ): super().__init__(meta=meta, partition_slice=partition_slice, io_backend=None, decoder=None) self._start_idx = start_idx @@ -366,7 +399,8 @@ def _get_tiles_fullframe(self, tiling_scheme: TilingScheme, dest_dtype="float32" buf = np.zeros((depth,) + tiling_scheme[0].shape, dtype=dest_dtype) buf_idx = 0 tile_start = self._start_idx - self._receiver.start() + # self._receiver.start() + assert self._receiver is not None while to_read > 0: # 1) put frame into tile buffer (including dtype conversion if needed) assert buf_idx < depth,\ diff --git a/src/libertem_live/detectors/dectris/sim.py b/src/libertem_live/detectors/dectris/sim.py index 71df5924..463178a0 100644 --- a/src/libertem_live/detectors/dectris/sim.py +++ b/src/libertem_live/detectors/dectris/sim.py @@ -55,7 +55,9 @@ def send_line(index, zmq_socket, mm): if self.is_stopped(): raise 
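# ---------------------------------------------------------------------------
# Editor's note: a sketch (not part of the patch) of the lazy-decode pattern
# introduced above: the receiver yields lightweight RawFrame-style objects and
# the expensive decode step only runs where the frame is consumed, i.e. in the
# worker process. The classes below are illustrative stand-ins, not the ones
# from this patch.
import numpy as np

class ToyRawFrame:
    def __init__(self, data, dtype, shape):
        self.data = data        # raw payload bytes, possibly still compressed
        self.dtype = dtype
        self.shape = shape

    def decode(self):
        # the real decode() also handles the bitshuffle-lz4 and lz4 encodings
        return np.frombuffer(self.data, dtype=self.dtype).reshape(self.shape)

class ToyReceiver:
    def __init__(self, n, shape=(4, 4)):
        self._n = n
        self._shape = shape
        self._i = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._i >= self._n:
            raise StopIteration()
        raw = np.full(self._shape, self._i, dtype=np.uint16).tobytes()
        self._i += 1
        return ToyRawFrame(raw, np.uint16, self._shape)

total = sum(frame.decode().sum() for frame in ToyReceiver(3))
# ---------------------------------------------------------------------------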
StopException("Server is stopped") res = zmq_socket.poll(100, flags=zmq.POLLOUT) - zmq_socket.send(data) + # XXX quite bizarrely, for this kind of data stream, using + # `send(..., copy=True)` is faster than `send(..., copy=False)`. + zmq_socket.send(data, copy=True) index += len(data) return index
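
Editor's note: below is a minimal sketch (not part of the patch) of the pyzmq send/receive
combination used in this change: copy=True on the sending side, as in sim.py, and
copy=False on the receiving side, as in ZeroMQReceiver.recv(), where the returned
Frame's buffer is wrapped in a numpy array without an extra copy. The PAIR sockets
and the inproc endpoint are chosen only to keep the example self-contained.

import numpy as np
import zmq

ctx = zmq.Context.instance()
sender = ctx.socket(zmq.PAIR)
receiver = ctx.socket(zmq.PAIR)
sender.bind("inproc://copy-demo")
receiver.connect("inproc://copy-demo")

frame = np.arange(16, dtype=np.uint32)
sender.send(frame.tobytes(), copy=True)    # sending side: copy=True, as in sim.py
msg = receiver.recv(copy=False)            # receiving side: zero-copy zmq.Frame
decoded = np.frombuffer(msg.buffer, dtype=np.uint32)   # msg.buffer is a memoryview
assert (decoded == frame).all()

sender.close()
receiver.close()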