Commit

Reformat code using black
seanthegeek committed Oct 26, 2024
1 parent 53f2fcd commit f163a9f
Showing 14 changed files with 1,719 additions and 1,296 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/python-tests.yaml
@@ -30,8 +30,7 @@ jobs:
make html
- name: Check code style
run: |
flake8 checkdmarc
flake8 tests.py
black --check .
- name: Run unit tests
run: |
coverage run tests.py
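
For context on the workflow change above: the style check now runs Black in check-only mode instead of flake8. A minimal sketch of reproducing that check locally (assuming `black` is installed in the current environment) could be:

```python
# Hypothetical local equivalent of the CI step above; assumes `black` is
# installed (e.g. `pip install black`) and this is run from the repo root.
import subprocess
import sys

# `black --check .` reports files that would be reformatted and exits
# non-zero, without changing anything on disk.
result = subprocess.run(["black", "--check", "."])
sys.exit(result.returncode)
```
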
188 changes: 114 additions & 74 deletions checkdmarc/__init__.py
@@ -41,16 +41,19 @@
__version__ = checkdmarc._constants.__version__


def check_domains(domains: list[str], parked: bool = False,
approved_nameservers: list[str] = None,
approved_mx_hostnames: bool = None,
skip_tls: bool = False,
bimi_selector: str = None,
include_tag_descriptions: bool = False,
nameservers: list[str] = None,
resolver: dns.resolver.Resolver = None,
timeout: float = 2.0,
wait: float = 0.0) -> Union[OrderedDict, list[OrderedDict]]:
def check_domains(
domains: list[str],
parked: bool = False,
approved_nameservers: list[str] = None,
approved_mx_hostnames: bool = None,
skip_tls: bool = False,
bimi_selector: str = None,
include_tag_descriptions: bool = False,
nameservers: list[str] = None,
resolver: dns.resolver.Resolver = None,
timeout: float = 2.0,
wait: float = 0.0,
) -> Union[OrderedDict, list[OrderedDict]]:
"""
Check the given domains for SPF and DMARC records, parse them, and return
them
@@ -82,9 +85,11 @@ def check_domains(domains: list[str], parked: bool = False,
- ``dmarc`` - A ``valid`` flag, plus the output of
:func:`checkdmarc.dmarc.parse_dmarc_record` or an ``error``
"""
domains = sorted(list(set(
map(lambda d: d.rstrip(".\r\n").strip().lower().split(",")[0],
domains))))
domains = sorted(
list(
set(map(lambda d: d.rstrip(".\r\n").strip().lower().split(",")[0], domains))
)
)
not_domains = []
for domain in domains:
if "." not in domain:
@@ -99,27 +104,31 @@ def check_domains(domains: list[str], parked: bool = False,
logging.debug(f"Checking: {domain}")

domain_results = OrderedDict(
[("domain", domain), ("base_domain", get_base_domain(domain)),
("dnssec", None), ("ns", []), ("mx", [])])
[
("domain", domain),
("base_domain", get_base_domain(domain)),
("dnssec", None),
("ns", []),
("mx", []),
]
)

domain_results["dnssec"] = test_dnssec(
domain,
nameservers=nameservers,
timeout=timeout
)
domain, nameservers=nameservers, timeout=timeout
)

domain_results["ns"] = check_ns(
domain,
approved_nameservers=approved_nameservers,
nameservers=nameservers,
resolver=resolver, timeout=timeout
)
resolver=resolver,
timeout=timeout,
)

mta_sts_mx_patterns = None
domain_results["mta_sts"] = check_mta_sts(domain,
nameservers=nameservers,
resolver=resolver,
timeout=timeout)
domain_results["mta_sts"] = check_mta_sts(
domain, nameservers=nameservers, resolver=resolver, timeout=timeout
)
if domain_results["mta_sts"]["valid"]:
mta_sts_mx_patterns = domain_results["mta_sts"]["policy"]["mx"]
domain_results["mx"] = check_mx(
@@ -129,31 +138,28 @@ def check_domains(domains: list[str], parked: bool = False,
skip_tls=skip_tls,
nameservers=nameservers,
resolver=resolver,
timeout=timeout
)
timeout=timeout,
)

domain_results["spf"] = check_spf(
domain,
parked=parked,
nameservers=nameservers,
resolver=resolver,
timeout=timeout
)
timeout=timeout,
)

domain_results["dmarc"] = check_dmarc(
domain,
parked=parked,
include_dmarc_tag_descriptions=include_tag_descriptions,
nameservers=nameservers,
resolver=resolver,
timeout=timeout
)
timeout=timeout,
)

domain_results["smtp_tls_reporting"] = check_smtp_tls_reporting(
domain,
nameservers=nameservers,
resolver=resolver,
timeout=timeout
domain, nameservers=nameservers, resolver=resolver, timeout=timeout
)

if bimi_selector is not None:
@@ -163,7 +169,8 @@ def check_domains(domains: list[str], parked: bool = False,
include_tag_descriptions=include_tag_descriptions,
nameservers=nameservers,
resolver=resolver,
timeout=timeout)
timeout=timeout,
)

results.append(domain_results)
if wait > 0.0:
@@ -175,11 +182,13 @@ def check_domains(domains: list[str], parked: bool = False,
return results


def check_ns(domain: str,
approved_nameservers: list[str] = None,
nameservers: list[str] = None,
resolver: dns.resolver.Resolver = None,
timeout: float = 2.0) -> OrderedDict:
def check_ns(
domain: str,
approved_nameservers: list[str] = None,
nameservers: list[str] = None,
resolver: dns.resolver.Resolver = None,
timeout: float = 2.0,
) -> OrderedDict:
"""
Returns a dictionary of nameservers and warnings or a dictionary with an
empty list and an error.
@@ -207,11 +216,12 @@ def check_ns(domain: str,
ns_results = get_nameservers(
domain,
approved_nameservers=approved_nameservers,
nameservers=nameservers, resolver=resolver,
timeout=timeout)
nameservers=nameservers,
resolver=resolver,
timeout=timeout,
)
except DNSException as error:
ns_results = OrderedDict([("hostnames", []),
("error", error.__str__())])
ns_results = OrderedDict([("hostnames", []), ("error", error.__str__())])
return ns_results


@@ -278,8 +288,9 @@ def results_to_csv_rows(results: Union[dict, list[dict]]) -> list[dict]:
row["bimi_l"] = _bimi["tags"]["l"]["value"]
if "a" in _bimi["tags"]:
row["bimi_a"] = _bimi["tags"]["a"]["value"]
row["mx"] = "|".join(list(
map(lambda r: f"{r['preference']}, {r['hostname']}", mx["hosts"])))
row["mx"] = "|".join(
list(map(lambda r: f"{r['preference']}, {r['hostname']}", mx["hosts"]))
)
tls = None
try:
tls_results = list(map(lambda r: f"{r['starttls']}", mx["hosts"]))
@@ -296,8 +307,7 @@ def results_to_csv_rows(results: Union[dict, list[dict]]) -> list[dict]:

starttls = None
try:
starttls_results = list(
map(lambda r: f"{r['starttls']}", mx["hosts"]))
starttls_results = list(map(lambda r: f"{r['starttls']}", mx["hosts"]))
for starttls_result in starttls_results:
starttls = starttls_result
if starttls_result is False:
@@ -335,27 +345,26 @@ def results_to_csv_rows(results: Union[dict, list[dict]]) -> list[dict]:
row["dmarc_sp"] = _dmarc["tags"]["sp"]["value"]
if "rua" in _dmarc["tags"]:
addresses = _dmarc["tags"]["rua"]["value"]
addresses = list(map(lambda u: "{}:{}".format(
u["scheme"],
u["address"]), addresses))
addresses = list(
map(lambda u: "{}:{}".format(u["scheme"], u["address"]), addresses)
)
row["dmarc_rua"] = "|".join(addresses)
if "ruf" in _dmarc["tags"]:
addresses = _dmarc["tags"]["ruf"]["value"]
addresses = list(map(lambda u: "{}:{}".format(
u["scheme"],
u["address"]), addresses))
addresses = list(
map(lambda u: "{}:{}".format(u["scheme"], u["address"]), addresses)
)
row["dmarc_ruf"] = "|".join(addresses)
row["dmarc_warnings"] = "|".join(_dmarc["warnings"])
if "error" in _smtp_tls_reporting:
row["smtp_tls_reporting_valid"] = False
row["smtp_tls_reporting_error"] = _smtp_tls_reporting["error"]
else:
row["smtp_tls_reporting_valid"] = True
row["smtp_tls_reporting_rua"] = "|".join(_smtp_tls_reporting[
"tags"]["rua"][
"value"])
row["smtp_tls_reporting_warnings"] = _smtp_tls_reporting[
"warnings"]
row["smtp_tls_reporting_rua"] = "|".join(
_smtp_tls_reporting["tags"]["rua"]["value"]
)
row["smtp_tls_reporting_warnings"] = _smtp_tls_reporting["warnings"]
rows.append(row)
return rows

@@ -370,18 +379,48 @@ def results_to_csv(results: dict) -> str:
Returns:
str: A CSV of results
"""
fields = ["domain", "base_domain", "dnssec", "spf_valid", "dmarc_valid",
"dmarc_adkim", "dmarc_aspf",
"dmarc_fo", "dmarc_p", "dmarc_pct", "dmarc_rf", "dmarc_ri",
"dmarc_rua", "dmarc_ruf", "dmarc_sp",
"tls", "starttls", "spf_record", "dmarc_record",
"dmarc_record_location", "mx", "mx_error", "mx_warnings",
"mta_sts_id", "mta_sts_mode", "mta_sts_max_age",
"smtp_tls_reporting_valid", "smtp_tls_reporting_rua",
"mta_sts_mx", "mta_sts_error", "mta_sts_warnings", "spf_error",
"spf_warnings", "dmarc_error", "dmarc_warnings",
"ns", "ns_error", "ns_warnings",
"smtp_tls_reporting_error", "smtp_tls_reporting_warnings"]
fields = [
"domain",
"base_domain",
"dnssec",
"spf_valid",
"dmarc_valid",
"dmarc_adkim",
"dmarc_aspf",
"dmarc_fo",
"dmarc_p",
"dmarc_pct",
"dmarc_rf",
"dmarc_ri",
"dmarc_rua",
"dmarc_ruf",
"dmarc_sp",
"tls",
"starttls",
"spf_record",
"dmarc_record",
"dmarc_record_location",
"mx",
"mx_error",
"mx_warnings",
"mta_sts_id",
"mta_sts_mode",
"mta_sts_max_age",
"smtp_tls_reporting_valid",
"smtp_tls_reporting_rua",
"mta_sts_mx",
"mta_sts_error",
"mta_sts_warnings",
"spf_error",
"spf_warnings",
"dmarc_error",
"dmarc_warnings",
"ns",
"ns_error",
"ns_warnings",
"smtp_tls_reporting_error",
"smtp_tls_reporting_warnings",
]
output = StringIO(newline="\n")
writer = DictWriter(output, fieldnames=fields)
writer.writeheader()
@@ -400,6 +439,7 @@ def output_to_file(path: str, content: str):
path (str): A file path
content (str): JSON or CSV text
"""
with open(path, "w", newline="\n", encoding="utf-8",
errors="ignore") as output_file:
with open(
path, "w", newline="\n", encoding="utf-8", errors="ignore"
) as output_file:
output_file.write(content)
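
Because this commit only reformats code with Black, the public helpers touched above keep their existing behavior and signatures. A minimal usage sketch based on those signatures (the domain names and output path are illustrative placeholders, not part of this commit) might look like:

```python
# Sketch only: ties together check_domains(), results_to_csv(), and
# output_to_file() as declared in the diff above. The domains and the
# output path are placeholders.
import checkdmarc

domains = ["example.com", "example.net"]

# Returns an OrderedDict for a single domain or a list of OrderedDicts
# for several, per the annotated return type.
results = checkdmarc.check_domains(domains, skip_tls=True)

# Flatten the nested results into CSV text and write it to disk.
csv_text = checkdmarc.results_to_csv(results)
checkdmarc.output_to_file("checkdmarc_results.csv", csv_text)
```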
