Add possibility to provide cred #11

Open · wants to merge 18 commits into base: main
2 changes: 1 addition & 1 deletion README.md
@@ -14,7 +14,7 @@ pip install dask-bigquery

## Example

`dask-bigquery` assumes that you are already authenticated.
`dask-bigquery` assumes that you are already authenticated.

```python
import dask_bigquery
```
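For context on how the new flag is meant to be used, here is a minimal usage sketch of `read_gbq` with the `fwd_creds` option added in this PR; the project, dataset, and table identifiers below are placeholders, not values from this repository:

```python
from dask_bigquery.core import read_gbq

# Placeholder identifiers; substitute your own project, dataset, and table.
ddf = read_gbq(
    project_id="your_project_id",
    dataset_id="your_dataset",
    table_id="your_table",
    fwd_creds=True,  # forward a refreshed bearer token to the workers (new in this PR)
)

print(ddf.head())
```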
47 changes: 43 additions & 4 deletions dask_bigquery/core.py
@@ -1,8 +1,11 @@
from __future__ import annotations

import os
from contextlib import contextmanager
from functools import partial

import google.auth.transport.requests
import google.oauth2.credentials
import pandas as pd
import pyarrow
from dask.base import tokenize
@@ -12,12 +15,13 @@
from google.api_core import client_info as rest_client_info
from google.api_core.gapic_v1 import client_info as grpc_client_info
from google.cloud import bigquery, bigquery_storage
from google.oauth2 import service_account

import dask_bigquery


@contextmanager
def bigquery_clients(project_id):
def bigquery_clients(project_id, credentials=None):
"""This context manager is a temporary solution until there is an
upstream solution to handle this.
See googleapis/google-cloud-python#9457
@@ -30,7 +34,9 @@ def bigquery_clients(project_id):
user_agent=f"dask-bigquery/{dask_bigquery.__version__}"
)

with bigquery.Client(project_id, client_info=bq_client_info) as bq_client:
with bigquery.Client(
project_id, credentials=credentials, client_info=bq_client_info
) as bq_client:
bq_storage_client = bigquery_storage.BigQueryReadClient(
credentials=bq_client._credentials,
client_info=bqstorage_client_info,
@@ -54,6 +60,7 @@ def bigquery_read(
make_create_read_session_request: callable,
project_id: str,
read_kwargs: dict,
cred_token: str,
stream_name: str,
) -> pd.DataFrame:
"""Read a single batch of rows via BQ Storage API, in Arrow binary format.
@@ -70,8 +77,16 @@
BigQuery Storage API Stream "name"
NOTE: Please set if reading from Storage API without any `row_restriction`.
https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
cred_token: str
google-auth bearer token used to reconstruct credentials on the worker
"""
with bigquery_clients(project_id) as (_, bqs_client):

if cred_token:
credentials = google.oauth2.credentials.Credentials(cred_token)
else:
credentials = None

with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
session = bqs_client.create_read_session(make_create_read_session_request())
schema = pyarrow.ipc.read_schema(
pyarrow.py_buffer(session.arrow_schema.serialized_schema)
@@ -91,6 +106,7 @@ def read_gbq(
row_filter: str = "",
columns: list[str] = None,
read_kwargs: dict = None,
fwd_creds: bool = False,
):
"""Read table as dask dataframe using BigQuery Storage API via Arrow format.
Partitions will be approximately balanced according to BigQuery stream allocation logic.
@@ -109,13 +125,35 @@
list of columns to load from the table
read_kwargs: dict
kwargs to pass to read_rows()
fwd_creds: bool
Set to True to forward credentials to the workers. Defaults to False.

Returns
-------
Dask DataFrame
"""
read_kwargs = read_kwargs or {}
with bigquery_clients(project_id) as (bq_client, bqs_client):

if fwd_creds:
creds_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
if creds_path is None:
raise ValueError("GOOGLE_APPLICATION_CREDENTIALS environment variable is not set")

credentials = service_account.Credentials.from_service_account_file(
Review comment:
You might want to call
https://googleapis.dev/python/google-auth/latest/reference/google.auth.html#google.auth.default
instead. That'll find the GOOGLE_APPLICATION_CREDENTIALS environment variable if available, but also allow for other authentication methods such as gcloud auth application-default login.

Review comment:
I believe any kind of credentials it returns should have a refresh function, just like the service account credentials.

Contributor Author:
Thanks Tim, I will take a look at this.

creds_path, scopes=["https://www.googleapis.com/auth/bigquery.readonly"]
)

auth_req = google.auth.transport.requests.Request()
credentials.refresh(auth_req)
cred_token = credentials.token
else:
credentials = None
cred_token = None

with bigquery_clients(project_id, credentials=credentials) as (
bq_client,
bqs_client,
):
table_ref = bq_client.get_table(f"{dataset_id}.{table_id}")
if table_ref.table_type == "VIEW":
raise TypeError("Table type VIEW not supported")
@@ -161,6 +199,7 @@ def make_create_read_session_request(row_filter=""):
make_create_read_session_request,
project_id,
read_kwargs,
cred_token,
),
label=label,
)
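Not part of the diff, but as a rough sketch of the reviewer's suggestion above, credential discovery via `google.auth.default` (instead of requiring a service-account file) could look like this, assuming the same read-only BigQuery scope:

```python
import google.auth
import google.auth.transport.requests

# Sketch of the reviewer's suggestion (not what this PR currently does):
# let google-auth discover credentials itself. This picks up
# GOOGLE_APPLICATION_CREDENTIALS when set, but also works with
# `gcloud auth application-default login` and other mechanisms.
credentials, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/bigquery.readonly"]
)

# Credentials returned this way should also support refresh(), so the
# bearer token could still be forwarded to the workers as cred_token.
credentials.refresh(google.auth.transport.requests.Request())
cred_token = credentials.token
```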
22 changes: 17 additions & 5 deletions dask_bigquery/tests/test_core.py
@@ -51,43 +51,54 @@ def dataset(df):
)


def test_read_gbq(df, dataset, client):
@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_gbq(df, dataset, fwd_creds, client):
project_id, dataset_id, table_id = dataset
ddf = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)
ddf = read_gbq(
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
fwd_creds=fwd_creds,
)

assert list(ddf.columns) == ["name", "number", "idx"]
assert ddf.npartitions == 2
assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))


def test_read_row_filter(df, dataset, client):
@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_row_filter(df, dataset, fwd_creds, client):
project_id, dataset_id, table_id = dataset
ddf = read_gbq(
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
row_filter="idx < 5",
fwd_creds=fwd_creds,
)

assert list(ddf.columns) == ["name", "number", "idx"]
assert ddf.npartitions == 2
assert assert_eq(ddf.set_index("idx").loc[:4], df.set_index("idx").loc[:4])


def test_read_kwargs(dataset, client):
@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_kwargs(dataset, fwd_creds, client):
project_id, dataset_id, table_id = dataset
ddf = read_gbq(
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
read_kwargs={"timeout": 1e-12},
fwd_creds=fwd_creds,
)

with pytest.raises(Exception, match="Deadline Exceeded"):
ddf.compute()


def test_read_columns(df, dataset, client):
@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_columns(df, dataset, fwd_creds, client):
project_id, dataset_id, table_id = dataset
assert df.shape[1] > 1, "Test data should have multiple columns"

@@ -97,5 +108,6 @@ def test_read_columns(df, dataset, client):
dataset_id=dataset_id,
table_id=table_id,
columns=columns,
fwd_creds=fwd_creds,
)
assert list(ddf.columns) == columns