From f163a9f7f033ef5a1a11bb11f605488b7b8a61f8 Mon Sep 17 00:00:00 2001 From: Sean Whalen Date: Sat, 26 Oct 2024 16:58:55 -0400 Subject: [PATCH] Reformat code using black --- .github/workflows/python-tests.yaml | 3 +- checkdmarc/__init__.py | 188 +++-- checkdmarc/_cli.py | 146 ++-- checkdmarc/bimi.py | 271 +++---- checkdmarc/dmarc.py | 1008 +++++++++++++++------------ checkdmarc/dnssec.py | 80 ++- checkdmarc/mta_sts.py | 201 +++--- checkdmarc/smtp.py | 156 +++-- checkdmarc/smtp_tls_reporting.py | 182 ++--- checkdmarc/spf.py | 381 +++++----- checkdmarc/utils.py | 131 ++-- docs/source/conf.py | 29 +- requirements.txt | 1 + tests.py | 238 ++++--- 14 files changed, 1719 insertions(+), 1296 deletions(-) diff --git a/.github/workflows/python-tests.yaml b/.github/workflows/python-tests.yaml index 831b0b1..c78d4de 100644 --- a/.github/workflows/python-tests.yaml +++ b/.github/workflows/python-tests.yaml @@ -30,8 +30,7 @@ jobs: make html - name: Check code style run: | - flake8 checkdmarc - flake8 tests.py + black --check . - name: Run unit tests run: | coverage run tests.py diff --git a/checkdmarc/__init__.py b/checkdmarc/__init__.py index 7526b4a..791cd7b 100644 --- a/checkdmarc/__init__.py +++ b/checkdmarc/__init__.py @@ -41,16 +41,19 @@ __version__ = checkdmarc._constants.__version__ -def check_domains(domains: list[str], parked: bool = False, - approved_nameservers: list[str] = None, - approved_mx_hostnames: bool = None, - skip_tls: bool = False, - bimi_selector: str = None, - include_tag_descriptions: bool = False, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0, - wait: float = 0.0) -> Union[OrderedDict, list[OrderedDict]]: +def check_domains( + domains: list[str], + parked: bool = False, + approved_nameservers: list[str] = None, + approved_mx_hostnames: bool = None, + skip_tls: bool = False, + bimi_selector: str = None, + include_tag_descriptions: bool = False, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, + wait: float = 0.0, +) -> Union[OrderedDict, list[OrderedDict]]: """ Check the given domains for SPF and DMARC records, parse them, and return them @@ -82,9 +85,11 @@ def check_domains(domains: list[str], parked: bool = False, - ``dmarc`` - A ``valid`` flag, plus the output of :func:`checkdmarc.dmarc.parse_dmarc_record` or an ``error`` """ - domains = sorted(list(set( - map(lambda d: d.rstrip(".\r\n").strip().lower().split(",")[0], - domains)))) + domains = sorted( + list( + set(map(lambda d: d.rstrip(".\r\n").strip().lower().split(",")[0], domains)) + ) + ) not_domains = [] for domain in domains: if "." not in domain: @@ -99,27 +104,31 @@ def check_domains(domains: list[str], parked: bool = False, logging.debug(f"Checking: {domain}") domain_results = OrderedDict( - [("domain", domain), ("base_domain", get_base_domain(domain)), - ("dnssec", None), ("ns", []), ("mx", [])]) + [ + ("domain", domain), + ("base_domain", get_base_domain(domain)), + ("dnssec", None), + ("ns", []), + ("mx", []), + ] + ) domain_results["dnssec"] = test_dnssec( - domain, - nameservers=nameservers, - timeout=timeout - ) + domain, nameservers=nameservers, timeout=timeout + ) domain_results["ns"] = check_ns( domain, approved_nameservers=approved_nameservers, nameservers=nameservers, - resolver=resolver, timeout=timeout - ) + resolver=resolver, + timeout=timeout, + ) mta_sts_mx_patterns = None - domain_results["mta_sts"] = check_mta_sts(domain, - nameservers=nameservers, - resolver=resolver, - timeout=timeout) + domain_results["mta_sts"] = check_mta_sts( + domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) if domain_results["mta_sts"]["valid"]: mta_sts_mx_patterns = domain_results["mta_sts"]["policy"]["mx"] domain_results["mx"] = check_mx( @@ -129,16 +138,16 @@ def check_domains(domains: list[str], parked: bool = False, skip_tls=skip_tls, nameservers=nameservers, resolver=resolver, - timeout=timeout - ) + timeout=timeout, + ) domain_results["spf"] = check_spf( domain, parked=parked, nameservers=nameservers, resolver=resolver, - timeout=timeout - ) + timeout=timeout, + ) domain_results["dmarc"] = check_dmarc( domain, @@ -146,14 +155,11 @@ def check_domains(domains: list[str], parked: bool = False, include_dmarc_tag_descriptions=include_tag_descriptions, nameservers=nameservers, resolver=resolver, - timeout=timeout - ) + timeout=timeout, + ) domain_results["smtp_tls_reporting"] = check_smtp_tls_reporting( - domain, - nameservers=nameservers, - resolver=resolver, - timeout=timeout + domain, nameservers=nameservers, resolver=resolver, timeout=timeout ) if bimi_selector is not None: @@ -163,7 +169,8 @@ def check_domains(domains: list[str], parked: bool = False, include_tag_descriptions=include_tag_descriptions, nameservers=nameservers, resolver=resolver, - timeout=timeout) + timeout=timeout, + ) results.append(domain_results) if wait > 0.0: @@ -175,11 +182,13 @@ def check_domains(domains: list[str], parked: bool = False, return results -def check_ns(domain: str, - approved_nameservers: list[str] = None, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def check_ns( + domain: str, + approved_nameservers: list[str] = None, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Returns a dictionary of nameservers and warnings or a dictionary with an empty list and an error. @@ -207,11 +216,12 @@ def check_ns(domain: str, ns_results = get_nameservers( domain, approved_nameservers=approved_nameservers, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) except DNSException as error: - ns_results = OrderedDict([("hostnames", []), - ("error", error.__str__())]) + ns_results = OrderedDict([("hostnames", []), ("error", error.__str__())]) return ns_results @@ -278,8 +288,9 @@ def results_to_csv_rows(results: Union[dict, list[dict]]) -> list[dict]: row["bimi_l"] = _bimi["tags"]["l"]["value"] if "a" in _bimi["tags"]: row["bimi_a"] = _bimi["tags"]["a"]["value"] - row["mx"] = "|".join(list( - map(lambda r: f"{r['preference']}, {r['hostname']}", mx["hosts"]))) + row["mx"] = "|".join( + list(map(lambda r: f"{r['preference']}, {r['hostname']}", mx["hosts"])) + ) tls = None try: tls_results = list(map(lambda r: f"{r['starttls']}", mx["hosts"])) @@ -296,8 +307,7 @@ def results_to_csv_rows(results: Union[dict, list[dict]]) -> list[dict]: starttls = None try: - starttls_results = list( - map(lambda r: f"{r['starttls']}", mx["hosts"])) + starttls_results = list(map(lambda r: f"{r['starttls']}", mx["hosts"])) for starttls_result in starttls_results: starttls = starttls_result if starttls_result is False: @@ -335,15 +345,15 @@ def results_to_csv_rows(results: Union[dict, list[dict]]) -> list[dict]: row["dmarc_sp"] = _dmarc["tags"]["sp"]["value"] if "rua" in _dmarc["tags"]: addresses = _dmarc["tags"]["rua"]["value"] - addresses = list(map(lambda u: "{}:{}".format( - u["scheme"], - u["address"]), addresses)) + addresses = list( + map(lambda u: "{}:{}".format(u["scheme"], u["address"]), addresses) + ) row["dmarc_rua"] = "|".join(addresses) if "ruf" in _dmarc["tags"]: addresses = _dmarc["tags"]["ruf"]["value"] - addresses = list(map(lambda u: "{}:{}".format( - u["scheme"], - u["address"]), addresses)) + addresses = list( + map(lambda u: "{}:{}".format(u["scheme"], u["address"]), addresses) + ) row["dmarc_ruf"] = "|".join(addresses) row["dmarc_warnings"] = "|".join(_dmarc["warnings"]) if "error" in _smtp_tls_reporting: @@ -351,11 +361,10 @@ def results_to_csv_rows(results: Union[dict, list[dict]]) -> list[dict]: row["smtp_tls_reporting_error"] = _smtp_tls_reporting["error"] else: row["smtp_tls_reporting_valid"] = True - row["smtp_tls_reporting_rua"] = "|".join(_smtp_tls_reporting[ - "tags"]["rua"][ - "value"]) - row["smtp_tls_reporting_warnings"] = _smtp_tls_reporting[ - "warnings"] + row["smtp_tls_reporting_rua"] = "|".join( + _smtp_tls_reporting["tags"]["rua"]["value"] + ) + row["smtp_tls_reporting_warnings"] = _smtp_tls_reporting["warnings"] rows.append(row) return rows @@ -370,18 +379,48 @@ def results_to_csv(results: dict) -> str: Returns: str: A CSV of results """ - fields = ["domain", "base_domain", "dnssec", "spf_valid", "dmarc_valid", - "dmarc_adkim", "dmarc_aspf", - "dmarc_fo", "dmarc_p", "dmarc_pct", "dmarc_rf", "dmarc_ri", - "dmarc_rua", "dmarc_ruf", "dmarc_sp", - "tls", "starttls", "spf_record", "dmarc_record", - "dmarc_record_location", "mx", "mx_error", "mx_warnings", - "mta_sts_id", "mta_sts_mode", "mta_sts_max_age", - "smtp_tls_reporting_valid", "smtp_tls_reporting_rua", - "mta_sts_mx", "mta_sts_error", "mta_sts_warnings", "spf_error", - "spf_warnings", "dmarc_error", "dmarc_warnings", - "ns", "ns_error", "ns_warnings", - "smtp_tls_reporting_error", "smtp_tls_reporting_warnings"] + fields = [ + "domain", + "base_domain", + "dnssec", + "spf_valid", + "dmarc_valid", + "dmarc_adkim", + "dmarc_aspf", + "dmarc_fo", + "dmarc_p", + "dmarc_pct", + "dmarc_rf", + "dmarc_ri", + "dmarc_rua", + "dmarc_ruf", + "dmarc_sp", + "tls", + "starttls", + "spf_record", + "dmarc_record", + "dmarc_record_location", + "mx", + "mx_error", + "mx_warnings", + "mta_sts_id", + "mta_sts_mode", + "mta_sts_max_age", + "smtp_tls_reporting_valid", + "smtp_tls_reporting_rua", + "mta_sts_mx", + "mta_sts_error", + "mta_sts_warnings", + "spf_error", + "spf_warnings", + "dmarc_error", + "dmarc_warnings", + "ns", + "ns_error", + "ns_warnings", + "smtp_tls_reporting_error", + "smtp_tls_reporting_warnings", + ] output = StringIO(newline="\n") writer = DictWriter(output, fieldnames=fields) writer.writeheader() @@ -400,6 +439,7 @@ def output_to_file(path: str, content: str): path (str): A file path content (str): JSON or CSV text """ - with open(path, "w", newline="\n", encoding="utf-8", - errors="ignore") as output_file: + with open( + path, "w", newline="\n", encoding="utf-8", errors="ignore" + ) as output_file: output_file.write(content) diff --git a/checkdmarc/_cli.py b/checkdmarc/_cli.py index 492cf88..4389071 100644 --- a/checkdmarc/_cli.py +++ b/checkdmarc/_cli.py @@ -9,8 +9,13 @@ import logging -from checkdmarc import (__version__, check_domains, results_to_json, - results_to_csv, output_to_file) +from checkdmarc import ( + __version__, + check_domains, + results_to_json, + results_to_csv, + output_to_file, +) """Copyright 2019-2023 Sean Whalen @@ -30,45 +35,68 @@ def _main(): """Called when the module in executed""" arg_parser = ArgumentParser(description=__doc__) - arg_parser.add_argument("domain", nargs="+", - help="one or more domains, or a single path to a " - "file containing a list of domains") - arg_parser.add_argument("-p", "--parked", help="indicate that the " - "domains are parked", - action="store_true", default=False) - arg_parser.add_argument("--ns", nargs="+", - help="approved nameserver substrings") - arg_parser.add_argument("--mx", nargs="+", - help="approved MX hostname substrings") - arg_parser.add_argument("-d", "--descriptions", action="store_true", - help="include descriptions of tags in " - "the JSON output") - arg_parser.add_argument("-f", "--format", default="json", - help="specify JSON or CSV screen output format") - arg_parser.add_argument("-o", "--output", nargs="+", - help="one or more file paths to output to " - "(must end in .json or .csv) " - "(silences screen output)") - arg_parser.add_argument("-n", "--nameserver", nargs="+", - help="nameservers to query") - arg_parser.add_argument("-t", "--timeout", - help="number of seconds to wait for an answer " - "from DNS (default 2.0)", - type=float, - default=2.0) - arg_parser.add_argument("-b", "--bimi-selector", - default="default", - help="the BIMI selector to use") - arg_parser.add_argument("-v", "--version", action="version", - version=__version__) - arg_parser.add_argument("-w", "--wait", type=float, - help="number of seconds to wait between " - "checking domains (default 0.0)", - default=0.0), - arg_parser.add_argument("--skip-tls", action="store_true", - help="skip TLS/SSL testing") - arg_parser.add_argument("--debug", action="store_true", - help="enable debugging output") + arg_parser.add_argument( + "domain", + nargs="+", + help="one or more domains, or a single path to a " + "file containing a list of domains", + ) + arg_parser.add_argument( + "-p", + "--parked", + help="indicate that the " "domains are parked", + action="store_true", + default=False, + ) + arg_parser.add_argument("--ns", nargs="+", help="approved nameserver substrings") + arg_parser.add_argument("--mx", nargs="+", help="approved MX hostname substrings") + arg_parser.add_argument( + "-d", + "--descriptions", + action="store_true", + help="include descriptions of tags in " "the JSON output", + ) + arg_parser.add_argument( + "-f", + "--format", + default="json", + help="specify JSON or CSV screen output format", + ) + arg_parser.add_argument( + "-o", + "--output", + nargs="+", + help="one or more file paths to output to " + "(must end in .json or .csv) " + "(silences screen output)", + ) + arg_parser.add_argument( + "-n", "--nameserver", nargs="+", help="nameservers to query" + ) + arg_parser.add_argument( + "-t", + "--timeout", + help="number of seconds to wait for an answer " "from DNS (default 2.0)", + type=float, + default=2.0, + ) + arg_parser.add_argument( + "-b", "--bimi-selector", default="default", help="the BIMI selector to use" + ) + arg_parser.add_argument("-v", "--version", action="version", version=__version__) + arg_parser.add_argument( + "-w", + "--wait", + type=float, + help="number of seconds to wait between " "checking domains (default 0.0)", + default=0.0, + ), + arg_parser.add_argument( + "--skip-tls", action="store_true", help="skip TLS/SSL testing" + ) + arg_parser.add_argument( + "--debug", action="store_true", help="enable debugging output" + ) args = arg_parser.parse_args() @@ -81,9 +109,16 @@ def _main(): domains = args.domain if len(domains) == 1 and os.path.exists(domains[0]): with open(domains[0]) as domains_file: - domains = sorted(list(set( - map(lambda d: d.rstrip(".\r\n").strip().lower().split(",")[0], - domains_file.readlines())))) + domains = sorted( + list( + set( + map( + lambda d: d.rstrip(".\r\n").strip().lower().split(",")[0], + domains_file.readlines(), + ) + ) + ) + ) not_domains = [] for domain in domains: if "." not in domain: @@ -91,14 +126,18 @@ def _main(): for domain in not_domains: domains.remove(domain) - results = check_domains(domains, skip_tls=args.skip_tls, - parked=args.parked, - approved_nameservers=args.ns, - approved_mx_hostnames=args.mx, - include_tag_descriptions=args.descriptions, - nameservers=args.nameserver, timeout=args.timeout, - bimi_selector=args.bimi_selector, - wait=args.wait) + results = check_domains( + domains, + skip_tls=args.skip_tls, + parked=args.parked, + approved_nameservers=args.ns, + approved_mx_hostnames=args.mx, + include_tag_descriptions=args.descriptions, + nameservers=args.nameserver, + timeout=args.timeout, + bimi_selector=args.bimi_selector, + wait=args.wait, + ) if args.output is None: if args.format.lower() == "json": @@ -112,8 +151,7 @@ def _main(): csv_path = path.lower().endswith(".csv") if not json_path and not csv_path: - logging.error( - f"Output path {path} must end in .json or .csv") + logging.error(f"Output path {path} must end in .json or .csv") else: if path.lower().endswith(".json"): output_to_file(path, results_to_json(results)) diff --git a/checkdmarc/bimi.py b/checkdmarc/bimi.py index 9128380..d5f9df4 100644 --- a/checkdmarc/bimi.py +++ b/checkdmarc/bimi.py @@ -11,15 +11,10 @@ import dns import requests import xmltodict -from pyleri import (Grammar, - Regex, - Sequence, - List - ) +from pyleri import Grammar, Regex, Sequence, List from checkdmarc._constants import SYNTAX_ERROR_MARKER, USER_AGENT -from checkdmarc.utils import (WSP_REGEX, HTTPS_REGEX, query_dns, - get_base_domain) +from checkdmarc.utils import WSP_REGEX, HTTPS_REGEX, query_dns, get_base_domain """Copyright 2019-2023 Sean Whalen @@ -35,9 +30,9 @@ See the License for the specific language governing permissions and limitations under the License.""" -BIMI_VERSION_REGEX_STRING = fr"v{WSP_REGEX}*={WSP_REGEX}*BIMI1{WSP_REGEX}*;" +BIMI_VERSION_REGEX_STRING = rf"v{WSP_REGEX}*={WSP_REGEX}*BIMI1{WSP_REGEX}*;" BIMI_TAG_VALUE_REGEX_STRING = ( - fr"([a-z]{{1,2}}){WSP_REGEX}*={WSP_REGEX}*(bimi1|{HTTPS_REGEX})?" + rf"([a-z]{{1,2}}){WSP_REGEX}*={WSP_REGEX}*(bimi1|{HTTPS_REGEX})?" ) BIMI_TAG_VALUE_REGEX = re.compile(BIMI_TAG_VALUE_REGEX_STRING, re.IGNORECASE) @@ -48,11 +43,12 @@ class _BIMIWarning(Exception): class BIMIError(Exception): """Raised when a fatal BIMI error occurs""" + def __init__(self, msg: str, data: dict = None): """ - Args: - msg (str): The error message - data (dict): A dictionary of data to include in the results + Args: + msg (str): The error message + data (dict): A dictionary of data to include in the results """ self.data = data Exception.__init__(self, msg) @@ -60,6 +56,7 @@ def __init__(self, msg: str, data: dict = None): class BIMIRecordNotFound(BIMIError): """Raised when a BIMI record could not be found""" + def __init__(self, error): if isinstance(error, dns.exception.Timeout): error.kwargs["timeout"] = round(error.kwargs["timeout"], 1) @@ -87,9 +84,9 @@ class UnrelatedTXTRecordFoundAtBIMI(BIMIError): class SPFRecordFoundWhereBIMIRecordShouldBe(UnrelatedTXTRecordFoundAtBIMI): """Raised when an SPF record is found where a BIMI record should be; - most likely, the ``selector_bimi`` subdomain - record does not actually exist, and the request for ``TXT`` records was - redirected to the base domain""" + most likely, the ``selector_bimi`` subdomain + record does not actually exist, and the request for ``TXT`` records was + redirected to the base domain""" class BIMIRecordInWrongLocation(BIMIError): @@ -102,58 +99,66 @@ class MultipleBIMIRecords(BIMIError): class _BIMIGrammar(Grammar): """Defines Pyleri grammar for BIMI records""" + version_tag = Regex(BIMI_VERSION_REGEX_STRING) tag_value = Regex(BIMI_TAG_VALUE_REGEX_STRING, re.IGNORECASE) START = Sequence( - version_tag, List(tag_value, - delimiter=Regex(f"{WSP_REGEX}*;{WSP_REGEX}*"), - opt=True)) + version_tag, + List(tag_value, delimiter=Regex(f"{WSP_REGEX}*;{WSP_REGEX}*"), opt=True), + ) bimi_tags = OrderedDict( - v=OrderedDict(name="Version", - required=True, - description='Identifies the record ' - 'retrieved as a BIMI ' - 'record. It MUST have the ' - 'value of "BIMI1". The ' - 'value of this tag MUST ' - 'match precisely; if it ' - 'does not or it is absent, ' - 'the entire retrieved ' - 'record MUST be ignored. ' - 'It MUST be the first ' - 'tag in the list.'), - a=OrderedDict(name="Authority Evidence Location", - required=False, - default="", - description='If present, this tag MUST have an empty value ' - 'or its value MUST be a single URI. An empty ' - 'value for the tag is interpreted to mean the ' - 'Domain Owner does not wish to publish or does ' - 'not have authority evidence to disclose. The ' - 'URI, if present, MUST contain a fully ' - 'qualified domain name (FQDN) and MUST specify ' - 'HTTPS as the URI scheme ("https"). The URI ' - 'SHOULD specify the location of a publicly ' - 'retrievable BIMI Evidence Document.' - ), - l=OrderedDict(name="Location", - required=False, - default="", - description='The value of this tag is either empty ' - 'indicating declination to publish, or a single ' - 'URI representing the location of a Brand ' - 'Indicator file. The only supported transport ' - 'is HTTPS.' - ) + v=OrderedDict( + name="Version", + required=True, + description="Identifies the record " + "retrieved as a BIMI " + "record. It MUST have the " + 'value of "BIMI1". The ' + "value of this tag MUST " + "match precisely; if it " + "does not or it is absent, " + "the entire retrieved " + "record MUST be ignored. " + "It MUST be the first " + "tag in the list.", + ), + a=OrderedDict( + name="Authority Evidence Location", + required=False, + default="", + description="If present, this tag MUST have an empty value " + "or its value MUST be a single URI. An empty " + "value for the tag is interpreted to mean the " + "Domain Owner does not wish to publish or does " + "not have authority evidence to disclose. The " + "URI, if present, MUST contain a fully " + "qualified domain name (FQDN) and MUST specify " + 'HTTPS as the URI scheme ("https"). The URI ' + "SHOULD specify the location of a publicly " + "retrievable BIMI Evidence Document.", + ), + l=OrderedDict( + name="Location", + required=False, + default="", + description="The value of this tag is either empty " + "indicating declination to publish, or a single " + "URI representing the location of a Brand " + "Indicator file. The only supported transport " + "is HTTPS.", + ), ) -def _query_bimi_record(domain: str, selector: str = "default", - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0): +def _query_bimi_record( + domain: str, + selector: str = "default", + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +): """ Queries DNS for a BIMI record @@ -176,8 +181,9 @@ def _query_bimi_record(domain: str, selector: str = "default", unrelated_records = [] try: - records = query_dns(target, "TXT", nameservers=nameservers, - resolver=resolver, timeout=timeout) + records = query_dns( + target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + ) for record in records: if record.startswith(txt_prefix): bimi_record_count += 1 @@ -185,32 +191,35 @@ def _query_bimi_record(domain: str, selector: str = "default", unrelated_records.append(record) if bimi_record_count > 1: - raise MultipleBIMIRecords( - "Multiple BMI records are not permitted") + raise MultipleBIMIRecords("Multiple BMI records are not permitted") if len(unrelated_records) > 0: ur_str = "\n\n".join(unrelated_records) raise UnrelatedTXTRecordFoundAtBIMI( "Unrelated TXT records were discovered. These should be " "removed, as some receivers may not expect to find " "unrelated TXT records " - f"at {target}\n\n{ur_str}") + f"at {target}\n\n{ur_str}" + ) bimi_record = records[0] except dns.resolver.NoAnswer: try: - records = query_dns(domain, "TXT", - nameservers=nameservers, resolver=resolver, - timeout=timeout) + records = query_dns( + domain, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) for record in records: if record.startswith(txt_prefix): raise BIMIRecordInWrongLocation( - "The BIMI record must be located at " - f"{target}, not {domain}") + "The BIMI record must be located at " f"{target}, not {domain}" + ) except dns.resolver.NoAnswer: pass except dns.resolver.NXDOMAIN: - raise BIMIRecordNotFound( - f"The domain {domain} does not exist") + raise BIMIRecordNotFound(f"The domain {domain} does not exist") except Exception as error: BIMIRecordNotFound(error) @@ -222,10 +231,13 @@ def _query_bimi_record(domain: str, selector: str = "default", return bimi_record -def query_bimi_record(domain: str, selector: str = "default", - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def query_bimi_record( + domain: str, + selector: str = "default", + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Queries DNS for a BIMI record @@ -253,41 +265,46 @@ def query_bimi_record(domain: str, selector: str = "default", warnings = [] base_domain = get_base_domain(domain) location = domain.lower() - record = _query_bimi_record(domain, selector=selector, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + record = _query_bimi_record( + domain, + selector=selector, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) try: - root_records = query_dns(domain, "TXT", - nameservers=nameservers, resolver=resolver, - timeout=timeout) + root_records = query_dns( + domain, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + ) for root_record in root_records: if root_record.startswith("v=BIMI1"): - warnings.append(f"BIMI record at root of {domain} " - "has no effect") + warnings.append(f"BIMI record at root of {domain} " "has no effect") except dns.resolver.NXDOMAIN: - raise BIMIRecordNotFound( - f"The domain {domain} does not exist") + raise BIMIRecordNotFound(f"The domain {domain} does not exist") except dns.exception.DNSException: pass if record is None and domain != base_domain and selector != "default": - record = _query_bimi_record(base_domain, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + record = _query_bimi_record( + base_domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) location = base_domain if record is None: raise BIMIRecordNotFound( f"A BIMI record does not exist at the {selector} selector for " - f"this domain or its base domain") + f"this domain or its base domain" + ) - return OrderedDict([("record", record), ("location", location), - ("warnings", warnings)]) + return OrderedDict( + [("record", record), ("location", location), ("warnings", warnings)] + ) def parse_bimi_record( - record: str, - include_tag_descriptions: bool = False, - syntax_error_marker: str = SYNTAX_ERROR_MARKER) -> OrderedDict: + record: str, + include_tag_descriptions: bool = False, + syntax_error_marker: str = SYNTAX_ERROR_MARKER, +) -> OrderedDict: """ Parses a BIMI record @@ -323,11 +340,13 @@ def parse_bimi_record( logging.debug("Parsing the BIMI record") session = requests.Session() session.headers = {"User-Agent": USER_AGENT} - spf_in_dmarc_error_msg = "Found a SPF record where a BIMI record " \ - "should be; most likely, the _bimi " \ - "subdomain record does not actually exist, " \ - "and the request for TXT records was " \ - "redirected to the base domain" + spf_in_dmarc_error_msg = ( + "Found a SPF record where a BIMI record " + "should be; most likely, the _bimi " + "subdomain record does not actually exist, " + "and the request for TXT records was " + "redirected to the base domain" + ) warnings = [] record = record.strip('"') if record.lower().startswith("v=spf1"): @@ -336,14 +355,20 @@ def parse_bimi_record( parsed_record = bimi_syntax_checker.parse(record) if not parsed_record.is_valid: expecting = list( - map(lambda x: str(x).strip('"'), list(parsed_record.expecting))) - marked_record = (record[:parsed_record.pos] + syntax_error_marker + - record[parsed_record.pos:]) + map(lambda x: str(x).strip('"'), list(parsed_record.expecting)) + ) + marked_record = ( + record[: parsed_record.pos] + + syntax_error_marker + + record[parsed_record.pos :] + ) expecting = " or ".join(expecting) - raise BIMISyntaxError(f"Error: Expected {expecting} at position " - f"{parsed_record.pos} " - f"(marked with {syntax_error_marker}) in: " - f"{marked_record}") + raise BIMISyntaxError( + f"Error: Expected {expecting} at position " + f"{parsed_record.pos} " + f"(marked with {syntax_error_marker}) in: " + f"{marked_record}" + ) pairs = BIMI_TAG_VALUE_REGEX.findall(record) tags = OrderedDict() @@ -363,8 +388,7 @@ def parse_bimi_record( response.raise_for_status() raw_xml = response.text except Exception as e: - warnings.append(f"Unable to download " - f"{tag_value} - {str(e)}") + warnings.append(f"Unable to download " f"{tag_value} - {str(e)}") try: if isinstance(raw_xml, bytes): raw_xml = raw_xml.decode(errors="ignore") @@ -386,7 +410,9 @@ def parse_bimi_record( if base_profile != "tiny-ps": warnings.append(f"The SVG base profile must be tiny-ps") if width != height: - warnings.append("The SVG dimensions must be square, not {width}x{height}") + warnings.append( + "The SVG dimensions must be square, not {width}x{height}" + ) if getsizeof(raw_xml) > 32000: warnings.append("The SVG file exceeds to maximum size of 32 KB") except Exception as e: @@ -396,17 +422,22 @@ def parse_bimi_record( response = session.get(tag_value) response.raise_for_status() except Exception as e: - warnings.append(f"Unable to download Authority Evidence at " - f"{tag_value} - {str(e)}") + warnings.append( + f"Unable to download Authority Evidence at " + f"{tag_value} - {str(e)}" + ) return OrderedDict(tags=tags, warnings=warnings) -def check_bimi(domain: str, selector: str = "default", - include_tag_descriptions: bool = False, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def check_bimi( + domain: str, + selector: str = "default", + include_tag_descriptions: bool = False, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Returns a dictionary with a parsed BIMI record or an error. @@ -444,13 +475,15 @@ def check_bimi(domain: str, selector: str = "default", bimi_query = query_bimi_record( domain, selector=selector, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) bimi_results["selector"] = selector bimi_results["record"] = bimi_query["record"] parsed_bimi = parse_bimi_record( - bimi_results["record"], - include_tag_descriptions=include_tag_descriptions) + bimi_results["record"], include_tag_descriptions=include_tag_descriptions + ) bimi_results["tags"] = parsed_bimi["tags"] bimi_results["warnings"] = parsed_bimi["warnings"] except BIMIError as error: diff --git a/checkdmarc/dmarc.py b/checkdmarc/dmarc.py index 6788395..ea6761f 100644 --- a/checkdmarc/dmarc.py +++ b/checkdmarc/dmarc.py @@ -9,15 +9,21 @@ from typing import Union import dns -from pyleri import (Grammar, - Regex, - Sequence, - List, - ) +from pyleri import ( + Grammar, + Regex, + Sequence, + List, +) -from checkdmarc.utils import (WSP_REGEX, query_dns, get_base_domain, - MAILTO_REGEX, DNSException) +from checkdmarc.utils import ( + WSP_REGEX, + query_dns, + get_base_domain, + MAILTO_REGEX, + DNSException, +) from checkdmarc.utils import get_mx_records from checkdmarc._constants import SYNTAX_ERROR_MARKER @@ -35,13 +41,12 @@ See the License for the specific language governing permissions and limitations under the License.""" -DMARC_VERSION_REGEX_STRING = fr"v{WSP_REGEX}*={WSP_REGEX}*DMARC1{WSP_REGEX}*;" +DMARC_VERSION_REGEX_STRING = rf"v{WSP_REGEX}*={WSP_REGEX}*DMARC1{WSP_REGEX}*;" DMARC_TAG_VALUE_REGEX_STRING = ( - fr"([a-z]{{1,5}}){WSP_REGEX}*={WSP_REGEX}*([\w.:@/+!,_\- ]+)" + rf"([a-z]{{1,5}}){WSP_REGEX}*={WSP_REGEX}*([\w.:@/+!,_\- ]+)" ) -DMARC_TAG_VALUE_REGEX = re.compile(DMARC_TAG_VALUE_REGEX_STRING, - re.IGNORECASE) +DMARC_TAG_VALUE_REGEX = re.compile(DMARC_TAG_VALUE_REGEX_STRING, re.IGNORECASE) class _DMARCWarning(Exception): @@ -100,9 +105,9 @@ class UnrelatedTXTRecordFoundAtDMARC(DMARCError): class SPFRecordFoundWhereDMARCRecordShouldBe(UnrelatedTXTRecordFoundAtDMARC): """Raised when an SPF record is found where a DMARC record should be; - most likely, the ``_dmarc`` subdomain - record does not actually exist, and the request for ``TXT`` records was - redirected to the base domain""" + most likely, the ``_dmarc`` subdomain + record does not actually exist, and the request for ``TXT`` records was + redirected to the base domain""" class DMARCRecordInWrongLocation(DMARCError): @@ -111,274 +116,290 @@ class DMARCRecordInWrongLocation(DMARCError): class DMARCReportEmailAddressMissingMXRecords(_DMARCWarning): """Raised when an email address in a DMARC report URI is missing MX - records""" + records""" class UnverifiedDMARCURIDestination(_DMARCWarning): """Raised when the destination of a DMARC report URI does not indicate - that it accepts reports for the domain""" + that it accepts reports for the domain""" class MultipleDMARCRecords(DMARCError): """Raised when multiple DMARC records are found, in violation of - RFC 7486, section 6.6.3""" + RFC 7486, section 6.6.3""" class _DMARCGrammar(Grammar): """Defines Pyleri grammar for DMARC records""" + version_tag = Regex(DMARC_VERSION_REGEX_STRING, re.IGNORECASE) tag_value = Regex(DMARC_TAG_VALUE_REGEX_STRING, re.IGNORECASE) START = Sequence( version_tag, - List( - tag_value, - delimiter=Regex(f"{WSP_REGEX}*;{WSP_REGEX}*"), - opt=True)) - - -dmarc_tags = OrderedDict(adkim=OrderedDict(name="DKIM Alignment Mode", - required=False, - default="r", - description='In relaxed mode, ' - 'the Organizational ' - 'Domains of both the ' - 'DKIM-authenticated ' - 'signing domain (taken ' - 'from the value of the ' - '"d=" tag in the ' - 'signature) and that ' - 'of the RFC 5322 ' - 'From domain ' - 'must be equal if the ' - 'identifiers are to be ' - 'considered aligned.'), - aspf=OrderedDict(name="SPF alignment mode", - required=False, - default="r", - description='In relaxed mode, ' - 'the SPF-authenticated ' - 'domain and RFC5322 ' - 'From domain must have ' - 'the same ' - 'Organizational Domain. ' - 'In strict mode, only ' - 'an exact DNS domain ' - 'match is considered to ' - 'produce Identifier ' - 'Alignment.'), - fo=OrderedDict(name="Failure Reporting Options", - required=False, - default="0", - description='Provides requested ' - 'options for generation ' - 'of failure reports. ' - 'Report generators MAY ' - 'choose to adhere to the ' - 'requested options. ' - 'This tag\'s content ' - 'MUST be ignored if ' - 'a "ruf" tag (below) is ' - 'not also specified. ' - 'The value of this tag is ' - 'a colon-separated list ' - 'of characters that ' - 'indicate failure ' - 'reporting options.', - values={ - "0": 'Generate a DMARC failure ' - 'report if all underlying ' - 'authentication mechanisms ' - 'fail to produce an aligned ' - '"pass" result.', - "1": 'Generate a DMARC failure ' - 'report if any underlying ' - 'authentication mechanism ' - 'produced something other ' - 'than an aligned ' - '"pass" result.', - "d": 'Generate a DKIM failure ' - 'report if the message had ' - 'a signature that failed ' - 'evaluation, regardless of ' - 'its alignment. DKIM-' - 'specific reporting is ' - 'described in AFRF-DKIM.', - "s": 'Generate an SPF failure ' - 'report if the message ' - 'failed SPF evaluation, ' - 'regardless of its alignment.' - ' SPF-specific reporting is ' - 'described in AFRF-SPF' - } - ), - p=OrderedDict(name="Requested Mail Receiver Policy", - reqired=True, - description='Specifies the policy to ' - 'be enacted by the ' - 'Receiver at the ' - 'request of the ' - 'Domain Owner. The ' - 'policy applies to ' - 'the domain and to its ' - 'subdomains, unless ' - 'subdomain policy ' - 'is explicitly described ' - 'using the "sp" tag.', - values={ - "none": 'The Domain Owner requests ' - 'no specific action be ' - 'taken regarding delivery ' - 'of messages.', - "quarantine": 'The Domain Owner ' - 'wishes to have ' - 'email that fails ' - 'the DMARC mechanism ' - 'check be treated by ' - 'Mail Receivers as ' - 'suspicious. ' - 'Depending on the ' - 'capabilities of the ' - 'MailReceiver, ' - 'this can mean ' - '"place into spam ' - 'folder", ' - '"scrutinize ' - 'with additional ' - 'intensity", and/or ' - '"flag as ' - 'suspicious".', - "reject": 'The Domain Owner wishes ' - 'for Mail Receivers to ' - 'reject ' - 'email that fails the ' - 'DMARC mechanism check. ' - 'Rejection SHOULD ' - 'occur during the SMTP ' - 'transaction.' - } - ), - pct=OrderedDict(name="Percentage", - required=False, - default=100, - description='Integer percentage of ' - 'messages from the ' - 'Domain Owner\'s ' - 'mail stream to which ' - 'the DMARC policy is to ' - 'be applied. ' - 'However, this ' - 'MUST NOT be applied to ' - 'the DMARC-generated ' - 'reports, all of which ' - 'must be sent and ' - 'received unhindered. ' - 'The purpose of the ' - '"pct" tag is to allow ' - 'Domain Owners to enact ' - 'a slow rollout of ' - 'enforcement of the ' - 'DMARC mechanism.' - ), - rf=OrderedDict(name="Report Format", - required=False, - default="afrf", - description='A list separated by ' - 'colons of one or more ' - 'report formats as ' - 'requested by the ' - 'Domain Owner to be ' - 'used when a message ' - 'fails both SPF and DKIM ' - 'tests to report details ' - 'of the individual ' - 'failure. Only "afrf" ' - '(the auth-failure report ' - 'type) is currently ' - 'supported in the ' - 'DMARC standard.', - values={ - "afrf": ' "Authentication Failure ' - 'Reporting Using the ' - 'Abuse Reporting Format", ' - 'RFC 6591, April 2012,' - '' - } - ), - ri=OrderedDict(name="Report Interval", - required=False, - default=86400, - description='Indicates a request to ' - 'Receivers to generate ' - 'aggregate reports ' - 'separated by no more ' - 'than the requested ' - 'number of seconds. ' - 'DMARC implementations ' - 'MUST be able to provide ' - 'daily reports and ' - 'SHOULD be able to ' - 'provide hourly reports ' - 'when requested. ' - 'However, anything other ' - 'than a daily report is ' - 'understood to ' - 'be accommodated on a ' - 'best-effort basis.' - ), - rua=OrderedDict(name="Aggregate Feedback Addresses", - required=False, - description=' A comma-separated list ' - 'of DMARC URIs to which ' - 'aggregate feedback ' - 'is to be sent.' - ), - ruf=OrderedDict(name="Forensic Feedback Addresses", - required=False, - description=' A comma-separated list ' - 'of DMARC URIs to which ' - 'forensic feedback ' - 'is to be sent.' - ), - sp=OrderedDict(name="Subdomain Policy", - required=False, - description='Indicates the policy to ' - 'be enacted by the ' - 'Receiver at the request ' - 'of the Domain Owner. ' - 'It applies only to ' - 'subdomains of the ' - 'domain queried, and not ' - 'to the domain itself. ' - 'Its syntax is identical ' - 'to that of the "p" tag ' - 'defined above. If ' - 'absent, the policy ' - 'specified by the "p" ' - 'tag MUST be applied ' - 'for subdomains.' - ), - v=OrderedDict(name="Version", - reqired=True, - description='Identifies the record ' - 'retrieved as a DMARC ' - 'record. It MUST have the ' - 'value of "DMARC1". The ' - 'value of this tag MUST ' - 'match precisely; if it ' - 'does not or it is absent, ' - 'the entire retrieved ' - 'record MUST be ignored. ' - 'It MUST be the first ' - 'tag in the list.') - ) - - -def _query_dmarc_record(domain: str, nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0, - ignore_unrelated_records: bool = False - ) -> Union[str, None]: + List(tag_value, delimiter=Regex(f"{WSP_REGEX}*;{WSP_REGEX}*"), opt=True), + ) + + +dmarc_tags = OrderedDict( + adkim=OrderedDict( + name="DKIM Alignment Mode", + required=False, + default="r", + description="In relaxed mode, " + "the Organizational " + "Domains of both the " + "DKIM-authenticated " + "signing domain (taken " + "from the value of the " + '"d=" tag in the ' + "signature) and that " + "of the RFC 5322 " + "From domain " + "must be equal if the " + "identifiers are to be " + "considered aligned.", + ), + aspf=OrderedDict( + name="SPF alignment mode", + required=False, + default="r", + description="In relaxed mode, " + "the SPF-authenticated " + "domain and RFC5322 " + "From domain must have " + "the same " + "Organizational Domain. " + "In strict mode, only " + "an exact DNS domain " + "match is considered to " + "produce Identifier " + "Alignment.", + ), + fo=OrderedDict( + name="Failure Reporting Options", + required=False, + default="0", + description="Provides requested " + "options for generation " + "of failure reports. " + "Report generators MAY " + "choose to adhere to the " + "requested options. " + "This tag's content " + "MUST be ignored if " + 'a "ruf" tag (below) is ' + "not also specified. " + "The value of this tag is " + "a colon-separated list " + "of characters that " + "indicate failure " + "reporting options.", + values={ + "0": "Generate a DMARC failure " + "report if all underlying " + "authentication mechanisms " + "fail to produce an aligned " + '"pass" result.', + "1": "Generate a DMARC failure " + "report if any underlying " + "authentication mechanism " + "produced something other " + "than an aligned " + '"pass" result.', + "d": "Generate a DKIM failure " + "report if the message had " + "a signature that failed " + "evaluation, regardless of " + "its alignment. DKIM-" + "specific reporting is " + "described in AFRF-DKIM.", + "s": "Generate an SPF failure " + "report if the message " + "failed SPF evaluation, " + "regardless of its alignment." + " SPF-specific reporting is " + "described in AFRF-SPF", + }, + ), + p=OrderedDict( + name="Requested Mail Receiver Policy", + reqired=True, + description="Specifies the policy to " + "be enacted by the " + "Receiver at the " + "request of the " + "Domain Owner. The " + "policy applies to " + "the domain and to its " + "subdomains, unless " + "subdomain policy " + "is explicitly described " + 'using the "sp" tag.', + values={ + "none": "The Domain Owner requests " + "no specific action be " + "taken regarding delivery " + "of messages.", + "quarantine": "The Domain Owner " + "wishes to have " + "email that fails " + "the DMARC mechanism " + "check be treated by " + "Mail Receivers as " + "suspicious. " + "Depending on the " + "capabilities of the " + "MailReceiver, " + "this can mean " + '"place into spam ' + 'folder", ' + '"scrutinize ' + "with additional " + 'intensity", and/or ' + '"flag as ' + 'suspicious".', + "reject": "The Domain Owner wishes " + "for Mail Receivers to " + "reject " + "email that fails the " + "DMARC mechanism check. " + "Rejection SHOULD " + "occur during the SMTP " + "transaction.", + }, + ), + pct=OrderedDict( + name="Percentage", + required=False, + default=100, + description="Integer percentage of " + "messages from the " + "Domain Owner's " + "mail stream to which " + "the DMARC policy is to " + "be applied. " + "However, this " + "MUST NOT be applied to " + "the DMARC-generated " + "reports, all of which " + "must be sent and " + "received unhindered. " + "The purpose of the " + '"pct" tag is to allow ' + "Domain Owners to enact " + "a slow rollout of " + "enforcement of the " + "DMARC mechanism.", + ), + rf=OrderedDict( + name="Report Format", + required=False, + default="afrf", + description="A list separated by " + "colons of one or more " + "report formats as " + "requested by the " + "Domain Owner to be " + "used when a message " + "fails both SPF and DKIM " + "tests to report details " + "of the individual " + 'failure. Only "afrf" ' + "(the auth-failure report " + "type) is currently " + "supported in the " + "DMARC standard.", + values={ + "afrf": ' "Authentication Failure ' + "Reporting Using the " + 'Abuse Reporting Format", ' + "RFC 6591, April 2012," + "" + }, + ), + ri=OrderedDict( + name="Report Interval", + required=False, + default=86400, + description="Indicates a request to " + "Receivers to generate " + "aggregate reports " + "separated by no more " + "than the requested " + "number of seconds. " + "DMARC implementations " + "MUST be able to provide " + "daily reports and " + "SHOULD be able to " + "provide hourly reports " + "when requested. " + "However, anything other " + "than a daily report is " + "understood to " + "be accommodated on a " + "best-effort basis.", + ), + rua=OrderedDict( + name="Aggregate Feedback Addresses", + required=False, + description=" A comma-separated list " + "of DMARC URIs to which " + "aggregate feedback " + "is to be sent.", + ), + ruf=OrderedDict( + name="Forensic Feedback Addresses", + required=False, + description=" A comma-separated list " + "of DMARC URIs to which " + "forensic feedback " + "is to be sent.", + ), + sp=OrderedDict( + name="Subdomain Policy", + required=False, + description="Indicates the policy to " + "be enacted by the " + "Receiver at the request " + "of the Domain Owner. " + "It applies only to " + "subdomains of the " + "domain queried, and not " + "to the domain itself. " + "Its syntax is identical " + 'to that of the "p" tag ' + "defined above. If " + "absent, the policy " + 'specified by the "p" ' + "tag MUST be applied " + "for subdomains.", + ), + v=OrderedDict( + name="Version", + reqired=True, + description="Identifies the record " + "retrieved as a DMARC " + "record. It MUST have the " + 'value of "DMARC1". The ' + "value of this tag MUST " + "match precisely; if it " + "does not or it is absent, " + "the entire retrieved " + "record MUST be ignored. " + "It MUST be the first " + "tag in the list.", + ), +) + + +def _query_dmarc_record( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, + ignore_unrelated_records: bool = False, +) -> Union[str, None]: """ Queries DNS for a DMARC record @@ -400,8 +421,9 @@ def _query_dmarc_record(domain: str, nameservers: list[str] = None, unrelated_records = [] try: - records = query_dns(target, "TXT", nameservers=nameservers, - resolver=resolver, timeout=timeout) + records = query_dns( + target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + ) for record in records: if record.startswith(txt_prefix): dmarc_records.append(record) @@ -417,7 +439,8 @@ def _query_dmarc_record(domain: str, nameservers: list[str] = None, if len(dmarc_records) > 1: raise MultipleDMARCRecords( "Multiple DMARC policy records are not permitted - " - "https://tools.ietf.org/html/rfc7489#section-6.6.3") + "https://tools.ietf.org/html/rfc7489#section-6.6.3" + ) if len(unrelated_records) > 0: if not ignore_unrelated_records: ur_str = "\n\n".join(unrelated_records) @@ -425,25 +448,29 @@ def _query_dmarc_record(domain: str, nameservers: list[str] = None, "Unrelated TXT records were discovered. These should be " "removed, as some receivers may not expect to find " f"unrelated TXT records at {target}\n\n{ur_str}", - data={"target": target}) + data={"target": target}, + ) if len(dmarc_records) == 1: dmarc_record = dmarc_records[0] except dns.resolver.NoAnswer: try: - records = query_dns(domain, "TXT", - nameservers=nameservers, resolver=resolver, - timeout=timeout) + records = query_dns( + domain, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) for record in records: if record.startswith(txt_prefix): raise DMARCRecordInWrongLocation( - "The DMARC record must be located at " - f"{target}, not {domain}") + "The DMARC record must be located at " f"{target}, not {domain}" + ) except dns.resolver.NoAnswer: pass except dns.resolver.NXDOMAIN: - raise DMARCRecordNotFound( - f"The domain {0} does not exist".format(domain)) + raise DMARCRecordNotFound(f"The domain {0} does not exist".format(domain)) except Exception as error: raise DMARCRecordNotFound(error) @@ -461,10 +488,13 @@ def _query_dmarc_record(domain: str, nameservers: list[str] = None, return dmarc_record -def query_dmarc_record(domain: str, nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0, - ignore_unrelated_records: bool = False) -> OrderedDict: +def query_dmarc_record( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, + ignore_unrelated_records: bool = False, +) -> OrderedDict: """ Queries DNS for a DMARC record @@ -496,9 +526,12 @@ def query_dmarc_record(domain: str, nameservers: list[str] = None, try: record = _query_dmarc_record( - domain, nameservers=nameservers, - resolver=resolver, timeout=timeout, - ignore_unrelated_records=ignore_unrelated_records) + domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ignore_unrelated_records=ignore_unrelated_records, + ) except DMARCRecordNotFound: # Skip this exception as we want to query the base domain. If we fail # at that, at the end of this function we will raise another @@ -506,16 +539,14 @@ def query_dmarc_record(domain: str, nameservers: list[str] = None, record = None try: - root_records = query_dns(domain, "TXT", - nameservers=nameservers, resolver=resolver, - timeout=timeout) + root_records = query_dns( + domain, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + ) for root_record in root_records: if root_record.startswith("v=DMARC1"): - warnings.append(f"DMARC record at root of {domain} " - "has no effect") + warnings.append(f"DMARC record at root of {domain} " "has no effect") except dns.resolver.NXDOMAIN: - raise DMARCRecordNotFound( - f"The domain {domain} does not exist") + raise DMARCRecordNotFound(f"The domain {domain} does not exist") except dns.exception.DNSException: pass @@ -525,19 +556,22 @@ def query_dmarc_record(domain: str, nameservers: list[str] = None, nameservers=nameservers, resolver=resolver, timeout=timeout, - ignore_unrelated_records=ignore_unrelated_records) + ignore_unrelated_records=ignore_unrelated_records, + ) location = base_domain if record is None: raise DMARCRecordNotFound( - "A DMARC record does not exist for this domain or its base domain") + "A DMARC record does not exist for this domain or its base domain" + ) - return OrderedDict([("record", record), ("location", location), - ("warnings", warnings)]) + return OrderedDict( + [("record", record), ("location", location), ("warnings", warnings)] + ) def get_dmarc_tag_description( - tag: str, - value: Union[str, list[str]] = None) -> OrderedDict: + tag: str, value: Union[str, list[str]] = None +) -> OrderedDict: """ Get the name, default value, and description for a DMARC tag, amd/or a description for a tag value @@ -571,7 +605,8 @@ def get_dmarc_tag_description( description = new_description return OrderedDict( - [("name", name), ("default", default), ("description", description)]) + [("name", name), ("default", default), ("description", description)] + ) def parse_dmarc_report_uri(uri: str) -> OrderedDict: @@ -598,7 +633,8 @@ def parse_dmarc_report_uri(uri: str) -> OrderedDict: if len(mailto_matches) != 1: raise InvalidDMARCReportURI( ( - f"{uri} is not a valid DMARC report URI" + ( + f"{uri} is not a valid DMARC report URI" + + ( "" if uri.startswith("mailto:") else ( @@ -615,16 +651,18 @@ def parse_dmarc_report_uri(uri: str) -> OrderedDict: if size_limit == "": size_limit = None - return OrderedDict([("scheme", scheme), ("address", email_address), - ("size_limit", size_limit)]) + return OrderedDict( + [("scheme", scheme), ("address", email_address), ("size_limit", size_limit)] + ) def check_wildcard_dmarc_report_authorization( - domain: str, - nameservers: list[str] = None, - ignore_unrelated_records: bool = False, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> bool: + domain: str, + nameservers: list[str] = None, + ignore_unrelated_records: bool = False, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> bool: """ Checks for a wildcard DMARC report authorization record, e.g.: @@ -648,9 +686,13 @@ def check_wildcard_dmarc_report_authorization( dmarc_record_count = 0 unrelated_records = [] try: - records = query_dns(wildcard_target, "TXT", - nameservers=nameservers, resolver=resolver, - timeout=timeout) + records = query_dns( + wildcard_target, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) for record in records: if record.startswith("v=DMARC1"): @@ -665,7 +707,8 @@ def check_wildcard_dmarc_report_authorization( "These should be removed, as some " "receivers may not expect to find unrelated TXT records " f"at {wildcard_target}\n\n{ur_str}", - data={"target": wildcard_target}) + data={"target": wildcard_target}, + ) if dmarc_record_count < 1: return False @@ -675,53 +718,62 @@ def check_wildcard_dmarc_report_authorization( return True -def verify_dmarc_report_destination(source_domain: str, - destination_domain: str, - nameservers: list[str] = None, - ignore_unrelated_records: bool = False, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> None: +def verify_dmarc_report_destination( + source_domain: str, + destination_domain: str, + nameservers: list[str] = None, + ignore_unrelated_records: bool = False, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> None: """ - Checks if the report destination accepts reports for the source domain - per RFC 7489, section 7.1. Raises - `checkdmarc.dmarc.UnverifiedDMARCURIDestination` if it doesn't accept. - - Args: - source_domain (str): The source domain - destination_domain (str): The destination domain - nameservers (list): A list of nameservers to query - ignore_unrelated_records (bool): Ignore unrelated TXT records - resolver (dns.resolver.Resolver): A resolver object to use for DNS - requests - timeout (float): number of seconds to wait for an answer from DNS + Checks if the report destination accepts reports for the source domain + per RFC 7489, section 7.1. Raises + `checkdmarc.dmarc.UnverifiedDMARCURIDestination` if it doesn't accept. - Raises: - :exc:`checkdmarc.dmarc.UnverifiedDMARCURIDestination` - :exc:`checkdmarc.dmarc.UnrelatedTXTRecordFound` - """ + Args: + source_domain (str): The source domain + destination_domain (str): The destination domain + nameservers (list): A list of nameservers to query + ignore_unrelated_records (bool): Ignore unrelated TXT records + resolver (dns.resolver.Resolver): A resolver object to use for DNS + requests + timeout (float): number of seconds to wait for an answer from DNS + + Raises: + :exc:`checkdmarc.dmarc.UnverifiedDMARCURIDestination` + :exc:`checkdmarc.dmarc.UnrelatedTXTRecordFound` + """ source_domain = source_domain.lower() destination_domain = destination_domain.lower() if get_base_domain(source_domain) != get_base_domain(destination_domain): if check_wildcard_dmarc_report_authorization( - destination_domain, - nameservers=nameservers, - ignore_unrelated_records=ignore_unrelated_records, - resolver=resolver): + destination_domain, + nameservers=nameservers, + ignore_unrelated_records=ignore_unrelated_records, + resolver=resolver, + ): return target = f"{source_domain}._report._dmarc.{destination_domain}" - message = f"{destination_domain} does not indicate that it accepts " \ - f"DMARC reports about {source_domain} - " \ - "Authorization record not found: " \ - f'{source_domain}._report._dmarc.{destination_domain} " \ + message = ( + f"{destination_domain} does not indicate that it accepts " + f"DMARC reports about {source_domain} - " + "Authorization record not found: " + f'{source_domain}._report._dmarc.{destination_domain} " \ IN TXT "v=DMARC1"' + ) dmarc_record_count = 0 unrelated_records = [] try: - records = query_dns(target, "TXT", - nameservers=nameservers, resolver=resolver, - timeout=timeout) + records = query_dns( + target, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) for record in records: if record.startswith("v=DMARC1"): @@ -735,7 +787,9 @@ def verify_dmarc_report_destination(source_domain: str, "Unrelated TXT records were discovered. " "These should be removed, as some " "receivers may not expect to find unrelated TXT records " - f"at {target}\n\n{ur_str}", data={"target": target}) + f"at {target}\n\n{ur_str}", + data={"target": target}, + ) if dmarc_record_count < 1: raise UnverifiedDMARCURIDestination(message) @@ -744,13 +798,16 @@ def verify_dmarc_report_destination(source_domain: str, def parse_dmarc_record( - record: str, domain: str, parked: bool = False, - include_tag_descriptions: bool = False, - nameservers: list[str] = None, - ignore_unrelated_records: bool = False, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0, - syntax_error_marker: str = SYNTAX_ERROR_MARKER) -> OrderedDict: + record: str, + domain: str, + parked: bool = False, + include_tag_descriptions: bool = False, + nameservers: list[str] = None, + ignore_unrelated_records: bool = False, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, + syntax_error_marker: str = SYNTAX_ERROR_MARKER, +) -> OrderedDict: """ Parses a DMARC record @@ -792,11 +849,13 @@ def parse_dmarc_record( """ logging.debug(f"Parsing the DMARC record for {domain}") - spf_in_dmarc_error_msg = "Found a SPF record where a DMARC record " \ - "should be; most likely, the _dmarc " \ - "subdomain record does not actually exist, " \ - "and the request for TXT records was " \ - "redirected to the base domain" + spf_in_dmarc_error_msg = ( + "Found a SPF record where a DMARC record " + "should be; most likely, the _dmarc " + "subdomain record does not actually exist, " + "and the request for TXT records was " + "redirected to the base domain" + ) warnings = [] record = record.strip('"') if record.lower().startswith("v=spf1"): @@ -805,14 +864,20 @@ def parse_dmarc_record( parsed_record = dmarc_syntax_checker.parse(record) if not parsed_record.is_valid: expecting = list( - map(lambda x: str(x).strip('"'), list(parsed_record.expecting))) - marked_record = (record[:parsed_record.pos] + syntax_error_marker + - record[parsed_record.pos:]) + map(lambda x: str(x).strip('"'), list(parsed_record.expecting)) + ) + marked_record = ( + record[: parsed_record.pos] + + syntax_error_marker + + record[parsed_record.pos :] + ) expecting = " or ".join(expecting) - raise DMARCSyntaxError(f"Error: Expected {expecting} at position " - f"{parsed_record.pos} " - f"(marked with {syntax_error_marker}) in: " - f"{marked_record}") + raise DMARCSyntaxError( + f"Error: Expected {expecting} at position " + f"{parsed_record.pos} " + f"(marked with {syntax_error_marker}) in: " + f"{marked_record}" + ) pairs = DMARC_TAG_VALUE_REGEX.findall(record) tags = OrderedDict() @@ -820,20 +885,20 @@ def parse_dmarc_record( # Find explicit tags for pair in pairs: tags[pair[0].lower()] = OrderedDict( - [("value", str(pair[1].strip())), ("explicit", True)]) + [("value", str(pair[1].strip())), ("explicit", True)] + ) # Include implicit tags and their defaults for tag in dmarc_tags.keys(): if tag not in tags and "default" in dmarc_tags[tag]: tags[tag] = OrderedDict( - [("value", dmarc_tags[tag]["default"]), ("explicit", False)]) + [("value", dmarc_tags[tag]["default"]), ("explicit", False)] + ) if "p" not in tags: - raise DMARCSyntaxError( - 'The record is missing the required policy ("p") tag') + raise DMARCSyntaxError('The record is missing the required policy ("p") tag') tags["p"]["value"] = tags["p"]["value"].lower() if "sp" not in tags: - tags["sp"] = OrderedDict([("value", tags["p"]["value"]), - ("explicit", False)]) + tags["sp"] = OrderedDict([("value", tags["p"]["value"]), ("explicit", False)]) if list(tags.keys())[1] != "p": raise DMARCSyntaxError("the p tag must immediately follow the v tag") tags["v"]["value"] = tags["v"]["value"].upper() @@ -849,38 +914,37 @@ def parse_dmarc_record( tag_value = tag_value.split(":") if "0" in tag_value and "1" in tag_value: warnings.append( - "When 1 is present in the fo tag, including 0 is " - "redundant" + "When 1 is present in the fo tag, including 0 is " "redundant" ) for value in tag_value: if value not in allowed_values: raise InvalidDMARCTagValue( - f"{value} is not a valid option for the DMARC fo tag") + f"{value} is not a valid option for the DMARC fo tag" + ) elif tag == "rf": tag_value = tag_value.lower().split(":") for value in tag_value: if value not in allowed_values: raise InvalidDMARCTagValue( - f"{value} is not a valid option for the DMARC " - "rf tag") + f"{value} is not a valid option for the DMARC " "rf tag" + ) elif allowed_values and tag_value not in allowed_values: allowed_values_str = ",".join(allowed_values) raise InvalidDMARCTagValue( f"Tag {tag} must have one of the following values: " - f"{allowed_values_str} - not {tags[tag]['value']}") + f"{allowed_values_str} - not {tags[tag]['value']}" + ) try: tags["pct"]["value"] = int(tags["pct"]["value"]) except ValueError: - raise InvalidDMARCTagValue( - "The value of the pct tag must be an integer") + raise InvalidDMARCTagValue("The value of the pct tag must be an integer") try: tags["ri"]["value"] = int(tags["ri"]["value"]) except ValueError: - raise InvalidDMARCTagValue( - "The value of the ri tag must be an integer") + raise InvalidDMARCTagValue("The value of the ri tag must be an integer") if "rua" in tags: parsed_uris = [] @@ -898,12 +962,15 @@ def parse_dmarc_record( nameservers=nameservers, ignore_unrelated_records=ignore_unrelated_records, resolver=resolver, - timeout=timeout) + timeout=timeout, + ) try: - hosts = get_mx_records(email_domain, - nameservers=nameservers, - resolver=resolver, - timeout=timeout) + hosts = get_mx_records( + email_domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) if len(hosts) == 0: raise DMARCReportEmailAddressMissingMXRecords( "The domain for rua email address " @@ -913,18 +980,28 @@ def parse_dmarc_record( raise DMARCReportEmailAddressMissingMXRecords( "Failed to retrieve MX records for the domain of " "rua email address " - f"{email_address} - {warning}") + f"{email_address} - {warning}" + ) except _DMARCWarning as warning: warnings.append(str(warning)) tags["rua"]["value"] = parsed_uris if len(parsed_uris) > 2: - warnings.append(str(_DMARCBestPracticeWarning( - "Some DMARC reporters might not send to more than two rua URIs" - ))) + warnings.append( + str( + _DMARCBestPracticeWarning( + "Some DMARC reporters might not send to more than two rua URIs" + ) + ) + ) else: - warnings.append(str(_DMARCBestPracticeWarning( - "rua tag (destination for aggregate reports) not found"))) + warnings.append( + str( + _DMARCBestPracticeWarning( + "rua tag (destination for aggregate reports) not found" + ) + ) + ) if "ruf" in tags.keys(): parsed_uris = [] @@ -942,12 +1019,15 @@ def parse_dmarc_record( nameservers=nameservers, ignore_unrelated_records=ignore_unrelated_records, resolver=resolver, - timeout=timeout) + timeout=timeout, + ) try: - hosts = get_mx_records(email_domain, - nameservers=nameservers, - resolver=resolver, - timeout=timeout) + hosts = get_mx_records( + email_domain, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) if len(hosts) == 0: raise DMARCReportEmailAddressMissingMXRecords( "The domain for ruf email address " @@ -965,25 +1045,31 @@ def parse_dmarc_record( tags["ruf"]["value"] = parsed_uris if len(parsed_uris) > 2: - warnings.append(str(_DMARCBestPracticeWarning( - "Some DMARC reporters might not send to more than two ruf URIs" - ))) + warnings.append( + str( + _DMARCBestPracticeWarning( + "Some DMARC reporters might not send to more than two ruf URIs" + ) + ) + ) if tags["pct"]["value"] < 0 or tags["pct"]["value"] > 100: - warnings.append(str(InvalidDMARCTagValue( - "pct value must be an integer between 0 and 100"))) + warnings.append( + str(InvalidDMARCTagValue("pct value must be an integer between 0 and 100")) + ) elif tags["pct"]["value"] < 100: - warning_msg = "pct value is less than 100. This leads to " \ - "inconsistent and unpredictable policy " \ - "enforcement. Consider using p=none to " \ - "monitor results instead" + warning_msg = ( + "pct value is less than 100. This leads to " + "inconsistent and unpredictable policy " + "enforcement. Consider using p=none to " + "monitor results instead" + ) warnings.append(str(_DMARCBestPracticeWarning(warning_msg))) if parked and tags["p"] != "reject": warning_msg = "Policy (p=) should be reject for parked domains" warnings.append(str(_DMARCBestPracticeWarning(warning_msg))) if parked and tags["sp"] != "reject": - warning_msg = "Subdomain policy (sp=) should be reject for " \ - "parked domains" + warning_msg = "Subdomain policy (sp=) should be reject for " "parked domains" warnings.append(str(_DMARCBestPracticeWarning(warning_msg))) # Add descriptions if requested @@ -999,11 +1085,13 @@ def parse_dmarc_record( return OrderedDict([("tags", tags), ("warnings", warnings)]) -def get_dmarc_record(domain: str, - include_tag_descriptions: bool = False, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def get_dmarc_record( + domain: str, + include_tag_descriptions: bool = False, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Retrieves a DMARC record for a domain and parses it @@ -1035,64 +1123,71 @@ def get_dmarc_record(domain: str, :exc:`checkdmarc.dmarc.UnrelatedTXTRecordFound` :exc:`checkdmarc.dmarc.DMARCReportEmailAddressMissingMXRecords` """ - query = query_dmarc_record(domain, nameservers=nameservers, - resolver=resolver, timeout=timeout) + query = query_dmarc_record( + domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) tag_descriptions = include_tag_descriptions - tags = parse_dmarc_record(query["record"], query["location"], - include_tag_descriptions=tag_descriptions, - nameservers=nameservers, resolver=resolver, - timeout=timeout) - - return OrderedDict([("record", - query["record"]), - ("location", query["location"]), - ("parsed", tags)]) - + tags = parse_dmarc_record( + query["record"], + query["location"], + include_tag_descriptions=tag_descriptions, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) -def check_dmarc(domain: str, parked: bool = False, - include_dmarc_tag_descriptions: bool = False, - ignore_unrelated_records: bool = False, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: + return OrderedDict( + [("record", query["record"]), ("location", query["location"]), ("parsed", tags)] + ) + + +def check_dmarc( + domain: str, + parked: bool = False, + include_dmarc_tag_descriptions: bool = False, + ignore_unrelated_records: bool = False, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ - Returns a dictionary with a parsed DMARC record or an error + Returns a dictionary with a parsed DMARC record or an error - Args: - domain (str): A domain name - parked (bool): The domain is parked - include_dmarc_tag_descriptions (bool): Include tag descriptions - ignore_unrelated_records (bool): Ignore unrelated TXT records - nameservers (list): A list of nameservers to query - resolver (dns.resolver.Resolver): A resolver object to use for DNS - requests - timeout (float): number of seconds to wait for a record from DNS + Args: + domain (str): A domain name + parked (bool): The domain is parked + include_dmarc_tag_descriptions (bool): Include tag descriptions + ignore_unrelated_records (bool): Ignore unrelated TXT records + nameservers (list): A list of nameservers to query + resolver (dns.resolver.Resolver): A resolver object to use for DNS + requests + timeout (float): number of seconds to wait for a record from DNS - Returns: - OrderedDict: An ``OrderedDict`` with the following keys: + Returns: + OrderedDict: An ``OrderedDict`` with the following keys: - - ``record`` - the unparsed DMARC record string - - ``location`` - the domain where the record was found - - ``warnings`` - warning conditions found + - ``record`` - the unparsed DMARC record string + - ``location`` - the domain where the record was found + - ``warnings`` - warning conditions found - If a DNS error occurs, the dictionary will have the - following keys: + If a DNS error occurs, the dictionary will have the + following keys: - - ``error`` - An error message - - ``valid`` - False + - ``error`` - An error message + - ``valid`` - False - """ - dmarc_results = OrderedDict([("record", None), ("valid", True), - ("location", None)]) + """ + dmarc_results = OrderedDict([("record", None), ("valid", True), ("location", None)]) try: dmarc_query = query_dmarc_record( domain, ignore_unrelated_records=ignore_unrelated_records, nameservers=nameservers, resolver=resolver, - timeout=timeout) + timeout=timeout, + ) dmarc_results["record"] = dmarc_query["record"] dmarc_results["location"] = dmarc_query["location"] parsed_dmarc_record = parse_dmarc_record( @@ -1101,13 +1196,14 @@ def check_dmarc(domain: str, parked: bool = False, parked=parked, include_tag_descriptions=include_dmarc_tag_descriptions, ignore_unrelated_records=ignore_unrelated_records, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) dmarc_results["warnings"] = dmarc_query["warnings"] dmarc_results["tags"] = parsed_dmarc_record["tags"] - dmarc_results["warnings"] += parsed_dmarc_record[ - "warnings"] + dmarc_results["warnings"] += parsed_dmarc_record["warnings"] except DMARCError as error: dmarc_results["error"] = str(error) dmarc_results["valid"] = False diff --git a/checkdmarc/dnssec.py b/checkdmarc/dnssec.py index 2cd7652..af94e81 100644 --- a/checkdmarc/dnssec.py +++ b/checkdmarc/dnssec.py @@ -32,8 +32,12 @@ TLSA_CACHE = ExpiringDict(max_len=200000, max_age_seconds=1800) -def get_dnskey(domain: str, nameservers: list[str] = None, - timeout: float = 2.0, cache: ExpiringDict = None): +def get_dnskey( + domain: str, + nameservers: list[str] = None, + timeout: float = 2.0, + cache: ExpiringDict = None, +): """ Get a DNSKEY RRSet on the given domain @@ -57,9 +61,7 @@ def get_dnskey(domain: str, nameservers: list[str] = None, return cache[domain] logging.debug(f"Checking for DNSKEY records at {domain}") - request = dns.message.make_query(domain, - dns.rdatatype.DNSKEY, - want_dnssec=True) + request = dns.message.make_query(domain, dns.rdatatype.DNSKEY, want_dnssec=True) for nameserver in nameservers: try: response = dns.query.udp(request, nameserver, timeout=timeout) @@ -69,13 +71,13 @@ def get_dnskey(domain: str, nameservers: list[str] = None, logging.debug(f"No DNSKEY records found at {domain}") base_domain = get_base_domain(domain) if domain != base_domain: - return get_dnskey(base_domain, - nameservers=nameservers, - timeout=timeout) + return get_dnskey( + base_domain, nameservers=nameservers, timeout=timeout + ) cache[domain] = None return None rrset = answer[0] - name = dns.name.from_text(f'{domain}.') + name = dns.name.from_text(f"{domain}.") key = {name: rrset} cache[domain] = key return key @@ -84,10 +86,12 @@ def get_dnskey(domain: str, nameservers: list[str] = None, logging.debug(f"DNSKEY query error: {e}") -def test_dnssec(domain: str, - nameservers: list[str] = None, - timeout: float = 2.0, - cache: ExpiringDict = None) -> bool: +def test_dnssec( + domain: str, + nameservers: list[str] = None, + timeout: float = 2.0, + cache: ExpiringDict = None, +) -> bool: """ Check for DNSSEC on the given domain @@ -111,19 +115,18 @@ def test_dnssec(domain: str, key = get_dnskey(domain, nameservers=nameservers, timeout=timeout) if key is None: return False - rdatatypes = [dns.rdatatype.DNSKEY, - dns.rdatatype.MX, - dns.rdatatype.A, - dns.rdatatype.NS, - dns.rdatatype.CNAME] + rdatatypes = [ + dns.rdatatype.DNSKEY, + dns.rdatatype.MX, + dns.rdatatype.A, + dns.rdatatype.NS, + dns.rdatatype.CNAME, + ] for rdatatype in rdatatypes: - request = dns.message.make_query(domain, - rdatatype, - want_dnssec=True) + request = dns.message.make_query(domain, rdatatype, want_dnssec=True) for nameserver in nameservers: try: - response = dns.query.udp(request, nameserver, - timeout=timeout) + response = dns.query.udp(request, nameserver, timeout=timeout) if response is not None: answer = response.answer if len(answer) != 2: @@ -141,10 +144,14 @@ def test_dnssec(domain: str, return False -def get_tlsa_records(hostname: str, nameservers: list[str] = None, - timeout: float = 2.0, port: int = 25, - protocol: str = "tcp", - cache: ExpiringDict = None) -> list[str]: +def get_tlsa_records( + hostname: str, + nameservers: list[str] = None, + timeout: float = 2.0, + port: int = 25, + protocol: str = "tcp", + cache: ExpiringDict = None, +) -> list[str]: """ Checks for TLSA records on the given hostname @@ -170,9 +177,9 @@ def get_tlsa_records(hostname: str, nameservers: list[str] = None, return TLSA_CACHE[query_hostname] tlsa_records = [] logging.debug(f"Checking for TLSA records at {query_hostname}") - request = dns.message.make_query(query_hostname, - dns.rdatatype.TLSA, - want_dnssec=True) + request = dns.message.make_query( + query_hostname, dns.rdatatype.TLSA, want_dnssec=True + ) for nameserver in nameservers: try: response = dns.query.udp(request, nameserver, timeout=timeout) @@ -181,19 +188,18 @@ def get_tlsa_records(hostname: str, nameservers: list[str] = None, if len(answer) != 2: return tlsa_records dnskey = get_dnskey( - domain=hostname, - nameservers=nameservers, - timeout=timeout + domain=hostname, nameservers=nameservers, timeout=timeout ) if dnskey is None: - logging.debug(f"Found TLSA records at {hostname} but not " - f"a DNSKEY record to verify them") + logging.debug( + f"Found TLSA records at {hostname} but not " + f"a DNSKEY record to verify them" + ) return tlsa_records rrset = answer[0] rrsig = answer[1] dns.dnssec.validate(rrset, rrsig, dnskey) - tlsa_records = list(map(lambda x: str(x), - list(rrset.items.keys()))) + tlsa_records = list(map(lambda x: str(x), list(rrset.items.keys()))) cache[query_hostname] = tlsa_records return tlsa_records except Exception as e: diff --git a/checkdmarc/mta_sts.py b/checkdmarc/mta_sts.py index 179e16f..3eaace2 100644 --- a/checkdmarc/mta_sts.py +++ b/checkdmarc/mta_sts.py @@ -9,11 +9,12 @@ import dns import requests -from pyleri import (Grammar, - Regex, - Sequence, - List, - ) +from pyleri import ( + Grammar, + Regex, + Sequence, + List, +) from checkdmarc.utils import query_dns, WSP_REGEX from checkdmarc._constants import SYNTAX_ERROR_MARKER, USER_AGENT @@ -33,8 +34,8 @@ limitations under the License.""" -MTA_STS_VERSION_REGEX_STRING = fr"v{WSP_REGEX}*={WSP_REGEX}*STSv1{WSP_REGEX}*;" -MTA_STS_TAG_VALUE_REGEX_STRING = fr"([a-z]{{1,2}}){WSP_REGEX}*={WSP_REGEX}*([\ +MTA_STS_VERSION_REGEX_STRING = rf"v{WSP_REGEX}*={WSP_REGEX}*STSv1{WSP_REGEX}*;" +MTA_STS_TAG_VALUE_REGEX_STRING = rf"([a-z]{{1,2}}){WSP_REGEX}*={WSP_REGEX}*([\ a-z0-9]+)" MTA_STS_MX_REGEX_STRING = r"[a-z0-9\-*.]+" @@ -43,11 +44,12 @@ class MTASTSError(Exception): """Raised when a fatal MTA-STS error occurs""" + def __init__(self, msg: str, data: dict = None): """ - Args: - msg (str): The error message - data (dict): A dictionary of data to include in the results + Args: + msg (str): The error message + data (dict): A dictionary of data to include in the results """ self.data = data Exception.__init__(self, msg) @@ -55,6 +57,7 @@ def __init__(self, msg: str, data: dict = None): class MTASTSRecordNotFound(MTASTSError): """Raised when an MTA-STS record could not be found""" + def __init__(self, error): if isinstance(error, dns.exception.Timeout): error.kwargs["timeout"] = round(error.kwargs["timeout"], 1) @@ -78,9 +81,9 @@ class UnrelatedTXTRecordFoundAtMTASTS(MTASTSError): class SPFRecordFoundWhereMTASTSRecordShouldBe(UnrelatedTXTRecordFoundAtMTASTS): """Raised when an SPF record is found where an MTA-STS record should be; - most likely, the ``_mta-sts`` subdomain - record does not actually exist, and the request for ``TXT`` records was - redirected to the base domain""" + most likely, the ``_mta-sts`` subdomain + record does not actually exist, and the request for ``TXT`` records was + redirected to the base domain""" class MTASTSRecordInWrongLocation(MTASTSError): @@ -105,38 +108,43 @@ class MTASTSPolicySyntaxError(MTASTSPolicyError): class _STSGrammar(Grammar): """Defines Pyleri grammar for MTA-STS records""" + version_tag = Regex(MTA_STS_VERSION_REGEX_STRING, re.IGNORECASE) tag_value = Regex(MTA_STS_TAG_VALUE_REGEX_STRING, re.IGNORECASE) START = Sequence( version_tag, - List( - tag_value, - delimiter=Regex(f"{WSP_REGEX}*;{WSP_REGEX}*"), - opt=True)) + List(tag_value, delimiter=Regex(f"{WSP_REGEX}*;{WSP_REGEX}*"), opt=True), + ) mta_sts_tags = OrderedDict( - v=OrderedDict(name="Version", - required=True, - description='Currently, only "STSv1" is supported.'), - id=OrderedDict(name="id", - required=True, - description='A short string used to track policy ' - 'updates. This string MUST uniquely identify ' - 'a given instance of a policy, such that ' - 'senders can determine when the policy has ' - 'been updated by comparing to the "id" of a ' - 'previously seen policy. There is no implied ' - 'ordering of "id" fields between revisions.') + v=OrderedDict( + name="Version", + required=True, + description='Currently, only "STSv1" is supported.', + ), + id=OrderedDict( + name="id", + required=True, + description="A short string used to track policy " + "updates. This string MUST uniquely identify " + "a given instance of a policy, such that " + "senders can determine when the policy has " + 'been updated by comparing to the "id" of a ' + "previously seen policy. There is no implied " + 'ordering of "id" fields between revisions.', + ), ) STS_TAG_VALUE_REGEX = re.compile(MTA_STS_TAG_VALUE_REGEX_STRING, re.IGNORECASE) -def query_mta_sts_record(domain: str, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def query_mta_sts_record( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Queries DNS for an MTA-STS record @@ -168,8 +176,9 @@ def query_mta_sts_record(domain: str, unrelated_records = [] try: - records = query_dns(target, "TXT", nameservers=nameservers, - resolver=resolver, timeout=timeout) + records = query_dns( + target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + ) for record in records: if record.startswith(txt_prefix): sts_record_count += 1 @@ -177,32 +186,36 @@ def query_mta_sts_record(domain: str, unrelated_records.append(record) if sts_record_count > 1: - raise MultipleMTASTSRecords( - "Multiple MTA-STS records are not permitted") + raise MultipleMTASTSRecords("Multiple MTA-STS records are not permitted") if len(unrelated_records) > 0: ur_str = "\n\n".join(unrelated_records) raise UnrelatedTXTRecordFoundAtMTASTS( "Unrelated TXT records were discovered. These should be " "removed, as some receivers may not expect to find " "unrelated TXT records " - f"at {target}\n\n{ur_str}") + f"at {target}\n\n{ur_str}" + ) sts_record = records[0] except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN): try: - records = query_dns(domain, "TXT", - nameservers=nameservers, resolver=resolver, - timeout=timeout) + records = query_dns( + domain, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) for record in records: if record.startswith(txt_prefix): raise MTASTSRecordInWrongLocation( "The MTA-STS record must be located at " - f"{target}, not {domain}") + f"{target}, not {domain}" + ) except dns.resolver.NoAnswer: pass except dns.resolver.NXDOMAIN: - raise MTASTSRecordNotFound( - f"The domain {domain} does not exist") + raise MTASTSRecordNotFound(f"The domain {domain} does not exist") except Exception as error: raise MTASTSRecordNotFound(error) except Exception as error: @@ -210,16 +223,17 @@ def query_mta_sts_record(domain: str, if sts_record is None: raise MTASTSRecordNotFound( - "An MTA-STS DNS record does not exist for this domain") + "An MTA-STS DNS record does not exist for this domain" + ) - return OrderedDict([("record", sts_record), - ("warnings", warnings)]) + return OrderedDict([("record", sts_record), ("warnings", warnings)]) def parse_mta_sts_record( - record: str, - include_tag_descriptions: bool = False, - syntax_error_marker: str = SYNTAX_ERROR_MARKER) -> OrderedDict: + record: str, + include_tag_descriptions: bool = False, + syntax_error_marker: str = SYNTAX_ERROR_MARKER, +) -> OrderedDict: """ Parses an MTA-STS record @@ -249,11 +263,13 @@ def parse_mta_sts_record( """ logging.debug("Parsing the MTA-STS record") - spf_in_dmarc_error_msg = "Found a SPF record where a MTA-STS record " \ - "should be; most likely, the _mta-sts " \ - "subdomain record does not actually exist, " \ - "and the request for TXT records was " \ - "redirected to the base domain" + spf_in_dmarc_error_msg = ( + "Found a SPF record where a MTA-STS record " + "should be; most likely, the _mta-sts " + "subdomain record does not actually exist, " + "and the request for TXT records was " + "redirected to the base domain" + ) warnings = [] record = record.strip('"') if record.lower().startswith("v=spf1"): @@ -262,14 +278,20 @@ def parse_mta_sts_record( parsed_record = sts_syntax_checker.parse(record) if not parsed_record.is_valid: expecting = list( - map(lambda x: str(x).strip('"'), list(parsed_record.expecting))) - marked_record = (record[:parsed_record.pos] + syntax_error_marker + - record[parsed_record.pos:]) + map(lambda x: str(x).strip('"'), list(parsed_record.expecting)) + ) + marked_record = ( + record[: parsed_record.pos] + + syntax_error_marker + + record[parsed_record.pos :] + ) expecting = " or ".join(expecting) - raise MTASTSRecordSyntaxError(f"Error: Expected {expecting} " - f"at position {parsed_record.pos} " - f"(marked with {syntax_error_marker}) " - f"in: {marked_record}") + raise MTASTSRecordSyntaxError( + f"Error: Expected {expecting} " + f"at position {parsed_record.pos} " + f"(marked with {syntax_error_marker}) " + f"in: {marked_record}" + ) pairs = STS_TAG_VALUE_REGEX.findall(record) tags = OrderedDict() @@ -315,11 +337,15 @@ def download_mta_sts_policy(domain: str) -> OrderedDict: content_type = response.headers["Content-Type"].split(";")[0] content_type = content_type.strip() if content_type != expected_content_type: - warnings.append(f"Content-Type header should be " - f"{expected_content_type} not {content_type}") + warnings.append( + f"Content-Type header should be " + f"{expected_content_type} not {content_type}" + ) else: - warnings.append("The Content-Type header is missing. It should " - f"be set to {expected_content_type}") + warnings.append( + "The Content-Type header is missing. It should " + f"be set to {expected_content_type}" + ) except Exception as e: raise MTASTSPolicyDownloadError(str(e)) @@ -360,13 +386,11 @@ def parse_mta_sts_policy(policy: str) -> OrderedDict: continue key_value = lines[i].split(":") if len(key_value) != 2: - raise MTASTSPolicySyntaxError( - f"Line {line}: Not a key: value pair") + raise MTASTSPolicySyntaxError(f"Line {line}: Not a key: value pair") key = key_value[0].strip() value = key_value[1].strip() if key not in acceptable_keys: - raise MTASTSPolicySyntaxError( - f"Line {line}: Unexpected key: {key}") + raise MTASTSPolicySyntaxError(f"Line {line}: Unexpected key: {key}") if key in parsed_policy and key != "mx": MTASTSPolicySyntaxError(f"Line {line}: Duplicate key: {key}") elif key == "version" and value not in versions: @@ -374,8 +398,7 @@ def parse_mta_sts_policy(policy: str) -> OrderedDict: elif key == "mode" and value not in modes: MTASTSPolicySyntaxError(f"Line {line}: Invalid mode: {value}") elif key == "max_age": - error_msg = ("max_age must be an integer value between 0 and " - "31557600") + error_msg = "max_age must be an integer value between 0 and " "31557600" if "." in value: raise MTASTSPolicySyntaxError(error_msg) try: @@ -388,26 +411,29 @@ def parse_mta_sts_policy(policy: str) -> OrderedDict: parsed_policy[key] = value else: if len(MTA_STS_MX_REGEX.findall(value)) == 0: - raise MTASTSPolicySyntaxError(f"Line {line}: Invalid mx " - f"value: {value}") + raise MTASTSPolicySyntaxError( + f"Line {line}: Invalid mx " f"value: {value}" + ) mx.append(value) for required_key in required_keys: if required_key not in parsed_policy: - raise MTASTSPolicySyntaxError(f"Missing required key: " - f"{required_key}") + raise MTASTSPolicySyntaxError(f"Missing required key: " f"{required_key}") if parsed_policy["mode"] != "none" and len(mx) == 0: - raise MTASTSPolicySyntaxError(f"{parsed_policy['mode']} mode requires " - f"at least one mx value") + raise MTASTSPolicySyntaxError( + f"{parsed_policy['mode']} mode requires " f"at least one mx value" + ) parsed_policy["mx"] = mx return OrderedDict(policy=parsed_policy, warnings=warnings) -def check_mta_sts(domain: str, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def check_mta_sts( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Returns a dictionary with a parsed MTA-STS policy or an error. @@ -436,9 +462,8 @@ def check_mta_sts(domain: str, mta_sts_results = OrderedDict([("valid", True)]) try: mta_sts_record = query_mta_sts_record( - domain, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) warnings = mta_sts_record["warnings"] mta_sts_record = parse_mta_sts_record(mta_sts_record["record"]) mta_sts_results["id"] = mta_sts_record["tags"]["id"]["value"] @@ -455,8 +480,7 @@ def check_mta_sts(domain: str, return mta_sts_results -def mx_in_mta_sts_patterns(mx_hostname: str, mta_sts_mx_patterns: list[str])\ - -> bool: +def mx_in_mta_sts_patterns(mx_hostname: str, mta_sts_mx_patterns: list[str]) -> bool: """ Tests is a given MX hostname is covered by a given list of MX patterns from an MTA-STS policy: @@ -469,8 +493,7 @@ def mx_in_mta_sts_patterns(mx_hostname: str, mta_sts_mx_patterns: list[str])\ """ for pattern in mta_sts_mx_patterns: regex_pattern = pattern.replace(r".", r"\.") - regex_pattern = regex_pattern.replace(r"*", - r"[a-z0-9\-.]+") + regex_pattern = regex_pattern.replace(r"*", r"[a-z0-9\-.]+") if len(re.findall(regex_pattern, mx_hostname, re.IGNORECASE)) > 0: return True return False diff --git a/checkdmarc/smtp.py b/checkdmarc/smtp.py index e98e95f..fc8ba28 100644 --- a/checkdmarc/smtp.py +++ b/checkdmarc/smtp.py @@ -14,8 +14,12 @@ import timeout_decorator from expiringdict import ExpiringDict -from checkdmarc.utils import (DNSException, - get_a_records, get_reverse_dns, get_mx_records) +from checkdmarc.utils import ( + DNSException, + get_a_records, + get_reverse_dns, + get_mx_records, +) from checkdmarc.mta_sts import mx_in_mta_sts_patterns from checkdmarc.dnssec import test_dnssec, get_tlsa_records @@ -42,10 +46,12 @@ class SMTPError(Exception): """Raised when SMTP error occurs""" -@timeout_decorator.timeout(5, timeout_exception=SMTPError, - exception_message="Connection timed out") -def test_tls(hostname: str, ssl_context: ssl.SSLContext = None, - cache: ExpiringDict = None) -> bool: +@timeout_decorator.timeout( + 5, timeout_exception=SMTPError, exception_message="Connection timed out" +) +def test_tls( + hostname: str, ssl_context: ssl.SSLContext = None, cache: ExpiringDict = None +) -> bool: """ Attempt to connect to an SMTP server port 465 and validate TLS/SSL support @@ -153,11 +159,12 @@ def test_tls(hostname: str, ssl_context: ssl.SSLContext = None, return tls -@timeout_decorator.timeout(5, timeout_exception=SMTPError, - exception_message="Connection timed out") -def test_starttls(hostname: str, - ssl_context: ssl.SSLContext = None, - cache: ExpiringDict = None) -> bool: +@timeout_decorator.timeout( + 5, timeout_exception=SMTPError, exception_message="Connection timed out" +) +def test_starttls( + hostname: str, ssl_context: ssl.SSLContext = None, cache: ExpiringDict = None +) -> bool: """ Attempt to connect to an SMTP server and validate STARTTLS support @@ -266,13 +273,16 @@ def test_starttls(hostname: str, raise SMTPError(error) -def get_mx_hosts(domain: str, skip_tls: bool = False, - approved_hostnames: list[str] = None, - mta_sts_mx_patterns: list[str] = None, - parked: bool = False, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0): +def get_mx_hosts( + domain: str, + skip_tls: bool = False, + approved_hostnames: list[str] = None, + mta_sts_mx_patterns: list[str] = None, + parked: bool = False, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +): """ Gets MX hostname and their addresses @@ -304,26 +314,31 @@ def get_mx_hosts(domain: str, skip_tls: bool = False, hostnames = set() dupe_hostnames = set() logging.debug(f"Getting MX records for {domain}") - mx_records = get_mx_records(domain, nameservers=nameservers, - resolver=resolver, timeout=timeout) + mx_records = get_mx_records( + domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) for record in mx_records: - hosts.append(OrderedDict([("preference", record["preference"]), - ("hostname", record["hostname"].lower()), - ("addresses", [])])) + hosts.append( + OrderedDict( + [ + ("preference", record["preference"]), + ("hostname", record["hostname"].lower()), + ("addresses", []), + ] + ) + ) if parked and len(hosts) > 0: warnings.append("MX records found on parked domains") elif not parked and len(hosts) == 0: warnings.append("No MX records found. Is the domain parked?") if approved_hostnames: - approved_hostnames = list(map(lambda h: h.lower(), - approved_hostnames)) + approved_hostnames = list(map(lambda h: h.lower(), approved_hostnames)) for host in hosts: hostname = host["hostname"] if hostname in hostnames: if hostname not in dupe_hostnames: - warnings.append( - f"Hostname {hostname} is listed in multiple MX records") + warnings.append(f"Hostname {hostname} is listed in multiple MX records") dupe_hostnames.add(hostname) continue hostnames.add(hostname) @@ -337,65 +352,62 @@ def get_mx_hosts(domain: str, skip_tls: bool = False, warnings.append(f"Unapproved MX hostname: {hostname}") if mta_sts_mx_patterns: if not mx_in_mta_sts_patterns(hostname, mta_sts_mx_patterns): - warnings.append(f"{hostname} is not included in the MTA-STS " - f"policy") + warnings.append(f"{hostname} is not included in the MTA-STS " f"policy") try: dnssec = False try: - dnssec = test_dnssec(hostname, - nameservers=nameservers, - timeout=timeout) + dnssec = test_dnssec(hostname, nameservers=nameservers, timeout=timeout) except Exception as e: logging.debug(e) host["dnssec"] = dnssec host["addresses"] = [] - host["addresses"] = get_a_records(hostname, - nameservers=nameservers, - resolver=resolver, - timeout=timeout) - tlsa_records = get_tlsa_records(hostname, - nameservers=nameservers, - timeout=timeout) + host["addresses"] = get_a_records( + hostname, nameservers=nameservers, resolver=resolver, timeout=timeout + ) + tlsa_records = get_tlsa_records( + hostname, nameservers=nameservers, timeout=timeout + ) if len(tlsa_records) > 0: host["tlsa"] = tlsa_records if len(host["addresses"]) == 0: - warnings.append( - f"{hostname} does not have any A or AAAA DNS records") + warnings.append(f"{hostname} does not have any A or AAAA DNS records") except Exception as e: if hostname.lower().endswith(".msv1.invalid"): - warnings.append(f"{e}. Consider using a TXT record to " - " validate domain ownership in Office 365 " - "instead.") + warnings.append( + f"{e}. Consider using a TXT record to " + " validate domain ownership in Office 365 " + "instead." + ) else: warnings.append(e.__str__()) for address in host["addresses"]: try: - reverse_hostnames = get_reverse_dns(address, - nameservers=nameservers, - resolver=resolver, - timeout=timeout) + reverse_hostnames = get_reverse_dns( + address, nameservers=nameservers, resolver=resolver, timeout=timeout + ) except DNSException: reverse_hostnames = [] if len(reverse_hostnames) == 0: warnings.append( - f"{address} does not have any reverse DNS (PTR) " - "records") + f"{address} does not have any reverse DNS (PTR) " "records" + ) for reverse_hostname in reverse_hostnames: try: - _addresses = get_a_records(reverse_hostname, - resolver=resolver) + _addresses = get_a_records(reverse_hostname, resolver=resolver) except DNSException as warning: warnings.append(str(warning)) _addresses = [] if address not in _addresses: - warnings.append(f"The reverse DNS of " - f"{address} is {reverse_hostname}, but " - "the A/AAAA DNS records for " - f"{reverse_hostname} do not resolve to " - f"{address}") + warnings.append( + f"The reverse DNS of " + f"{address} is {reverse_hostname}, but " + "the A/AAAA DNS records for " + f"{reverse_hostname} do not resolve to " + f"{address}" + ) if not skip_tls and platform.system() == "Windows": logging.warning("Testing TLS is not supported on Windows") skip_tls = True @@ -403,16 +415,14 @@ def get_mx_hosts(domain: str, skip_tls: bool = False, logging.debug(f"Skipping TLS/SSL tests on {hostname}") else: try: - starttls = test_starttls(hostname, - cache=STARTTLS_CACHE) + starttls = test_starttls(hostname, cache=STARTTLS_CACHE) tls = starttls if not starttls: warnings.append(f"STARTTLS is not supported on {hostname}") tls = test_tls(hostname, cache=TLS_CACHE) if not tls: - warnings.append(f"SSL/TLS is not supported on " - f"{hostname}") + warnings.append(f"SSL/TLS is not supported on " f"{hostname}") host["tls"] = tls host["starttls"] = starttls except DNSException as warning: @@ -432,12 +442,15 @@ def get_mx_hosts(domain: str, skip_tls: bool = False, return OrderedDict([("hosts", hosts), ("warnings", warnings)]) -def check_mx(domain: str, approved_mx_hostnames: list[str] = None, - mta_sts_mx_patterns: list[str] = None, - skip_tls: bool = False, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def check_mx( + domain: str, + approved_mx_hostnames: list[str] = None, + mta_sts_mx_patterns: list[str] = None, + skip_tls: bool = False, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Gets MX hostname and their addresses, or an empty list of hosts and an error if a DNS error occurs @@ -474,9 +487,10 @@ def check_mx(domain: str, approved_mx_hostnames: list[str] = None, skip_tls=skip_tls, approved_hostnames=approved_mx_hostnames, mta_sts_mx_patterns=mta_sts_mx_patterns, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) except DNSException as error: - mx_results = OrderedDict([("hosts", []), - ("error", str(error))]) + mx_results = OrderedDict([("hosts", []), ("error", str(error))]) return mx_results diff --git a/checkdmarc/smtp_tls_reporting.py b/checkdmarc/smtp_tls_reporting.py index 62b4e3e..c917edb 100644 --- a/checkdmarc/smtp_tls_reporting.py +++ b/checkdmarc/smtp_tls_reporting.py @@ -9,15 +9,10 @@ from collections import OrderedDict import dns -from pyleri import (Grammar, - Regex, - Sequence, - List - ) +from pyleri import Grammar, Regex, Sequence, List from checkdmarc._constants import SYNTAX_ERROR_MARKER -from checkdmarc.utils import (WSP_REGEX, MAILTO_REGEX_STRING, HTTPS_REGEX, - query_dns) +from checkdmarc.utils import WSP_REGEX, MAILTO_REGEX_STRING, HTTPS_REGEX, query_dns """Copyright 2019-2023 Sean Whalen @@ -33,16 +28,17 @@ See the License for the specific language governing permissions and limitations under the License.""" -SMTPTLSREPORTING_VERSION_REGEX_STRING = (fr"v{WSP_REGEX}*=" - fr"{WSP_REGEX}*TLSRPTv1{WSP_REGEX}*;") -SMTPTLSREPORTING_URI_REGEX_STRING = fr"({MAILTO_REGEX_STRING}|{HTTPS_REGEX})" +SMTPTLSREPORTING_VERSION_REGEX_STRING = ( + rf"v{WSP_REGEX}*=" rf"{WSP_REGEX}*TLSRPTv1{WSP_REGEX}*;" +) +SMTPTLSREPORTING_URI_REGEX_STRING = rf"({MAILTO_REGEX_STRING}|{HTTPS_REGEX})" SMTPTLSREPORTING_TAG_VALUE_REGEX_STRING = ( - fr"([a-z]{{1,3}}){WSP_REGEX}*={WSP_REGEX}*" - fr"([^\s;]+)" + rf"([a-z]{{1,3}}){WSP_REGEX}*={WSP_REGEX}*" rf"([^\s;]+)" ) SMTPTLSREPORTING_TAG_VALUE_REGEX = re.compile( - SMTPTLSREPORTING_TAG_VALUE_REGEX_STRING, re.IGNORECASE) + SMTPTLSREPORTING_TAG_VALUE_REGEX_STRING, re.IGNORECASE +) SMTPTLSREPORTING_URI_REGEX = re.compile( SMTPTLSREPORTING_URI_REGEX_STRING, re.IGNORECASE @@ -55,11 +51,12 @@ class _SMTPTLSReportingWarning(Exception): class SMTPTLSReportingError(Exception): """Raised when a fatal SMTP TLS Reporting error occurs""" + def __init__(self, msg: str, data: dict = None): """ - Args: - msg (str): The error message - data (dict): A dictionary of data to include in the results + Args: + msg (str): The error message + data (dict): A dictionary of data to include in the results """ self.data = data Exception.__init__(self, msg) @@ -67,6 +64,7 @@ def __init__(self, msg: str, data: dict = None): class SMTPTLSReportingRecordNotFound(SMTPTLSReportingError): """Raised when an SMTP TLS Reporting record could not be found""" + def __init__(self, error): if isinstance(error, dns.exception.Timeout): error.kwargs["timeout"] = round(error.kwargs["timeout"], 1) @@ -90,15 +88,15 @@ class UnrelatedTXTRecordFoundAtTLSRPT(SMTPTLSReportingError): class SPFRecordFoundWhereTLSRPTShouldBe(UnrelatedTXTRecordFoundAtTLSRPT): """Raised when an SPF record is found where an SMTP TLS Reporting record - should be; - most likely, the ``_smtp._tls.SMTPTLSReporting`` subdomain - record does not actually exist, and the request for ``TXT`` records was - redirected to the base domain""" + should be; + most likely, the ``_smtp._tls.SMTPTLSReporting`` subdomain + record does not actually exist, and the request for ``TXT`` records was + redirected to the base domain""" class SMTPTLSReportingRecordInWrongLocation(SMTPTLSReportingError): """Raised when an SMTP TLS Reporting record is found at the root of a - domain""" + domain""" class MultipleSMTPTLSReportingRecords(SMTPTLSReportingError): @@ -107,36 +105,35 @@ class MultipleSMTPTLSReportingRecords(SMTPTLSReportingError): class _SMTPTLSReportingGrammar(Grammar): """Defines Pyleri grammar for SMTP TLS Reporting records""" + version_tag = Regex(SMTPTLSREPORTING_VERSION_REGEX_STRING) tag_value = Regex(SMTPTLSREPORTING_TAG_VALUE_REGEX_STRING, re.IGNORECASE) START = Sequence( - version_tag, List(tag_value, - delimiter=Regex(f"{WSP_REGEX}*;{WSP_REGEX}*"), - opt=True)) + version_tag, + List(tag_value, delimiter=Regex(f"{WSP_REGEX}*;{WSP_REGEX}*"), opt=True), + ) smtp_rpt_tags = OrderedDict( - v=OrderedDict( - name="Version", - description="Must be TLSRPTv1", - required=True - ), + v=OrderedDict(name="Version", description="Must be TLSRPTv1", required=True), rua=OrderedDict( name="Aggregate Reporting URIs", - description='A URI specifying the endpoint to which aggregate ' - 'information about policy validation results should be ' - 'sent. Two URI schemes are supported: "mailto" and ' - '"https". As with DMARC the Policy Domain can specify a ' - 'comma-separated list of URIs.', - required=False - ) + description="A URI specifying the endpoint to which aggregate " + "information about policy validation results should be " + 'sent. Two URI schemes are supported: "mailto" and ' + '"https". As with DMARC the Policy Domain can specify a ' + "comma-separated list of URIs.", + required=False, + ), ) -def query_smtp_tls_reporting_record(domain: str, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def query_smtp_tls_reporting_record( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Queries DNS for an SMTP TLS Reporting record @@ -168,8 +165,9 @@ def query_smtp_tls_reporting_record(domain: str, unrelated_records = [] try: - records = query_dns(target, "TXT", nameservers=nameservers, - resolver=resolver, timeout=timeout) + records = query_dns( + target, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + ) for record in records: if record.startswith(txt_prefix): sts_record_count += 1 @@ -178,31 +176,37 @@ def query_smtp_tls_reporting_record(domain: str, if sts_record_count > 1: raise MultipleSMTPTLSReportingRecords( - "Multiple SMTP TLS Reporting records are not permitted") + "Multiple SMTP TLS Reporting records are not permitted" + ) if len(unrelated_records) > 0: ur_str = "\n\n".join(unrelated_records) raise UnrelatedTXTRecordFoundAtTLSRPT( "Unrelated TXT records were discovered. These should be " "removed, as some receivers may not expect to find " "unrelated TXT records " - f"at {target}\n\n{ur_str}") + f"at {target}\n\n{ur_str}" + ) sts_record = records[0] except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN): try: - records = query_dns(domain, "TXT", - nameservers=nameservers, resolver=resolver, - timeout=timeout) + records = query_dns( + domain, + "TXT", + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) for record in records: if record.startswith(txt_prefix): raise SMTPTLSReportingRecordInWrongLocation( "The SMTP TLS Reporting record must be located at " - f"{target}, not {domain}") + f"{target}, not {domain}" + ) except dns.resolver.NoAnswer: pass except dns.resolver.NXDOMAIN: - raise SMTPTLSReportingRecordNotFound( - f"The domain {domain} does not exist") + raise SMTPTLSReportingRecordNotFound(f"The domain {domain} does not exist") except Exception as error: raise SMTPTLSReportingRecordNotFound(error) except Exception as error: @@ -210,16 +214,17 @@ def query_smtp_tls_reporting_record(domain: str, if sts_record is None: raise SMTPTLSReportingRecordNotFound( - "An SMTP TLS Reporting DNS record does not exist for this domain") + "An SMTP TLS Reporting DNS record does not exist for this domain" + ) - return OrderedDict([("record", sts_record), - ("warnings", warnings)]) + return OrderedDict([("record", sts_record), ("warnings", warnings)]) def parse_smtp_tls_reporting_record( - record: str, - include_tag_descriptions: bool = False, - syntax_error_marker: str = SYNTAX_ERROR_MARKER) -> OrderedDict: + record: str, + include_tag_descriptions: bool = False, + syntax_error_marker: str = SYNTAX_ERROR_MARKER, +) -> OrderedDict: """ Parses an SMTP TLS Reporting record @@ -249,11 +254,13 @@ def parse_smtp_tls_reporting_record( """ logging.debug("Parsing the SMTP TLS Reporting record") - spf_in_smtp_error_msg = ("Found a SPF record where a SMTP TLS Reporting " - "record should be; most likely, the _smtp._tls " - "subdomain record does not actually exist, " - "and the request for TXT records was " - "redirected to the base domain") + spf_in_smtp_error_msg = ( + "Found a SPF record where a SMTP TLS Reporting " + "record should be; most likely, the _smtp._tls " + "subdomain record does not actually exist, " + "and the request for TXT records was " + "redirected to the base domain" + ) warnings = [] record = record.strip('"') if record.lower().startswith("v=spf1"): @@ -262,15 +269,21 @@ def parse_smtp_tls_reporting_record( parsed_record = smtp_tls_syntax_checker.parse(record) if not parsed_record.is_valid: expecting = list( - map(lambda x: str(x).strip('"'), list(parsed_record.expecting))) - marked_record = (record[:parsed_record.pos] + syntax_error_marker + - record[parsed_record.pos:]) + map(lambda x: str(x).strip('"'), list(parsed_record.expecting)) + ) + marked_record = ( + record[: parsed_record.pos] + + syntax_error_marker + + record[parsed_record.pos :] + ) expecting = " or ".join(expecting) - raise SMTPTLSReportingSyntaxError(f"Error: Expected {expecting} " - f"at position {parsed_record.pos} " - f"(marked with" - f" {syntax_error_marker}) " - f"in: {marked_record}") + raise SMTPTLSReportingSyntaxError( + f"Error: Expected {expecting} " + f"at position {parsed_record.pos} " + f"(marked with" + f" {syntax_error_marker}) " + f"in: {marked_record}" + ) pairs = SMTPTLSREPORTING_TAG_VALUE_REGEX.findall(record) tags = OrderedDict() @@ -279,27 +292,30 @@ def parse_smtp_tls_reporting_record( tag = pair[0].lower().strip() tag_value = str(pair[1].strip()) if tag not in smtp_rpt_tags: - raise InvalidSMTPTLSReportingTag(f"{tag} is not a valid SMTP TLS " - f"Reporting record tag") + raise InvalidSMTPTLSReportingTag( + f"{tag} is not a valid SMTP TLS " f"Reporting record tag" + ) tags[tag] = OrderedDict(value=tag_value) if include_tag_descriptions: tags[tag]["description"] = smtp_rpt_tags[tag]["description"] if "rua" not in tags: - SMTPTLSReportingSyntaxError("The record is missing the required rua " - "tag") + SMTPTLSReportingSyntaxError("The record is missing the required rua " "tag") tags["rua"]["value"] = tags["rua"]["value"].split(",") for uri in tags["rua"]["value"]: if len(SMTPTLSREPORTING_URI_REGEX.findall(uri)) != 1: - raise SMTPTLSReportingSyntaxError(f"{uri} is not a valid SMTP " - f"TLS reporting URI") + raise SMTPTLSReportingSyntaxError( + f"{uri} is not a valid SMTP " f"TLS reporting URI" + ) return OrderedDict(tags=tags, warnings=warnings) -def check_smtp_tls_reporting(domain: str, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def check_smtp_tls_reporting( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Returns a dictionary with a parsed MTA-STS policy or an error. @@ -328,12 +344,12 @@ def check_smtp_tls_reporting(domain: str, smtp_tls_reporting_results = OrderedDict([("valid", True)]) try: smtp_tls_reporting_record = query_smtp_tls_reporting_record( - domain, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) warnings = smtp_tls_reporting_record["warnings"] smtp_tls_reporting_record = parse_smtp_tls_reporting_record( - smtp_tls_reporting_record["record"]) + smtp_tls_reporting_record["record"] + ) warnings += smtp_tls_reporting_record["warnings"] smtp_tls_reporting_results["tags"] = smtp_tls_reporting_record["tags"] smtp_tls_reporting_results["warnings"] = warnings diff --git a/checkdmarc/spf.py b/checkdmarc/spf.py index 4fe2875..c3fb41f 100644 --- a/checkdmarc/spf.py +++ b/checkdmarc/spf.py @@ -9,16 +9,17 @@ import dns import ipaddress -from pyleri import (Grammar, - Regex, - Sequence, - Repeat - ) +from pyleri import Grammar, Regex, Sequence, Repeat from checkdmarc._constants import SYNTAX_ERROR_MARKER -from checkdmarc.utils import (query_dns, get_a_records, - get_txt_records, get_mx_records, - DNSException, DNSExceptionNXDOMAIN) +from checkdmarc.utils import ( + query_dns, + get_a_records, + get_txt_records, + get_mx_records, + DNSException, + DNSExceptionNXDOMAIN, +) """Copyright 2019-2023 Sean Whalen @@ -65,7 +66,7 @@ class _SPFWarning(Exception): class _SPFMissingRecords(_SPFWarning): """Raised when a mechanism in a ``SPF`` record is missing the requested - A/AAAA or MX records""" + A/AAAA or MX records""" class _SPFDuplicateInclude(_SPFWarning): @@ -116,24 +117,21 @@ class SPFIncludeLoop(SPFError): class _SPFGrammar(Grammar): """Defines Pyleri grammar for SPF records""" + version_tag = Regex(SPF_VERSION_TAG_REGEX_STRING) mechanism = Regex(SPF_MECHANISM_REGEX_STRING, re.IGNORECASE) START = Sequence(version_tag, Repeat(mechanism)) -spf_qualifiers = { - "": "pass", - "?": "neutral", - "+": "pass", - "-": "fail", - "~": "softfail" -} +spf_qualifiers = {"": "pass", "?": "neutral", "+": "pass", "-": "fail", "~": "softfail"} -def query_spf_record(domain: str, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def query_spf_record( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Queries DNS for an SPF record @@ -158,43 +156,41 @@ def query_spf_record(domain: str, spf_type_records = [] spf_txt_records = [] try: - spf_type_records += query_dns(domain, "SPF", - nameservers=nameservers, - resolver=resolver, timeout=timeout) + spf_type_records += query_dns( + domain, "SPF", nameservers=nameservers, resolver=resolver, timeout=timeout + ) except (dns.resolver.NoAnswer, Exception): pass if len(spf_type_records) > 0: - message = "SPF type DNS records found. Use of DNS Type SPF has been " \ - "removed in the standards " \ - "track version of SPF, RFC 7208. These records should " \ - "be removed and replaced with TXT records: " \ - f"{','.join(spf_type_records)}" + message = ( + "SPF type DNS records found. Use of DNS Type SPF has been " + "removed in the standards " + "track version of SPF, RFC 7208. These records should " + "be removed and replaced with TXT records: " + f"{','.join(spf_type_records)}" + ) warnings.append(message) try: - answers = query_dns(domain, "TXT", nameservers=nameservers, - resolver=resolver, timeout=timeout) + answers = query_dns( + domain, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + ) spf_record = None for record in answers: if record.startswith(txt_prefix): spf_txt_records.append(record) if len(spf_txt_records) > 1: - raise MultipleSPFRTXTRecords( - f"{domain} has multiple SPF TXT records") + raise MultipleSPFRTXTRecords(f"{domain} has multiple SPF TXT records") elif len(spf_txt_records) == 1: spf_record = spf_txt_records[0] if spf_record is None: raise SPFRecordNotFound( - f"{domain} " - f"does not have a SPF TXT record", - domain) + f"{domain} " f"does not have a SPF TXT record", domain + ) except dns.resolver.NoAnswer: - raise SPFRecordNotFound( - f"{domain} does not have a SPF TXT record", - domain) + raise SPFRecordNotFound(f"{domain} does not have a SPF TXT record", domain) except dns.resolver.NXDOMAIN: - raise SPFRecordNotFound(f"The domain {domain} does not exist", - domain) + raise SPFRecordNotFound(f"The domain {domain} does not exist", domain) except Exception as error: raise SPFRecordNotFound(error, domain) @@ -202,13 +198,16 @@ def query_spf_record(domain: str, def parse_spf_record( - record: str, domain: str, - parked: bool = False, seen: bool = None, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - recursion: OrderedDict = None, - timeout: float = 2.0, - syntax_error_marker: str = SYNTAX_ERROR_MARKER) -> OrderedDict: + record: str, + domain: str, + parked: bool = False, + seen: bool = None, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + recursion: OrderedDict = None, + timeout: float = 2.0, + syntax_error_marker: str = SYNTAX_ERROR_MARKER, +) -> OrderedDict: """ Parses an SPF record, including resolving ``a``, ``mx``, and ``include`` mechanisms @@ -243,14 +242,16 @@ def parse_spf_record( seen = [domain] if recursion is None: recursion = [domain] - record = record.replace('" ', '').replace('"', '') + record = record.replace('" ', "").replace('"', "") warnings = [] spf_syntax_checker = _SPFGrammar() if parked: correct_record = "v=spf1 -all" if record != correct_record: - warnings.append("The SPF record for parked domains should be: " - f"{correct_record} not: {record}") + warnings.append( + "The SPF record for parked domains should be: " + f"{correct_record} not: {record}" + ) if len(AFTER_ALL_REGEX.findall(record)) > 0: warnings.append("Any text after the all mechanism is ignored") record = AFTER_ALL_REGEX.sub(r"\1", record) @@ -258,26 +259,32 @@ def parse_spf_record( if not parsed_record.is_valid: pos = parsed_record.pos expecting = list( - map(lambda x: str(x).strip('"'), list(parsed_record.expecting))) + map(lambda x: str(x).strip('"'), list(parsed_record.expecting)) + ) expecting = " or ".join(expecting) marked_record = record[:pos] + syntax_error_marker + record[pos:] raise SPFSyntaxError( f"{domain}: Expected {expecting} at position {pos} " - f"(marked with {syntax_error_marker}) in: {marked_record}") + f"(marked with {syntax_error_marker}) in: {marked_record}" + ) matches = SPF_MECHANISM_REGEX.findall(record.lower()) - parsed = OrderedDict([("pass", []), - ("neutral", []), - ("softfail", []), - ("fail", []), - ("include", []), - ("redirect", None), - ("exp", None), - ("all", "neutral")]) + parsed = OrderedDict( + [ + ("pass", []), + ("neutral", []), + ("softfail", []), + ("fail", []), + ("include", []), + ("redirect", None), + ("exp", None), + ("all", "neutral"), + ] + ) lookup_mechanism_count = 0 void_lookup_mechanism_count = 0 for match in matches: - mechanism = match[1].lower().strip(':=') + mechanism = match[1].lower().strip(":=") if mechanism in lookup_mechanisms: lookup_mechanism_count += 1 if lookup_mechanism_count > 10: @@ -285,79 +292,92 @@ def parse_spf_record( "Parsing the SPF record requires " f"{lookup_mechanism_count}/10 maximum DNS lookups - " "https://tools.ietf.org/html/rfc7208#section-4.6.4", - dns_lookups=lookup_mechanism_count) + dns_lookups=lookup_mechanism_count, + ) for match in matches: result = spf_qualifiers[match[0]] - mechanism = match[1].strip(':=') + mechanism = match[1].strip(":=") value = match[2] try: if mechanism == "ip4": try: - if not isinstance(ipaddress.ip_network(value, - strict=False), - ipaddress.IPv4Network): - raise SPFSyntaxError(f"{value} is not a valid ipv4 " - "value. Looks like ipv6") + if not isinstance( + ipaddress.ip_network(value, strict=False), ipaddress.IPv4Network + ): + raise SPFSyntaxError( + f"{value} is not a valid ipv4 " "value. Looks like ipv6" + ) except ValueError: raise SPFSyntaxError(f"{value} is not a valid ipv4 value") elif mechanism == "ip6": try: - if not isinstance(ipaddress.ip_network(value, - strict=False), - ipaddress.IPv6Network): - raise SPFSyntaxError(f"{value} is not a valid ipv6 " - "value. Looks like ipv4") + if not isinstance( + ipaddress.ip_network(value, strict=False), ipaddress.IPv6Network + ): + raise SPFSyntaxError( + f"{value} is not a valid ipv6 " "value. Looks like ipv4" + ) except ValueError: raise SPFSyntaxError(f"{value} is not a valid ipv6 value") if mechanism == "a": if value == "": value = domain - a_records = get_a_records(value, nameservers=nameservers, - resolver=resolver, timeout=timeout) + a_records = get_a_records( + value, nameservers=nameservers, resolver=resolver, timeout=timeout + ) if len(a_records) == 0: raise _SPFMissingRecords( - f"{value.lower()} does not have any A/AAAA records") + f"{value.lower()} does not have any A/AAAA records" + ) for record in a_records: - parsed[result].append(OrderedDict( - [("value", record), ("mechanism", mechanism)])) + parsed[result].append( + OrderedDict([("value", record), ("mechanism", mechanism)]) + ) elif mechanism == "mx": if value == "": value = domain - mx_hosts = get_mx_records(value, nameservers=nameservers, - resolver=resolver, timeout=timeout) + mx_hosts = get_mx_records( + value, nameservers=nameservers, resolver=resolver, timeout=timeout + ) if len(mx_hosts) == 0: raise _SPFMissingRecords( - f"{value.lower()} does not have any MX records") + f"{value.lower()} does not have any MX records" + ) if len(mx_hosts) > 10: url = "https://tools.ietf.org/html/rfc7208#section-4.6.4" raise SPFTooManyDNSLookups( f"{value} has more than 10 MX records - {url}", - dns_lookups=len(mx_hosts)) + dns_lookups=len(mx_hosts), + ) for host in mx_hosts: hostname = host["hostname"] - parsed[result].append(OrderedDict( - [("value", hostname), - ("mechanism", mechanism)])) + parsed[result].append( + OrderedDict([("value", hostname), ("mechanism", mechanism)]) + ) elif mechanism == "redirect": if value.lower() in recursion: raise SPFRedirectLoop(f"Redirect loop: {value.lower()}") seen.append(value.lower()) try: - redirect_record = query_spf_record(value, - nameservers=nameservers, - resolver=resolver, - timeout=timeout) + redirect_record = query_spf_record( + value, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) redirect_record = redirect_record["record"] - redirect = parse_spf_record(redirect_record, value, - seen=seen, - recursion=recursion + [ - value.lower()], - nameservers=nameservers, - resolver=resolver, - timeout=timeout) + redirect = parse_spf_record( + redirect_record, + value, + seen=seen, + recursion=recursion + [value.lower()], + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) lookup_mechanism_count += redirect["dns_lookups"] void_lookup_mechanism_count += redirect["dns_void_lookups"] if lookup_mechanism_count > 10: @@ -367,7 +387,8 @@ def parse_spf_record( "DNS lookups - " "https://tools.ietf.org/html/rfc7208" "#section-4.6.4", - dns_lookups=lookup_mechanism_count) + dns_lookups=lookup_mechanism_count, + ) if void_lookup_mechanism_count > 2: u = "https://tools.ietf.org/html/rfc7208#section-4.6.4" raise SPFTooManyVoidDNSLookups( @@ -375,13 +396,18 @@ def parse_spf_record( f"{void_lookup_mechanism_count}/2 maximum void " "DNS lookups - " f"{u}", - dns_void_lookups=void_lookup_mechanism_count) + dns_void_lookups=void_lookup_mechanism_count, + ) parsed["redirect"] = OrderedDict( - [("domain", value), ("record", redirect_record), - ("dns_lookups", redirect["dns_lookups"]), - ("dns_void_lookups", redirect["dns_void_lookups"]), - ("parsed", redirect["parsed"]), - ("warnings", redirect["warnings"])]) + [ + ("domain", value), + ("record", redirect_record), + ("dns_lookups", redirect["dns_lookups"]), + ("dns_void_lookups", redirect["dns_void_lookups"]), + ("parsed", redirect["parsed"]), + ("warnings", redirect["warnings"]), + ] + ) warnings += redirect["warnings"] except DNSException as error: if isinstance(error, DNSExceptionNXDOMAIN): @@ -396,27 +422,29 @@ def parse_spf_record( pointer = " -> ".join(recursion + [value.lower()]) raise SPFIncludeLoop(f"Include loop: {pointer}") if value.lower() in seen: - raise _SPFDuplicateInclude( - f"Duplicate include: {value.lower()}") + raise _SPFDuplicateInclude(f"Duplicate include: {value.lower()}") seen.append(value.lower()) if "%{" in value: - include = OrderedDict( - [("domain", value)]) + include = OrderedDict([("domain", value)]) parsed["include"].append(include) continue try: - include_record = query_spf_record(value, - nameservers=nameservers, - resolver=resolver, - timeout=timeout) + include_record = query_spf_record( + value, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) include_record = include_record["record"] - include = parse_spf_record(include_record, value, - seen=seen, - recursion=recursion + [ - value.lower()], - nameservers=nameservers, - resolver=resolver, - timeout=timeout) + include = parse_spf_record( + include_record, + value, + seen=seen, + recursion=recursion + [value.lower()], + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) lookup_mechanism_count += include["dns_lookups"] void_lookup_mechanism_count += include["dns_void_lookups"] if lookup_mechanism_count > 10: @@ -426,7 +454,8 @@ def parse_spf_record( "DNS lookups - " "https://tools.ietf.org/html/rfc7208" "#section-4.6.4", - dns_lookups=lookup_mechanism_count) + dns_lookups=lookup_mechanism_count, + ) if void_lookup_mechanism_count > 2: u = "https://tools.ietf.org/html/rfc7208#section-4.6.4" raise SPFTooManyVoidDNSLookups( @@ -434,13 +463,18 @@ def parse_spf_record( f"{void_lookup_mechanism_count}/2 maximum void " "DNS lookups - " f"{u}", - dns_void_lookups=void_lookup_mechanism_count) + dns_void_lookups=void_lookup_mechanism_count, + ) include = OrderedDict( - [("domain", value), ("record", include_record), - ("dns_lookups", include["dns_lookups"]), - ("dns_void_lookups", include["dns_void_lookups"]), - ("parsed", include["parsed"]), - ("warnings", include["warnings"])]) + [ + ("domain", value), + ("record", include_record), + ("dns_lookups", include["dns_lookups"]), + ("dns_void_lookups", include["dns_void_lookups"]), + ("parsed", include["parsed"]), + ("warnings", include["warnings"]), + ] + ) parsed["include"].append(include) warnings += include["warnings"] @@ -453,13 +487,17 @@ def parse_spf_record( raise error elif mechanism == "ptr": parsed[result].append( - OrderedDict([("value", value), ("mechanism", mechanism)])) - raise _SPFWarning("The ptr mechanism should not be used - " - "https://tools.ietf.org/html/rfc7208" - "#section-5.5") + OrderedDict([("value", value), ("mechanism", mechanism)]) + ) + raise _SPFWarning( + "The ptr mechanism should not be used - " + "https://tools.ietf.org/html/rfc7208" + "#section-5.5" + ) else: parsed[result].append( - OrderedDict([("value", value), ("mechanism", mechanism)])) + OrderedDict([("value", value), ("mechanism", mechanism)]) + ) except (_SPFWarning, DNSException) as warning: if isinstance(warning, (_SPFMissingRecords, DNSExceptionNXDOMAIN)): @@ -470,17 +508,25 @@ def parse_spf_record( f"{void_lookup_mechanism_count}/2 maximum void DNS " "lookups - " "https://tools.ietf.org/html/rfc7208#section-4.6.4", - dns_void_lookups=void_lookup_mechanism_count) + dns_void_lookups=void_lookup_mechanism_count, + ) warnings.append(str(warning)) return OrderedDict( - [('dns_lookups', lookup_mechanism_count), - ('dns_void_lookups', void_lookup_mechanism_count), - ("parsed", parsed), ("warnings", warnings)]) - - -def get_spf_record(domain: str, nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: + [ + ("dns_lookups", lookup_mechanism_count), + ("dns_void_lookups", void_lookup_mechanism_count), + ("parsed", parsed), + ("warnings", warnings), + ] + ) + + +def get_spf_record( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Retrieves and parses an SPF record @@ -502,20 +548,25 @@ def get_spf_record(domain: str, nameservers: list[str] = None, :exc:`checkdmarc.SPFTooManyDNSLookups` """ - record = query_spf_record(domain, nameservers=nameservers, - resolver=resolver, timeout=timeout) + record = query_spf_record( + domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) record = record["record"] - parsed_record = parse_spf_record(record, domain, nameservers=nameservers, - resolver=resolver, timeout=timeout) + parsed_record = parse_spf_record( + record, domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) parsed_record["record"] = record return parsed_record -def check_spf(domain: str, parked: bool = False, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> OrderedDict: +def check_spf( + domain: str, + parked: bool = False, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> OrderedDict: """ Returns a dictionary with a parsed SPF record or an error. @@ -544,26 +595,30 @@ def check_spf(domain: str, parked: bool = False, - ``valid`` - False """ spf_results = OrderedDict( - [("record", None), ("valid", True), ("dns_lookups", None), - ("dns_void_lookups", None)]) + [ + ("record", None), + ("valid", True), + ("dns_lookups", None), + ("dns_void_lookups", None), + ] + ) try: spf_query = query_spf_record( - domain, - nameservers=nameservers, resolver=resolver, - timeout=timeout) + domain, nameservers=nameservers, resolver=resolver, timeout=timeout + ) spf_results["record"] = spf_query["record"] spf_results["warnings"] = spf_query["warnings"] - parsed_spf = parse_spf_record(spf_results["record"], - domain, - parked=parked, - nameservers=nameservers, - resolver=resolver, - timeout=timeout) - - spf_results["dns_lookups"] = parsed_spf[ - "dns_lookups"] - spf_results["dns_void_lookups"] = parsed_spf[ - "dns_void_lookups"] + parsed_spf = parse_spf_record( + spf_results["record"], + domain, + parked=parked, + nameservers=nameservers, + resolver=resolver, + timeout=timeout, + ) + + spf_results["dns_lookups"] = parsed_spf["dns_lookups"] + spf_results["dns_void_lookups"] = parsed_spf["dns_void_lookups"] spf_results["parsed"] = parsed_spf["parsed"] spf_results["warnings"] += parsed_spf["warnings"] except SPFError as error: diff --git a/checkdmarc/utils.py b/checkdmarc/utils.py index 0302ab2..eec03d0 100644 --- a/checkdmarc/utils.py +++ b/checkdmarc/utils.py @@ -29,9 +29,7 @@ DNS_CACHE = ExpiringDict(max_len=200000, max_age_seconds=1800) WSP_REGEX = r"[ \t]" -HTTPS_REGEX = ( - r"(https:\/\/)([\w\-]+\.)+[\w-]+([\w\- ,.\/?%&=]*)" -) +HTTPS_REGEX = r"(https:\/\/)([\w\-]+\.)+[\w-]+([\w\- ,.\/?%&=]*)" MAILTO_REGEX_STRING = ( r"^(mailto):([\w\-!#$%&'*+-/=?^_`{|}~]" r"[\w\-.!#$%&'*+-/=?^_`{|}~]*@[\w\-.]+)(!\w+)?" @@ -73,9 +71,14 @@ def get_base_domain(domain: str) -> str: return psl.privatesuffix(domain) or domain -def query_dns(domain: str, record_type: str, nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0, cache: ExpiringDict = None) -> list[str]: +def query_dns( + domain: str, + record_type: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, + cache: ExpiringDict = None, +) -> list[str]: """ Queries DNS @@ -108,12 +111,17 @@ def query_dns(domain: str, record_type: str, nameservers: list[str] = None, resolver.timeout = timeout resolver.lifetime = timeout if record_type == "TXT": - resource_records = list(map( - lambda r: r.strings, - resolver.resolve(domain, record_type, lifetime=timeout))) + resource_records = list( + map( + lambda r: r.strings, + resolver.resolve(domain, record_type, lifetime=timeout), + ) + ) _resource_record = [ resource_record[0][:0].join(resource_record) - for resource_record in resource_records if resource_record] + for resource_record in resource_records + if resource_record + ] records = [] for r in _resource_record: try: @@ -122,18 +130,24 @@ def query_dns(domain: str, record_type: str, nameservers: list[str] = None, pass records.append(r) else: - records = list(map( - lambda r: r.to_text().replace('"', '').rstrip("."), - resolver.resolve(domain, record_type, lifetime=timeout))) + records = list( + map( + lambda r: r.to_text().replace('"', "").rstrip("."), + resolver.resolve(domain, record_type, lifetime=timeout), + ) + ) if type(cache) is ExpiringDict: cache[cache_key] = records return records -def get_a_records(domain: str, nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> list[str]: +def get_a_records( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> list[str]: """ Queries DNS for A and AAAA records @@ -156,8 +170,9 @@ def get_a_records(domain: str, nameservers: list[str] = None, for qt in qtypes: try: logging.debug(f"Getting {qt} records for {domain}") - addresses += query_dns(domain, qt, nameservers=nameservers, - resolver=resolver, timeout=timeout) + addresses += query_dns( + domain, qt, nameservers=nameservers, resolver=resolver, timeout=timeout + ) except dns.resolver.NXDOMAIN: raise DNSExceptionNXDOMAIN(f"The domain {domain} does not exist") except dns.resolver.NoAnswer: @@ -170,9 +185,12 @@ def get_a_records(domain: str, nameservers: list[str] = None, return addresses -def get_reverse_dns(ip_address: str, nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> list[str]: +def get_reverse_dns( + ip_address: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> list[str]: """ Queries for an IP addresses reverse DNS hostname(s) @@ -193,8 +211,9 @@ def get_reverse_dns(ip_address: str, nameservers: list[str] = None, try: name = str(dns.reversename.from_address(ip_address)) logging.debug(f"Getting PTR records for {ip_address}") - hostnames = query_dns(name, "PTR", nameservers=nameservers, - resolver=resolver, timeout=timeout) + hostnames = query_dns( + name, "PTR", nameservers=nameservers, resolver=resolver, timeout=timeout + ) except dns.resolver.NXDOMAIN: return [] except Exception as error: @@ -203,9 +222,12 @@ def get_reverse_dns(ip_address: str, nameservers: list[str] = None, return hostnames -def get_txt_records(domain: str, nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> list[str]: +def get_txt_records( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> list[str]: """ Queries DNS for TXT records @@ -224,23 +246,26 @@ def get_txt_records(domain: str, nameservers: list[str] = None, """ try: - records = query_dns(domain, "TXT", nameservers=nameservers, - resolver=resolver, timeout=timeout) + records = query_dns( + domain, "TXT", nameservers=nameservers, resolver=resolver, timeout=timeout + ) except dns.resolver.NXDOMAIN: raise DNSExceptionNXDOMAIN(f"The domain {domain} does not exist") except dns.resolver.NoAnswer: - raise DNSException( - f"The domain {domain} does not have any TXT records") + raise DNSException(f"The domain {domain} does not have any TXT records") except Exception as error: raise DNSException(error) return records -def get_nameservers(domain: str, approved_nameservers: list[str] = None, - nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> dict: +def get_nameservers( + domain: str, + approved_nameservers: list[str] = None, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> dict: """ Gets a list of nameservers for a given domain @@ -263,20 +288,18 @@ def get_nameservers(domain: str, approved_nameservers: list[str] = None, ns_records = [] try: - ns_records = query_dns(domain, "NS", - nameservers=nameservers, - resolver=resolver, timeout=timeout) + ns_records = query_dns( + domain, "NS", nameservers=nameservers, resolver=resolver, timeout=timeout + ) except dns.resolver.NXDOMAIN: - raise DNSExceptionNXDOMAIN( - f"The domain {domain} does not exist") + raise DNSExceptionNXDOMAIN(f"The domain {domain} does not exist") except dns.resolver.NoAnswer: pass except Exception as error: raise DNSException(error) if approved_nameservers: - approved_nameservers = list(map(lambda h: h.lower(), - approved_nameservers)) + approved_nameservers = list(map(lambda h: h.lower(), approved_nameservers)) for nameserver in ns_records: if approved_nameservers: approved = False @@ -290,9 +313,12 @@ def get_nameservers(domain: str, approved_nameservers: list[str] = None, return OrderedDict([("hostnames", ns_records), ("warnings", warnings)]) -def get_mx_records(domain: str, nameservers: list[str] = None, - resolver: dns.resolver.Resolver = None, - timeout: float = 2.0) -> list[OrderedDict]: +def get_mx_records( + domain: str, + nameservers: list[str] = None, + resolver: dns.resolver.Resolver = None, + timeout: float = 2.0, +) -> list[OrderedDict]: """ Queries DNS for a list of Mail Exchange hosts @@ -314,21 +340,22 @@ def get_mx_records(domain: str, nameservers: list[str] = None, hosts = [] try: logging.debug(f"Checking for MX records on {domain}") - answers = query_dns(domain, "MX", nameservers=nameservers, - resolver=resolver, timeout=timeout) - if answers == ['0 ']: - logging.debug("\"No Service\" MX record found") + answers = query_dns( + domain, "MX", nameservers=nameservers, resolver=resolver, timeout=timeout + ) + if answers == ["0 "]: + logging.debug('"No Service" MX record found') return [] for record in answers: record = record.split(" ") preference = int(record[0]) hostname = record[1].rstrip(".").strip().lower() - hosts.append(OrderedDict( - [("preference", preference), ("hostname", hostname)])) + hosts.append( + OrderedDict([("preference", preference), ("hostname", hostname)]) + ) hosts = sorted(hosts, key=lambda h: (h["preference"], h["hostname"])) except dns.resolver.NXDOMAIN: - raise DNSExceptionNXDOMAIN( - f"The domain {domain} does not exist") + raise DNSExceptionNXDOMAIN(f"The domain {domain} does not exist") except dns.resolver.NoAnswer: pass except Exception as error: diff --git a/docs/source/conf.py b/docs/source/conf.py index 04a5ce7..9e7b38f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,15 +12,16 @@ # import os import sys + sys.path.insert(0, os.path.abspath(os.path.join("..", ".."))) from checkdmarc import __version__ # -- Project information ----------------------------------------------------- -project = 'checkdmarc' -copyright = '2017, Sean Whalen' -author = 'Sean Whalen' +project = "checkdmarc" +copyright = "2017, Sean Whalen" +author = "Sean Whalen" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -36,13 +37,15 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.todo', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages', - 'sphinx.ext.napoleon', - 'myst_parser'] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", + "sphinx.ext.napoleon", + "myst_parser", +] myst_enable_extensions = [ "amsmath", @@ -64,7 +67,7 @@ autoclass_content = "init" # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffixes of source filenames. @@ -81,9 +84,9 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/requirements.txt b/requirements.txt index 707cbd4..0460601 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ autopep8 hatch pytest ruff +black diff --git a/tests.py b/tests.py index 7064657..5008f73 100755 --- a/tests.py +++ b/tests.py @@ -12,10 +12,7 @@ import checkdmarc.dmarc import checkdmarc.dnssec -known_good_domains = [ - "fbi.gov", - "pm.me" -] +known_good_domains = ["fbi.gov", "pm.me"] class Test(unittest.TestCase): @@ -29,20 +26,27 @@ def testKnownGood(self): dmarc_error = None for mx in result["mx"]["hosts"]: self.assertEqual( - mx["starttls"], True, + mx["starttls"], + True, "Host of known good domain {0} failed STARTTLS check: {1}" - "\n\n{0}".format(result["domain"], mx["hostname"]) + "\n\n{0}".format(result["domain"], mx["hostname"]), ) if "error" in result["spf"]: spf_error = result["spf"]["error"] if "error" in result["dmarc"]: dmarc_error = result["dmarc"]["error"] - self.assertEqual(result["spf"]["valid"], True, - "Known good domain {0} failed SPF check:" - "\n\n{1}".format(result["domain"], spf_error)) - self.assertEqual(result["dmarc"]["valid"], True, - "Known good domain {0} failed DMARC check:" - "\n\n{1}".format(result["domain"], dmarc_error)) + self.assertEqual( + result["spf"]["valid"], + True, + "Known good domain {0} failed SPF check:" + "\n\n{1}".format(result["domain"], spf_error), + ) + self.assertEqual( + result["dmarc"]["valid"], + True, + "Known good domain {0} failed DMARC check:" + "\n\n{1}".format(result["domain"], dmarc_error), + ) def testDMARCMixedFormatting(self): """DMARC records with extra spaces and mixed case are still valid""" @@ -51,7 +55,7 @@ def testDMARCMixedFormatting(self): "v = DMARC1;p=reject;", "v = DMARC1\t;\tp=reject\t;", "v = DMARC1\t;\tp\t\t\t=\t\t\treject\t;", - "V=DMARC1;p=reject;" + "V=DMARC1;p=reject;", ] for example in examples: @@ -120,101 +124,154 @@ def testIncludeMissingSPF(self): """SPF records that include domains that are missing SPF records raise SPFRecordNotFound""" - spf_record = '"v=spf1 include:spf.comendosystems.com ' \ - 'include:bounce.peytz.dk include:etrack.indicia.dk ' \ - 'include:etrack1.com include:mail1.dialogportal.com ' \ - 'include:mail2.dialogportal.com a:mailrelay.jppol.dk ' \ - 'a:sendmail.jppol.dk ?all"' + spf_record = ( + '"v=spf1 include:spf.comendosystems.com ' + "include:bounce.peytz.dk include:etrack.indicia.dk " + "include:etrack1.com include:mail1.dialogportal.com " + "include:mail2.dialogportal.com a:mailrelay.jppol.dk " + 'a:sendmail.jppol.dk ?all"' + ) domain = "ekstrabladet.dk" - self.assertRaises(checkdmarc.spf.SPFRecordNotFound, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFRecordNotFound, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testTooManySPFDNSLookups(self): """SPF records with > 10 SPF mechanisms that cause DNS lookups raise SPFTooManyDNSLookups""" - spf_record = "v=spf1 a include:_spf.salesforce.com " \ - "include:spf.protection.outlook.com " \ - "include:spf.constantcontact.com " \ - "include:_spf.elasticemail.com " \ - "include:servers.mcsv.net " \ - "include:_spf.google.com " \ - "~all" + spf_record = ( + "v=spf1 a include:_spf.salesforce.com " + "include:spf.protection.outlook.com " + "include:spf.constantcontact.com " + "include:_spf.elasticemail.com " + "include:servers.mcsv.net " + "include:_spf.google.com " + "~all" + ) domain = "example.com" - self.assertRaises(checkdmarc.spf.SPFTooManyDNSLookups, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFTooManyDNSLookups, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testTooManySPFVoidDNSLookups(self): """SPF records with > 2 void DNS lookups""" - spf_record = "v=spf1 a:13Mk4olS9VWhQqXRl90fKJrD.example.com " \ - "mx:SfGiqBnQfRbOMapQJhozxo2B.example.com " \ - "a:VAFeyU9N2KJX518aGsN3w6VS.example.com " \ - "~all" + spf_record = ( + "v=spf1 a:13Mk4olS9VWhQqXRl90fKJrD.example.com " + "mx:SfGiqBnQfRbOMapQJhozxo2B.example.com " + "a:VAFeyU9N2KJX518aGsN3w6VS.example.com " + "~all" + ) domain = "example.com" - self.assertRaises(checkdmarc.spf.SPFTooManyVoidDNSLookups, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFTooManyVoidDNSLookups, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testSPFSyntaxErrors(self): """SPF record syntax errors raise SPFSyntaxError""" - spf_record = '"v=spf1 mx a:mail.cohaesio.net ' \ - 'include: trustpilotservice.com ~all"' + spf_record = ( + '"v=spf1 mx a:mail.cohaesio.net ' 'include: trustpilotservice.com ~all"' + ) domain = "2021.ai" - self.assertRaises(checkdmarc.spf.SPFSyntaxError, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFSyntaxError, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testSPFInvalidIPv4(self): """Invalid ipv4 SPF mechanism values raise SPFSyntaxError""" - spf_record = "v=spf1 ip4:78.46.96.236 +a +mx +ip4:138.201.239.158 " \ - "+ip4:78.46.224.83 " \ - "+ip4:relay.mailchannels.net +ip4:138.201.60.20 ~all" + spf_record = ( + "v=spf1 ip4:78.46.96.236 +a +mx +ip4:138.201.239.158 " + "+ip4:78.46.224.83 " + "+ip4:relay.mailchannels.net +ip4:138.201.60.20 ~all" + ) domain = "surftown.dk" - self.assertRaises(checkdmarc.spf.SPFSyntaxError, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFSyntaxError, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testSPFInvalidIPv6inIPv4(self): """Invalid ipv4 SPF mechanism values raise SPFSyntaxError""" spf_record = "v=spf1 ip4:1200:0000:AB00:1234:0000:2552:7777:1313 ~all" domain = "surftown.dk" - self.assertRaises(checkdmarc.spf.SPFSyntaxError, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFSyntaxError, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testSPFInvalidIPv4Range(self): """Invalid ipv4 SPF mechanism values raise SPFSyntaxError""" spf_record = "v=spf1 ip4:78.46.96.236/99 ~all" domain = "surftown.dk" - self.assertRaises(checkdmarc.spf.SPFSyntaxError, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFSyntaxError, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testSPFInvalidIPv6(self): """Invalid ipv6 SPF mechanism values raise SPFSyntaxError""" spf_record = "v=spf1 ip6:1200:0000:AB00:1234:O000:2552:7777:1313 ~all" domain = "surftown.dk" - self.assertRaises(checkdmarc.spf.SPFSyntaxError, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFSyntaxError, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testSPFInvalidIPv4inIPv6(self): """Invalid ipv6 SPF mechanism values raise SPFSyntaxError""" spf_record = "v=spf1 ip6:78.46.96.236 ~all" domain = "surftown.dk" - self.assertRaises(checkdmarc.spf.SPFSyntaxError, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFSyntaxError, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testSPFInvalidIPv6Range(self): """Invalid ipv6 SPF mechanism values raise SPFSyntaxError""" record = "v=spf1 ip6:1200:0000:AB00:1234:0000:2552:7777:1313/130 ~all" domain = "surftown.dk" - self.assertRaises(checkdmarc.spf.SPFSyntaxError, - checkdmarc.spf.parse_spf_record, record, domain) + self.assertRaises( + checkdmarc.spf.SPFSyntaxError, + checkdmarc.spf.parse_spf_record, + record, + domain, + ) def testSPFIncludeLoop(self): """SPF record with include loop raises SPFIncludeLoop""" spf_record = '"v=spf1 include:example.com"' domain = "example.com" - self.assertRaises(checkdmarc.spf.SPFIncludeLoop, - checkdmarc.spf.parse_spf_record, spf_record, domain) + self.assertRaises( + checkdmarc.spf.SPFIncludeLoop, + checkdmarc.spf.parse_spf_record, + spf_record, + domain, + ) def testSPFMissingMXRecord(self): """A warning is issued if an SPF record contains a mx mechanism @@ -223,8 +280,9 @@ def testSPFMissingMXRecord(self): spf_record = '"v=spf1 mx ~all"' domain = "seanthegeek.net" results = checkdmarc.spf.parse_spf_record(spf_record, domain) - self.assertIn("{0} does not have any MX records".format(domain), - results["warnings"]) + self.assertIn( + "{0} does not have any MX records".format(domain), results["warnings"] + ) def testSPFMissingARecord(self): """A warning is issued if an SPF record contains a mx mechanism @@ -233,46 +291,60 @@ def testSPFMissingARecord(self): spf_record = '"v=spf1 a ~all"' domain = "cardinalhealth.net" results = checkdmarc.spf.parse_spf_record(spf_record, domain) - self.assertIn("cardinalhealth.net does not have any A/AAAA records", - results["warnings"]) + self.assertIn( + "cardinalhealth.net does not have any A/AAAA records", results["warnings"] + ) def testDMARCPctLessThan100Warning(self): """A warning is issued if the DMARC pvt value is less than 100""" - dmarc_record = "v=DMARC1; p=none; sp=none; fo=1; pct=50; adkim=r; " \ - "aspf=r; rf=afrf; ri=86400; " \ - "rua=mailto:eits.dmarcrua@energy.gov; " \ - "ruf=mailto:eits.dmarcruf@energy.gov" + dmarc_record = ( + "v=DMARC1; p=none; sp=none; fo=1; pct=50; adkim=r; " + "aspf=r; rf=afrf; ri=86400; " + "rua=mailto:eits.dmarcrua@energy.gov; " + "ruf=mailto:eits.dmarcruf@energy.gov" + ) domain = "energy.gov" results = checkdmarc.dmarc.parse_dmarc_record(dmarc_record, domain) - self.assertIn("pct value is less than 100", - results["warnings"][0]) + self.assertIn("pct value is less than 100", results["warnings"][0]) def testInvalidDMARCURI(self): """An invalid DMARC report URI raises InvalidDMARCReportURI""" - dmarc_record = "v=DMARC1; p=none; rua=reports@dmarc.cyber.dhs.gov," \ - "mailto:dmarcreports@usdoj.gov" + dmarc_record = ( + "v=DMARC1; p=none; rua=reports@dmarc.cyber.dhs.gov," + "mailto:dmarcreports@usdoj.gov" + ) domain = "dea.gov" - self.assertRaises(checkdmarc.dmarc.InvalidDMARCReportURI, - checkdmarc.dmarc.parse_dmarc_record, dmarc_record, - domain) - - dmarc_record = "v=DMARC1; p=none; rua=__" \ - "mailto:reports@dmarc.cyber.dhs.gov," \ - "mailto:dmarcreports@usdoj.gov" - self.assertRaises(checkdmarc.dmarc.InvalidDMARCReportURI, - checkdmarc.dmarc.parse_dmarc_record, dmarc_record, - domain) + self.assertRaises( + checkdmarc.dmarc.InvalidDMARCReportURI, + checkdmarc.dmarc.parse_dmarc_record, + dmarc_record, + domain, + ) + + dmarc_record = ( + "v=DMARC1; p=none; rua=__" + "mailto:reports@dmarc.cyber.dhs.gov," + "mailto:dmarcreports@usdoj.gov" + ) + self.assertRaises( + checkdmarc.dmarc.InvalidDMARCReportURI, + checkdmarc.dmarc.parse_dmarc_record, + dmarc_record, + domain, + ) def testInvalidDMARCPolicyValue(self): - """An invalid DMARC policy value raises InvalidDMARCTagValue """ + """An invalid DMARC policy value raises InvalidDMARCTagValue""" dmarc_record = "v=DMARC1; p=foo; rua=mailto:dmarc@example.com" domain = "example.com" - self.assertRaises(checkdmarc.dmarc.InvalidDMARCTagValue, - checkdmarc.dmarc.parse_dmarc_record, - dmarc_record, - domain) + self.assertRaises( + checkdmarc.dmarc.InvalidDMARCTagValue, + checkdmarc.dmarc.parse_dmarc_record, + dmarc_record, + domain, + ) if __name__ == "__main__":