This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b8f73768ec Add `as_dict` param to `BigQueryGetDataOperator` (#30887)
b8f73768ec is described below

commit b8f73768ec13f8d4cc1605cca3fa93be6caac473
Author: Shahar Epstein <[email protected]>
AuthorDate: Tue May 9 09:05:24 2023 +0300

    Add `as_dict` param to `BigQueryGetDataOperator` (#30887)
    
    * Add "as_dict" param to BigQueryGetDataOperator
---
 airflow/providers/google/cloud/hooks/bigquery.py   | 12 ++++++--
 .../providers/google/cloud/operators/bigquery.py   | 35 +++++++++++++++-------
 .../providers/google/cloud/triggers/bigquery.py    | 13 ++++++--
 .../operators/cloud/bigquery.rst                   |  7 +++--
 .../providers/google/cloud/hooks/test_bigquery.py  | 26 ++++++++++++++++
 .../google/cloud/operators/test_bigquery.py        | 16 +++++++---
 6 files changed, 87 insertions(+), 22 deletions(-)
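
For context, a minimal usage sketch (not part of this commit) of how the new parameter could be set on the operator in a DAG; the dag_id, dataset, table and field names below are hypothetical:

    from pendulum import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator

    with DAG(
        dag_id="example_bigquery_get_data_as_dict",  # hypothetical dag_id
        start_date=datetime(2023, 5, 1, tz="UTC"),
        schedule=None,
        catchup=False,
    ):
        # With as_dict=True the task returns e.g. [{'name': 'Tony', 'age': 10}, {'name': 'Mike', 'age': 20}];
        # with the default as_dict=False it returns [['Tony', 10], ['Mike', 20]].
        get_data = BigQueryGetDataOperator(
            task_id="get_data",
            dataset_id="my_dataset",  # hypothetical dataset
            table_id="my_table",      # hypothetical table
            selected_fields="name,age",
            max_results=10,
            as_dict=True,
        )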

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index f24562da7b..a091dd73fe 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -3125,20 +3125,26 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
             job_query_resp = await job_client.query(query_request, cast(Session, session))
             return job_query_resp["jobReference"]["jobId"]
 
-    def get_records(self, query_results: dict[str, Any]) -> list[Any]:
+    def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
         """
         Given the output query response from gcloud-aio bigquery, convert the response to records.
 
         :param query_results: the results from a SQL query
+        :param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists.
         """
-        buffer = []
+        buffer: list[Any] = []
         if "rows" in query_results and query_results["rows"]:
             rows = query_results["rows"]
             fields = query_results["schema"]["fields"]
             col_types = [field["type"] for field in fields]
             for dict_row in rows:
                 typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])]
-                buffer.append(typed_row)
+                if not as_dict:
+                    buffer.append(typed_row)
+                else:
+                    fields_names = [field["name"] for field in fields]
+                    typed_row_dict = {k: v for k, v in zip(fields_names, typed_row)}
+                    buffer.append(typed_row_dict)
         return buffer
 
     def value_check(
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 22150a6221..3d3d9719cc 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -758,12 +758,19 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
 
 class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     """
-    Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
-    and returns data in a python list. The number of elements in the returned list will
-    be equal to the number of rows fetched. Each element in the list will again be a list
-    where element would represent the columns values for that row.
+    Fetches the data from a BigQuery table (alternatively fetch data for selected columns) and returns data
+    in either of the following two formats, based on "as_dict" value:
+    1. False (Default) - A Python list of lists, with the number of nested lists equal to the number of rows
+    fetched. Each nested list represents a row, where the elements within it correspond to the column values
+    for that particular row.
 
-    **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]``
+    **Example Result**: ``[['Tony', 10], ['Mike', 20]]``
+
+
+    2. True - A Python list of dictionaries, where each dictionary represents a row. In each dictionary,
+    the keys are the column names and the values are the corresponding values for those columns.
+
+    **Example Result**: ``[{'name': 'Tony', 'age': 10}, {'name': 'Mike', 'age': 20}]``
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
@@ -810,6 +817,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     :param deferrable: Run operator in the deferrable mode
     :param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job.
         Defaults to 4 seconds.
+    :param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists
+        (default: False).
     """
 
     template_fields: Sequence[str] = (
@@ -835,6 +844,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = False,
         poll_interval: float = 4.0,
+        as_dict: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -849,6 +859,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         self.project_id = project_id
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.as_dict = as_dict
 
     def _submit_job(
         self,
@@ -884,7 +895,6 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
-        self.hook = hook
 
         if not self.deferrable:
             self.log.info(
@@ -910,21 +920,26 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 
             self.log.info("Total extracted rows: %s", len(rows))
 
-            table_data = [row.values() for row in rows]
+            if self.as_dict:
+                table_data = [{k: v for k, v in row.items()} for row in rows]
+            else:
+                table_data = [row.values() for row in rows]
+
             return table_data
 
         job = self._submit_job(hook, job_id="")
-        self.job_id = job.job_id
-        context["ti"].xcom_push(key="job_id", value=self.job_id)
+
+        context["ti"].xcom_push(key="job_id", value=job.job_id)
         self.defer(
             timeout=self.execution_timeout,
             trigger=BigQueryGetDataTrigger(
                 conn_id=self.gcp_conn_id,
-                job_id=self.job_id,
+                job_id=job.job_id,
                 dataset_id=self.dataset_id,
                 table_id=self.table_id,
                 project_id=hook.project_id,
                 poll_interval=self.poll_interval,
+                as_dict=self.as_dict,
             ),
             method_name="execute_complete",
         )
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index ba4ce8c19b..1da7f87f90 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -165,7 +165,16 @@ class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
 
 
 class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
-    """BigQueryGetDataTrigger run on the trigger worker, inherits from 
BigQueryInsertJobTrigger class"""
+    """
+    BigQueryGetDataTrigger run on the trigger worker, inherits from 
BigQueryInsertJobTrigger class
+
+    :param as_dict: if True returns the result as a list of dictionaries, 
otherwise as list of lists
+        (default: False).
+    """
+
+    def __init__(self, as_dict: bool = False, **kwargs):
+        super().__init__(**kwargs)
+        self.as_dict = as_dict
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serializes BigQueryInsertJobTrigger arguments and classpath."""
@@ -190,7 +199,7 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
                 response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
                 if response_from_hook == "success":
                     query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)
-                    records = hook.get_records(query_results)
+                    records = hook.get_records(query_results=query_results, as_dict=self.as_dict)
                     self.log.debug("Response from hook: %s", response_from_hook)
                     yield TriggerEvent(
                         {
diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
index 61f4439fe1..6529ee4522 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
@@ -208,10 +208,11 @@ To fetch data from a BigQuery table you can use
 Alternatively you can fetch data for selected columns if you pass fields to
 ``selected_fields``.
 
-This operator returns data in a Python list where the number of elements in the
-returned list will be equal to the number of rows fetched. Each element in the
-list will again be a list where elements would represent the column values for
+The result of this operator can be retrieved in two different formats based on the value of the ``as_dict`` parameter:
+``False`` (default) - A Python list of lists, where the number of elements in the returned list will be equal to the number of rows fetched. Each element in the
+returned list is itself a nested list whose elements represent the column values for
 that row.
+``True`` - A Python list of dictionaries, where each dictionary represents a row. In each dictionary, the keys are the column names and the values are the corresponding values for those columns.
 
 .. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
     :language: python
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 0e09080fa7..a508f14249 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2348,3 +2348,29 @@ class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass):
         assert isinstance(result[0][0], int)
         assert isinstance(result[0][1], float)
         assert isinstance(result[0][2], str)
+
+    def test_get_records_as_dict(self):
+        query_result = {
+            "kind": "bigquery#getQueryResultsResponse",
+            "etag": "test_etag",
+            "schema": {
+                "fields": [
+                    {"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"},
+                    {"name": "f1_", "type": "FLOAT", "mode": "NULLABLE"},
+                    {"name": "f2_", "type": "STRING", "mode": "NULLABLE"},
+                ]
+            },
+            "jobReference": {
+                "projectId": "test_airflow-providers",
+                "jobId": "test_jobid",
+                "location": "US",
+            },
+            "totalRows": "1",
+            "rows": [{"f": [{"v": "22"}, {"v": "3.14"}, {"v": "PI"}]}],
+            "totalBytesProcessed": "0",
+            "jobComplete": True,
+            "cacheHit": False,
+        }
+        hook = BigQueryAsyncHook()
+        result = hook.get_records(query_result, as_dict=True)
+        assert result == [{"f0_": 22, "f1_": 3.14, "f2_": "PI"}]
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index d5fafb6995..c5051d92e6 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -785,8 +785,9 @@ class TestBigQueryOperator:
 
 
 class TestBigQueryGetDataOperator:
+    @pytest.mark.parametrize("as_dict", [True, False])
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
-    def test_execute(self, mock_hook):
+    def test_execute(self, mock_hook, as_dict):
         max_results = 100
         selected_fields = "DATE"
         operator = BigQueryGetDataOperator(
@@ -797,6 +798,7 @@ class TestBigQueryGetDataOperator:
             max_results=max_results,
             selected_fields=selected_fields,
             location=TEST_DATASET_LOCATION,
+            as_dict=as_dict,
         )
         operator.execute(None)
         mock_hook.return_value.list_rows.assert_called_once_with(
@@ -840,9 +842,10 @@ class TestBigQueryGetDataOperator:
             exc.value.trigger, BigQueryGetDataTrigger
         ), "Trigger is not a BigQueryGetDataTrigger"
 
+    @pytest.mark.parametrize("as_dict", [True, False])
     @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_bigquery_get_data_operator_async_without_selected_fields(
-        self, mock_hook, create_task_instance_of_operator
+        self, mock_hook, create_task_instance_of_operator, as_dict
     ):
         """
         Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
@@ -862,6 +865,7 @@ class TestBigQueryGetDataOperator:
             table_id=TEST_TABLE_ID,
             max_results=100,
             deferrable=True,
+            as_dict=as_dict,
         )
 
         with pytest.raises(TaskDeferred) as exc:
@@ -871,7 +875,8 @@ class TestBigQueryGetDataOperator:
             exc.value.trigger, BigQueryGetDataTrigger
         ), "Trigger is not a BigQueryGetDataTrigger"
 
-    def test_bigquery_get_data_operator_execute_failure(self):
+    @pytest.mark.parametrize("as_dict", [True, False])
+    def test_bigquery_get_data_operator_execute_failure(self, as_dict):
         """Tests that an AirflowException is raised in case of error event"""
 
         operator = BigQueryGetDataOperator(
@@ -880,6 +885,7 @@ class TestBigQueryGetDataOperator:
             table_id="any",
             max_results=100,
             deferrable=True,
+            as_dict=as_dict,
         )
 
         with pytest.raises(AirflowException):
@@ -887,7 +893,8 @@ class TestBigQueryGetDataOperator:
                 context=None, event={"status": "error", "message": "test failure message"}
             )
 
-    def test_bigquery_get_data_op_execute_complete_with_records(self):
+    @pytest.mark.parametrize("as_dict", [True, False])
+    def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
         """Asserts that exception is raised with correct expected exception 
message"""
 
         operator = BigQueryGetDataOperator(
@@ -896,6 +903,7 @@ class TestBigQueryGetDataOperator:
             table_id="any",
             max_results=100,
             deferrable=True,
+            as_dict=as_dict,
         )
 
         with mock.patch.object(operator.log, "info") as mock_log_info:
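
To illustrate the hook-level change in isolation, here is a standalone sketch (not taken from the commit) of the row conversion that BigQueryAsyncHook.get_records performs; the sample payload and field names are made up, and the bq_cast() type coercion is omitted for brevity, so values stay strings:

    from typing import Any


    def rows_from_query_results(query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
        """Convert a gcloud-aio style query response into a list of rows."""
        buffer: list[Any] = []
        if "rows" in query_results and query_results["rows"]:
            fields = query_results["schema"]["fields"]
            names = [field["name"] for field in fields]
            for dict_row in query_results["rows"]:
                # The real hook casts each value with bq_cast(); raw strings are kept here.
                row = [cell["v"] for cell in dict_row["f"]]
                buffer.append(dict(zip(names, row)) if as_dict else row)
        return buffer


    sample = {
        "schema": {"fields": [{"name": "name", "type": "STRING"}, {"name": "age", "type": "INTEGER"}]},
        "rows": [{"f": [{"v": "Tony"}, {"v": "10"}]}, {"f": [{"v": "Mike"}, {"v": "20"}]}],
    }
    print(rows_from_query_results(sample))                # [['Tony', '10'], ['Mike', '20']]
    print(rows_from_query_results(sample, as_dict=True))  # [{'name': 'Tony', 'age': '10'}, {'name': 'Mike', 'age': '20'}]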
