
Commit

Fix CI + use from_map
jrbourbeau committed Jul 17, 2024
1 parent 18fd1f3 · commit e37232a
Showing 2 changed files with 31 additions and 50 deletions.
dask_bigquery/core.py (25 changes: 4 additions & 21 deletions)
@@ -5,13 +5,11 @@
 from contextlib import contextmanager
 from functools import partial
 
+import dask.dataframe as dd
 import gcsfs
 import pandas as pd
 import pyarrow
 from dask.base import tokenize
-from dask.dataframe.core import new_dd_object
-from dask.highlevelgraph import HighLevelGraph
-from dask.layers import DataFrameIOLayer
 from google.api_core import client_info as rest_client_info
 from google.api_core import exceptions
 from google.api_core.gapic_v1 import client_info as grpc_client_info
@@ -206,19 +204,7 @@ def make_create_read_session_request():
         )
         meta = schema.empty_table().to_pandas(**arrow_options)
 
-        label = "read-gbq-"
-        output_name = label + tokenize(
-            project_id,
-            dataset_id,
-            table_id,
-            row_filter,
-            read_kwargs,
-        )
-
-        layer = DataFrameIOLayer(
-            output_name,
-            meta.columns,
-            [stream.name for stream in session.streams],
+        return dd.from_map(
             partial(
                 bigquery_read,
                 make_create_read_session_request=make_create_read_session_request,
@@ -227,12 +213,9 @@
                 arrow_options=arrow_options,
                 credentials=credentials,
             ),
-            label=label,
+            [stream.name for stream in session.streams],
+            meta=meta,
         )
-        divisions = tuple([None] * (len(session.streams) + 1))
-
-        graph = HighLevelGraph({output_name: layer}, {output_name: set()})
-        return new_dd_object(graph, output_name, meta, divisions)
 
 
 def to_gbq(
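The change above drops the hand-assembled DataFrameIOLayer / HighLevelGraph / new_dd_object pipeline in favor of dask.dataframe.from_map, which takes the per-partition read function, the list of BigQuery Storage read stream names (one partition per stream), and a meta frame describing the output schema. A minimal sketch of that pattern with a stand-in read function (read_partition, stream_names, and the toy columns are illustrative, not part of dask_bigquery):

```python
from functools import partial

import dask.dataframe as dd
import pandas as pd


def read_partition(stream_name, *, row_filter=None):
    # Stand-in for dask_bigquery's bigquery_read: one input value -> one pandas partition.
    return pd.DataFrame({"stream": [stream_name], "rows": [100]})


# Empty frame describing the output schema; this is what `meta=` communicates to Dask.
meta = pd.DataFrame({"stream": pd.Series(dtype=object), "rows": pd.Series(dtype="int64")})

stream_names = ["stream-0", "stream-1", "stream-2"]

# from_map builds the task graph itself: one task per element of `stream_names`.
ddf = dd.from_map(partial(read_partition, row_filter=None), stream_names, meta=meta)
assert ddf.npartitions == len(stream_names)
```

Since from_map has no way to know how the streams partition the index, the resulting collection carries unknown divisions, which is what the check_divisions=False change in test_roundtrip below accommodates.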
dask_bigquery/tests/test_core.py (56 changes: 27 additions & 29 deletions)
@@ -35,14 +35,13 @@ def df():
         for i in range(10)
     ]
 
-    yield pd.DataFrame(records)
+    df = pd.DataFrame(records)
+    df["timestamp"] = df["timestamp"].astype("datetime64[us, UTC]")
+    yield df
 
 
 @pytest.fixture(scope="module")
-def dataset():
-    project_id = os.environ.get("DASK_BIGQUERY_PROJECT_ID")
-    if not project_id:
-        credentials, project_id = google.auth.default()
+def dataset(project_id):
     dataset_id = f"{sys.platform}_{uuid.uuid4().hex}"
 
     with bigquery.Client() as bq_client:
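The df fixture now coerces the timestamp column to timezone-aware, microsecond-precision values, which lines up with BigQuery's TIMESTAMP type (microsecond resolution), so the round-trip comparisons compare like with like instead of hitting a ns-vs-us dtype mismatch. A small illustration of the cast (assumes pandas 2.x, which supports non-nanosecond datetime dtypes):

```python
import pandas as pd

s = pd.Series(pd.to_datetime(["2024-07-17 12:00:00.123456789"], utc=True))
print(s.dtype)  # datetime64[ns, UTC]

# Cast to the dtype used by the fixture; precision below microseconds is dropped.
s = s.astype("datetime64[us, UTC]")
print(s.dtype)  # datetime64[us, UTC]
```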
@@ -110,25 +109,30 @@ def required_partition_filter_table(dataset, df):
     yield project_id, dataset_id, table_id
 
 
+@pytest.fixture(scope="module")
+def project_id():
+    project_id = os.environ.get("DASK_BIGQUERY_PROJECT_ID")
+    if not project_id:
+        _, project_id = google.auth.default()
+
+    yield project_id
+
+
 @pytest.fixture
 def google_creds():
-    env_creds_file = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
-    if env_creds_file:
-        credentials = json.load(open(env_creds_file))
-    else:
+    if os.environ.get("GOOGLE_APPLICATION_CREDENTIALS"):
+        credentials = json.load(open(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")))
+    elif os.environ.get("DASK_BIGQUERY_GCP_CREDENTIALS"):
+        credentials = json.loads(os.environ.get("DASK_BIGQUERY_GCP_CREDENTIALS"))
+    else:
         credentials, _ = google.auth.default()
 
     yield credentials
 
 
 @pytest.fixture
-def bucket(google_creds):
-    project_id = google_creds["project_id"]
-    env_project_id = os.environ.get("DASK_BIGQUERY_PROJECT_ID")
-    if env_project_id:
-        project_id = env_project_id
-
+def bucket(google_creds, project_id):
     bucket = f"dask-bigquery-tmp-{uuid.uuid4().hex}"
 
     fs = gcsfs.GCSFileSystem(
         project=project_id, access="read_write", token=google_creds
     )
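The reworked google_creds fixture tries a key file first (GOOGLE_APPLICATION_CREDENTIALS), then inline JSON from DASK_BIGQUERY_GCP_CREDENTIALS (presumably how the CI job supplies a service-account key), and finally falls back to application default credentials. The same lookup order as a standalone helper (a sketch for illustration, not part of the test suite; note the fallback yields a google.auth Credentials object rather than a dict):

```python
import json
import os

import google.auth


def resolve_credentials():
    """Mirror the fixture's order: key file, then inline JSON, then default credentials."""
    key_file = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if key_file:
        with open(key_file) as f:
            return json.load(f)
    inline = os.environ.get("DASK_BIGQUERY_GCP_CREDENTIALS")
    if inline:
        return json.loads(inline)
    credentials, _ = google.auth.default()
    return credentials
```

gcsfs's token= argument accepts a service-account dict as well as a Credentials object, which is why the bucket fixture can pass google_creds straight through to GCSFileSystem.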
@@ -140,12 +144,7 @@ def bucket(google_creds):
 
 
 @pytest.fixture
-def write_dataset(google_creds):
-    project_id = google_creds["project_id"]
-    env_project_id = os.environ.get("DASK_BIGQUERY_PROJECT_ID")
-    if env_project_id:
-        project_id = env_project_id
-
+def write_dataset(google_creds, project_id):
     dataset_id = f"{sys.platform}_{uuid.uuid4().hex}"
 
     yield google_creds, project_id, dataset_id, None
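The write fixtures now lean on pytest's fixture injection: declaring project_id as a parameter is enough for pytest to pass in the module-scoped fixture's value, instead of each fixture re-deriving the project from the credentials dict and environment. A stripped-down sketch of the pattern (write_target and the default project name are illustrative, not from the test module):

```python
import os
import uuid

import pytest


@pytest.fixture(scope="module")
def project_id():
    # Resolved once per test module; "example-project" is a placeholder default.
    yield os.environ.get("DASK_BIGQUERY_PROJECT_ID", "example-project")


@pytest.fixture
def write_target(project_id):
    # pytest matches the parameter name to the fixture above and injects its value.
    yield project_id, f"tmp_{uuid.uuid4().hex}"
```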
@@ -158,8 +157,7 @@ def write_existing_dataset(google_creds):
 
 
 @pytest.fixture
-def write_existing_dataset(google_creds):
-    project_id = os.environ.get("DASK_BIGQUERY_PROJECT_ID", google_creds["project_id"])
+def write_existing_dataset(google_creds, project_id):
     dataset_id = "persistent_dataset"
     table_id = f"table_to_write_{sys.platform}_{uuid.uuid4().hex}"
 
@@ -181,7 +179,7 @@ def write_existing_dataset(google_creds):
         [
             ("name", pa.string()),
             ("number", pa.uint8()),
-            ("timestamp", pa.timestamp("ns")),
+            ("timestamp", pa.timestamp("us")),
             ("idx", pa.uint8()),
         ]
     ),
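The expected Arrow schema in the write test changes for the same reason as the df fixture: BigQuery's TIMESTAMP type stores microseconds, so timestamps surfaced through pyarrow report as timestamp[us] rather than timestamp[ns]. For reference, the unit shows up directly on the schema (a standalone snippet, not the test's actual assertion):

```python
import pyarrow as pa

expected = pa.schema(
    [
        ("name", pa.string()),
        ("number", pa.uint8()),
        ("timestamp", pa.timestamp("us")),  # microsecond precision, matching BigQuery TIMESTAMP
        ("idx", pa.uint8()),
    ]
)
print(expected.field("timestamp").type)  # timestamp[us]
```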
@@ -285,14 +283,14 @@ def test_roundtrip(df, dataset_fixture, request):
 
     ddf_out = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)
     # bigquery does not guarantee ordering, so let's reindex
-    assert_eq(ddf.set_index("idx"), ddf_out.set_index("idx"))
+    assert_eq(ddf.set_index("idx"), ddf_out.set_index("idx"), check_divisions=False)
 
 
 def test_read_gbq(df, table, client):
     project_id, dataset_id, table_id = table
     ddf = read_gbq(project_id=project_id, dataset_id=dataset_id, table_id=table_id)
 
-    assert list(ddf.columns) == ["name", "number", "timestamp", "idx"]
+    assert list(df.columns) == list(ddf.columns)
     assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))
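test_roundtrip now passes check_divisions=False because read_gbq builds its collection with dd.from_map, whose divisions are unknown (one partition per read stream, with no index bounds), while assert_eq checks divisions by default. A toy illustration of the flag (the frames here are placeholders for the BigQuery round trip):

```python
import pandas as pd

import dask.dataframe as dd
from dask.dataframe.utils import assert_eq

pdf = pd.DataFrame({"idx": range(6), "x": range(6)}).set_index("idx")

known = dd.from_pandas(pdf, npartitions=2)  # divisions are known
unknown = dd.from_map(
    lambda i: pdf.iloc[i * 3 : (i + 1) * 3],
    [0, 1],
    meta=pdf.iloc[:0],
)  # divisions are unknown

assert known.known_divisions
assert not unknown.known_divisions

# The data is identical; only the division metadata differs.
assert_eq(known, unknown, check_divisions=False)
```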


@@ -305,7 +303,7 @@ def test_read_row_filter(df, table, client):
         row_filter="idx < 5",
     )
 
-    assert list(ddf.columns) == ["name", "number", "timestamp", "idx"]
+    assert list(df.columns) == list(ddf.columns)
     assert assert_eq(ddf.set_index("idx").loc[:4], df.set_index("idx").loc[:4])


@@ -361,7 +359,7 @@ def test_read_gbq_credentials(df, dataset_fixture, request, monkeypatch):
         credentials=credentials,
     )
 
-    assert list(ddf.columns) == ["name", "number", "timestamp", "idx"]
+    assert list(df.columns) == list(ddf.columns)
     assert assert_eq(ddf.set_index("idx"), df.set_index("idx"))


