Skip to content

Commit

Permalink
Add oauth_console authentication type
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco committed Dec 5, 2023
1 parent 22720e3 commit 0c031da
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20231201-103207.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Add oauth_console authentication type
time: 2023-12-01T10:32:07.070436+01:00
custom:
Author: hovaesco
Issue: ""
PR: "379"
31 changes: 31 additions & 0 deletions dbt/adapters/trino/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def _create_trino_profile(cls, profile):
return TrinoJwtCredentials
elif method == "oauth":
return TrinoOauthCredentials
elif method == "oauth_console":
return TrinoOauthConsoleCredentials
return TrinoNoneCredentials

@classmethod
Expand Down Expand Up @@ -273,6 +275,35 @@ def trino_auth(self):
return self.OAUTH


@dataclass
class TrinoOauthConsoleCredentials(TrinoCredentials):
host: str
port: Port
user: Optional[str] = None
client_tags: Optional[List[str]] = None
roles: Optional[Dict[str, str]] = None
cert: Optional[str] = None
http_headers: Optional[Dict[str, str]] = None
session_properties: Dict[str, Any] = field(default_factory=dict)
prepared_statements_enabled: bool = PREPARED_STATEMENTS_ENABLED_DEFAULT
retries: Optional[int] = trino.constants.DEFAULT_MAX_ATTEMPTS
timezone: Optional[str] = None
OAUTH = trino.auth.OAuth2Authentication(
redirect_auth_url_handler=trino.auth.ConsoleRedirectHandler()
)

@property
def http_scheme(self):
return HttpScheme.HTTPS

@property
def method(self):
return "oauth_console"

def trino_auth(self):
return self.OAUTH


class ConnectionWrapper(object):
"""Wrap a Trino connection in a way that accomplishes two tasks:
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TrinoKerberosCredentials,
TrinoLdapCredentials,
TrinoNoneCredentials,
TrinoOauthConsoleCredentials,
TrinoOauthCredentials,
)

Expand Down Expand Up @@ -397,6 +398,34 @@ def test_oauth_authentication(self):
self.assertEqual(credentials.client_tags, ["dev", "oauth"])
self.assertEqual(credentials.timezone, "UTC")

def test_oauth_console_authentication(self):
connection = self.acquire_connection_with_profile(
{
"type": "trino",
"catalog": "trinodb",
"host": "database",
"port": 5439,
"method": "oauth_console",
"schema": "dbt_test_schema",
"cert": "/path/to/cert",
"client_tags": ["dev", "oauth_console"],
"http_headers": {"X-Trino-Client-Info": "dbt-trino"},
"session_properties": {
"query_max_run_time": "4h",
"exchange_compression": True,
},
"timezone": "UTC",
}
)
credentials = connection.credentials
self.assertIsInstance(credentials, TrinoOauthConsoleCredentials)
self.assert_default_connection_credentials(credentials)
self.assertEqual(credentials.http_scheme, HttpScheme.HTTPS)
self.assertEqual(credentials.cert, "/path/to/cert")
self.assertEqual(connection.credentials.prepared_statements_enabled, True)
self.assertEqual(credentials.client_tags, ["dev", "oauth_console"])
self.assertEqual(credentials.timezone, "UTC")


class TestPreparedStatementsEnabled(TestCase):
def setup_profile(self, credentials):
Expand Down

0 comments on commit 0c031da

Please sign in to comment.