This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4703f9a0e5 BigQueryHook list_rows/get_datasets_list can return iterator (#30543)
4703f9a0e5 is described below

commit 4703f9a0e589557f5176a6f466ae83fe52644cf6
Author: Victor Chiapaikeo <[email protected]>
AuthorDate: Sat Apr 8 13:01:57 2023 -0400

    BigQueryHook list_rows/get_datasets_list can return iterator (#30543)
---
 airflow/providers/google/cloud/hooks/bigquery.py   |  33 +++++--
 .../providers/google/cloud/hooks/test_bigquery.py  | 109 ++++++++++++++-------
 2 files changed, 99 insertions(+), 43 deletions(-)
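
For context on the API change below: both hook methods gain a return_iterator flag that hands back the un-consumed iterator instead of a materialized list. A minimal usage sketch, assuming a configured Google Cloud connection; the connection id, dataset, and table names are placeholders, not part of this commit:

    from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

    hook = BigQueryHook(gcp_conn_id="google_cloud_default")  # placeholder connection id

    # Default behaviour is unchanged: rows are materialized into a list.
    rows = hook.list_rows(dataset_id="my_dataset", table_id="my_table", max_results=10)

    # With return_iterator=True the RowIterator is returned before any page has
    # been fetched, so the caller can drive pagination itself.
    row_iterator = hook.list_rows(
        dataset_id="my_dataset",
        table_id="my_table",
        max_results=10,
        return_iterator=True,
    )
    first_page = list(next(row_iterator.pages))  # fetch the first page of Row objects
    token = row_iterator.next_page_token         # token for the following page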

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 3698b863d4..e25918d083 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -34,6 +34,7 @@ from typing import Any, Iterable, Mapping, NoReturn, Sequence, Union, cast
 
 from aiohttp import ClientSession as ClientSession
 from gcloud.aio.bigquery import Job, Table as Table_async
+from google.api_core.page_iterator import HTTPIterator
 from google.api_core.retry import Retry
 from google.cloud.bigquery import (
     DEFAULT_RETRY,
@@ -46,7 +47,7 @@ from google.cloud.bigquery import (
     SchemaField,
 )
 from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference
-from google.cloud.bigquery.table import EncryptionConfiguration, Row, Table, TableReference
+from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference
 from google.cloud.exceptions import NotFound
 from googleapiclient.discovery import Resource, build
 from pandas import DataFrame
@@ -1006,7 +1007,8 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         max_results: int | None = None,
         page_token: str | None = None,
         retry: Retry = DEFAULT_RETRY,
-    ) -> list[DatasetListItem]:
+        return_iterator: bool = False,
+    ) -> list[DatasetListItem] | HTTPIterator:
         """
         Method returns full list of BigQuery datasets in the current project
 
@@ -1026,8 +1028,10 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
             ``next_page_token`` of the :class:`~google.api_core.page_iterator.HTTPIterator`.
         :param page_token: str
         :param retry: How to retry the RPC.
+        :param return_iterator: Instead of returning a list[DatasetListItem], returns an HTTPIterator
+            which can be used to obtain the next_page_token property.
         """
-        datasets = self.get_client(project_id=project_id).list_datasets(
+        iterator = self.get_client(project_id=project_id).list_datasets(
             project=project_id,
             include_all=include_all,
             filter=filter_,
@@ -1035,8 +1039,13 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
             page_token=page_token,
             retry=retry,
         )
-        datasets_list = list(datasets)
 
+        # If iterator is requested, we cannot perform a list() on it to log the number
+        # of datasets because we will have started iteration
+        if return_iterator:
+            return iterator
+
+        datasets_list = list(iterator)
         self.log.info("Datasets List: %s", len(datasets_list))
         return datasets_list
 
@@ -1232,7 +1241,9 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         start_index: int | None = None,
         project_id: str | None = None,
         location: str | None = None,
-    ) -> list[Row]:
+        retry: Retry = DEFAULT_RETRY,
+        return_iterator: bool = False,
+    ) -> list[Row] | RowIterator:
         """
         List the rows of the table.
         See https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/list
@@ -1247,6 +1258,9 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
         :param start_index: zero based index of the starting row to read.
         :param project_id: Project ID for the project which the client acts on behalf of.
         :param location: Default location for job.
+        :param retry: How to retry the RPC.
+        :param return_iterator: Instead of returning a list[Row], returns a RowIterator
+            which can be used to obtain the next_page_token property.
         :return: list of rows
         """
         location = location or self.location
@@ -1265,14 +1279,17 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
             table_id=table_id,
         )
 
-        result = self.get_client(project_id=project_id, location=location).list_rows(
+        iterator = self.get_client(project_id=project_id, location=location).list_rows(
             table=Table.from_api_repr(table),
             selected_fields=selected_fields,
             max_results=max_results,
             page_token=page_token,
             start_index=start_index,
+            retry=retry,
         )
-        return list(result)
+        if return_iterator:
+            return iterator
+        return list(iterator)
 
     @GoogleBaseHook.fallback_to_default_project_id
     def get_schema(self, dataset_id: str, table_id: str, project_id: str | None = None) -> dict:
@@ -2455,7 +2472,7 @@ class BigQueryBaseCursor(LoggingMixin):
         )
         return self.hook.get_dataset_tables_list(*args, **kwargs)
 
-    def get_datasets_list(self, *args, **kwargs) -> list:
+    def get_datasets_list(self, *args, **kwargs) -> list | HTTPIterator:
         """
         This method is deprecated.
         Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_datasets_list`
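
To illustrate the get_datasets_list half of the change above, a sketch under the same placeholder assumptions (my-project is a stand-in project id):

    # The HTTPIterator yields DatasetListItem objects lazily, page by page.
    iterator = hook.get_datasets_list(
        project_id="my-project",
        max_results=50,
        return_iterator=True,
    )
    for dataset in iterator:
        print(dataset.dataset_id)
    # Once the final page has been consumed, iterator.next_page_token is None.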
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 175b0e7c24..eef00d32c4 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -23,8 +23,10 @@ from unittest import mock
 
 import pytest
 from gcloud.aio.bigquery import Job, Table as Table_async
+from google.api_core import page_iterator
 from google.cloud.bigquery import DEFAULT_RETRY, DatasetReference, Table, TableReference
 from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem
+from google.cloud.bigquery.table import _EmptyRowIterator
 from google.cloud.exceptions import NotFound
 
 from airflow.exceptions import AirflowException
@@ -431,45 +433,63 @@ class TestBigQueryHookMethods(_BigQueryBaseTestClass):
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.SchemaField")
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
     def test_list_rows(self, mock_client, mock_schema, mock_table):
-        self.hook.list_rows(
-            dataset_id=DATASET_ID,
-            table_id=TABLE_ID,
-            max_results=10,
-            selected_fields=["field_1", "field_2"],
-            page_token="page123",
-            start_index=5,
-            location=LOCATION,
-        )
-        mock_table.from_api_repr.assert_called_once_with({"tableReference": TABLE_REFERENCE_REPR})
-        mock_schema.assert_has_calls([mock.call(x, "") for x in ["field_1", "field_2"]])
-        mock_client.return_value.list_rows.assert_called_once_with(
-            table=mock_table.from_api_repr.return_value,
-            max_results=10,
-            selected_fields=mock.ANY,
-            page_token="page123",
-            start_index=5,
-        )
+        mock_row_iterator = _EmptyRowIterator()
+        mock_client.return_value.list_rows.return_value = mock_row_iterator
+
+        for return_iterator, expected in [(False, []), (True, mock_row_iterator)]:
+            actual = self.hook.list_rows(
+                dataset_id=DATASET_ID,
+                table_id=TABLE_ID,
+                max_results=10,
+                selected_fields=["field_1", "field_2"],
+                page_token="page123",
+                start_index=5,
+                location=LOCATION,
+                return_iterator=return_iterator,
+            )
+            mock_table.from_api_repr.assert_called_once_with({"tableReference": TABLE_REFERENCE_REPR})
+            mock_schema.assert_has_calls([mock.call(x, "") for x in ["field_1", "field_2"]])
+            mock_client.return_value.list_rows.assert_called_once_with(
+                table=mock_table.from_api_repr.return_value,
+                max_results=10,
+                selected_fields=mock.ANY,
+                page_token="page123",
+                start_index=5,
+                retry=DEFAULT_RETRY,
+            )
+            assert actual == expected
+            mock_table.from_api_repr.reset_mock()
+            mock_client.return_value.list_rows.reset_mock()
 
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Table")
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
     def test_list_rows_with_empty_selected_fields(self, mock_client, mock_table):
-        self.hook.list_rows(
-            dataset_id=DATASET_ID,
-            table_id=TABLE_ID,
-            max_results=10,
-            page_token="page123",
-            selected_fields=[],
-            start_index=5,
-            location=LOCATION,
-        )
-        mock_table.from_api_repr.assert_called_once_with({"tableReference": TABLE_REFERENCE_REPR})
-        mock_client.return_value.list_rows.assert_called_once_with(
-            table=mock_table.from_api_repr.return_value,
-            max_results=10,
-            page_token="page123",
-            selected_fields=None,
-            start_index=5,
-        )
+        mock_row_iterator = _EmptyRowIterator()
+        mock_client.return_value.list_rows.return_value = mock_row_iterator
+
+        for return_iterator, expected in [(False, []), (True, mock_row_iterator)]:
+            actual = self.hook.list_rows(
+                dataset_id=DATASET_ID,
+                table_id=TABLE_ID,
+                max_results=10,
+                page_token="page123",
+                selected_fields=[],
+                start_index=5,
+                location=LOCATION,
+                return_iterator=return_iterator,
+            )
+            mock_table.from_api_repr.assert_called_once_with({"tableReference": TABLE_REFERENCE_REPR})
+            mock_client.return_value.list_rows.assert_called_once_with(
+                table=mock_table.from_api_repr.return_value,
+                max_results=10,
+                page_token="page123",
+                selected_fields=None,
+                start_index=5,
+                retry=DEFAULT_RETRY,
+            )
+            assert actual == expected
+            mock_table.from_api_repr.reset_mock()
+            mock_client.return_value.list_rows.reset_mock()
 
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
     def test_run_table_delete(self, mock_client):
@@ -1553,6 +1573,25 @@ class TestDatasetsOperations(_BigQueryBaseTestClass):
         for exp, res in zip(datasets, result):
             assert res.full_dataset_id == exp["id"]
 
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
+    def test_get_datasets_list_returns_iterator(self, mock_client):
+        client = mock.sentinel.client
+        mock_iterator = page_iterator.HTTPIterator(
+            client, mock.sentinel.api_request, "/foo", mock.sentinel.item_to_value
+        )
+        mock_client.return_value.list_datasets.return_value = mock_iterator
+        actual = self.hook.get_datasets_list(project_id=PROJECT_ID, return_iterator=True)
+
+        mock_client.return_value.list_datasets.assert_called_once_with(
+            project=PROJECT_ID,
+            include_all=False,
+            filter=None,
+            max_results=None,
+            page_token=None,
+            retry=DEFAULT_RETRY,
+        )
+        assert actual == mock_iterator
+
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client")
     def test_delete_dataset(self, mock_client):
         delete_contents = True

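A closing note on how the pieces compose: a token captured from a returned iterator can be passed back as page_token on a later call to resume listing. Sketch, reusing the placeholder names from above:

    # Resume from a previously captured token; a plain list is returned again.
    more_rows = hook.list_rows(
        dataset_id="my_dataset",
        table_id="my_table",
        max_results=10,
        page_token=token,
    )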