Skip to content

Commit

Permalink
fix(bigquery): make get_status work with JWT auth [TCTC-8923] (#1698)
Browse files Browse the repository at this point in the history
* fix(bigquery): make get_status work with JWT auth [TCTC-8923]

Signed-off-by: Luka Peschke <[email protected]>

* chore: mypy

Signed-off-by: Luka Peschke <[email protected]>

* chore: coverage

Signed-off-by: Luka Peschke <[email protected]>

---------

Signed-off-by: Luka Peschke <[email protected]>
  • Loading branch information
lukapeschke authored Jul 8, 2024
1 parent ec00b77 commit d72a392
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Fix

- BigQuery: the JWT token auth method is now supported in the status check.
- HTTP: allow connector to be instanciated without passing positional arguments to auth.

## [6.3.0] 2024-06-21
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,12 @@ warn_unused_ignores = true
ignore_missing_imports = true
files = [
"toucan_connectors/auth.py",
"toucan_connectors/toucan_connector.py",
"toucan_connectors/google_big_query/google_big_query_connector.py",
"toucan_connectors/hubspot_private_app/hubspot_connector.py",
"toucan_connectors/snowflake/snowflake_connector.py",
"toucan_connectors/mongo/mongo_connector.py",
"toucan_connectors/peakina/peakina_connector.py",
"toucan_connectors/mongo/mongo_connector.py"
"toucan_connectors/snowflake/snowflake_connector.py",
"toucan_connectors/toucan_connector.py",
]

[tool.ruff]
Expand Down
35 changes: 26 additions & 9 deletions tests/google_big_query/test_google_big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pydantic import ValidationError
from pytest_mock import MockerFixture, MockFixture

from toucan_connectors.common import ConnectorStatus
from toucan_connectors.google_big_query.google_big_query_connector import (
GoogleBigQueryConnector,
GoogleBigQueryDataSource,
Expand Down Expand Up @@ -137,15 +138,6 @@ def test_prepare_parameters_empty():
assert len(parameters) == 0


@patch("google.cloud.bigquery.Client", autospec=True)
@patch("cryptography.hazmat.primitives.serialization.load_pem_private_key")
def test_connect(load_pem_private_key, client, fixture_credentials, fixture_scope):
credentials = GoogleBigQueryConnector._get_google_credentials(fixture_credentials, fixture_scope)
assert isinstance(credentials, Credentials)
connection = GoogleBigQueryConnector._connect(credentials)
assert isinstance(connection, Client)


def test__http_is_present_as_attr(
mocker: MockFixture,
gbq_connector_with_jwt: GoogleBigQueryConnector,
Expand Down Expand Up @@ -911,6 +903,7 @@ def test_get_status(mocker: MockerFixture, fixture_credentials: GoogleCredential
status = connector.get_status()
assert status.status is False
assert status.details == [
("Credentials provided", True),
("Private key validity", False),
("Sample BigQuery job", False),
]
Expand All @@ -930,6 +923,7 @@ def connect_spy_fail(*args, **kwargs):
status = connector.get_status()
assert status.status is False
assert status.details == [
("Credentials provided", True),
("Private key validity", True),
("Sample BigQuery job", False),
]
Expand All @@ -947,6 +941,7 @@ def query(self, *args, **kwargs):
status = connector.get_status()
assert status.status is False
assert status.details == [
("Credentials provided", True),
("Private key validity", True),
("Sample BigQuery job", False),
]
Expand All @@ -964,7 +959,29 @@ def query(self, *args, **kwargs):
status = connector.get_status()
assert status.status is True
assert status.details == [
("Credentials provided", True),
("Private key validity", True),
("Sample BigQuery job", True),
]
assert status.error is None


def test_get_status_with_jwt(mocker: MockerFixture, gbq_connector_with_jwt: GoogleBigQueryConnector) -> None:
http_connect_mock = mocker.patch.object(gbq_connector_with_jwt, "_http_connect")
status = gbq_connector_with_jwt.get_status()
http_connect_mock.assert_called_once_with(http_session=mocker.ANY, project_id="THE_JWT_project_id")
# no private key validity should appear here, as JWT auth was used
assert status == ConnectorStatus(
status=True, message=None, error=None, details=[("Credentials provided", True), ("Sample BigQuery job", True)]
)


def test_get_status_no_creds() -> None:
conn = GoogleBigQueryConnector(name="woups", scopes=["https://www.googleapis.com/auth/bigquery"])

assert conn.get_status() == ConnectorStatus(
status=False,
message=None,
error="Either google credentials or a JWT token must be provided",
details=[("Credentials provided", False), ("Sample BigQuery job", False)],
)
80 changes: 49 additions & 31 deletions toucan_connectors/google_big_query/google_big_query_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import cached_property
from itertools import groupby
from timeit import default_timer as timer
from typing import Any, Dict, Iterable, Union
from typing import Any, Iterable, Union

import pandas as pd
import requests
Expand Down Expand Up @@ -83,16 +83,16 @@ class GoogleBigQueryDataSource(ToucanDataSource):
...,
description="You can find details on the query syntax "
'<a href="https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax">here</a>',
widget="sql",
widget="sql", # type:ignore[call-arg]
)
query_object: Dict = Field(
query_object: dict | None = Field( # type:ignore[pydantic-field]
None,
description="An object describing a simple select query This field is used internally",
**{"ui.hidden": True},
**{"ui.hidden": True}, # type:ignore[arg-type]
)
language: str = Field("sql", **{"ui.hidden": True})
database: str = Field(None, **{"ui.hidden": True})
db_schema: str = Field(None, description="The name of the db_schema you want to query.")
language: str = Field("sql", **{"ui.hidden": True}) # type:ignore[arg-type,pydantic-field]
database: str | None = Field(None, **{"ui.hidden": True}) # type:ignore[arg-type,pydantic-field]
db_schema: str | None = Field(None, description="The name of the db_schema you want to query.")

@classmethod
def get_form(cls, connector: "GoogleBigQueryConnector", current_config: dict[str, Any]):
Expand Down Expand Up @@ -122,6 +122,7 @@ def _define_query_param(name: str, value: Any) -> BigQueryParam:
return bigquery_helpers.scalar_to_query_parameter(value=value, name=name)


_CREDENTIALS_CHECK_NAME = "Credentials provided"
_KEY_CHECK_NAME = "Private key validity"
_SAMPLE_QUERY = "Sample BigQuery job"

Expand All @@ -143,7 +144,7 @@ class GoogleBigQueryConnector(ToucanConnector, DiscoverableConnector, data_sourc
jwt_credentials: JWTCredentials | None = Field(
None,
title="Google Credentials With JWT",
description="You need to signe a JWT token, that will be use here with the project_id",
description="You need to sign a JWT token, that will be used here with the project_id",
)

dialect: Dialect = Field(
Expand Down Expand Up @@ -204,10 +205,13 @@ def _http_connect(http_session: requests.Session, project_id: str) -> bigquery.C

return client

@staticmethod
def _connect(credentials: Credentials) -> bigquery.Client:
def _connect(self, credentials: Credentials | JWTCredentials) -> bigquery.Client:
start = timer()
client = bigquery.Client(credentials=credentials)
if isinstance(credentials, Credentials):
client = bigquery.Client(credentials=credentials)
else:
session = CustomRequestSession(credentials.jwt_token)
client = self._http_connect(http_session=session, project_id=self._get_project_id())
end = timer()
_LOGGER.info(
f"[benchmark][google_big_query] - connect {end - start} seconds",
Expand Down Expand Up @@ -241,8 +245,8 @@ def _execute_query(client: bigquery.Client, query: str, parameters: list) -> pd.
try:
start = timer()
query = GoogleBigQueryConnector._clean_query(query)
result_iterator = (
client.query(
result_iterator: Iterable[pd.DataFrame] = (
client.query( # type:ignore[assignment]
query,
job_config=bigquery.QueryJobConfig(query_parameters=parameters),
)
Expand Down Expand Up @@ -294,7 +298,7 @@ def _bigquery_client_with_google_creds(self) -> bigquery.Client:
try:
assert self.credentials is not None
credentials = GoogleBigQueryConnector._get_google_credentials(self.credentials, self.scopes)
return GoogleBigQueryConnector._connect(credentials)
return self._connect(credentials)
except AssertionError as excp:
raise GoogleClientCreationError from excp

Expand All @@ -303,13 +307,13 @@ def _bigquery_client(self) -> bigquery.Client:
if self.jwt_credentials and self.jwt_credentials.jwt_token:
try:
# We try to instantiate the bigquery.Client with the given jwt-token
_session = CustomRequestSession(self.jwt_credentials.jwt_token)
client = GoogleBigQueryConnector._http_connect(http_session=_session, project_id=self._get_project_id())
_LOGGER.info("bigqueryClient created with the JWT provided !")
session = CustomRequestSession(self.jwt_credentials.jwt_token)
client = GoogleBigQueryConnector._http_connect(http_session=session, project_id=self._get_project_id())
_LOGGER.debug("BigQuery client created using the provided JWT token")

return client
except InvalidJWTToken:
_LOGGER.info("JWT login failed, falling back to GoogleCredentials if they are presents")
except InvalidJWTToken as exc:
_LOGGER.warning("Login with JWT token failed, falling back to google credentials", exc_info=exc)
# or we fallback on default google-credentials
return self._bigquery_client_with_google_creds()

Expand Down Expand Up @@ -359,7 +363,7 @@ def _format_columns(x: str):

unformatted_db_tree["columns"] = unformatted_db_tree["columns"].apply(_format_columns)
return (
unformatted_db_tree.groupby(["name", "schema", "database", "type"], group_keys=False)["columns"]
unformatted_db_tree.groupby(["name", "schema", "database", "type"], group_keys=False)["columns"] # type:ignore[return-value]
.apply(list)
.reset_index()
.to_dict(orient="records")
Expand Down Expand Up @@ -452,7 +456,7 @@ def _available_schs(self) -> list[str]: # pragma: no cover
datasets = client.list_datasets(timeout=10)
dataset_ids = (ds.dataset_id for ds in datasets)

return pd.Series(dataset_ids).values
return pd.Series(dataset_ids).values # type:ignore[call-overload]

def _get_project_structure(self, db_name: str | None = None, schema_name: str | None = None) -> list[TableInfo]:
client = self._get_bigquery_client()
Expand Down Expand Up @@ -484,12 +488,32 @@ def get_model(self, db_name: str | None = None, schema_name: str | None = None)
return self._get_project_structure(db_name, schema_name)

def get_status(self) -> ConnectorStatus:
checks: list[tuple[str, bool | None]] = []
try:
credentials = get_google_oauth2_credentials(self.credentials)
if self.credentials:
credentials = get_google_oauth2_credentials(self.credentials)
checks += [
(_CREDENTIALS_CHECK_NAME, True),
(_KEY_CHECK_NAME, True),
]
elif self.jwt_credentials:
credentials = self.jwt_credentials
checks.append((_CREDENTIALS_CHECK_NAME, True))
else:
return ConnectorStatus(
status=False,
details=[
(_CREDENTIALS_CHECK_NAME, False),
(_SAMPLE_QUERY, False),
],
error="Either google credentials or a JWT token must be provided",
)

except Exception as exc:
return ConnectorStatus(
status=False,
details=[
(_CREDENTIALS_CHECK_NAME, True),
(_KEY_CHECK_NAME, False),
(_SAMPLE_QUERY, False),
],
Expand All @@ -499,18 +523,12 @@ def get_status(self) -> ConnectorStatus:
try:
client = self._connect(credentials)
client.query("SELECT SESSION_USER() as whoami")
checks.append((_SAMPLE_QUERY, True))
except Exception as exc:
return ConnectorStatus(
status=False,
details=[
(_KEY_CHECK_NAME, True),
(_SAMPLE_QUERY, False),
],
details=[*checks, (_SAMPLE_QUERY, False)],
error=str(exc),
)

return ConnectorStatus(
status=True,
details=[(_KEY_CHECK_NAME, True), (_SAMPLE_QUERY, True)],
error=None,
)
return ConnectorStatus(status=True, details=checks, error=None)

0 comments on commit d72a392

Please sign in to comment.