This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b8f73768ec Add `as_dict` param to `BigQueryGetDataOperator` (#30887)
b8f73768ec is described below
commit b8f73768ec13f8d4cc1605cca3fa93be6caac473
Author: Shahar Epstein <[email protected]>
AuthorDate: Tue May 9 09:05:24 2023 +0300
Add `as_dict` param to `BigQueryGetDataOperator` (#30887)
* Add "as_dict" param to BigQueryGetDataOperator
---
airflow/providers/google/cloud/hooks/bigquery.py | 12 ++++++--
.../providers/google/cloud/operators/bigquery.py | 35 +++++++++++++++-------
.../providers/google/cloud/triggers/bigquery.py | 13 ++++++--
.../operators/cloud/bigquery.rst | 7 +++--
.../providers/google/cloud/hooks/test_bigquery.py | 26 ++++++++++++++++
.../google/cloud/operators/test_bigquery.py | 16 +++++++---
6 files changed, 87 insertions(+), 22 deletions(-)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index f24562da7b..a091dd73fe 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -3125,20 +3125,26 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
job_query_resp = await job_client.query(query_request, cast(Session, session))
return job_query_resp["jobReference"]["jobId"]
- def get_records(self, query_results: dict[str, Any]) -> list[Any]:
+ def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
"""
Given the output query response from gcloud-aio bigquery, convert the response to records.
:param query_results: the results from a SQL query
+ :param as_dict: if True, returns the result as a list of dictionaries, otherwise as a list of lists.
"""
- buffer = []
+ buffer: list[Any] = []
if "rows" in query_results and query_results["rows"]:
rows = query_results["rows"]
fields = query_results["schema"]["fields"]
col_types = [field["type"] for field in fields]
for dict_row in rows:
typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])]
- buffer.append(typed_row)
+ if not as_dict:
+ buffer.append(typed_row)
+ else:
+ fields_names = [field["name"] for field in fields]
+ typed_row_dict = {k: v for k, v in zip(fields_names, typed_row)}
+ buffer.append(typed_row_dict)
return buffer
def value_check(
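For readers of this change, the conversion can be pictured with a small standalone sketch. This is not the hook code itself; the simplified cast helper and the sample payload are illustrative assumptions, though the payload follows the shape of the BigQuery getQueryResults response used in the tests further down.

# Standalone illustration of the as_dict behaviour added above; not the actual hook.
from __future__ import annotations
from typing import Any

def simple_cast(value: str, col_type: str) -> Any:
    # Simplified stand-in for Airflow's bq_cast; only INTEGER and FLOAT are handled here.
    if col_type == "INTEGER":
        return int(value)
    if col_type == "FLOAT":
        return float(value)
    return value

def records_from_query_results(query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
    fields = query_results["schema"]["fields"]
    col_types = [field["type"] for field in fields]
    col_names = [field["name"] for field in fields]
    buffer: list[Any] = []
    for row in query_results.get("rows") or []:
        typed_row = [simple_cast(cell["v"], col_types[idx]) for idx, cell in enumerate(row["f"])]
        buffer.append(dict(zip(col_names, typed_row)) if as_dict else typed_row)
    return buffer

sample = {
    "schema": {"fields": [{"name": "name", "type": "STRING"}, {"name": "age", "type": "INTEGER"}]},
    "rows": [{"f": [{"v": "Tony"}, {"v": "10"}]}, {"f": [{"v": "Mike"}, {"v": "20"}]}],
}
print(records_from_query_results(sample))                # [['Tony', 10], ['Mike', 20]]
print(records_from_query_results(sample, as_dict=True))  # [{'name': 'Tony', 'age': 10}, {'name': 'Mike', 'age': 20}]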
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 22150a6221..3d3d9719cc 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -758,12 +758,19 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
class BigQueryGetDataOperator(GoogleCloudBaseOperator):
"""
- Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
- and returns data in a python list. The number of elements in the returned list will
- be equal to the number of rows fetched. Each element in the list will again be a list
- where element would represent the columns values for that row.
+ Fetches the data from a BigQuery table (alternatively fetch data for selected columns) and returns data
+ in either of the following two formats, based on the "as_dict" value:
+ 1. False (Default) - A Python list of lists, with the number of nested lists equal to the number of rows
+ fetched. Each nested list represents a row, where the elements within it correspond to the column values
+ for that particular row.
- **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]``
+ **Example Result**: ``[['Tony', 10], ['Mike', 20]]``
+
+
+ 2. True - A Python list of dictionaries, where each dictionary represents a row. In each dictionary,
+ the keys are the column names and the values are the corresponding values for those columns.
+
+ **Example Result**: ``[{'name': 'Tony', 'age': 10}, {'name': 'Mike', 'age': 20}]``
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -810,6 +817,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
:param deferrable: Run operator in the deferrable mode
:param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job.
Defaults to 4 seconds.
+ :param as_dict: if True, returns the result as a list of dictionaries, otherwise as a list of lists
+ (default: False).
"""
template_fields: Sequence[str] = (
@@ -835,6 +844,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
poll_interval: float = 4.0,
+ as_dict: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -849,6 +859,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
self.project_id = project_id
self.deferrable = deferrable
self.poll_interval = poll_interval
+ self.as_dict = as_dict
def _submit_job(
self,
@@ -884,7 +895,6 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
- self.hook = hook
if not self.deferrable:
self.log.info(
@@ -910,21 +920,26 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
self.log.info("Total extracted rows: %s", len(rows))
- table_data = [row.values() for row in rows]
+ if self.as_dict:
+ table_data = [{k: v for k, v in row.items()} for row in rows]
+ else:
+ table_data = [row.values() for row in rows]
+
return table_data
job = self._submit_job(hook, job_id="")
- self.job_id = job.job_id
- context["ti"].xcom_push(key="job_id", value=self.job_id)
+
+ context["ti"].xcom_push(key="job_id", value=job.job_id)
self.defer(
timeout=self.execution_timeout,
trigger=BigQueryGetDataTrigger(
conn_id=self.gcp_conn_id,
- job_id=self.job_id,
+ job_id=job.job_id,
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=hook.project_id,
poll_interval=self.poll_interval,
+ as_dict=self.as_dict,
),
method_name="execute_complete",
)
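From the operator's side, the new flag is just another constructor argument. A minimal usage sketch follows; the DAG id, dataset and table names are made-up placeholders, not part of this commit.

# Hypothetical DAG snippet showing the new as_dict flag; identifiers are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator

with DAG(dag_id="example_bq_get_data_as_dict", start_date=datetime(2023, 5, 1), schedule=None):
    get_rows = BigQueryGetDataOperator(
        task_id="get_rows",
        dataset_id="my_dataset",      # placeholder dataset
        table_id="my_table",          # placeholder table
        max_results=10,
        selected_fields="name,age",   # comma-separated column list
        as_dict=True,                 # XCom return value becomes a list of dicts
    )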
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index ba4ce8c19b..1da7f87f90 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -165,7 +165,16 @@ class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
- """BigQueryGetDataTrigger run on the trigger worker, inherits from
BigQueryInsertJobTrigger class"""
+ """
+ BigQueryGetDataTrigger runs on the trigger worker, inherits from BigQueryInsertJobTrigger class
+
+ :param as_dict: if True, returns the result as a list of dictionaries, otherwise as a list of lists
+ (default: False).
+ """
+
+ def __init__(self, as_dict: bool = False, **kwargs):
+ super().__init__(**kwargs)
+ self.as_dict = as_dict
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryInsertJobTrigger arguments and classpath."""
@@ -190,7 +199,7 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if response_from_hook == "success":
query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)
- records = hook.get_records(query_results)
+ records = hook.get_records(query_results=query_results, as_dict=self.as_dict)
self.log.debug("Response from hook: %s", response_from_hook)
yield TriggerEvent(
{
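The trigger change is purely a pass-through: the flag is stored on the trigger and handed to BigQueryAsyncHook.get_records once the job succeeds. Constructing the trigger directly would look roughly like the sketch below; normally BigQueryGetDataOperator builds it inside defer(), and the ids shown are placeholders.

# Hypothetical direct construction; in practice the operator builds this in defer().
from airflow.providers.google.cloud.triggers.bigquery import BigQueryGetDataTrigger

trigger = BigQueryGetDataTrigger(
    conn_id="google_cloud_default",
    job_id="placeholder-job-id",
    project_id="my-project",
    dataset_id="my_dataset",
    table_id="my_table",
    poll_interval=4.0,
    as_dict=True,  # forwarded to BigQueryAsyncHook.get_records in the run loop
)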
diff --git a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
index 61f4439fe1..6529ee4522 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/bigquery.rst
@@ -208,10 +208,11 @@ To fetch data from a BigQuery table you can use
Alternatively you can fetch data for selected columns if you pass fields to
``selected_fields``.
-This operator returns data in a Python list where the number of elements in the
-returned list will be equal to the number of rows fetched. Each element in the
-list will again be a list where elements would represent the column values for
+The result of this operator can be retrieved in two different formats, based on the value of the ``as_dict`` parameter:
+``False`` (default) - A Python list of lists, where the number of elements in the outer list is equal to the number of rows fetched.
+Each element in the outer list is itself a list whose elements represent the column values for
 that row.
+``True`` - A Python list of dictionaries, where each dictionary represents a row. In each dictionary, the keys are the column names and the values are the corresponding values for those columns.
.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
:language: python
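To make the difference concrete for DAG authors, a downstream task could consume the operator's XCom in either shape. An illustrative sketch; the task and field names are assumptions, not part of the documented example.

# Illustrative consumer of BigQueryGetDataOperator's return value; names are placeholders.
from airflow.decorators import task

@task
def summarize(rows):
    # With as_dict=False each row is a list, e.g. ['Tony', 10];
    # with as_dict=True each row is a dict, e.g. {'name': 'Tony', 'age': 10}.
    for row in rows:
        name, age = (row["name"], row["age"]) if isinstance(row, dict) else (row[0], row[1])
        print(f"{name} is {age}")

# Wiring it after the operator task: summarize(get_rows.output)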
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 0e09080fa7..a508f14249 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2348,3 +2348,29 @@ class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass):
assert isinstance(result[0][0], int)
assert isinstance(result[0][1], float)
assert isinstance(result[0][2], str)
+
+ def test_get_records_as_dict(self):
+ query_result = {
+ "kind": "bigquery#getQueryResultsResponse",
+ "etag": "test_etag",
+ "schema": {
+ "fields": [
+ {"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"},
+ {"name": "f1_", "type": "FLOAT", "mode": "NULLABLE"},
+ {"name": "f2_", "type": "STRING", "mode": "NULLABLE"},
+ ]
+ },
+ "jobReference": {
+ "projectId": "test_airflow-providers",
+ "jobId": "test_jobid",
+ "location": "US",
+ },
+ "totalRows": "1",
+ "rows": [{"f": [{"v": "22"}, {"v": "3.14"}, {"v": "PI"}]}],
+ "totalBytesProcessed": "0",
+ "jobComplete": True,
+ "cacheHit": False,
+ }
+ hook = BigQueryAsyncHook()
+ result = hook.get_records(query_result, as_dict=True)
+ assert result == [{"f0_": 22, "f1_": 3.14, "f2_": "PI"}]
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index d5fafb6995..c5051d92e6 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -785,8 +785,9 @@ class TestBigQueryOperator:
class TestBigQueryGetDataOperator:
+ @pytest.mark.parametrize("as_dict", [True, False])
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
- def test_execute(self, mock_hook):
+ def test_execute(self, mock_hook, as_dict):
max_results = 100
selected_fields = "DATE"
operator = BigQueryGetDataOperator(
@@ -797,6 +798,7 @@ class TestBigQueryGetDataOperator:
max_results=max_results,
selected_fields=selected_fields,
location=TEST_DATASET_LOCATION,
+ as_dict=as_dict,
)
operator.execute(None)
mock_hook.return_value.list_rows.assert_called_once_with(
@@ -840,9 +842,10 @@ class TestBigQueryGetDataOperator:
exc.value.trigger, BigQueryGetDataTrigger
), "Trigger is not a BigQueryGetDataTrigger"
+ @pytest.mark.parametrize("as_dict", [True, False])
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_get_data_operator_async_without_selected_fields(
- self, mock_hook, create_task_instance_of_operator
+ self, mock_hook, create_task_instance_of_operator, as_dict
):
"""
Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
@@ -862,6 +865,7 @@ class TestBigQueryGetDataOperator:
table_id=TEST_TABLE_ID,
max_results=100,
deferrable=True,
+ as_dict=as_dict,
)
with pytest.raises(TaskDeferred) as exc:
@@ -871,7 +875,8 @@ class TestBigQueryGetDataOperator:
exc.value.trigger, BigQueryGetDataTrigger
), "Trigger is not a BigQueryGetDataTrigger"
- def test_bigquery_get_data_operator_execute_failure(self):
+ @pytest.mark.parametrize("as_dict", [True, False])
+ def test_bigquery_get_data_operator_execute_failure(self, as_dict):
"""Tests that an AirflowException is raised in case of error event"""
operator = BigQueryGetDataOperator(
@@ -880,6 +885,7 @@ class TestBigQueryGetDataOperator:
table_id="any",
max_results=100,
deferrable=True,
+ as_dict=as_dict,
)
with pytest.raises(AirflowException):
@@ -887,7 +893,8 @@ class TestBigQueryGetDataOperator:
context=None, event={"status": "error", "message": "test failure message"}
)
- def test_bigquery_get_data_op_execute_complete_with_records(self):
+ @pytest.mark.parametrize("as_dict", [True, False])
+ def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
"""Asserts that exception is raised with correct expected exception
message"""
operator = BigQueryGetDataOperator(
@@ -896,6 +903,7 @@ class TestBigQueryGetDataOperator:
table_id="any",
max_results=100,
deferrable=True,
+ as_dict=as_dict,
)
with mock.patch.object(operator.log, "info") as mock_log_info: