Skip to content

Commit

Permalink
refactored run_scan method; changed test to use regex for key ids; fa…
Browse files Browse the repository at this point in the history
…ll back to TCP if a UDP query raises a truncated flag exception
  • Loading branch information
fabian-hk committed Apr 11, 2020
1 parent 9ad2895 commit c4d9ad1
Show file tree
Hide file tree
Showing 16 changed files with 360 additions and 240 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ found records divided into secure and insecure records
from dnssec_scanner import DNSSECScanner, DNSSECScannerResult

scanner = DNSSECScanner("www.ietf.org")
res = scanner.run_scan() # type: DNSSECScannerResult
res = scanner.run() # type: DNSSECScannerResult
print(res)
```

Expand Down
77 changes: 47 additions & 30 deletions dnssec_scanner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
import datetime

from dnssec_scanner.validation import (
validate_zone,
validate_zone_keys,
validate_rrset,
validate_ds,
)
from dnssec_scanner import nsec
from dnssec_scanner.utils import DNSSECScannerResult, Zone, State
from dnssec_scanner.utils import DNSSECScannerResult, Zone, State, SoaState
from dnssec_scanner import utils
from dnssec_scanner.messages import Message, Msg

Expand All @@ -33,7 +33,7 @@ def __init__(self, domain: str):
self.domain = domain
self.root_zone = self.initialize_root_zone()

def run_scan(self) -> DNSSECScannerResult:
def run(self) -> DNSSECScannerResult:
resolver = dns.resolver.Resolver()
resolver.nameservers = self.RESOLVER_IPS

Expand All @@ -44,15 +44,39 @@ def run_scan(self) -> DNSSECScannerResult:
return result

def scan_zone(
self, zone: Zone, result: DNSSECScannerResult, resolver: dns.resolver.Resolver,
self, zone: Zone, result: DNSSECScannerResult, resolver: dns.resolver.Resolver,
) -> DNSSECScannerResult:
log.info(f"Entering {zone.name} zone")

self.get_dnskey(zone)

validate_zone_keys(zone, result)

success = self.search_soa(zone, result, resolver)

if success == SoaState.FOUND:
return result
elif success == SoaState.FOUND_CNAME:
return self.scan_zone(self.root_zone, result, resolver)

next_zone = self.get_ns(zone, resolver)

success = self.get_ds(zone, next_zone, result)

if success:
validate_ds(zone, result)

return self.scan_zone(next_zone, result, resolver)

def get_dnskey(self, zone: Zone):
response = utils.dns_query(zone.name, zone.ip, dns.rdatatype.DNSKEY)

zone.DNSKEY = utils.get_rr_by_type(response.answer, dns.rdatatype.DNSKEY)
zone.DNSKEY_RRSIG = utils.get_rr_by_type(response.answer, dns.rdatatype.RRSIG)

def search_soa(
self, zone: Zone, result: DNSSECScannerResult, resolver: dns.resolver.Resolver
) -> SoaState:
response = utils.dns_query(self.domain, zone.ip, dns.rdatatype.SOA)

rrsets = response.answer + response.authority
Expand All @@ -63,40 +87,42 @@ def scan_zone(
# Domain name does not exist. Validate with NSEC the integrity of the none-existence.
result.note = "Domain name does not exist"
zone.RR = rrsets
validate_zone(zone, result)
nsec.proof_none_existence(zone, result, False)
return result
return SoaState.FOUND
elif utils.get_rr_by_type(rrsets, dns.rdatatype.SOA):
# We are in the zone for the domain name.
validate_zone(zone, result)

# We are in the zone for the domain name
rr_types = self.find_records(zone)
zone.RR = self.get_records(zone, rr_types)

validate_rrset(zone, result, True)
return result
return SoaState.FOUND
elif utils.get_rrs_by_type(rrsets, dns.rdatatype.CNAME):
# We have found a CNAME RR set so we have to start from the top again
validate_zone(zone, result)

zone.RR = rrsets
validate_rrset(zone, result) # validate CNAME entry

self.domain = str(
utils.get_rr_by_type(rrsets, dns.rdatatype.CNAME).items[0].target
)
result.domain = self.domain
return self.scan_zone(self.root_zone, result, resolver)
return SoaState.FOUND_CNAME

return SoaState.NOT_FOUND

def get_ns(self, zone: Zone, resolver: dns.resolver.Resolver) -> Zone:
response = utils.dns_query(self.domain, zone.ip, dns.rdatatype.NS)

ns = utils.get_rr_by_type(response.authority, dns.rdatatype.NS)
next_zone_name = str(ns.name)
response = utils.get_rr_by_type(response.authority, dns.rdatatype.NS)
next_zone_name = str(response.name)
zone.child_name = next_zone_name

validate_zone(zone, result)
next_zone_domain = response.items[0].to_text()
next_zone_ip = resolver.query(next_zone_domain, "A").rrset.items[0].address

return Zone(next_zone_name, next_zone_ip, next_zone_domain, zone)

response = utils.dns_query(next_zone_name, zone.ip, dns.rdatatype.DS)
def get_ds(self, zone: Zone, next_zone: Zone, result: DNSSECScannerResult) -> bool:
response = utils.dns_query(next_zone.name, zone.ip, dns.rdatatype.DS)

zone.RR = response.answer
if not utils.get_rr_by_type(zone.RR, dns.rdatatype.DS):
Expand All @@ -106,15 +132,9 @@ def scan_zone(
msg.set_not_found(Msg.NOT_FOUND)
result.errors.append(str(msg))
result.change_state(False)
else:
validate_ds(zone, result)

next_zone_domain = ns.items[0].to_text()
next_zone_ip = resolver.query(next_zone_domain, "A").rrset.items[0].address

next_zone = Zone(next_zone_name, next_zone_ip, next_zone_domain, zone)
return False

return self.scan_zone(next_zone, result, resolver)
return True

def find_records(self, zone: Zone) -> Set[int]:
# define a default list of records in case ANY does not return anything
Expand All @@ -133,8 +153,7 @@ def find_records(self, zone: Zone) -> Set[int]:
]

# ask with ANY for all existing records
request = dns.message.make_query(self.domain, dns.rdatatype.ANY, payload=16384)
response = dns.query.tcp(request, zone.ip)
response = utils.dns_query(self.domain, zone.ip, dns.rdatatype.ANY)

for rr in response.answer:
if rr.rdtype != dns.rdatatype.DNSKEY and rr.rdtype != dns.rdatatype.RRSIG:
Expand All @@ -144,9 +163,7 @@ def find_records(self, zone: Zone) -> Set[int]:
rr_types = set(rr_types)
return rr_types

def get_records(
self, zone: Zone, rrs: Set[int]
) -> List[dns.rrset.RRset]:
def get_records(self, zone: Zone, rrs: Set[int]) -> List[dns.rrset.RRset]:
output = []
for rr in rrs:
response = utils.dns_query(self.domain, zone.ip, rr)
Expand Down
2 changes: 1 addition & 1 deletion dnssec_scanner/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def main():
raise ValueError("You have to enter a valid domain name.")

scanner = DNSSECScanner(domain)
result = scanner.run_scan()
result = scanner.run()
print(result)


Expand Down
1 change: 1 addition & 0 deletions dnssec_scanner/nsec/nsec3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def nsec3_proof_of_none_existence(
nsec3param = utils.get_rr_by_type(zone.RR, dns.rdatatype.NSEC3PARAM)

# TODO if there is only one NSEC3 record check if the owner hash is the hash of the QNAME
# TODO check if the Opt-Out flag is set

# search for closest enclosure
status, closest_encloser, next_closer_name = find_closest_encloser(
Expand Down
32 changes: 27 additions & 5 deletions dnssec_scanner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from .messages import Message

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("dnssec_scanner")


Expand Down Expand Up @@ -110,29 +109,52 @@ def __init__(self, name: str, ip: str, domain: str, parent: Optional[Zone]):
self.DNSKEY_RRSIG: Optional[dns.rrset.RRset] = None
self.trusted_DS: List[dns.rrset.RRset] = []
self.untrusted_DS: List[dns.rrset.RRset] = []
self.RR: Optional[dns.rrset.RRset] = None
self.RR: Optional[List[dns.rrset.RRset]] = None
self.child_name: str = ""

def __str__(self):
return f"{self.name} @{self.ip}"


class SoaState(Enum):
FOUND = 0
FOUND_CNAME = 1
NOT_FOUND = 2


def dns_query(
domain: str, ip: str, type: int, tries: Optional[int] = 0
) -> dns.message.Message:
try:
request = dns.message.make_query(domain, type, want_dnssec=True, payload=16384)
request = dns.message.make_query(domain, type, want_dnssec=True, payload=32768)
return dns.query.udp(request, ip, timeout=1)
except dns.exception.Timeout as e:
log.warning("Query timeout")
log.debug("Query timeout")
if tries < 5:
return dns_query(domain, ip, type, tries + 1)
else:
raise e
except dns.message.Truncated as e:
log.debug("Truncated flag was set - trying again with TCP")
return dns_query_tcp(domain, ip, type)


def dns_query_tcp(
domain: str, ip: str, type: int, tries: Optional[int] = 0
) -> dns.message.Message:
try:
request = dns.message.make_query(domain, type, want_dnssec=True, payload=32768)
return dns.query.tcp(request, ip, timeout=4)
except dns.exception.Timeout as e:
log.debug("Query timeout")
if tries < 5:
return dns_query(domain, ip, type, tries + 1)
else:
raise e


def get_rr_by_type(
items: List[dns.rrset.RRset], rdtype: dns.rdatatype
items: List[dns.rrset.RRset], rdtype: dns.rdatatype
) -> Optional[dns.rrset.RRset]:
for item in items:
if item.rdtype == rdtype:
Expand Down
17 changes: 9 additions & 8 deletions dnssec_scanner/validation.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
from __future__ import annotations
from typing import List, Tuple, Optional
import logging

import dns

from . import utils
from .utils import DNSSECScannerResult, Zone, Key
from .messages import Message, Validator, Msg, Types

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("dnssec_scanner")


def validate_zone(zone: Zone, result: DNSSECScannerResult):
def validate_zone_keys(zone: Zone, result: DNSSECScannerResult):
if zone.DNSKEY and zone.DNSKEY_RRSIG:
trusted_ksks, untrusted_ksks = validate_ksks(zone, result)
validate_zsks(zone, trusted_ksks, untrusted_ksks, result)
Expand Down Expand Up @@ -151,7 +147,9 @@ def validate_zsks(
zone.DNSKEY, sig, {dns.name.from_text(zone.name): [ksk]},
)
except dns.dnssec.ValidationFailure as e:
msg.add_warning(validator, key_id, f"{Msg.VALIDATION_FAILURE} ({e})", zone.DNSKEY)
msg.add_warning(
validator, key_id, f"{Msg.VALIDATION_FAILURE} ({e})", zone.DNSKEY
)
else:
msg.set_success(validator, key_id)
msg.validated = success
Expand All @@ -172,7 +170,10 @@ def validate_zsks(
)
except dns.dnssec.ValidationFailure as e:
msg.add_warning(
Validator.ZSK, key_id, f"{Msg.VALIDATION_FAILURE} ({e})", zone.DNSKEY
Validator.ZSK,
key_id,
f"{Msg.VALIDATION_FAILURE} ({e})",
zone.DNSKEY,
)
else:
msg.set_success(Validator.ZSK, key_id)
Expand Down Expand Up @@ -259,7 +260,7 @@ def validate_rrset(
s = result.compute_message(msg)
result.change_state(s)
res &= s
if save and msg:
if save and msg and result.state == utils.State.SECURE:
result.secure_rrsets.append(rr)
note.append(dns.rdatatype.to_text(rr.rdtype))
elif save:
Expand Down
8 changes: 3 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/fabian-hk/dnssec_scanner.git",
packages=setuptools.find_packages(),
packages=setuptools.find_packages(exclude=["tests"]),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
],
python_requires=">=3.7",
install_requires=[
"dnspython @ git+https://github.com/fabian-hk/dnspython.git@feature/nsec3-hash#egg=dnspython",
"dnspython",
"tabulate",
"dataclasses",
"pycryptodome>=3.4",
Expand All @@ -29,7 +29,5 @@
"python-dateutil",
"ecdsa",
],
entry_points={
"console_scripts": ["dnssec-scanner=dnssec_scanner.cli:main"]
},
entry_points={"console_scripts": ["dnssec-scanner=dnssec_scanner.cli:main"]},
)
38 changes: 21 additions & 17 deletions tests/test_google_com.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,38 @@
import unittest
import logging

import dns.rdatatype

from tests.utils.custom_test_case import CustomTestCase as CTC
from dnssec_scanner import DNSSECScanner, State
from dnssec_scanner.messages import Validator, Msg, Types
from tests.messages_testing import TestMessage
from tests.utils.messages_testing import TestMessage

log = logging.getLogger("dnssec_scanner")
log.setLevel(logging.WARNING)


class GoogleCom(unittest.TestCase):
class GoogleCom(CTC):
"""
Last checked on 09.04.2020
Last checked on 11.04.2020
"""

# fmt: off
LOGS = [
str(TestMessage(".", "", Types.KSK, "20326", Msg.VALIDATED, Validator.DS, "20326")),
str(TestMessage(".", "", dns.rdatatype.DNSKEY, "20326,33853,48903", Msg.VALIDATED, Validator.KSK, "20326")),
str(TestMessage(".", "com.", dns.rdatatype.DS, "30909", Msg.VALIDATED, Validator.ZSK, "48903")),
str(TestMessage("com.", "", Types.KSK, "30909", Msg.VALIDATED, Validator.DS, "30909")),
str(TestMessage("com.", "", dns.rdatatype.DNSKEY, "30909,56311", Msg.VALIDATED, Validator.KSK, "30909")),
str(TestMessage(".", "", Types.KSK, CTC.SINGLE_PATTERN, Msg.VALIDATED, Validator.DS, CTC.SINGLE_PATTERN)),
str(TestMessage(".", "", dns.rdatatype.DNSKEY, CTC.MULTI_PATTERN, Msg.VALIDATED, Validator.KSK,
CTC.SINGLE_PATTERN)),
str(TestMessage(".", "com.", dns.rdatatype.DS, CTC.MULTI_PATTERN, Msg.VALIDATED, Validator.ZSK,
CTC.SINGLE_PATTERN)),
str(TestMessage("com.", "", Types.KSK, CTC.SINGLE_PATTERN, Msg.VALIDATED, Validator.DS, CTC.SINGLE_PATTERN)),
str(TestMessage("com.", "", dns.rdatatype.DNSKEY, CTC.MULTI_PATTERN, Msg.VALIDATED, Validator.KSK,
CTC.SINGLE_PATTERN)),
str(TestMessage("com.", "CK0POJMG874LJREF7EFN8430QVIT8BSM.com.", dns.rdatatype.NSEC3, "", Msg.VALIDATED,
Validator.ZSK, "56311")),
str(TestMessage("com.", "com.", dns.rdatatype.SOA, "", Msg.VALIDATED, Validator.ZSK, "56311")),
Validator.ZSK, CTC.SINGLE_PATTERN)),
str(TestMessage("com.", "com.", dns.rdatatype.SOA, "", Msg.VALIDATED, Validator.ZSK, CTC.SINGLE_PATTERN)),
str(TestMessage("com.", "S84BDVKNH5AGDSI7F5J0O3NPRHU0G7JQ.com.", dns.rdatatype.NSEC3, "", Msg.VALIDATED,
Validator.ZSK, "56311")),
str(TestMessage("com.", "com.", dns.rdatatype.NSEC3PARAM, "", Msg.VALIDATED, Validator.ZSK, "56311")),
Validator.ZSK, CTC.SINGLE_PATTERN)),
str(TestMessage("com.", "com.", dns.rdatatype.NSEC3PARAM, "", Msg.VALIDATED, Validator.ZSK,
CTC.SINGLE_PATTERN)),
"com. zone: Found closest encloser com.",
"com. zone: Found NSEC3 that covers the next closer name google.com.",
"com. zone: Successfully proved that google.com. does not support DNSSEC",
Expand All @@ -53,10 +57,10 @@ class GoogleCom(unittest.TestCase):

def test_dnssec(self):
scanner = DNSSECScanner("google.com")
result = scanner.run_scan()
result = scanner.run()

self.assertCountEqual(self.LOGS, result.logs)
self.assertCountEqual(self.WARNIGNS, result.warnings)
self.assertCountEqual(self.ERRORS, result.errors)
self.assert_list(self.LOGS, result.logs)
self.assert_list(self.WARNIGNS, result.warnings)
self.assert_list(self.ERRORS, result.errors)

self.assertEqual(State.INSECURE, result.state)
Loading

0 comments on commit c4d9ad1

Please sign in to comment.