From 5c471cd60028cd716927fbd0b74d70175591e83f Mon Sep 17 00:00:00 2001 From: Gregory Oschwald Date: Wed, 29 Jan 2025 11:49:16 -0800 Subject: [PATCH] Add missing types --- maxminddb/reader.py | 6 +- tests/decoder_test.py | 34 +++++------ tests/reader_test.py | 132 +++++++++++++++++++++++------------------- 3 files changed, 92 insertions(+), 80 deletions(-) diff --git a/maxminddb/reader.py b/maxminddb/reader.py index 7feef2c..403139c 100644 --- a/maxminddb/reader.py +++ b/maxminddb/reader.py @@ -16,7 +16,7 @@ import struct from ipaddress import IPv4Address, IPv6Address from os import PathLike -from typing import IO, Any, AnyStr, Dict, List, Optional, Tuple, Union +from typing import IO, Any, AnyStr, Dict, Iterator, List, Optional, Tuple, Union from maxminddb.const import MODE_AUTO, MODE_FD, MODE_FILE, MODE_MEMORY, MODE_MMAP from maxminddb.decoder import Decoder @@ -177,10 +177,10 @@ def get_with_prefix_len( return self._resolve_data_pointer(pointer), prefix_len return None, prefix_len - def __iter__(self): + def __iter__(self) -> Iterator: return self._generate_children(0, 0, 0) - def _generate_children(self, node, depth, ip_acc): + def _generate_children(self, node, depth, ip_acc) -> Iterator: if ip_acc != 0 and node == self._ipv4_start: # Skip nodes aliased to IPv4 return diff --git a/tests/decoder_test.py b/tests/decoder_test.py index de3c999..e965643 100644 --- a/tests/decoder_test.py +++ b/tests/decoder_test.py @@ -7,7 +7,7 @@ class TestDecoder(unittest.TestCase): - def test_arrays(self): + def test_arrays(self) -> None: arrays = { b"\x00\x04": [], b"\x01\x04\x43\x46\x6f\x6f": ["Foo"], @@ -15,14 +15,14 @@ def test_arrays(self): } self.validate_type_decoding("arrays", arrays) - def test_boolean(self): + def test_boolean(self) -> None: booleans = { b"\x00\x07": False, b"\x01\x07": True, } self.validate_type_decoding("booleans", booleans) - def test_double(self): + def test_double(self) -> None: doubles = { b"\x68\x00\x00\x00\x00\x00\x00\x00\x00": 0.0, b"\x68\x3f\xe0\x00\x00\x00\x00\x00\x00": 0.5, @@ -35,7 +35,7 @@ def test_double(self): } self.validate_type_decoding("double", doubles) - def test_float(self): + def test_float(self) -> None: floats = { b"\x04\x08\x00\x00\x00\x00": 0.0, b"\x04\x08\x3f\x80\x00\x00": 1.0, @@ -49,7 +49,7 @@ def test_float(self): } self.validate_type_decoding("float", floats) - def test_int32(self): + def test_int32(self) -> None: int32 = { b"\x00\x01": 0, b"\x04\x01\xff\xff\xff\xff": -1, @@ -66,7 +66,7 @@ def test_int32(self): } self.validate_type_decoding("int32", int32) - def test_map(self): + def test_map(self) -> None: maps = { b"\xe0": {}, b"\xe1\x42\x65\x6e\x43\x46\x6f\x6f": {"en": "Foo"}, @@ -85,7 +85,7 @@ def test_map(self): } self.validate_type_decoding("maps", maps) - def test_pointer(self): + def test_pointer(self) -> None: pointers = { b"\x20\x00": 0, b"\x20\x05": 5, @@ -133,17 +133,17 @@ def test_pointer(self): b"\x5f\x00\x10\x53" + 70000 * b"\x78": "x" * 70000, } - def test_string(self): + def test_string(self) -> None: self.validate_type_decoding("string", self.strings) - def test_byte(self): + def test_byte(self) -> None: b = { bytes([0xC0 ^ k[0]]) + k[1:]: v.encode("utf-8") for k, v in self.strings.items() } self.validate_type_decoding("byte", b) - def test_uint16(self): + def test_uint16(self) -> None: uint16 = { b"\xa0": 0, b"\xa1\xff": 255, @@ -153,7 +153,7 @@ def test_uint16(self): } self.validate_type_decoding("uint16", uint16) - def test_uint32(self): + def test_uint32(self) -> None: uint32 = { b"\xc0": 0, b"\xc1\xff": 255, @@ -165,7 +165,7 @@ def test_uint32(self): } self.validate_type_decoding("uint32", uint32) - def generate_large_uint(self, bits): + def generate_large_uint(self, bits) -> dict: ctrl_byte = b"\x02" if bits == 64 else b"\x03" uints = { b"\x00" + ctrl_byte: 0, @@ -178,17 +178,17 @@ def generate_large_uint(self, bits): uints[input] = expected return uints - def test_uint64(self): + def test_uint64(self) -> None: self.validate_type_decoding("uint64", self.generate_large_uint(64)) - def test_uint128(self): + def test_uint128(self) -> None: self.validate_type_decoding("uint128", self.generate_large_uint(128)) - def validate_type_decoding(self, type, tests): + def validate_type_decoding(self, type, tests) -> None: for input, expected in tests.items(): self.check_decoding(type, input, expected) - def check_decoding(self, type, input, expected, name=None): + def check_decoding(self, type, input, expected, name=None) -> None: name = name or expected db = mmap.mmap(-1, len(input)) db.write(input) @@ -204,7 +204,7 @@ def check_decoding(self, type, input, expected, name=None): else: self.assertEqual(expected, actual, type) - def test_real_pointers(self): + def test_real_pointers(self) -> None: with open("tests/data/test-data/maps-with-pointers.raw", "r+b") as db_file: mm = mmap.mmap(db_file.fileno(), 0) decoder = Decoder(mm, 0) diff --git a/tests/reader_test.py b/tests/reader_test.py index 31bf585..ac06b8f 100644 --- a/tests/reader_test.py +++ b/tests/reader_test.py @@ -6,7 +6,7 @@ import pathlib import threading import unittest -from typing import Type, Union +from typing import Type, Union, cast from unittest import mock import maxminddb @@ -25,9 +25,10 @@ MODE_MMAP, MODE_MMAP_EXT, ) +from maxminddb.reader import Reader -def get_reader_from_file_descriptor(filepath, mode): +def get_reader_from_file_descriptor(filepath, mode) -> Reader: """Patches open_database() for class TestFDReader().""" if mode == MODE_FD: with open(filepath, "rb") as mmdb_fh: @@ -39,7 +40,8 @@ def get_reader_from_file_descriptor(filepath, mode): return maxminddb.open_database(filepath, mode) -class BaseTestReader: +class BaseTestReader(unittest.TestCase): + mode: int readerClass: Union[ Type["maxminddb.extension.Reader"], Type["maxminddb.reader.Reader"], @@ -51,12 +53,12 @@ class BaseTestReader: if os.name != "nt": mp = multiprocessing.get_context("fork") - def ipf(self, ip): + def ipf(self, ip) -> Union[ipaddress.IPv4Address, ipaddress.IPv6Address]: if self.use_ip_objects: return ipaddress.ip_address(ip) return ip - def test_reader(self): + def test_reader(self) -> None: for record_size in [24, 28, 32]: for ip_version in [4, 6]: file_name = ( @@ -76,7 +78,7 @@ def test_reader(self): self._check_ip_v6(reader, file_name) reader.close() - def test_get_with_prefix_len(self): + def test_get_with_prefix_len(self) -> None: decoder_record = { "array": [1, 2, 3], "boolean": True, @@ -174,10 +176,10 @@ def test_get_with_prefix_len(self): for test in tests: with open_database( - "tests/data/test-data/" + test["file_name"], + "tests/data/test-data/" + cast(str, test["file_name"]), self.mode, ) as reader: - (record, prefix_len) = reader.get_with_prefix_len(test["ip"]) + (record, prefix_len) = reader.get_with_prefix_len(cast(str, test["ip"])) self.assertEqual( prefix_len, @@ -188,10 +190,13 @@ def test_get_with_prefix_len(self): self.assertEqual( record, test["expected_record"], - "expected_record for " + test["ip"] + " in " + test["file_name"], + "expected_record for " + + cast(str, test["ip"]) + + " in " + + cast(str, test["file_name"]), ) - def test_iterator(self): + def test_iterator(self) -> None: tests = ( { "database": "ipv4", @@ -239,12 +244,12 @@ def test_iterator(self): networks = [str(n) for (n, _) in reader] self.assertEqual(networks, test["expected"], f) - def test_decoder(self): + def test_decoder(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, ) - record = reader.get(self.ipf("::1.1.1.0")) + record = cast(dict, reader.get(self.ipf("::1.1.1.0"))) self.assertEqual(record["array"], [1, 2, 3]) self.assertEqual(record["boolean"], True) @@ -267,7 +272,7 @@ def test_decoder(self): self.assertEqual(1329227995784915872903807060280344576, record["uint128"]) reader.close() - def test_metadata_pointers(self): + def test_metadata_pointers(self) -> None: with open_database( "tests/data/test-data/MaxMind-DB-test-metadata-pointers.mmdb", self.mode, @@ -277,7 +282,7 @@ def test_metadata_pointers(self): reader.metadata().database_type, ) - def test_no_ipv4_search_tree(self): + def test_no_ipv4_search_tree(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-no-ipv4-search-tree.mmdb", self.mode, @@ -287,7 +292,7 @@ def test_no_ipv4_search_tree(self): self.assertEqual(reader.get(self.ipf("192.1.1.1")), "::0/64") reader.close() - def test_ipv6_address_in_ipv4_database(self): + def test_ipv6_address_in_ipv4_database(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb", self.mode, @@ -301,16 +306,16 @@ def test_ipv6_address_in_ipv4_database(self): reader.get(self.ipf("2001::")) reader.close() - def test_opening_path(self): + def test_opening_path(self) -> None: with open_database( pathlib.Path("tests/data/test-data/MaxMind-DB-test-decoder.mmdb"), self.mode, ) as reader: self.assertEqual(reader.metadata().database_type, "MaxMind DB Decoder Test") - def test_no_extension_exception(self): + def test_no_extension_exception(self) -> None: real_extension = maxminddb._extension - maxminddb._extension = None + maxminddb._extension = None # type: ignore with self.assertRaisesRegex( ValueError, "MODE_MMAP_EXT requires the maxminddb.extension module to be available", @@ -321,7 +326,7 @@ def test_no_extension_exception(self): ) maxminddb._extension = real_extension - def test_broken_database(self): + def test_broken_database(self) -> None: reader = open_database( "tests/data/test-data/GeoIP2-City-Test-Broken-Double-Format.mmdb", self.mode, @@ -335,7 +340,7 @@ def test_broken_database(self): reader.get(self.ipf("2001:220::")) reader.close() - def test_ip_validation(self): + def test_ip_validation(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, @@ -347,11 +352,11 @@ def test_ip_validation(self): reader.get("not_ip") reader.close() - def test_missing_database(self): + def test_missing_database(self) -> None: with self.assertRaisesRegex(FileNotFoundError, "No such file or directory"): open_database("file-does-not-exist.mmdb", self.mode) - def test_nondatabase(self): + def test_nondatabase(self) -> None: with self.assertRaisesRegex( InvalidDatabaseError, r"Error opening database file \(README.rst\). " @@ -360,7 +365,7 @@ def test_nondatabase(self): open_database("README.rst", self.mode) # This is from https://github.com/maxmind/MaxMind-DB-Reader-python/issues/58 - def test_database_with_invalid_utf8_key(self): + def test_database_with_invalid_utf8_key(self) -> None: reader = open_database( "tests/data/bad-data/maxminddb-python/bad-unicode-in-map-key.mmdb", self.mode, @@ -368,15 +373,15 @@ def test_database_with_invalid_utf8_key(self): with self.assertRaises(UnicodeDecodeError): reader.get_with_prefix_len("163.254.149.39") - def test_too_many_constructor_args(self): + def test_too_many_constructor_args(self) -> None: with self.assertRaises(TypeError): - self.readerClass("README.md", self.mode, 1) + self.readerClass("README.md", self.mode, 1) # type: ignore - def test_bad_constructor_mode(self): + def test_bad_constructor_mode(self) -> None: with self.assertRaisesRegex(ValueError, r"Unsupported open mode \(100\)"): - self.readerClass("README.md", mode=100) + self.readerClass("README.md", mode=100) # type: ignore - def test_no_constructor_args(self): + def test_no_constructor_args(self) -> None: with self.assertRaisesRegex( TypeError, r" 1 required positional argument|" @@ -384,45 +389,45 @@ def test_no_constructor_args(self): r"takes at least 2 arguments|" r"function missing required argument \'database\' \(pos 1\)", ): - self.readerClass() + self.readerClass() # type: ignore - def test_too_many_get_args(self): + def test_too_many_get_args(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, ) with self.assertRaises(TypeError): - reader.get(self.ipf("1.1.1.1"), "blah") + reader.get(self.ipf("1.1.1.1"), "blah") # type: ignore reader.close() - def test_no_get_args(self): + def test_no_get_args(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, ) with self.assertRaises(TypeError): - reader.get() + reader.get() # type: ignore reader.close() - def test_incorrect_get_arg_type(self): + def test_incorrect_get_arg_type(self) -> None: reader = open_database("tests/data/test-data/GeoIP2-City-Test.mmdb", self.mode) with self.assertRaisesRegex( TypeError, "argument 1 must be a string or ipaddress object", ): - reader.get(1) + reader.get(1) # type: ignore reader.close() - def test_metadata_args(self): + def test_metadata_args(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, ) with self.assertRaises(TypeError): - reader.metadata("blah") + reader.metadata("blah") # type: ignore reader.close() - def test_metadata_unknown_attribute(self): + def test_metadata_unknown_attribute(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, @@ -432,25 +437,27 @@ def test_metadata_unknown_attribute(self): AttributeError, "'Metadata' object has no attribute 'blah'", ): - metadata.blah + metadata.blah # type: ignore reader.close() - def test_close(self): + def test_close(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, ) reader.close() - def test_double_close(self): + def test_double_close(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, ) reader.close() - self.assertIsNone(reader.close(), "Double close does not throw an exception") + self.assertIsNone( + reader.close(), "Double close does not throw an exception" + ) # type: ignore - def test_closed_get(self): + def test_closed_get(self) -> None: if self.mode in [MODE_MEMORY, MODE_FD]: return reader = open_database( @@ -464,13 +471,13 @@ def test_closed_get(self): ): reader.get(self.ipf("1.1.1.1")) - def test_with_statement(self): + def test_with_statement(self) -> None: filename = "tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb" with open_database(filename, self.mode) as reader: self._check_ip_v4(reader, filename) self.assertEqual(reader.closed, True) - def test_with_statement_close(self): + def test_with_statement_close(self) -> None: filename = "tests/data/test-data/MaxMind-DB-test-ipv4-24.mmdb" reader = open_database(filename, self.mode) reader.close() @@ -481,7 +488,7 @@ def test_with_statement_close(self): ), reader: pass - def test_closed(self): + def test_closed(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, @@ -494,7 +501,7 @@ def test_closed(self): # extension and the pure Python reader. If we do, the pure Python # reader will need to throw an exception or the extension will need # to keep the metadata in memory. - def test_closed_metadata(self): + def test_closed_metadata(self) -> None: reader = open_database( "tests/data/test-data/MaxMind-DB-test-decoder.mmdb", self.mode, @@ -633,15 +640,17 @@ def _check_ip_v6(self, reader, file_name) -> None: self.assertIsNone(reader.get(self.ipf(ip))) -def has_maxminddb_extension(): - return maxminddb.extension and hasattr(maxminddb.extension, "Reader") +def has_maxminddb_extension() -> bool: + return maxminddb.extension and hasattr( + maxminddb.extension, "Reader" + ) # type: ignore @unittest.skipIf( not has_maxminddb_extension() and not os.environ.get("MM_FORCE_EXT_TESTS"), "No C extension module found. Skipping tests", ) -class TestExtensionReader(BaseTestReader, unittest.TestCase): +class TestExtensionReader(BaseTestReader): mode = MODE_MMAP_EXT if has_maxminddb_extension(): @@ -652,7 +661,7 @@ class TestExtensionReader(BaseTestReader, unittest.TestCase): not has_maxminddb_extension() and not os.environ.get("MM_FORCE_EXT_TESTS"), "No C extension module found. Skipping tests", ) -class TestExtensionReaderWithIPObjects(BaseTestReader, unittest.TestCase): +class TestExtensionReaderWithIPObjects(BaseTestReader): mode = MODE_MMAP_EXT use_ip_objects = True @@ -660,7 +669,7 @@ class TestExtensionReaderWithIPObjects(BaseTestReader, unittest.TestCase): readerClass = maxminddb.extension.Reader -class TestAutoReader(BaseTestReader, unittest.TestCase): +class TestAutoReader(BaseTestReader): mode = MODE_AUTO readerClass: Union[ @@ -673,31 +682,31 @@ class TestAutoReader(BaseTestReader, unittest.TestCase): readerClass = maxminddb.reader.Reader -class TestMMAPReader(BaseTestReader, unittest.TestCase): +class TestMMAPReader(BaseTestReader): mode = MODE_MMAP readerClass = maxminddb.reader.Reader # We want one pure Python test to use IP objects, it doesn't # really matter which one. -class TestMMAPReaderWithIPObjects(BaseTestReader, unittest.TestCase): +class TestMMAPReaderWithIPObjects(BaseTestReader): mode = MODE_MMAP use_ip_objects = True readerClass = maxminddb.reader.Reader -class TestFileReader(BaseTestReader, unittest.TestCase): +class TestFileReader(BaseTestReader): mode = MODE_FILE readerClass = maxminddb.reader.Reader -class TestMemoryReader(BaseTestReader, unittest.TestCase): +class TestMemoryReader(BaseTestReader): mode = MODE_MEMORY readerClass = maxminddb.reader.Reader -class TestFDReader(BaseTestReader, unittest.TestCase): - def setUp(self): +class TestFDReader(BaseTestReader): + def setUp(self) -> None: self.open_database_patcher = mock.patch(__name__ + ".open_database") self.addCleanup(self.open_database_patcher.stop) self.open_database = self.open_database_patcher.start() @@ -708,9 +717,12 @@ def setUp(self): class TestOldReader(unittest.TestCase): - def test_old_reader(self): + def test_old_reader(self) -> None: reader = maxminddb.Reader("tests/data/test-data/MaxMind-DB-test-decoder.mmdb") - record = reader.get("::1.1.1.0") + record = cast(dict, reader.get("::1.1.1.0")) self.assertEqual(record["array"], [1, 2, 3]) reader.close() + + +del BaseTestReader