Skip to content

Commit

Permalink
feat: add bigquery_client as a parameter for read_gbq and to_gbq (#878)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Sweña (Swast) <swast@google.com>
  • Loading branch information
sycai and tswast authored Feb 20, 2025
1 parent efdbc13 commit d42a562
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 2 deletions.
43 changes: 41 additions & 2 deletions pandas_gbq/gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def __init__(
client_secret=None,
user_agent=None,
rfc9110_delimiter=False,
bigquery_client=None,
):
global context
from google.api_core.exceptions import ClientError, GoogleAPIError
Expand All @@ -288,6 +289,14 @@ def __init__(
self.client_secret = client_secret
self.user_agent = user_agent
self.rfc9110_delimiter = rfc9110_delimiter
self.use_bqstorage_api = use_bqstorage_api

if bigquery_client is not None:
# If a bq client is already provided, use it to populate auth fields.
self.project_id = bigquery_client.project
self.credentials = bigquery_client._credentials
self.client = bigquery_client
return

default_project = None

Expand Down Expand Up @@ -325,8 +334,9 @@ def __init__(
if context.project is None:
context.project = self.project_id

self.client = self.get_client()
self.use_bqstorage_api = use_bqstorage_api
self.client = _get_client(
self.user_agent, self.rfc9110_delimiter, self.project_id, self.credentials
)

def _start_timer(self):
self.start = time.time()
Expand Down Expand Up @@ -702,6 +712,7 @@ def read_gbq(
client_secret=None,
*,
col_order=None,
bigquery_client=None,
):
r"""Read data from Google BigQuery to a pandas DataFrame.
Expand Down Expand Up @@ -849,6 +860,9 @@ def read_gbq(
the user is attempting to connect to.
col_order : list(str), optional
Alias for columns, retained for backwards compatibility.
bigquery_client : google.cloud.bigquery.Client, optional
A Google Cloud BigQuery Python Client instance. If provided, it will be used for reading
data, while the project and credentials parameters will be ignored.
Returns
-------
Expand Down Expand Up @@ -900,6 +914,7 @@ def read_gbq(
auth_redirect_uri=auth_redirect_uri,
client_id=client_id,
client_secret=client_secret,
bigquery_client=bigquery_client,
)

if _is_query(query_or_table):
Expand Down Expand Up @@ -971,6 +986,7 @@ def to_gbq(
client_secret=None,
user_agent=None,
rfc9110_delimiter=False,
bigquery_client=None,
):
"""Write a DataFrame to a Google BigQuery table.
Expand Down Expand Up @@ -1087,6 +1103,9 @@ def to_gbq(
rfc9110_delimiter : bool
Sets user agent delimiter to a hyphen or a slash.
Default is False, meaning a hyphen will be used.
bigquery_client : google.cloud.bigquery.Client, optional
A Google Cloud BigQuery Python Client instance. If provided, it will be used for reading
data, while the project, user_agent, and credentials parameters will be ignored.
.. versionadded:: 0.23.3
"""
Expand Down Expand Up @@ -1157,6 +1176,7 @@ def to_gbq(
client_secret=client_secret,
user_agent=user_agent,
rfc9110_delimiter=rfc9110_delimiter,
bigquery_client=bigquery_client,
)
bqclient = connector.client

Expand Down Expand Up @@ -1492,3 +1512,22 @@ def create_user_agent(
user_agent = f"{user_agent} {identity}"

return user_agent


def _get_client(user_agent, rfc9110_delimiter, project_id, credentials):
import google.api_core.client_info

bigquery = FEATURES.bigquery_try_import()

user_agent = create_user_agent(
user_agent=user_agent, rfc9110_delimiter=rfc9110_delimiter
)

client_info = google.api_core.client_info.ClientInfo(
user_agent=user_agent,
)
return bigquery.Client(
project=project_id,
credentials=credentials,
client_info=client_info,
)
14 changes: 14 additions & 0 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def to_gbq(credentials, project_id):
)


@pytest.fixture
def to_gbq_with_bq_client(bigquery_client):
import pandas_gbq

return functools.partial(pandas_gbq.to_gbq, bigquery_client=bigquery_client)


@pytest.fixture
def read_gbq(credentials, project_id):
import pandas_gbq
Expand All @@ -63,6 +70,13 @@ def read_gbq(credentials, project_id):
)


@pytest.fixture
def read_gbq_with_bq_client(bigquery_client):
import pandas_gbq

return functools.partial(pandas_gbq.read_gbq, bigquery_client=bigquery_client)


@pytest.fixture()
def random_dataset_id(bigquery_client: bigquery.Client, project_id: str):
dataset_id = prefixer.create_prefix()
Expand Down
10 changes: 10 additions & 0 deletions tests/system/test_gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,3 +1398,13 @@ def test_to_gbq_does_not_override_mode(gbq_table, gbq_connector):
)

assert verify_schema(gbq_connector, gbq_table.dataset_id, table_id, table_schema)


def test_gbqconnector_init_with_bq_client(bigquery_client):
gbq_connector = gbq.GbqConnector(
project_id="project_id", credentials=None, bigquery_client=bigquery_client
)

assert gbq_connector.project_id == bigquery_client.project
assert gbq_connector.credentials is bigquery_client._credentials
assert gbq_connector.client is bigquery_client
11 changes: 11 additions & 0 deletions tests/system/test_read_gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,14 @@ def test_dml_query(read_gbq, writable_table: str):
"""
result = read_gbq(query)
assert result is not None


def test_read_gbq_with_bq_client(read_gbq_with_bq_client):
query = "SELECT * FROM UNNEST([1, 2, 3]) AS numbers"

actual_result = read_gbq_with_bq_client(query)

expected_result = pandas.DataFrame(
{"numbers": pandas.Series([1, 2, 3], dtype="Int64")}
)
pandas.testing.assert_frame_equal(actual_result, expected_result)
14 changes: 14 additions & 0 deletions tests/system/test_to_gbq.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,17 @@ def test_dataframe_round_trip_with_table_schema(
pandas.testing.assert_frame_equal(
expected_df.set_index("row_num").sort_index(), round_trip
)


def test_dataframe_round_trip_with_bq_client(
to_gbq_with_bq_client, read_gbq_with_bq_client, random_dataset_id
):
table_id = (
f"{random_dataset_id}.round_trip_w_bq_client_{random.randrange(1_000_000)}"
)
df = pandas.DataFrame({"numbers": pandas.Series([1, 2, 3], dtype="Int64")})

to_gbq_with_bq_client(df, table_id)
result = read_gbq_with_bq_client(table_id)

pandas.testing.assert_frame_equal(result, df)

0 comments on commit d42a562

Please sign in to comment.