diff --git a/dask_bigquery/core.py b/dask_bigquery/core.py
index 7ea9c9d..df3b3fa 100644
--- a/dask_bigquery/core.py
+++ b/dask_bigquery/core.py
@@ -10,6 +10,8 @@
 import pandas as pd
 import pyarrow
 from dask.base import tokenize
+from dask.dataframe._compat import PANDAS_GE_220
+from dask.dataframe.utils import pyarrow_strings_enabled
 from google.api_core import client_info as rest_client_info
 from google.api_core import exceptions
 from google.api_core.gapic_v1 import client_info as grpc_client_info
@@ -95,6 +97,7 @@ def bigquery_read(
     read_kwargs: dict,
     arrow_options: dict,
     credentials: dict = None,
+    convert_string: bool = False,
 ) -> pd.DataFrame:
     """Read a single batch of rows via BQ Storage API, in Arrow binary format.
 
@@ -114,7 +117,15 @@ def bigquery_read(
         BigQuery Storage API Stream "name"
         NOTE: Please set if reading from Storage API without any `row_restriction`.
         https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
+    convert_string: bool
+        Whether to convert strings directly to arrow strings in the output DataFrame
     """
+    arrow_options = arrow_options.copy()
+    if convert_string:
+        types_mapper = _get_types_mapper(arrow_options.get("types_mapper", {}.get))
+        if types_mapper is not None:
+            arrow_options["types_mapper"] = types_mapper
+
     with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
         session = bqs_client.create_read_session(make_create_read_session_request())
         schema = pyarrow.ipc.read_schema(
@@ -130,6 +141,37 @@ def bigquery_read(
     return pd.concat(shards)
 
 
+def _get_types_mapper(user_mapper):
+    type_mappers = []
+
+    # always use the user-defined mapper first, if available
+    if user_mapper is not None:
+        type_mappers.append(user_mapper)
+
+    type_mappers.append({pyarrow.string(): pd.StringDtype("pyarrow")}.get)
+    if PANDAS_GE_220:
+        type_mappers.append({pyarrow.large_string(): pd.StringDtype("pyarrow")}.get)
+    type_mappers.append({pyarrow.date32(): pd.ArrowDtype(pyarrow.date32())}.get)
+    type_mappers.append({pyarrow.date64(): pd.ArrowDtype(pyarrow.date64())}.get)
+
+    def _convert_decimal_type(type):
+        if pyarrow.types.is_decimal(type):
+            return pd.ArrowDtype(type)
+        return None
+
+    type_mappers.append(_convert_decimal_type)
+
+    def default_types_mapper(pyarrow_dtype):
+        """Try all type mappers in order, starting from the user type mapper."""
+        for type_converter in type_mappers:
+            converted_type = type_converter(pyarrow_dtype)
+            if converted_type is not None:
+                return converted_type
+
+    if len(type_mappers) > 0:
+        return default_types_mapper
+
+
 def read_gbq(
     project_id: str,
     dataset_id: str,
@@ -196,13 +238,19 @@ def make_create_read_session_request():
                 ),
             )
 
+        arrow_options_meta = arrow_options.copy()
+        if pyarrow_strings_enabled():
+            types_mapper = _get_types_mapper(arrow_options.get("types_mapper", {}.get))
+            if types_mapper is not None:
+                arrow_options_meta["types_mapper"] = types_mapper
+
         # Create a read session in order to detect the schema.
         # Read sessions are light weight and will be auto-deleted after 24 hours.
         session = bqs_client.create_read_session(make_create_read_session_request())
         schema = pyarrow.ipc.read_schema(
             pyarrow.py_buffer(session.arrow_schema.serialized_schema)
         )
-        meta = schema.empty_table().to_pandas(**arrow_options)
+        meta = schema.empty_table().to_pandas(**arrow_options_meta)
 
         return dd.from_map(
             partial(
@@ -212,6 +260,7 @@ def make_create_read_session_request():
                 read_kwargs=read_kwargs,
                 arrow_options=arrow_options,
                 credentials=credentials,
+                convert_string=pyarrow_strings_enabled(),
             ),
             [stream.name for stream in session.streams],
             meta=meta,
diff --git a/dask_bigquery/tests/__init__.py b/dask_bigquery/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dask_bigquery/tests/test_core.py b/dask_bigquery/tests/test_core.py
index bc4c48c..cbfeeb4 100644
--- a/dask_bigquery/tests/test_core.py
+++ b/dask_bigquery/tests/test_core.py
@@ -5,13 +5,14 @@
 import uuid
 from datetime import datetime, timedelta, timezone
 
+import dask
 import dask.dataframe as dd
 import gcsfs
 import google.auth
 import pandas as pd
 import pyarrow as pa
 import pytest
-from dask.dataframe.utils import assert_eq
+from dask.dataframe.utils import assert_eq, pyarrow_strings_enabled
 from distributed.utils_test import cleanup  # noqa: F401
 from distributed.utils_test import client  # noqa: F401
 from distributed.utils_test import cluster_fixture  # noqa: F401
@@ -380,11 +381,32 @@ def test_arrow_options(table):
         project_id=project_id,
         dataset_id=dataset_id,
         table_id=table_id,
-        arrow_options={
-            "types_mapper": {pa.string(): pd.StringDtype(storage="pyarrow")}.get
-        },
+        arrow_options={"types_mapper": {pa.int64(): pd.Float32Dtype()}.get},
     )
-    assert ddf.dtypes["name"] == pd.StringDtype(storage="pyarrow")
+    assert ddf.dtypes["number"] == pd.Float32Dtype()
+
+
+@pytest.mark.parametrize("convert_string", [True, False, None])
+def test_convert_string(table, convert_string, df):
+    project_id, dataset_id, table_id = table
+    config = {}
+    if convert_string is not None:
+        config = {"dataframe.convert-string": convert_string}
+    with dask.config.set(config):
+        ddf = read_gbq(
+            project_id=project_id,
+            dataset_id=dataset_id,
+            table_id=table_id,
+        )
+        # Roundtrip through `dd.from_pandas` to check consistent
+        # behavior with Dask DataFrame
+        result = dd.from_pandas(df, npartitions=1)
+    if convert_string is True or (convert_string is None and pyarrow_strings_enabled()):
+        assert ddf.dtypes["name"] == pd.StringDtype(storage="pyarrow")
+    else:
+        assert ddf.dtypes["name"] == object
+
+    assert assert_eq(ddf.set_index("idx"), result.set_index("idx"))
 
 
 @pytest.mark.skipif(sys.platform == "darwin", reason="Segfaults on macOS")
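
For reference, a minimal usage sketch of the behaviour this change enables (not part of the diff; the project, dataset, and table names below are placeholders): with `dataframe.convert-string` enabled, which recent Dask releases turn on by default when a suitable pyarrow is installed, `read_gbq` returns pyarrow-backed string columns instead of object-dtype strings, matching how Dask DataFrame handles strings elsewhere.

    import dask
    from dask_bigquery import read_gbq

    # Opt in explicitly; on recent dask this is already the default
    # whenever pyarrow is available.
    with dask.config.set({"dataframe.convert-string": True}):
        ddf = read_gbq(
            project_id="my-project",   # placeholder
            dataset_id="my_dataset",   # placeholder
            table_id="my_table",       # placeholder
        )

    # String columns now show up as string[pyarrow] rather than object.
    print(ddf.dtypes)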