Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Jan 8, 2024
1 parent dc21155 commit 07b2c54
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 83 deletions.
1 change: 0 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ a method returning the same connection.
```python
import django_rq


# RQ
# Configuration to pretend there is a Redis service available.
# Set up the connection before RQ Django reads the settings.
Expand Down
4 changes: 2 additions & 2 deletions fakeredis/aioredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import sys
from typing import Union, Optional
from typing import Union, Optional, Any

from ._server import FakeBaseConnectionMixin

Expand Down Expand Up @@ -31,7 +31,7 @@ def _decode_error(self, error):
parser = DefaultParser(1)
return parser.parse_error(error.value)

def put_response(self, msg) -> None:
def put_response(self, msg:Any) -> None:
if not self.responses:
return
self.responses.put_nowait(msg)
Expand Down
12 changes: 6 additions & 6 deletions fakeredis/commands_mixins/bitmap_mixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Tuple
from typing import Tuple, Any

from fakeredis import _msgs as msgs
from fakeredis._commands import (
Expand All @@ -9,7 +9,7 @@
BitOffset,
BitValue,
fix_range_string,
fix_range,
fix_range, CommandItem,
)
from fakeredis._helpers import SimpleError, casematch

Expand All @@ -18,7 +18,7 @@ class BitfieldEncoding:
signed: bool
size: int

def __init__(self, encoding):
def __init__(self, encoding: bytes) -> None:
match = re.match(br'^([ui])(\d+)$', encoding)
if match is None:
raise SimpleError(msgs.INVALID_BITFIELD_TYPE)
Expand All @@ -33,16 +33,16 @@ def __init__(self, encoding):
class BitmapCommandsMixin:
# TODO: bitfield, bitfield_ro, bitpos

def __init(self, *args, **kwargs):
def __init(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.version: Tuple[int]

@staticmethod
def _bytes_as_bin_string(value):
def _bytes_as_bin_string(value: bytes) -> str:
return "".join([bin(i).lstrip("0b").rjust(8, "0") for i in value])

@command((Key(bytes), Int), (bytes,))
def bitpos(self, key, bit, *args):
def bitpos(self, key: CommandItem, bit: int, *args: bytes) -> int:
if bit != 0 and bit != 1:
raise SimpleError(msgs.BIT_ARG_MUST_BE_ZERO_OR_ONE)
if len(args) > 3:
Expand Down
10 changes: 5 additions & 5 deletions fakeredis/commands_mixins/connection_mixin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List, Union

from fakeredis import _msgs as msgs
from fakeredis._commands import command, DbIndex
Expand All @@ -19,11 +19,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._server: Any

@command((bytes,))
def echo(self, message):
def echo(self, message: bytes) -> bytes:
return message

@command((), (bytes,))
def ping(self, *args):
def ping(self, *args: bytes) -> Union[List[bytes], bytes, SimpleString]:
if len(args) > 1:
msg = msgs.WRONG_ARGS_MSG6.format("ping")
raise SimpleError(msg)
Expand All @@ -33,7 +33,7 @@ def ping(self, *args):
return args[0] if args else PONG

@command((DbIndex,))
def select(self, index):
def select(self, index: DbIndex) -> SimpleString:
self._db = self._server.dbs[index]
self._db_num = index
self._db_num = index # type: ignore
return OK
75 changes: 35 additions & 40 deletions fakeredis/commands_mixins/hash_mixin.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
import itertools
import math
import random
from typing import Callable
from typing import Callable, List, Tuple, Any, Optional

from fakeredis import _msgs as msgs
from fakeredis._commands import command, Key, Hash, Int, Float
from fakeredis._helpers import SimpleError, OK, casematch
from fakeredis._commands import command, Key, Hash, Int, Float, CommandItem
from fakeredis._helpers import SimpleError, OK, casematch, SimpleString


class HashCommandsMixin:
_encodeint: Callable[[int, ], bytes]
_encodefloat: Callable[[float, bool], bytes]
_scan: Callable[[CommandItem, int, bytes, bytes], Tuple[int, List[bytes]]]

def _hset(self, key: CommandItem, *args: bytes) -> int:
h = key.value
keys_count = len(h.keys())
h.update(
dict(zip(*[iter(args)] * 2)) # type: ignore
) # https://stackoverflow.com/a/12739974/1056460
created = len(h.keys()) - keys_count

key.updated()
return created

@command((Key(Hash), bytes), (bytes,))
def hdel(self, key, *fields):
def hdel(self, key: CommandItem, *fields: bytes) -> int:
h = key.value
rem = 0
for field in fields:
Expand All @@ -24,20 +36,20 @@ def hdel(self, key, *fields):
return rem

@command((Key(Hash), bytes))
def hexists(self, key, field):
def hexists(self, key: CommandItem, field: bytes) -> int:
return int(field in key.value)

@command((Key(Hash), bytes))
def hget(self, key, field):
def hget(self, key: CommandItem, field: bytes) -> Any:
return key.value.get(field)

@command((Key(Hash),))
def hgetall(self, key):
def hgetall(self, key: CommandItem) -> List[bytes]:
return list(itertools.chain(*key.value.items()))

@command(fixed=(Key(Hash), bytes, bytes))
def hincrby(self, key, field, amount):
amount = Int.decode(amount)
def hincrby(self, key: CommandItem, field: bytes, amount_bytes: bytes) -> int:
amount = Int.decode(amount_bytes)
field_value = Int.decode(
key.value.get(field, b"0"), decode_error=msgs.INVALID_HASH_MSG
)
Expand All @@ -47,7 +59,7 @@ def hincrby(self, key, field, amount):
return c

@command((Key(Hash), bytes, bytes))
def hincrbyfloat(self, key, field, amount):
def hincrbyfloat(self, key: CommandItem, field: bytes, amount: bytes) -> bytes:
c = Float.decode(key.value.get(field, b"0")) + Float.decode(amount)
if not math.isfinite(c):
raise SimpleError(msgs.NONFINITE_MSG)
Expand All @@ -57,30 +69,24 @@ def hincrbyfloat(self, key, field, amount):
return encoded

@command((Key(Hash),))
def hkeys(self, key):
def hkeys(self, key: CommandItem) -> List[bytes]:
return list(key.value.keys())

@command((Key(Hash),))
def hlen(self, key):
def hlen(self, key: CommandItem) -> int:
return len(key.value)

@command((Key(Hash), bytes), (bytes,))
def hmget(self, key, *fields):
def hmget(self, key: CommandItem, *fields: bytes) -> List[bytes]:
return [key.value.get(field) for field in fields]

@command((Key(Hash), bytes, bytes), (bytes, bytes))
def hmset(self, key, *args):
def hmset(self, key: CommandItem, *args: bytes) -> SimpleString:
self.hset(key, *args)
return OK

@command(
(
Key(Hash),
Int,
),
(bytes, bytes),
)
def hscan(self, key, cursor, *args):
@command((Key(Hash), Int), (bytes, bytes))
def hscan(self, key: CommandItem, cursor: int, *args: bytes) -> List[Any]:
cursor, keys = self._scan(key.value, cursor, *args)
items = []
for k in keys:
Expand All @@ -89,33 +95,25 @@ def hscan(self, key, cursor, *args):
return [cursor, items]

@command((Key(Hash), bytes, bytes), (bytes, bytes))
def hset(self, key, *args):
h = key.value
keys_count = len(h.keys())
h.update(
dict(zip(*[iter(args)] * 2))
) # https://stackoverflow.com/a/12739974/1056460
created = len(h.keys()) - keys_count

key.updated()
return created
def hset(self, key: CommandItem, *args: bytes) -> int:
return self._hset(key, *args)

@command((Key(Hash), bytes, bytes))
def hsetnx(self, key, field, value):
def hsetnx(self, key: CommandItem, field: bytes, value: bytes) -> int:
if field in key.value:
return 0
return self.hset(key, field, value)
return self._hset(key, field, value)

@command((Key(Hash), bytes))
def hstrlen(self, key, field):
def hstrlen(self, key: CommandItem, field: bytes) -> int:
return len(key.value.get(field, b""))

@command((Key(Hash),))
def hvals(self, key):
def hvals(self, key: CommandItem) -> List[bytes]:
return list(key.value.values())

@command(name="HRANDFIELD", fixed=(Key(Hash),), repeat=(bytes,))
def hrandfield(self, key, *args):
def hrandfield(self, key: CommandItem, *args: bytes) -> Optional[List[bytes]]:
if len(args) > 2:
raise SimpleError(msgs.SYNTAX_ERROR_MSG)
if key.value is None or len(key.value) == 0:
Expand All @@ -135,6 +133,3 @@ def hrandfield(self, key, *args):
else:
res = [t[0] for t in res]
return res

def _scan(self, keys, cursor, *args):
raise NotImplementedError # Implemented in BaseFakeSocket
45 changes: 23 additions & 22 deletions fakeredis/commands_mixins/pubsub_mixin.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from typing import Tuple, Any, Dict, Callable
from typing import Tuple, Any, Dict, Callable, List, Iterable

from fakeredis import _msgs as msgs
from fakeredis._commands import command
from fakeredis._helpers import NoResponse, compile_pattern, SimpleError


class PubSubCommandsMixin:
put_response: Callable[[Any], None]

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(PubSubCommandsMixin, self).__init__(*args, **kwargs)
self._pubsub = 0 # Count of subscriptions
self._server: Any
self.version: Tuple[int]
self.put_response: Callable

def _subscribe(self, channels, subscribers, mtype):
def _subscribe(self, channels: Iterable[bytes], subscribers: Dict[bytes, Any], mtype: bytes) -> NoResponse:
for channel in channels:
subs = subscribers[channel]
if self not in subs:
Expand All @@ -24,7 +24,8 @@ def _subscribe(self, channels, subscribers, mtype):
self.put_response(msg)
return NoResponse()

def _unsubscribe(self, channels, subscribers, mtype):
def _unsubscribe(self, channels: Iterable[bytes], subscribers: Dict[bytes, Any],
mtype: bytes) -> NoResponse:
if not channels:
channels = []
for channel, subs in subscribers.items():
Expand All @@ -41,36 +42,36 @@ def _unsubscribe(self, channels, subscribers, mtype):
self.put_response(msg)
return NoResponse()

def _numsub(self, subscribers: Dict[bytes, Any], *channels):
def _numsub(self, subscribers: Dict[bytes, Any], *channels: bytes) -> List[Any]:
tuples_list = [(ch, len(subscribers.get(ch, []))) for ch in channels]
return [item for sublist in tuples_list for item in sublist]

@command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
def psubscribe(self, *patterns):
def psubscribe(self, *patterns: bytes) -> NoResponse:
return self._subscribe(patterns, self._server.psubscribers, b"psubscribe")

@command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
def subscribe(self, *channels):
def subscribe(self, *channels: bytes) -> NoResponse:
return self._subscribe(channels, self._server.subscribers, b"subscribe")

@command((bytes,), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
def ssubscribe(self, *channels):
def ssubscribe(self, *channels: bytes) -> NoResponse:
return self._subscribe(channels, self._server.ssubscribers, b"ssubscribe")

@command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
def punsubscribe(self, *patterns):
def punsubscribe(self, *patterns: bytes) -> NoResponse:
return self._unsubscribe(patterns, self._server.psubscribers, b"punsubscribe")

@command((), (bytes,), flags=msgs.FLAG_NO_SCRIPT)
def unsubscribe(self, *channels):
def unsubscribe(self, *channels: bytes) -> NoResponse:
return self._unsubscribe(channels, self._server.subscribers, b"unsubscribe")

@command(fixed=(), repeat=(bytes,), flags=msgs.FLAG_NO_SCRIPT)
def sunsubscribe(self, *channels):
def sunsubscribe(self, *channels: bytes) -> NoResponse:
return self._unsubscribe(channels, self._server.ssubscribers, b"sunsubscribe")

@command((bytes, bytes))
def publish(self, channel, message):
def publish(self, channel: bytes, message: bytes) -> int:
receivers = 0
msg = [b"message", channel, message]
subs = self._server.subscribers.get(channel, set())
Expand All @@ -87,7 +88,7 @@ def publish(self, channel, message):
return receivers

@command((bytes, bytes))
def spublish(self, channel, message):
def spublish(self, channel: bytes, message: bytes) -> int:
receivers = 0
msg = [b"smessage", channel, message]
subs = self._server.ssubscribers.get(channel, set())
Expand All @@ -104,38 +105,38 @@ def spublish(self, channel, message):
return receivers

@command(name="PUBSUB NUMPAT", fixed=(), repeat=())
def pubsub_numpat(self, *_):
def pubsub_numpat(self, *_: Any) -> int:
return len(self._server.psubscribers)

def _channels(self, subscribers_dict: Dict[bytes, Any], *patterns):
def _channels(self, subscribers_dict: Dict[bytes, Any], *patterns: bytes) -> List[bytes]:
channels = list(subscribers_dict.keys())
if len(patterns) > 0:
regex = compile_pattern(patterns[0])
channels = [ch for ch in channels if regex.match(ch)]
return channels

@command(name="PUBSUB CHANNELS", fixed=(), repeat=(bytes,))
def pubsub_channels(self, *args):
def pubsub_channels(self, *args: bytes) -> List[bytes]:
return self._channels(self._server.subscribers, *args)

@command(name="PUBSUB SHARDCHANNELS", fixed=(), repeat=(bytes,))
def pubsub_shardchannels(self, *args):
def pubsub_shardchannels(self, *args: bytes) -> List[bytes]:
return self._channels(self._server.ssubscribers, *args)

@command(name="PUBSUB NUMSUB", fixed=(), repeat=(bytes,))
def pubsub_numsub(self, *args):
def pubsub_numsub(self, *args: bytes) -> List[Any]:
return self._numsub(self._server.subscribers, *args)

@command(name="PUBSUB SHARDNUMSUB", fixed=(), repeat=(bytes,))
def pubsub_shardnumsub(self, *args):
def pubsub_shardnumsub(self, *args: bytes) -> List[Any]:
return self._numsub(self._server.ssubscribers, *args)

@command(name="PUBSUB", fixed=())
def pubsub(self, *args):
def pubsub(self, *args: Any) -> None:
raise SimpleError(msgs.WRONG_ARGS_MSG6.format("pubsub"))

@command(name="PUBSUB HELP", fixed=())
def pubsub_help(self, *args):
def pubsub_help(self, *args: Any) -> List[bytes]:
if self.version >= (7,):
help_strings = [
"PUBSUB <subcommand> [<arg> [value] [opt] ...]. Subcommands are:",
Expand Down
Loading

0 comments on commit 07b2c54

Please sign in to comment.