diff --git a/.changes/unreleased/Features-20231201-103207.yaml b/.changes/unreleased/Features-20231201-103207.yaml new file mode 100644 index 00000000..ca50a873 --- /dev/null +++ b/.changes/unreleased/Features-20231201-103207.yaml @@ -0,0 +1,7 @@ +kind: Features +body: Add oauth_console authentication type +time: 2023-12-01T10:32:07.070436+01:00 +custom: + Author: hovaesco + Issue: "" + PR: "379" diff --git a/dbt/adapters/trino/connections.py b/dbt/adapters/trino/connections.py index dbe8cf17..f40b29b4 100644 --- a/dbt/adapters/trino/connections.py +++ b/dbt/adapters/trino/connections.py @@ -44,6 +44,8 @@ def _create_trino_profile(cls, profile): return TrinoJwtCredentials elif method == "oauth": return TrinoOauthCredentials + elif method == "oauth_console": + return TrinoOauthConsoleCredentials return TrinoNoneCredentials @classmethod @@ -273,6 +275,35 @@ def trino_auth(self): return self.OAUTH +@dataclass +class TrinoOauthConsoleCredentials(TrinoCredentials): + host: str + port: Port + user: Optional[str] = None + client_tags: Optional[List[str]] = None + roles: Optional[Dict[str, str]] = None + cert: Optional[str] = None + http_headers: Optional[Dict[str, str]] = None + session_properties: Dict[str, Any] = field(default_factory=dict) + prepared_statements_enabled: bool = PREPARED_STATEMENTS_ENABLED_DEFAULT + retries: Optional[int] = trino.constants.DEFAULT_MAX_ATTEMPTS + timezone: Optional[str] = None + OAUTH = trino.auth.OAuth2Authentication( + redirect_auth_url_handler=trino.auth.ConsoleRedirectHandler() + ) + + @property + def http_scheme(self): + return HttpScheme.HTTPS + + @property + def method(self): + return "oauth_console" + + def trino_auth(self): + return self.OAUTH + + class ConnectionWrapper(object): """Wrap a Trino connection in a way that accomplishes two tasks: diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index edb11d0a..1121065b 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -18,6 +18,7 @@ TrinoKerberosCredentials, TrinoLdapCredentials, TrinoNoneCredentials, + TrinoOauthConsoleCredentials, TrinoOauthCredentials, ) @@ -397,6 +398,34 @@ def test_oauth_authentication(self): self.assertEqual(credentials.client_tags, ["dev", "oauth"]) self.assertEqual(credentials.timezone, "UTC") + def test_oauth_console_authentication(self): + connection = self.acquire_connection_with_profile( + { + "type": "trino", + "catalog": "trinodb", + "host": "database", + "port": 5439, + "method": "oauth_console", + "schema": "dbt_test_schema", + "cert": "/path/to/cert", + "client_tags": ["dev", "oauth_console"], + "http_headers": {"X-Trino-Client-Info": "dbt-trino"}, + "session_properties": { + "query_max_run_time": "4h", + "exchange_compression": True, + }, + "timezone": "UTC", + } + ) + credentials = connection.credentials + self.assertIsInstance(credentials, TrinoOauthConsoleCredentials) + self.assert_default_connection_credentials(credentials) + self.assertEqual(credentials.http_scheme, HttpScheme.HTTPS) + self.assertEqual(credentials.cert, "/path/to/cert") + self.assertEqual(connection.credentials.prepared_statements_enabled, True) + self.assertEqual(credentials.client_tags, ["dev", "oauth_console"]) + self.assertEqual(credentials.timezone, "UTC") + class TestPreparedStatementsEnabled(TestCase): def setup_profile(self, credentials):