Skip to content

Commit

Permalink
Add require_certificate_validation Behavior Flag
Browse files Browse the repository at this point in the history
  • Loading branch information
damian3031 committed Nov 21, 2024
1 parent de67814 commit 6f12e46
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20241120-191809.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Revert cert default to False. Add require_certificate_validation Behavior Flag
time: 2024-11-20T19:18:09.725288+01:00
custom:
Author: damian3031
Issue: ""
PR: "447"
13 changes: 13 additions & 0 deletions dbt/adapters/trino/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ class TrinoAdapterResponse(AdapterResponse):

class TrinoConnectionManager(SQLConnectionManager):
TYPE = "trino"
behavior_flags = None

def __init__(self, profile, mp_context, behavior_flags=None) -> None:
super().__init__(profile, mp_context)

TrinoConnectionManager.behavior_flags = behavior_flags

@contextmanager
def exception_handler(self, sql):
Expand Down Expand Up @@ -465,6 +471,13 @@ def open(cls, connection):

credentials = connection.credentials

# set default `cert` value, according to
# require_certificate_validation behavior flag
if credentials.cert is None:
req_cert_val_flag = cls.behavior_flags.require_certificate_validation.setting
if req_cert_val_flag:
credentials.cert = True

# it's impossible for trino to fail here as 'connections' are actually
# just cursor factories.
trino_conn = trino.dbapi.connect(
Expand Down
21 changes: 21 additions & 0 deletions dbt/adapters/trino/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Support,
)
from dbt.adapters.sql import SQLAdapter
from dbt_common.behavior_flags import BehaviorFlag
from dbt_common.contracts.constraints import ConstraintType
from dbt_common.exceptions import DbtDatabaseError

Expand Down Expand Up @@ -47,6 +48,26 @@ class TrinoAdapter(SQLAdapter):
}
)

def __init__(self, config, mp_context) -> None:
super().__init__(config, mp_context)
self.connections = self.ConnectionManager(config, mp_context, self.behavior)

@property
def _behavior_flags(self) -> list[BehaviorFlag]:
return [
{ # type: ignore
"name": "require_certificate_validation",
"default": False,
"description": (
"SSL certificate validation is disabled by default. "
"It is legacy behavior which will be changed in future releases. "
"It is strongly advised to enable `require_certificate_validation` flag "
"or explicitly set `cert` configuration to `True` for security reasons. "
"You may receive an error after that if your SSL setup is incorrect."
),
}
]

@classmethod
def date_function(cls):
return "datenow()"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import warnings

import pytest
from dbt.tests.util import run_dbt, run_dbt_and_capture
from urllib3.exceptions import InsecureRequestWarning


class TestRequireCertificateValidationDefault:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {}}

def test_cert_default_value(self, project):
assert project.adapter.connections.profile.credentials.cert is None

def test_require_certificate_validation_logs(self, project):
dbt_args = ["show", "--inline", "select 1"]
_, logs = run_dbt_and_capture(dbt_args)
assert "It is strongly advised to enable `require_certificate_validation` flag" in logs

@pytest.mark.skip_profile("trino_starburst")
def test_require_certificate_validation_insecure_request_warning(self, project):
with warnings.catch_warnings(record=True) as w:
dbt_args = ["show", "--inline", "select 1"]
run_dbt(dbt_args)

# Check if any InsecureRequestWarning was raised
assert any(
issubclass(warning.category, InsecureRequestWarning) for warning in w
), "InsecureRequestWarning was not raised"


class TestRequireCertificateValidationFalse:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {"require_certificate_validation": False}}

def test_cert_default_value(self, project):
assert project.adapter.connections.profile.credentials.cert is None

def test_require_certificate_validation_logs(self, project):
dbt_args = ["show", "--inline", "select 1"]
_, logs = run_dbt_and_capture(dbt_args)
assert "It is strongly advised to enable `require_certificate_validation` flag" in logs

@pytest.mark.skip_profile("trino_starburst")
def test_require_certificate_validation_insecure_request_warning(self, project):
with warnings.catch_warnings(record=True) as w:
dbt_args = ["show", "--inline", "select 1"]
run_dbt(dbt_args)

# Check if any InsecureRequestWarning was raised
assert any(
issubclass(warning.category, InsecureRequestWarning) for warning in w
), "InsecureRequestWarning was not raised"


class TestRequireCertificateValidationTrue:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {"require_certificate_validation": True}}

def test_cert_default_value(self, project):
assert project.adapter.connections.profile.credentials.cert is True

def test_require_certificate_validation_logs(self, project):
dbt_args = ["show", "--inline", "select 1"]
_, logs = run_dbt_and_capture(dbt_args)
assert "It is strongly advised to enable `require_certificate_validation` flag" not in logs

@pytest.mark.skip_profile("trino_starburst")
def test_require_certificate_validation_insecure_request_warning(self, project):
with warnings.catch_warnings(record=True) as w:
dbt_args = ["show", "--inline", "select 1"]
run_dbt(dbt_args)

# Check if not any InsecureRequestWarning was raised
assert not any(
issubclass(warning.category, InsecureRequestWarning) for warning in w
), "InsecureRequestWarning was not raised"

0 comments on commit 6f12e46

Please sign in to comment.