Skip to content

Commit

Permalink
fix: make EPSS behave like other data sources
Browse files Browse the repository at this point in the history
This will make it so that `-d EPSS` will actually disable the EPSS data
source, and should make it fail more gracefully when the source is not
working for any reason.

Note that the EPSS source may not be working correctly even when not
disabled; I'll file a separate issue.

* fixes intel#4083

Signed-off-by: Terri Oda <[email protected]>
  • Loading branch information
terriko committed May 21, 2024
1 parent d24deaa commit 931f09b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 34 deletions.
5 changes: 5 additions & 0 deletions cve_bin_tool/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from cve_bin_tool.data_sources import (
DataSourceSupport,
curl_source,
epss_source,
gad_source,
nvd_source,
osv_source,
Expand Down Expand Up @@ -717,6 +718,10 @@ def main(argv=None):
source_curl = curl_source.Curl_Source()
enabled_sources.append(source_curl)

if "EPSS" not in disabled_sources:
source_epss = epss_source.Epss_Source()
enabled_sources.append(source_epss)

if "NVD" not in disabled_sources:
source_nvd = nvd_source.NVD_Source(
nvd_type=nvd_type,
Expand Down
51 changes: 24 additions & 27 deletions cve_bin_tool/cvedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class CVEDB:
LOGGER = LOGGER.getChild("CVEDB")
SOURCES = [
curl_source.Curl_Source,
epss_source.Epss_Source,
osv_source.OSV_Source,
gad_source.GAD_Source,
nvd_source.NVD_Source, # last to avoid data overwrites
Expand Down Expand Up @@ -477,11 +478,8 @@ def populate_db(self) -> None:
"""

self.populate_metrics()

# EPSS uses metrics table to get the EPSS metric id.
# It can't be run before creation of metrics table.
self.populate_epss()
self.store_epss_data()

for idx, data in enumerate(self.data):
_, source_name = data
Expand All @@ -494,22 +492,29 @@ def populate_db(self) -> None:
# if source_name != "NVD" and cve_data[0] is not None:
# cve_data = self.update_vendors(cve_data)

severity_data, affected_data = cve_data
if source_name == "EPSS":
if cve_data is not None:
self.store_epss_data(cve_data)

cursor = self.db_open_and_get_cursor()
else:
severity_data, affected_data = cve_data

if severity_data is not None and len(severity_data) > 0:
self.populate_severity(severity_data, cursor, data_source=source_name)
self.populate_cve_metrics(severity_data, cursor)
if affected_data is not None:
self.populate_affected(
affected_data,
cursor,
data_source=source_name,
)
if self.connection is not None:
self.connection.commit()
self.db_close()
cursor = self.db_open_and_get_cursor()

if severity_data is not None and len(severity_data) > 0:
self.populate_severity(
severity_data, cursor, data_source=source_name
)
self.populate_cve_metrics(severity_data, cursor)
if affected_data is not None:
self.populate_affected(
affected_data,
cursor,
data_source=source_name,
)
if self.connection is not None:
self.connection.commit()
self.db_close()

def populate_severity(self, severity_data, cursor, data_source):
"""Populate the database with CVE severities."""
Expand Down Expand Up @@ -627,14 +632,6 @@ def populate_metrics(self):
self.connection.commit()
self.db_close()

def populate_epss(self):
"""Exploit Prediction Scoring System (EPSS) data to help users evaluate risks
Add EPSS data into the database"""
epss = epss_source.Epss_Source()
cursor = self.db_open_and_get_cursor()
self.epss_data = run_coroutine(epss.update_epss(cursor))
self.db_close()

def metric_finder(self, cursor, cve):
"""
SQL query to retrieve the metrics_name based on the metrics_id
Expand Down Expand Up @@ -863,11 +860,11 @@ def populate_exploit_db(self, exploits):
self.connection.commit()
self.db_close()

def store_epss_data(self):
def store_epss_data(self, epss_data):
"""Insert Exploit Prediction Scoring System (EPSS) data into database."""
insert_cve_metrics = self.INSERT_QUERIES["insert_cve_metrics"]
cursor = self.db_open_and_get_cursor()
cursor.executemany(insert_cve_metrics, self.epss_data)
cursor.executemany(insert_cve_metrics, epss_data)
self.connection.commit()
self.db_close()

Expand Down
26 changes: 20 additions & 6 deletions cve_bin_tool/data_sources/epss_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class Epss_Source:
SOURCE = "Epss"
SOURCE = "EPSS"
CACHEDIR = DISK_LOCATION_DEFAULT
BACKUPCACHEDIR = DISK_LOCATION_BACKUP
LOGGER = logging.getLogger().getChild("CVEDB")
Expand All @@ -43,14 +43,12 @@ async def update_epss(self, cursor):
- EPSS score
- EPSS percentile
"""
self.EPSS_id_finder(cursor)
await self.download_and_parse_epss()
return self.epss_data
self.LOGGER.debug("Fetching EPSS data...")

async def download_and_parse_epss(self):
"""Downloads and parses the EPSS data from the CSV file."""
self.EPSS_id_finder(cursor)
await self.download_epss_data()
self.epss_data = self.parse_epss_data()
return self.epss_data

async def download_epss_data(self):
"""Downloads the EPSS CSV file and saves it to the local filesystem.
Expand Down Expand Up @@ -134,3 +132,19 @@ def parse_epss_data(self, file_path=None):
(cve_id, self.epss_metric_id, epss_score, epss_percentile)
)
return parsed_data

async def get_cve_data(self):
"""Gets EPSS data.
This function is so that the epss source matches the others api-wise to make for
easier disabling/enabling.
returns (data, "EPSS") so that the source can be identified for storing data
"""

try:
await self.update_epss()
except Exception as e:
self.LOGGER.debug(f"Error while fetching EPSS data: {e}")
self.LOGGER.error("Unable to fetch EPSS, skipping EPSS.")

return self.epss_data, self.SOURCE
2 changes: 1 addition & 1 deletion cve_bin_tool/data_sources/nvd_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def format_data_api2(self, all_cve_entries):
cvss_data = cve_cvss["cvssMetricV2"][0]["cvssData"]
cve["CVSS_version"] = 2
else:
LOGGER.info(f"Unknown CVSS metrics field {cve_item['id']}")
LOGGER.debug(f"Unknown CVSS metrics field {cve_item['id']}")
cvss_available = False
if cvss_available:
cve["severity"] = cvss_data.get("baseSeverity", "UNKNOWN")
Expand Down

0 comments on commit 931f09b

Please sign in to comment.