Skip to content

Commit

Permalink
fix:xread blocking (#275)
Browse files Browse the repository at this point in the history
fix #274
  • Loading branch information
cunla authored Jan 19, 2024
1 parent 1abca94 commit a057834
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
27 changes: 13 additions & 14 deletions fakeredis/commands_mixins/streams_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,19 @@ def xrevrange(self, key, _min, _max, *args):
(count,), _ = extract_args(args, ("+count",))
return self._xrange(key.value, _max, _min, True, count)

def _xread(self, stream_start_id_list: List, count: int, first_pass: bool):
def _xread(self, stream_start_id_list: List[Tuple[bytes, StreamRangeTest]], count: int, blocking: bool,
first_pass: bool):
max_inf = StreamRangeTest.decode(b"+")
res = list()
for item, start_id in stream_start_id_list:
for stream_name, start_id in stream_start_id_list:
item = CommandItem(stream_name, self._db, item=self._db.get(stream_name), default=None)
stream_results = self._xrange(item.value, start_id, max_inf, False, count)
if first_pass and (count is None):
return None
if len(stream_results) > 0:
res.append([item.key, stream_results])
if blocking and count and len(res) == 0:
return None
return res

def _xreadgroup(
Expand All @@ -135,6 +139,8 @@ def _xreadgroup(
@staticmethod
def _parse_start_id(key: CommandItem, s: bytes) -> StreamRangeTest:
if s == b"$":
if key.value is None:
return StreamRangeTest.decode(b"0-0")
return StreamRangeTest.decode(key.value.last_item_key(), exclusive=True)
return StreamRangeTest.decode(s, exclusive=True)

Expand All @@ -146,23 +152,16 @@ def xread(self, *args):
left_args = left_args[1:]
num_streams = int(len(left_args) / 2)

stream_start_id_list = list()
stream_start_id_list: List[Tuple[bytes, StreamRangeTest]] = list() # (name, start_id)
for i in range(num_streams):
item = CommandItem(
left_args[i], self._db, item=self._db.get(left_args[i]), default=None
)
item = CommandItem(left_args[i], self._db, item=self._db.get(left_args[i]), default=None)
start_id = self._parse_start_id(item, left_args[i + num_streams])
stream_start_id_list.append(
(
item,
start_id,
)
)
stream_start_id_list.append((left_args[i], start_id))
if timeout is None:
return self._xread(stream_start_id_list, count, False)
return self._xread(stream_start_id_list, count, blocking=False, first_pass=False)
else:
return self._blocking(
timeout / 1000.0, functools.partial(self._xread, stream_start_id_list, count)
timeout / 1000.0, functools.partial(self._xread, stream_start_id_list, count, True)
)

@command(name="XREADGROUP", fixed=(bytes, bytes, bytes), repeat=(bytes,))
Expand Down
26 changes: 26 additions & 0 deletions test/test_mixins/test_streams_commands.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import threading
import time
from typing import List

Expand Down Expand Up @@ -744,3 +745,28 @@ def test_xclaim(r: redis.Redis):
stream, group, consumer1,
min_idle_time=0, message_ids=(message_id,), justid=True,
) == [message_id, ]


def test_xread_blocking(create_redis):
# thread with xread block 0 should hang
# putting data in the stream should unblock it
event = threading.Event()
event.clear()

def thread_func():
while not event.is_set():
time.sleep(0.1)
r = create_redis(db=1)
r.xadd("stream", {"x": "1"})
time.sleep(0.1)

t = threading.Thread(target=thread_func)
t.start()
r1 = create_redis(db=1)
event.set()
result = r1.xread({"stream": "$"}, block=0, count=1)
event.clear()
t.join()
assert result[0][0] == b"stream"
assert result[0][1][0][1] == {b'x': b'1'}
pass

0 comments on commit a057834

Please sign in to comment.