Add possibility to provide cred #11

Open · wants to merge 18 commits into base: main
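Illustrative usage of the new `fwd_creds` option (a sketch, not part of the PR itself); the project, dataset, and table names are placeholders, and the call path follows the README example:

```python
import dask_bigquery

# Forward the locally resolved credentials to the Dask workers as a
# refreshed bearer token instead of relying on each worker's ambient auth.
ddf = dask_bigquery.read_gbq(
    project_id="your_project_id",  # placeholder
    dataset_id="your_dataset_id",  # placeholder
    table_id="your_table_id",      # placeholder
    fwd_creds=True,
)
```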
2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@ conda install -c conda-forge dask-bigquery

## Example

`dask-bigquery` assumes that you are already authenticated.

```python
import dask_bigquery
Expand Down
41 changes: 37 additions & 4 deletions dask_bigquery/core.py
@@ -3,6 +3,8 @@
from contextlib import contextmanager
from functools import partial

import google.auth.transport.requests
import google.oauth2.credentials
import pandas as pd
import pyarrow
from dask.base import tokenize
@@ -17,7 +19,7 @@


@contextmanager
def bigquery_clients(project_id):
def bigquery_clients(project_id, credentials=None):
"""This context manager is a temporary solution until there is an
upstream solution to handle this.
See googleapis/google-cloud-python#9457
@@ -30,7 +32,9 @@ def bigquery_clients(project_id):
user_agent=f"dask-bigquery/{dask_bigquery.__version__}"
)

with bigquery.Client(project_id, client_info=bq_client_info) as bq_client:
with bigquery.Client(
project_id, credentials=credentials, client_info=bq_client_info
) as bq_client:
bq_storage_client = bigquery_storage.BigQueryReadClient(
credentials=bq_client._credentials,
client_info=bqstorage_client_info,
@@ -54,6 +58,7 @@ def bigquery_read(
make_create_read_session_request: callable,
project_id: str,
read_kwargs: dict,
cred_token: str,
stream_name: str,
) -> pd.DataFrame:
"""Read a single batch of rows via BQ Storage API, in Arrow binary format.
@@ -70,8 +75,16 @@
BigQuery Storage API Stream "name"
NOTE: Please set if reading from Storage API without any `row_restriction`.
https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
cred_token: str
google-auth bearer token used to recreate the credentials on the worker
"""
with bigquery_clients(project_id) as (_, bqs_client):

if cred_token:
credentials = google.oauth2.credentials.Credentials(cred_token)
else:
credentials = None

with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
session = bqs_client.create_read_session(make_create_read_session_request())
schema = pyarrow.ipc.read_schema(
pyarrow.py_buffer(session.arrow_schema.serialized_schema)
@@ -91,6 +104,7 @@ def read_gbq(
row_filter: str = "",
columns: list[str] = None,
read_kwargs: dict = None,
fwd_creds: bool = False,
):
"""Read table as dask dataframe using BigQuery Storage API via Arrow format.
Partitions will be approximately balanced according to BigQuery stream allocation logic.
@@ -109,13 +123,31 @@
list of columns to load from the table
read_kwargs: dict
kwargs to pass to read_rows()
fwd_creds: bool
Set to True to forward credentials to the Dask workers. Defaults to False.

Returns
-------
Dask DataFrame
"""
read_kwargs = read_kwargs or {}
with bigquery_clients(project_id) as (bq_client, bqs_client):

if fwd_creds:
credentials, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/bigquery.readonly"]
)

auth_req = google.auth.transport.requests.Request()
credentials.refresh(auth_req)
cred_token = credentials.token
else:
credentials = None
cred_token = None

with bigquery_clients(project_id, credentials=credentials) as (
bq_client,
bqs_client,
):
table_ref = bq_client.get_table(f"{dataset_id}.{table_id}")
if table_ref.table_type == "VIEW":
raise TypeError("Table type VIEW not supported")
@@ -161,6 +193,7 @@ def make_create_read_session_request(row_filter=""):
make_create_read_session_request,
project_id,
read_kwargs,
cred_token,
),
label=label,
)
Expand Down
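For reference, a minimal standalone sketch of the credential handoff that the diff above implements: the client process resolves and refreshes application default credentials to obtain a bearer token string, and each worker rebuilds `Credentials` from that token before opening its BigQuery clients. The scope matches the one used in `read_gbq`; variable names are illustrative.

```python
import google.auth
import google.auth.transport.requests
import google.oauth2.credentials

# Client side: resolve application default credentials and refresh them so a
# short-lived bearer token is available; the plain token string is what gets
# shipped to the workers through the task graph.
credentials, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/bigquery.readonly"]
)
credentials.refresh(google.auth.transport.requests.Request())
cred_token = credentials.token

# Worker side: rebuild Credentials from the forwarded token and pass them to
# the BigQuery clients (see bigquery_clients above).
worker_credentials = google.oauth2.credentials.Credentials(cred_token)
```

Note that a token obtained this way is short-lived, and the worker-side `Credentials` built from it carry no refresh token.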
43 changes: 38 additions & 5 deletions dask_bigquery/tests/test_core.py
@@ -51,43 +51,54 @@ def dataset(df):
)


def test_read_gbq(df, dataset, client):
@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_gbq(df, dataset, fwd_creds, client):
project_id, dataset_id, table_id = dataset
ddf = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)
ddf = read_gbq(
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
fwd_creds=fwd_creds,
)

assert list(ddf.columns) == ["name", "number", "idx"]
assert ddf.npartitions == 2
assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))


def test_read_row_filter(df, dataset, client):
@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_row_filter(df, dataset, fwd_creds, client):
project_id, dataset_id, table_id = dataset
ddf = read_gbq(
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
row_filter="idx < 5",
fwd_creds=fwd_creds,
)

assert list(ddf.columns) == ["name", "number", "idx"]
assert ddf.npartitions == 2
assert assert_eq(ddf.set_index("idx").loc[:4], df.set_index("idx").loc[:4])


def test_read_kwargs(dataset, client):
@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_kwargs(dataset, fwd_creds, client):
project_id, dataset_id, table_id = dataset
ddf = read_gbq(
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
read_kwargs={"timeout": 1e-12},
fwd_creds=fwd_creds,
)

with pytest.raises(Exception, match="Deadline Exceeded"):
ddf.compute()


def test_read_columns(df, dataset, client):
@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_columns(df, dataset, fwd_creds, client):
project_id, dataset_id, table_id = dataset
assert df.shape[1] > 1, "Test data should have multiple columns"

@@ -97,5 +108,27 @@ def test_read_columns(df, dataset, client):
dataset_id=dataset_id,
table_id=table_id,
columns=columns,
fwd_creds=fwd_creds,
)
assert list(ddf.columns) == columns


@pytest.mark.parametrize("fwd_creds", [False, True])
def test_read_gbq_no_creds_fail(dataset, fwd_creds, monkeypatch, client):
"""This test is to check that if we do not have credentials
we can not authenticate.
"""
project_id, dataset_id, table_id = dataset

def mock_auth(scopes=["https://www.googleapis.com/auth/bigquery.readonly"]):
raise google.auth.exceptions.DefaultCredentialsError()

monkeypatch.setattr(google.auth, "default", mock_auth)

with pytest.raises(google.auth.exceptions.DefaultCredentialsError):
read_gbq(
project_id=project_id,
dataset_id=dataset_id,
table_id=table_id,
fwd_creds=fwd_creds,
)
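
As a usage note tied to the new test (an illustrative sketch, not part of the diff): when `google.auth` cannot resolve credentials, `read_gbq` surfaces `DefaultCredentialsError` for both `fwd_creds` settings, so callers can handle it explicitly. The identifiers below are placeholders:

```python
import dask_bigquery
import google.auth.exceptions

try:
    ddf = dask_bigquery.read_gbq(
        project_id="your_project_id",  # placeholder
        dataset_id="your_dataset_id",  # placeholder
        table_id="your_table_id",      # placeholder
        fwd_creds=True,
    )
except google.auth.exceptions.DefaultCredentialsError:
    # No ambient Google credentials were found (for example, no
    # GOOGLE_APPLICATION_CREDENTIALS and no gcloud login).
    print("Authenticate with Google Cloud before calling read_gbq")
```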