From d72a392337a8f4fea69949da5f44a4696e4a70e3 Mon Sep 17 00:00:00 2001 From: Luka Peschke Date: Mon, 8 Jul 2024 17:09:16 +0200 Subject: [PATCH] fix(bigquery): make get_status work with JWT auth [TCTC-8923] (#1698) * fix(bigquery): make get_status work with JWT auth [TCTC-8923] Signed-off-by: Luka Peschke * chore: mypy Signed-off-by: Luka Peschke * chore: coverage Signed-off-by: Luka Peschke --------- Signed-off-by: Luka Peschke --- CHANGELOG.md | 1 + pyproject.toml | 7 +- .../google_big_query/test_google_big_query.py | 35 +++++--- .../google_big_query_connector.py | 80 ++++++++++++------- 4 files changed, 80 insertions(+), 43 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a2412646..a70e093f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Fix +- BigQuery: the JWT token auth method is now supported in the status check. - HTTP: allow connector to be instanciated without passing positional arguments to auth. ## [6.3.0] 2024-06-21 diff --git a/pyproject.toml b/pyproject.toml index a9f6c4d5f..2cf4c18cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -160,11 +160,12 @@ warn_unused_ignores = true ignore_missing_imports = true files = [ "toucan_connectors/auth.py", - "toucan_connectors/toucan_connector.py", + "toucan_connectors/google_big_query/google_big_query_connector.py", "toucan_connectors/hubspot_private_app/hubspot_connector.py", - "toucan_connectors/snowflake/snowflake_connector.py", + "toucan_connectors/mongo/mongo_connector.py", "toucan_connectors/peakina/peakina_connector.py", - "toucan_connectors/mongo/mongo_connector.py" + "toucan_connectors/snowflake/snowflake_connector.py", + "toucan_connectors/toucan_connector.py", ] [tool.ruff] diff --git a/tests/google_big_query/test_google_big_query.py b/tests/google_big_query/test_google_big_query.py index 7971fc9e2..cf10b8a5d 100644 --- a/tests/google_big_query/test_google_big_query.py +++ b/tests/google_big_query/test_google_big_query.py @@ -16,6 +16,7 @@ from pydantic import ValidationError from pytest_mock import MockerFixture, MockFixture +from toucan_connectors.common import ConnectorStatus from toucan_connectors.google_big_query.google_big_query_connector import ( GoogleBigQueryConnector, GoogleBigQueryDataSource, @@ -137,15 +138,6 @@ def test_prepare_parameters_empty(): assert len(parameters) == 0 -@patch("google.cloud.bigquery.Client", autospec=True) -@patch("cryptography.hazmat.primitives.serialization.load_pem_private_key") -def test_connect(load_pem_private_key, client, fixture_credentials, fixture_scope): - credentials = GoogleBigQueryConnector._get_google_credentials(fixture_credentials, fixture_scope) - assert isinstance(credentials, Credentials) - connection = GoogleBigQueryConnector._connect(credentials) - assert isinstance(connection, Client) - - def test__http_is_present_as_attr( mocker: MockFixture, gbq_connector_with_jwt: GoogleBigQueryConnector, @@ -911,6 +903,7 @@ def test_get_status(mocker: MockerFixture, fixture_credentials: GoogleCredential status = connector.get_status() assert status.status is False assert status.details == [ + ("Credentials provided", True), ("Private key validity", False), ("Sample BigQuery job", False), ] @@ -930,6 +923,7 @@ def connect_spy_fail(*args, **kwargs): status = connector.get_status() assert status.status is False assert status.details == [ + ("Credentials provided", True), ("Private key validity", True), ("Sample BigQuery job", False), ] @@ -947,6 +941,7 @@ def query(self, *args, **kwargs): status = connector.get_status() assert status.status is False assert status.details == [ + ("Credentials provided", True), ("Private key validity", True), ("Sample BigQuery job", False), ] @@ -964,7 +959,29 @@ def query(self, *args, **kwargs): status = connector.get_status() assert status.status is True assert status.details == [ + ("Credentials provided", True), ("Private key validity", True), ("Sample BigQuery job", True), ] assert status.error is None + + +def test_get_status_with_jwt(mocker: MockerFixture, gbq_connector_with_jwt: GoogleBigQueryConnector) -> None: + http_connect_mock = mocker.patch.object(gbq_connector_with_jwt, "_http_connect") + status = gbq_connector_with_jwt.get_status() + http_connect_mock.assert_called_once_with(http_session=mocker.ANY, project_id="THE_JWT_project_id") + # no private key validity should appear here, as JWT auth was used + assert status == ConnectorStatus( + status=True, message=None, error=None, details=[("Credentials provided", True), ("Sample BigQuery job", True)] + ) + + +def test_get_status_no_creds() -> None: + conn = GoogleBigQueryConnector(name="woups", scopes=["https://www.googleapis.com/auth/bigquery"]) + + assert conn.get_status() == ConnectorStatus( + status=False, + message=None, + error="Either google credentials or a JWT token must be provided", + details=[("Credentials provided", False), ("Sample BigQuery job", False)], + ) diff --git a/toucan_connectors/google_big_query/google_big_query_connector.py b/toucan_connectors/google_big_query/google_big_query_connector.py index be01b6a95..3f8b37242 100644 --- a/toucan_connectors/google_big_query/google_big_query_connector.py +++ b/toucan_connectors/google_big_query/google_big_query_connector.py @@ -5,7 +5,7 @@ from functools import cached_property from itertools import groupby from timeit import default_timer as timer -from typing import Any, Dict, Iterable, Union +from typing import Any, Iterable, Union import pandas as pd import requests @@ -83,16 +83,16 @@ class GoogleBigQueryDataSource(ToucanDataSource): ..., description="You can find details on the query syntax " 'here', - widget="sql", + widget="sql", # type:ignore[call-arg] ) - query_object: Dict = Field( + query_object: dict | None = Field( # type:ignore[pydantic-field] None, description="An object describing a simple select query This field is used internally", - **{"ui.hidden": True}, + **{"ui.hidden": True}, # type:ignore[arg-type] ) - language: str = Field("sql", **{"ui.hidden": True}) - database: str = Field(None, **{"ui.hidden": True}) - db_schema: str = Field(None, description="The name of the db_schema you want to query.") + language: str = Field("sql", **{"ui.hidden": True}) # type:ignore[arg-type,pydantic-field] + database: str | None = Field(None, **{"ui.hidden": True}) # type:ignore[arg-type,pydantic-field] + db_schema: str | None = Field(None, description="The name of the db_schema you want to query.") @classmethod def get_form(cls, connector: "GoogleBigQueryConnector", current_config: dict[str, Any]): @@ -122,6 +122,7 @@ def _define_query_param(name: str, value: Any) -> BigQueryParam: return bigquery_helpers.scalar_to_query_parameter(value=value, name=name) +_CREDENTIALS_CHECK_NAME = "Credentials provided" _KEY_CHECK_NAME = "Private key validity" _SAMPLE_QUERY = "Sample BigQuery job" @@ -143,7 +144,7 @@ class GoogleBigQueryConnector(ToucanConnector, DiscoverableConnector, data_sourc jwt_credentials: JWTCredentials | None = Field( None, title="Google Credentials With JWT", - description="You need to signe a JWT token, that will be use here with the project_id", + description="You need to sign a JWT token, that will be used here with the project_id", ) dialect: Dialect = Field( @@ -204,10 +205,13 @@ def _http_connect(http_session: requests.Session, project_id: str) -> bigquery.C return client - @staticmethod - def _connect(credentials: Credentials) -> bigquery.Client: + def _connect(self, credentials: Credentials | JWTCredentials) -> bigquery.Client: start = timer() - client = bigquery.Client(credentials=credentials) + if isinstance(credentials, Credentials): + client = bigquery.Client(credentials=credentials) + else: + session = CustomRequestSession(credentials.jwt_token) + client = self._http_connect(http_session=session, project_id=self._get_project_id()) end = timer() _LOGGER.info( f"[benchmark][google_big_query] - connect {end - start} seconds", @@ -241,8 +245,8 @@ def _execute_query(client: bigquery.Client, query: str, parameters: list) -> pd. try: start = timer() query = GoogleBigQueryConnector._clean_query(query) - result_iterator = ( - client.query( + result_iterator: Iterable[pd.DataFrame] = ( + client.query( # type:ignore[assignment] query, job_config=bigquery.QueryJobConfig(query_parameters=parameters), ) @@ -294,7 +298,7 @@ def _bigquery_client_with_google_creds(self) -> bigquery.Client: try: assert self.credentials is not None credentials = GoogleBigQueryConnector._get_google_credentials(self.credentials, self.scopes) - return GoogleBigQueryConnector._connect(credentials) + return self._connect(credentials) except AssertionError as excp: raise GoogleClientCreationError from excp @@ -303,13 +307,13 @@ def _bigquery_client(self) -> bigquery.Client: if self.jwt_credentials and self.jwt_credentials.jwt_token: try: # We try to instantiate the bigquery.Client with the given jwt-token - _session = CustomRequestSession(self.jwt_credentials.jwt_token) - client = GoogleBigQueryConnector._http_connect(http_session=_session, project_id=self._get_project_id()) - _LOGGER.info("bigqueryClient created with the JWT provided !") + session = CustomRequestSession(self.jwt_credentials.jwt_token) + client = GoogleBigQueryConnector._http_connect(http_session=session, project_id=self._get_project_id()) + _LOGGER.debug("BigQuery client created using the provided JWT token") return client - except InvalidJWTToken: - _LOGGER.info("JWT login failed, falling back to GoogleCredentials if they are presents") + except InvalidJWTToken as exc: + _LOGGER.warning("Login with JWT token failed, falling back to google credentials", exc_info=exc) # or we fallback on default google-credentials return self._bigquery_client_with_google_creds() @@ -359,7 +363,7 @@ def _format_columns(x: str): unformatted_db_tree["columns"] = unformatted_db_tree["columns"].apply(_format_columns) return ( - unformatted_db_tree.groupby(["name", "schema", "database", "type"], group_keys=False)["columns"] + unformatted_db_tree.groupby(["name", "schema", "database", "type"], group_keys=False)["columns"] # type:ignore[return-value] .apply(list) .reset_index() .to_dict(orient="records") @@ -452,7 +456,7 @@ def _available_schs(self) -> list[str]: # pragma: no cover datasets = client.list_datasets(timeout=10) dataset_ids = (ds.dataset_id for ds in datasets) - return pd.Series(dataset_ids).values + return pd.Series(dataset_ids).values # type:ignore[call-overload] def _get_project_structure(self, db_name: str | None = None, schema_name: str | None = None) -> list[TableInfo]: client = self._get_bigquery_client() @@ -484,12 +488,32 @@ def get_model(self, db_name: str | None = None, schema_name: str | None = None) return self._get_project_structure(db_name, schema_name) def get_status(self) -> ConnectorStatus: + checks: list[tuple[str, bool | None]] = [] try: - credentials = get_google_oauth2_credentials(self.credentials) + if self.credentials: + credentials = get_google_oauth2_credentials(self.credentials) + checks += [ + (_CREDENTIALS_CHECK_NAME, True), + (_KEY_CHECK_NAME, True), + ] + elif self.jwt_credentials: + credentials = self.jwt_credentials + checks.append((_CREDENTIALS_CHECK_NAME, True)) + else: + return ConnectorStatus( + status=False, + details=[ + (_CREDENTIALS_CHECK_NAME, False), + (_SAMPLE_QUERY, False), + ], + error="Either google credentials or a JWT token must be provided", + ) + except Exception as exc: return ConnectorStatus( status=False, details=[ + (_CREDENTIALS_CHECK_NAME, True), (_KEY_CHECK_NAME, False), (_SAMPLE_QUERY, False), ], @@ -499,18 +523,12 @@ def get_status(self) -> ConnectorStatus: try: client = self._connect(credentials) client.query("SELECT SESSION_USER() as whoami") + checks.append((_SAMPLE_QUERY, True)) except Exception as exc: return ConnectorStatus( status=False, - details=[ - (_KEY_CHECK_NAME, True), - (_SAMPLE_QUERY, False), - ], + details=[*checks, (_SAMPLE_QUERY, False)], error=str(exc), ) - return ConnectorStatus( - status=True, - details=[(_KEY_CHECK_NAME, True), (_SAMPLE_QUERY, True)], - error=None, - ) + return ConnectorStatus(status=True, details=checks, error=None)