This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4703f9a0e5 BigQueryHook list_rows/get_datasets_list can return
iterator (#30543)
4703f9a0e5 is described below
commit 4703f9a0e589557f5176a6f466ae83fe52644cf6
Author: Victor Chiapaikeo <[email protected]>
AuthorDate: Sat Apr 8 13:01:57 2023 -0400
BigQueryHook list_rows/get_datasets_list can return iterator (#30543)
---
airflow/providers/google/cloud/hooks/bigquery.py | 33 +++++--
.../providers/google/cloud/hooks/test_bigquery.py | 109 ++++++++++++++-------
2 files changed, 99 insertions(+), 43 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py
b/airflow/providers/google/cloud/hooks/bigquery.py
index 3698b863d4..e25918d083 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -34,6 +34,7 @@ from typing import Any, Iterable, Mapping, NoReturn,
Sequence, Union, cast
from aiohttp import ClientSession as ClientSession
from gcloud.aio.bigquery import Job, Table as Table_async
+from google.api_core.page_iterator import HTTPIterator
from google.api_core.retry import Retry
from google.cloud.bigquery import (
DEFAULT_RETRY,
@@ -46,7 +47,7 @@ from google.cloud.bigquery import (
SchemaField,
)
from google.cloud.bigquery.dataset import AccessEntry, Dataset,
DatasetListItem, DatasetReference
-from google.cloud.bigquery.table import EncryptionConfiguration, Row, Table,
TableReference
+from google.cloud.bigquery.table import EncryptionConfiguration, Row,
RowIterator, Table, TableReference
from google.cloud.exceptions import NotFound
from googleapiclient.discovery import Resource, build
from pandas import DataFrame
@@ -1006,7 +1007,8 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
max_results: int | None = None,
page_token: str | None = None,
retry: Retry = DEFAULT_RETRY,
- ) -> list[DatasetListItem]:
+ return_iterator: bool = False,
+ ) -> list[DatasetListItem] | HTTPIterator:
"""
Method returns full list of BigQuery datasets in the current project
@@ -1026,8 +1028,10 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
``next_page_token`` of the
:class:`~google.api_core.page_iterator.HTTPIterator`.
:param page_token: str
:param retry: How to retry the RPC.
+        :param return_iterator: Instead of returning a list[DatasetListItem], returns a
HTTPIterator
+ which can be used to obtain the next_page_token property.
"""
- datasets = self.get_client(project_id=project_id).list_datasets(
+ iterator = self.get_client(project_id=project_id).list_datasets(
project=project_id,
include_all=include_all,
filter=filter_,
@@ -1035,8 +1039,13 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
page_token=page_token,
retry=retry,
)
- datasets_list = list(datasets)
+ # If iterator is requested, we cannot perform a list() on it to log
the number
+ # of datasets because we will have started iteration
+ if return_iterator:
+ return iterator
+
+ datasets_list = list(iterator)
self.log.info("Datasets List: %s", len(datasets_list))
return datasets_list
@@ -1232,7 +1241,9 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
start_index: int | None = None,
project_id: str | None = None,
location: str | None = None,
- ) -> list[Row]:
+ retry: Retry = DEFAULT_RETRY,
+ return_iterator: bool = False,
+ ) -> list[Row] | RowIterator:
"""
List the rows of the table.
See
https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/list
@@ -1247,6 +1258,9 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
:param start_index: zero based index of the starting row to read.
:param project_id: Project ID for the project which the client acts on
behalf of.
:param location: Default location for job.
+ :param retry: How to retry the RPC.
+ :param return_iterator: Instead of returning a list[Row], returns a
RowIterator
+ which can be used to obtain the next_page_token property.
:return: list of rows
"""
location = location or self.location
@@ -1265,14 +1279,17 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
table_id=table_id,
)
- result = self.get_client(project_id=project_id,
location=location).list_rows(
+ iterator = self.get_client(project_id=project_id,
location=location).list_rows(
table=Table.from_api_repr(table),
selected_fields=selected_fields,
max_results=max_results,
page_token=page_token,
start_index=start_index,
+ retry=retry,
)
- return list(result)
+ if return_iterator:
+ return iterator
+ return list(iterator)
@GoogleBaseHook.fallback_to_default_project_id
def get_schema(self, dataset_id: str, table_id: str, project_id: str |
None = None) -> dict:
@@ -2455,7 +2472,7 @@ class BigQueryBaseCursor(LoggingMixin):
)
return self.hook.get_dataset_tables_list(*args, **kwargs)
- def get_datasets_list(self, *args, **kwargs) -> list:
+ def get_datasets_list(self, *args, **kwargs) -> list | HTTPIterator:
"""
This method is deprecated.
Please use
`airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_datasets_list`
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py
b/tests/providers/google/cloud/hooks/test_bigquery.py
index 175b0e7c24..eef00d32c4 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -23,8 +23,10 @@ from unittest import mock
import pytest
from gcloud.aio.bigquery import Job, Table as Table_async
+from google.api_core import page_iterator
from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table,
TableReference
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem
+from google.cloud.bigquery.table import _EmptyRowIterator
from google.cloud.exceptions import NotFound
from airflow.exceptions import AirflowException
@@ -431,45 +433,63 @@ class TestBigQueryHookMethods(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.SchemaField")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
def test_list_rows(self, mock_client, mock_schema, mock_table):
- self.hook.list_rows(
- dataset_id=DATASET_ID,
- table_id=TABLE_ID,
- max_results=10,
- selected_fields=["field_1", "field_2"],
- page_token="page123",
- start_index=5,
- location=LOCATION,
- )
- mock_table.from_api_repr.assert_called_once_with({"tableReference":
TABLE_REFERENCE_REPR})
- mock_schema.assert_has_calls([mock.call(x, "") for x in ["field_1",
"field_2"]])
- mock_client.return_value.list_rows.assert_called_once_with(
- table=mock_table.from_api_repr.return_value,
- max_results=10,
- selected_fields=mock.ANY,
- page_token="page123",
- start_index=5,
- )
+ mock_row_iterator = _EmptyRowIterator()
+ mock_client.return_value.list_rows.return_value = mock_row_iterator
+
+ for return_iterator, expected in [(False, []), (True,
mock_row_iterator)]:
+ actual = self.hook.list_rows(
+ dataset_id=DATASET_ID,
+ table_id=TABLE_ID,
+ max_results=10,
+ selected_fields=["field_1", "field_2"],
+ page_token="page123",
+ start_index=5,
+ location=LOCATION,
+ return_iterator=return_iterator,
+ )
+
mock_table.from_api_repr.assert_called_once_with({"tableReference":
TABLE_REFERENCE_REPR})
+ mock_schema.assert_has_calls([mock.call(x, "") for x in
["field_1", "field_2"]])
+ mock_client.return_value.list_rows.assert_called_once_with(
+ table=mock_table.from_api_repr.return_value,
+ max_results=10,
+ selected_fields=mock.ANY,
+ page_token="page123",
+ start_index=5,
+ retry=DEFAULT_RETRY,
+ )
+ assert actual == expected
+ mock_table.from_api_repr.reset_mock()
+ mock_client.return_value.list_rows.reset_mock()
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
def test_list_rows_with_empty_selected_fields(self, mock_client,
mock_table):
- self.hook.list_rows(
- dataset_id=DATASET_ID,
- table_id=TABLE_ID,
- max_results=10,
- page_token="page123",
- selected_fields=[],
- start_index=5,
- location=LOCATION,
- )
- mock_table.from_api_repr.assert_called_once_with({"tableReference":
TABLE_REFERENCE_REPR})
- mock_client.return_value.list_rows.assert_called_once_with(
- table=mock_table.from_api_repr.return_value,
- max_results=10,
- page_token="page123",
- selected_fields=None,
- start_index=5,
- )
+ mock_row_iterator = _EmptyRowIterator()
+ mock_client.return_value.list_rows.return_value = mock_row_iterator
+
+ for return_iterator, expected in [(False, []), (True,
mock_row_iterator)]:
+ actual = self.hook.list_rows(
+ dataset_id=DATASET_ID,
+ table_id=TABLE_ID,
+ max_results=10,
+ page_token="page123",
+ selected_fields=[],
+ start_index=5,
+ location=LOCATION,
+ return_iterator=return_iterator,
+ )
+
mock_table.from_api_repr.assert_called_once_with({"tableReference":
TABLE_REFERENCE_REPR})
+ mock_client.return_value.list_rows.assert_called_once_with(
+ table=mock_table.from_api_repr.return_value,
+ max_results=10,
+ page_token="page123",
+ selected_fields=None,
+ start_index=5,
+ retry=DEFAULT_RETRY,
+ )
+ assert actual == expected
+ mock_table.from_api_repr.reset_mock()
+ mock_client.return_value.list_rows.reset_mock()
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
def test_run_table_delete(self, mock_client):
@@ -1553,6 +1573,25 @@ class TestDatasetsOperations(_BigQueryBaseTestClass):
for exp, res in zip(datasets, result):
assert res.full_dataset_id == exp["id"]
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
+ def test_get_datasets_list_returns_iterator(self, mock_client):
+ client = mock.sentinel.client
+ mock_iterator = page_iterator.HTTPIterator(
+ client, mock.sentinel.api_request, "/foo",
mock.sentinel.item_to_value
+ )
+ mock_client.return_value.list_datasets.return_value = mock_iterator
+ actual = self.hook.get_datasets_list(project_id=PROJECT_ID,
return_iterator=True)
+
+ mock_client.return_value.list_datasets.assert_called_once_with(
+ project=PROJECT_ID,
+ include_all=False,
+ filter=None,
+ max_results=None,
+ page_token=None,
+ retry=DEFAULT_RETRY,
+ )
+ assert actual == mock_iterator
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
def test_delete_dataset(self, mock_client):
delete_contents = True