Skip to content

Commit

Permalink
Add missing types
Browse files Browse the repository at this point in the history
  • Loading branch information
oschwald committed Jan 29, 2025
1 parent 147f44e commit 5c471cd
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 80 deletions.
6 changes: 3 additions & 3 deletions maxminddb/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import struct
from ipaddress import IPv4Address, IPv6Address
from os import PathLike
from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union
from typing import IO, Any, AnyStr, Dict, Iterator, List, Optional, Tuple, Union

from maxminddb.const import MODE_AUTO, MODE_FD, MODE_FILE, MODE_MEMORY, MODE_MMAP
from maxminddb.decoder import Decoder
Expand Down Expand Up @@ -177,10 +177,10 @@ def get_with_prefix_len(
return self._resolve_data_pointer(pointer), prefix_len
return None, prefix_len

def __iter__(self):
def __iter__(self) -> Iterator:
return self._generate_children(0, 0, 0)

def _generate_children(self, node, depth, ip_acc):
def _generate_children(self, node, depth, ip_acc) -> Iterator:
if ip_acc != 0 and node == self._ipv4_start:
# Skip nodes aliased to IPv4
return
Expand Down
34 changes: 17 additions & 17 deletions tests/decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@


class TestDecoder(unittest.TestCase):
def test_arrays(self):
def test_arrays(self) -> None:
arrays = {
b"\x00\x04": [],
b"\x01\x04\x43\x46\x6f\x6f": ["Foo"],
b"\x02\x04\x43\x46\x6f\x6f\x43\xe4\xba\xba": ["Foo", "人"],
}
self.validate_type_decoding("arrays", arrays)

def test_boolean(self):
def test_boolean(self) -> None:
booleans = {
b"\x00\x07": False,
b"\x01\x07": True,
}
self.validate_type_decoding("booleans", booleans)

def test_double(self):
def test_double(self) -> None:
doubles = {
b"\x68\x00\x00\x00\x00\x00\x00\x00\x00": 0.0,
b"\x68\x3f\xe0\x00\x00\x00\x00\x00\x00": 0.5,
Expand All @@ -35,7 +35,7 @@ def test_double(self):
}
self.validate_type_decoding("double", doubles)

def test_float(self):
def test_float(self) -> None:
floats = {
b"\x04\x08\x00\x00\x00\x00": 0.0,
b"\x04\x08\x3f\x80\x00\x00": 1.0,
Expand All @@ -49,7 +49,7 @@ def test_float(self):
}
self.validate_type_decoding("float", floats)

def test_int32(self):
def test_int32(self) -> None:
int32 = {
b"\x00\x01": 0,
b"\x04\x01\xff\xff\xff\xff": -1,
Expand All @@ -66,7 +66,7 @@ def test_int32(self):
}
self.validate_type_decoding("int32", int32)

def test_map(self):
def test_map(self) -> None:
maps = {
b"\xe0": {},
b"\xe1\x42\x65\x6e\x43\x46\x6f\x6f": {"en": "Foo"},
Expand All @@ -85,7 +85,7 @@ def test_map(self):
}
self.validate_type_decoding("maps", maps)

def test_pointer(self):
def test_pointer(self) -> None:
pointers = {
b"\x20\x00": 0,
b"\x20\x05": 5,
Expand Down Expand Up @@ -133,17 +133,17 @@ def test_pointer(self):
b"\x5f\x00\x10\x53" + 70000 * b"\x78": "x" * 70000,
}

def test_string(self):
def test_string(self) -> None:
self.validate_type_decoding("string", self.strings)

def test_byte(self):
def test_byte(self) -> None:
b = {
bytes([0xC0 ^ k[0]]) + k[1:]: v.encode("utf-8")
for k, v in self.strings.items()
}
self.validate_type_decoding("byte", b)

def test_uint16(self):
def test_uint16(self) -> None:
uint16 = {
b"\xa0": 0,
b"\xa1\xff": 255,
Expand All @@ -153,7 +153,7 @@ def test_uint16(self):
}
self.validate_type_decoding("uint16", uint16)

def test_uint32(self):
def test_uint32(self) -> None:
uint32 = {
b"\xc0": 0,
b"\xc1\xff": 255,
Expand All @@ -165,7 +165,7 @@ def test_uint32(self):
}
self.validate_type_decoding("uint32", uint32)

def generate_large_uint(self, bits):
def generate_large_uint(self, bits) -> dict:
ctrl_byte = b"\x02" if bits == 64 else b"\x03"
uints = {
b"\x00" + ctrl_byte: 0,
Expand All @@ -178,17 +178,17 @@ def generate_large_uint(self, bits):
uints[input] = expected
return uints

def test_uint64(self):
def test_uint64(self) -> None:
self.validate_type_decoding("uint64", self.generate_large_uint(64))

def test_uint128(self):
def test_uint128(self) -> None:
self.validate_type_decoding("uint128", self.generate_large_uint(128))

def validate_type_decoding(self, type, tests):
def validate_type_decoding(self, type, tests) -> None:
for input, expected in tests.items():
self.check_decoding(type, input, expected)

def check_decoding(self, type, input, expected, name=None):
def check_decoding(self, type, input, expected, name=None) -> None:
name = name or expected
db = mmap.mmap(-1, len(input))
db.write(input)
Expand All @@ -204,7 +204,7 @@ def check_decoding(self, type, input, expected, name=None):
else:
self.assertEqual(expected, actual, type)

def test_real_pointers(self):
def test_real_pointers(self) -> None:
with open("tests/data/test-data/maps-with-pointers.raw", "r+b") as db_file:
mm = mmap.mmap(db_file.fileno(), 0)
decoder = Decoder(mm, 0)
Expand Down
Loading

0 comments on commit 5c471cd

Please sign in to comment.