Skip to content

Commit

Permalink
WIP: Working prototyp for multiprocessing-based executor
Browse files Browse the repository at this point in the history
Still needs to be extracted into an executor and probably an adjusted
`UDFRunner`.

- properly drain and close queues
- use a single response queue - this simplifies result handling on the
  receiving side (might become bottleneck in many-core situations)
- perform decoding in the worker processes
- `ZeroMQReceiver`: zero-copy recv, at least for the payload
- some hacks to inject the receiver into the partition
  • Loading branch information
sk1p committed May 17, 2022
1 parent e5ae916 commit e91e601
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 120 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ data
.benchmarks
junit.xml
DEigerClient.py
profiles/
204 changes: 154 additions & 50 deletions prototypes/multip.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
import os
from contextlib import contextmanager
from threading import Thread
from queue import Empty
import time
import multiprocessing as mp
from typing import Union
from typing import List, Union
from typing_extensions import Literal
import multiprocessing as mp

import numpy as np

from libertem.udf.base import UDFPartRunner, UDFRunner, UDFParams
from libertem.io.dataset.base.tiling_scheme import TilingScheme
from libertem.udf.sum import SumUDF
from libertem.udf.base import UDFPartRunner, UDFRunner, UDFParams, UDF
from libertem.common import Shape
from libertem.common.executor import Environment
from regex import E

from libertem_live.channel import WorkerPool, WorkerQueues
from libertem_live.detectors.dectris.acquisition import AcquisitionParams, DectrisAcquisition, DetectorConfig, Receiver, TriggerMode
from libertem_live.detectors.dectris.acquisition import (
AcquisitionParams, DectrisAcquisition, DetectorConfig, Receiver,
TriggerMode, RawFrame
)

try:
import prctl
except ImportError:
prctl = None

MSG_TYPES = Union[
Literal['BEGIN_PARTITION'],
Expand All @@ -20,6 +35,20 @@
]


def set_thread_name(name: str):
"""
Set a thread name; mostly useful for using system tools for profiling
Parameters
----------
name : str
The thread name
"""
if prctl is None:
return
prctl.set_name(name)


class OfflineReceiver(Receiver):
"""
Mock Receiver that reads from a numpy array
Expand Down Expand Up @@ -64,7 +93,7 @@ def acquire(self):
try:
self._acq_state = AcquisitionParams(
sequence_id=42,
nimages=128,
nimages=self.data.shape[0],
trigger_mode=self._trigger_mode
)
self.trigger() # <-- this triggers, either via API or via HW trigger
Expand All @@ -73,41 +102,54 @@ def acquire(self):
self._acq_state = None


class DectrisUDFPartRunner(UDFPartRunner):
pass


def decode(payload):
return payload # TODO: move decoding step here


def run_udf_on_frames(partition, frames, udfs, params):
for frame in frames:
pass # TODO
partition._receiver = frames # FIXME # proper interface for this one
runner = UDFPartRunner(udfs)
if True:
result = runner.run_for_partition(
partition=partition,
params=params,
env=Environment(threaded_executor=False, threads_per_worker=1),
)
for frame in frames:
raise RuntimeError("frames should be fully consumed here!")
return result
if False:
for frame in frames:
frame.sum()


def get_frames(partition, request):
"""
Consume all FRAME messages until we get an END_PARTITION message
Consume all FRAME messages from the request queue until we get an
END_PARTITION message (which we also consume)
"""
while True:
with request.get() as msg:
header, payload = msg
header_type = header["type"]
if header_type == "FRAME":
frame_arr = decode(payload)
raw_frame = RawFrame(
data=payload,
encoding=header['encoding'],
dtype=header['dtype'],
shape=header['shape'],
)
frame_arr = raw_frame.decode()
yield frame_arr
elif header_type == "END_PARTITION":
print(f"partition {partition} done")
# print(f"partition {partition} done")
return
else:
raise RuntimeError(f"invalid header type {header}; FRAME or END_PARTITION expected")


def worker(queues: WorkerQueues):
udfs = []
def worker(queues: WorkerQueues, idx: int):
udfs: List[UDF] = []
params = None

set_thread_name(f"worker-{idx}")

while True:
with queues.request.get() as msg:
header, payload = msg
Expand All @@ -124,57 +166,107 @@ def worker(queues: WorkerQueues):
params = header["params"]
continue
elif header_type == "SHUTDOWN":
queues.request.close()
queues.response.close()
break
elif header_type == "WARMUP":
env = Environment(threaded_executor=False, threads_per_worker=2)
with env.enter():
pass
else:
raise RuntimeError(f"unknown message {header}")


def feed_workers(pool: WorkerPool, aq: DectrisAcquisition):
set_thread_name("feed_workers")

print("distributing work")
t0 = time.time()
with aq.acquire():
print("creating and starting receiver")
receiver = aq.get_receiver()
receiver.start()
idx = 0
for partition in aq.get_partitions():
qs = pool.get_worker_queues(idx)
partition._receiver = None
qs.request.put({"type": "START_PARTITION", "partition": partition})
for frame_idx in range(partition.shape.nav.size):
raw_frame = next(receiver)
qs.request.put({
"type": "FRAME",
"shape": raw_frame.shape,
"dtype": raw_frame.dtype,
"encoding": raw_frame.encoding,
}, payload=np.frombuffer(raw_frame.data, dtype=np.uint8))
qs.request.put({"type": "END_PARTITION"})
idx = (idx + 1) % pool.size
t1 = time.time()
print(f"finished feeding workers in {t1 - t0}s")


def run_on_acquisition(aq: DectrisAcquisition):
pool = WorkerPool(processes=2, worker_fn=worker)
pool = WorkerPool(processes=7, worker_fn=worker)
from perf_utils import perf

ts = TilingScheme.make_for_shape(
tileshape=Shape((24, 512, 512), sig_dims=2),
dataset_shape=Shape((128, 128, 512, 512), sig_dims=2),
)

for qs in pool.all_worker_queues():
qs.request.put({
"type": "SET_UDFS_AND_PARAMS",
"udfs": [object()],
"params": UDFParams(corrections=None, roi=None, tiling_scheme=None, kwargs=[{}])
"udfs": [SumUDF()],
"params": UDFParams(corrections=None, roi=None, tiling_scheme=ts, kwargs=[{}])
})
qs.request.put({
"type": "WARMUP",
})

REPEATS = 1

receiver = aq.get_receiver()
receiver.start()
# if True:
with perf("multiprocess-dectris"):
for i in range(2):
msg_thread = Thread(target=feed_workers, args=(pool, aq))
msg_thread.name = "feed_workers"
msg_thread.daemon = True
msg_thread.start()

# with perf("looped"):
if True:
for i in range(REPEATS):
receiver = aq.get_receiver()
receiver.start()
num_partitions = int(aq.shape.nav.size // aq._frames_per_partition)
t0 = time.time()
idx = 0
for partition in aq.get_partitions():
qs = pool.get_worker_queues(idx)
partition._receiver = None
qs.request.put({"type": "START_PARTITION", "partition": partition})
for frame_idx in range(partition.shape.nav.size):
frame = next(receiver)
qs.request.put({"type": "FRAME"}, frame)
qs.request.put({"type": "END_PARTITION"})
idx = (idx + 1) % pool.size
# synchronization:
print("waiting for response...")
with qs.response.get() as response:
print(response)
print("gathering responses...")

num_responses = 0
get_timeout = 0.1
while num_responses < num_partitions:
try:
with pool.response_queue.get(block=True, timeout=get_timeout) as response:
resp_header, payload = response
assert payload is None
assert resp_header['type'] == "RESULT", f"resp_header == {resp_header}"
num_responses += 1
except Empty:
continue
t1 = time.time()
print(t1-t0)
print(f"Max SHM usage: {qs.request._psa._used/1024/1024}MiB")

for qs in pool.all_worker_queues():
qs.request.put({"type": "SHUTDOWN"})
# after a while, the msg_thread has sent all partitions and exits:
msg_thread.join()

pool.close_resp_queue()

pool.join_all()
# ... and we can shut down the workers:
for qs, p in pool.all_workers():
qs.request.put({"type": "SHUTDOWN"})
qs.request.close()
qs.response.close()
p.join()


if __name__ == "__main__":
set_thread_name("main")
print(f"main pid {os.getpid()}")
if False:
dataset_shape = Shape((512, 512, 512), sig_dims=2)
data = np.random.randn(*dataset_shape).astype(np.uint8)
Expand All @@ -198,5 +290,17 @@ def run_on_acquisition(aq: DectrisAcquisition):
trigger_mode="exte",
)
aq.initialize(None)
with aq.acquire():
print(aq.shape, aq.dtype)

if False:
t0 = time.time()
with aq.acquire():
r = aq.get_receiver()
r.start()
for frame in r:
pass
t1 = time.time()
print(t1-t0)

if True:
run_on_acquisition(aq)
Loading

0 comments on commit e91e601

Please sign in to comment.