Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add Type Annotations #32

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[mypy]
warn_unused_configs=True
warn_unused_ignores=True
warn_return_any=True
warn_unreachable=True
warn_redundant_casts=True

check_untyped_defs=True
disallow_untyped_calls=True
disallow_untyped_defs=True
disallow_incomplete_defs=True
86 changes: 59 additions & 27 deletions pcapng/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@

import io
import itertools
from typing import Any, List, Tuple, Type

from pcapng import strictness as strictness
from pcapng.constants import link_types
from pcapng.flags import FlagField
from pcapng.structs import (
IntField,
ListField,
Expand Down Expand Up @@ -42,7 +44,7 @@ class Block(object):
# These are in addition to the above two properties
"magic_number",
"_decoded",
]
] # type: List[str]

def __init__(self, **kwargs):
if "raw" in kwargs:
Expand All @@ -66,13 +68,15 @@ def __init__(self, **kwargs):
self._decoded[key] = default

def __eq__(self, other):
# type: (Any) -> bool
if self.__class__ != other.__class__:
return False
keys = [x[0] for x in self.schema]
# Use `getattr()` so eg. @property calls are used
return [getattr(self, k) for k in keys] == [getattr(other, k) for k in keys]

def _write(self, outstream):
# type: (io.BytesIO) -> None
"""Writes this block into the given output stream"""
encoded_block = io.BytesIO()
self._encode(encoded_block)
Expand All @@ -89,6 +93,7 @@ def _write(self, outstream):
write_int(block_length, outstream, 32)

def _encode(self, outstream):
# type: (io.BytesIO) -> None
"""Encodes the fields of this block into raw data"""
for name, field, default in self.schema:
field.encode(
Expand Down Expand Up @@ -125,6 +130,7 @@ def __setattr__(self, name, value):
self._decoded[name] = value

def __repr__(self):
# type: () -> str
args = []
for item in self.schema:
name = item[0]
Expand All @@ -140,17 +146,19 @@ def __repr__(self):
class SectionMemberBlock(Block):
"""Block which must be a member of a section"""

__slots__ = ["section"]
__slots__ = ["section"] # type: List[str]

def __init__(self, section, **kwargs):
# type: (SectionMemberBlock, str) -> None
super(SectionMemberBlock, self).__init__(**kwargs)
self.section = section


def register_block(block):
# type: (Any) -> Block
"""Handy decorator to register a new known block type"""
KNOWN_BLOCKS[block.magic_number] = block
return block
return block # type: ignore


@register_block
Expand All @@ -170,7 +178,8 @@ class SectionHeader(Block):
"_interfaces_id",
"interfaces",
"interface_stats",
]
] # type: List[str]

schema = [
("version_major", IntField(16, False), 1),
("version_minor", IntField(16, False), 0),
Expand All @@ -196,6 +205,7 @@ def __init__(self, endianness="<", **kwargs):
super(SectionHeader, self).__init__(endianness=endianness, **kwargs)

def _encode(self, outstream):
# type: (io.BytesIO) -> None
write_int(0x1A2B3C4D, outstream, 32, endianness=self.endianness)
super(SectionHeader, self)._encode(outstream)

Expand All @@ -213,31 +223,37 @@ def new_member(self, cls, **kwargs):
return blk

def register_interface(self, interface):
# type: (Block) -> None
"""Helper method to register an interface within this section"""
assert isinstance(interface, InterfaceDescription)
interface_id = next(self._interfaces_id)
interface.interface_id = interface_id
self.interfaces[interface_id] = interface

def add_interface_stats(self, interface_stats):
# type: (Block) -> None
"""Helper method to register interface stats within this section"""
assert isinstance(interface_stats, InterfaceStatistics)
self.interface_stats[interface_stats.interface_id] = interface_stats

@property
def version(self):
# type: () -> Tuple[IntField, IntField]
return (self.version_major, self.version_minor)

@property
def length(self):
# type: () -> IntField
return self.section_length

# Block.decode() assumes all blocks have sections -- technically true...
@property
def section(self):
# type: () -> SectionHeader
return self

def __repr__(self):
# type: () -> str
return (
"<{name} version={version} endianness={endianness} "
"length={length} options={options}>"
Expand All @@ -260,7 +276,7 @@ class InterfaceDescription(SectionMemberBlock):
"""

magic_number = 0x00000001
__slots__ = ["interface_id"]
__slots__ = ["interface_id"] # type: List[str]
schema = [
("link_type", IntField(16, False), 0), # todo: enc/decode
("reserved", IntField(16, False), 0),
Expand Down Expand Up @@ -291,18 +307,19 @@ class InterfaceDescription(SectionMemberBlock):

@property # todo: cache this property
def timestamp_resolution(self):
# ------------------------------------------------------------
# Resolution of timestamps. If the Most Significant Bit is
# equal to zero, the remaining bits indicates the resolution
# of the timestamp as as a negative power of 10 (e.g. 6 means
# microsecond resolution, timestamps are the number of
# microseconds since 1/1/1970). If the Most Significant Bit is
# equal to one, the remaining bits indicates the resolution as
# as negative power of 2 (e.g. 10 means 1/1024 of second). If
# this option is not present, a resolution of 10^-6 is assumed
# (i.e. timestamps have the same resolution of the standard
# 'libpcap' timestamps).
# ------------------------------------------------------------
# type: () -> float
"""
Resolution of timestamps. If the Most Significant Bit is
equal to zero, the remaining bits indicates the resolution
of the timestamp as as a negative power of 10 (e.g. 6 means
microsecond resolution, timestamps are the number of
microseconds since 1/1/1970). If the Most Significant Bit is
equal to one, the remaining bits indicates the resolution as
as negative power of 2 (e.g. 10 means 1/1024 of second). If
this option is not present, a resolution of 10^-6 is assumed
(i.e. timestamps have the same resolution of the standard
'libpcap' timestamps).
"""

if "if_tsresol" in self.options:
return unpack_timestamp_resolution(self.options["if_tsresol"])
Expand All @@ -311,11 +328,13 @@ def timestamp_resolution(self):

@property
def statistics(self):
# type: () -> object
# todo: ensure we always have an interface id -> how??
return self.section.interface_stats.get(self.interface_id)

@property
def link_type_description(self):
# type: () -> str
try:
return link_types.LINKTYPE_DESCRIPTIONS[self.link_type]
except KeyError:
Expand All @@ -328,17 +347,19 @@ class BlockWithTimestampMixin(object):
of blocks that provide one.
"""

__slots__ = []
__slots__ = [] # type: List[str]

@property
def timestamp(self):
# type: () -> float
# First, get the accuracy from the ts_resol option
return (
(self.timestamp_high << 32) + self.timestamp_low
) * self.timestamp_resolution

@property
def timestamp_resolution(self):
# type: () -> float
return self.interface.timestamp_resolution

# todo: add some property returning a datetime() with timezone..
Expand All @@ -350,15 +371,17 @@ class BlockWithInterfaceMixin(object):
This includes all packet blocks as well as InterfaceStatistics.
"""

__slots__ = []
__slots__ = [] # type: List[str]

@property
def interface(self):
# type: () -> FlagField
# We need to get the correct interface from the section
# by looking up the interface_id
return self.section.interfaces[self.interface_id]

def _encode(self, outstream):
# type: (io.BytesIO) -> None
if len(self.section.interfaces) < 1:
strictness.problem(
"writing {cls} for section with no interfaces".format(
Expand All @@ -384,11 +407,12 @@ class BasePacketBlock(SectionMemberBlock, BlockWithInterfaceMixin):
the current length of the packet data.
"""

__slots__ = []
__slots__ = [] # type: List[str]
readonly_fields = set(("captured_len",))

@property
def captured_len(self):
# type: () -> int
return len(self.packet_data)

# Helper function. If the user hasn't explicitly set an original packet
Expand All @@ -397,6 +421,7 @@ def captured_len(self):
# the captured data length.
@property
def packet_len(self):
# type: () -> int
plen = self.__getattr__("packet_len") or 0 # this call prevents recursion
return plen or len(self.packet_data)

Expand All @@ -411,7 +436,7 @@ class EnhancedPacket(BasePacketBlock, BlockWithTimestampMixin):
"""

magic_number = 0x00000006
__slots__ = []
__slots__ = [] # type: List[str]
schema = [
("interface_id", IntField(32, False), 0),
("timestamp_high", IntField(32, False), 0),
Expand Down Expand Up @@ -443,7 +468,7 @@ class SimplePacket(BasePacketBlock):
"""

magic_number = 0x00000003
__slots__ = []
__slots__ = [] # type: List[str]
schema = [
# packet_len is NOT the captured length
("packet_len", IntField(32, False), 0),
Expand All @@ -469,6 +494,7 @@ def __init__(self, section, **kwargs):

@property
def interface_id(self):
# type: () -> int
"""
"The Simple Packet Block does not contain the Interface ID field.
Therefore, it MUST be assumed that all the Simple Packet Blocks have
Expand All @@ -479,6 +505,7 @@ def interface_id(self):

@property
def captured_len(self):
# type: () -> int
"""
"...the SnapLen value MUST be used to determine the size of the Packet
Data field length."
Expand All @@ -490,6 +517,7 @@ def captured_len(self):
return min(snap_len, self.packet_len)

def _encode(self, outstream):
# type: (io.BytesIO) -> None
fld_size = IntField(32, False)
fld_data = RawBytes(0)
if len(self.section.interfaces) > 1:
Expand Down Expand Up @@ -532,7 +560,7 @@ class ObsoletePacket(BasePacketBlock, BlockWithTimestampMixin):
"""

magic_number = 0x00000002
__slots__ = []
__slots__ = [] # type: List[str]
schema = [
("interface_id", IntField(16, False), 0),
("drops_count", IntField(16, False), 0),
Expand All @@ -555,6 +583,7 @@ class ObsoletePacket(BasePacketBlock, BlockWithTimestampMixin):
]

def enhanced(self):
# type: () -> SectionMemberBlock
"""Return an EnhancedPacket with this block's attributes."""
opts_dict = dict(self.options)
opts_dict["epb_dropcount"] = self.drops_count
Expand All @@ -576,6 +605,7 @@ def enhanced(self):
# Do this check in _write() instead of _encode() to ensure the block gets written
# with the correct magic number.
def _write(self, outstream):
# type: (io.BytesIO) -> None
strictness.problem("Packet Block is obsolete and must not be used")
if strictness.should_fix():
self.enhanced()._write(outstream)
Expand All @@ -597,7 +627,7 @@ class NameResolution(SectionMemberBlock):
"""

magic_number = 0x00000004
__slots__ = []
__slots__ = [] # type: List[str]
schema = [
("records", ListField(NameResolutionRecordField()), []),
(
Expand All @@ -611,7 +641,7 @@ class NameResolution(SectionMemberBlock):
),
None,
),
]
] # type: List[Tuple[str, FlagField, object]]


@register_block
Expand All @@ -627,7 +657,7 @@ class InterfaceStatistics(
"""

magic_number = 0x00000005
__slots__ = []
__slots__ = [] # type: List[str]
schema = [
("interface_id", IntField(32, False), 0),
("timestamp_high", IntField(32, False), 0),
Expand Down Expand Up @@ -658,11 +688,13 @@ class UnknownBlock(Block):
processing.
"""

__slots__ = ["block_type", "data"]
__slots__ = ["block_type", "data"] # type: List[str]

def __init__(self, block_type, data):
# type: (int, bytes) -> None
self.block_type = block_type
self.data = data

def __repr__(self):
# type: () -> str
return "UnknownBlock(0x{0:08X}, {1!r})".format(self.block_type, self.data)
Loading