This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d6e626b05 Fix `BigQueryGetDataOperator`'s query job bugs in 
deferrable mode (#31433)
0d6e626b05 is described below

commit 0d6e626b050a860462224ad64dc5e9831fe8624d
Author: Shahar Epstein <[email protected]>
AuthorDate: Mon May 22 21:20:28 2023 +0300

    Fix `BigQueryGetDataOperator`'s query job bugs in deferrable mode (#31433)
---
 .../providers/google/cloud/operators/bigquery.py   | 14 ++++++----
 .../providers/google/cloud/triggers/bigquery.py    |  1 +
 .../google/cloud/operators/test_bigquery.py        | 32 ++++++++++++++++++++++
 .../google/cloud/triggers/test_bigquery.py         |  1 +
 4 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/bigquery.py 
b/airflow/providers/google/cloud/operators/bigquery.py
index 8cf3489ccf..10aafd5d66 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -802,7 +802,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
     :param project_id: (Optional) The name of the project where the data
-        will be returned from. (templated)
+        will be returned from. If None, it will be derived from the hook's 
project ID. (templated)
     :param max_results: The maximum number of records (rows) to be fetched
         from the table. (templated)
     :param selected_fields: List of fields to return (comma-separated). If
@@ -872,7 +872,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         hook: BigQueryHook,
         job_id: str,
     ) -> BigQueryJob:
-        get_query = self.generate_query()
+        get_query = self.generate_query(hook=hook)
         configuration = {"query": {"query": get_query, "useLegacySql": 
self.use_legacy_sql}}
         """Submit a new job and get the job id for polling the status using 
Triggerer."""
         return hook.insert_job(
@@ -883,17 +883,21 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
             nowait=True,
         )
 
-    def generate_query(self) -> str:
+    def generate_query(self, hook: BigQueryHook) -> str:
         """
         Generate a select query if selected fields are given or with *
         for the given dataset and table id
+        :param hook: BigQuery Hook
         """
         query = "select "
         if self.selected_fields:
             query += self.selected_fields
         else:
             query += "*"
-        query += f" from `{self.project_id}.{self.dataset_id}.{self.table_id}` 
limit {self.max_results}"
+        query += (
+            f" from `{self.project_id or hook.project_id}.{self.dataset_id}"
+            f".{self.table_id}` limit {self.max_results}"
+        )
         return query
 
     def execute(self, context: Context):
@@ -906,7 +910,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         if not self.deferrable:
             self.log.info(
                 "Fetching Data from %s.%s.%s max results: %s",
-                self.project_id,
+                self.project_id or hook.project_id,
                 self.dataset_id,
                 self.table_id,
                 self.max_results,
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py 
b/airflow/providers/google/cloud/triggers/bigquery.py
index 1da7f87f90..c7b17af2ed 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -187,6 +187,7 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
                 "project_id": self.project_id,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
+                "as_dict": self.as_dict,
             },
         )
 
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py 
b/tests/providers/google/cloud/operators/test_bigquery.py
index 1e871678a9..b0547cec3f 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -814,6 +814,38 @@ class TestBigQueryGetDataOperator:
             location=TEST_DATASET_LOCATION,
         )
 
+    
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+    def test_generate_query__with_project_id(self, mock_hook):
+        operator = BigQueryGetDataOperator(
+            gcp_conn_id=GCP_CONN_ID,
+            task_id=TASK_ID,
+            dataset_id=TEST_DATASET,
+            table_id=TEST_TABLE_ID,
+            project_id=TEST_GCP_PROJECT_ID,
+            max_results=100,
+            use_legacy_sql=False,
+        )
+        assert (
+            operator.generate_query(hook=mock_hook) == f"select * from 
`{TEST_GCP_PROJECT_ID}."
+            f"{TEST_DATASET}.{TEST_TABLE_ID}` limit 100"
+        )
+
+    
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
+    def test_generate_query__without_project_id(self, mock_hook):
+        hook_project_id = mock_hook.project_id
+        operator = BigQueryGetDataOperator(
+            gcp_conn_id=GCP_CONN_ID,
+            task_id=TASK_ID,
+            dataset_id=TEST_DATASET,
+            table_id=TEST_TABLE_ID,
+            max_results=100,
+            use_legacy_sql=False,
+        )
+        assert (
+            operator.generate_query(hook=mock_hook) == f"select * from 
`{hook_project_id}."
+            f"{TEST_DATASET}.{TEST_TABLE_ID}` limit 100"
+        )
+
     
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
     def test_bigquery_get_data_operator_async_with_selected_fields(
         self, mock_hook, create_task_instance_of_operator
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py 
b/tests/providers/google/cloud/triggers/test_bigquery.py
index 410aa14b1a..cb997259bd 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -224,6 +224,7 @@ class TestBigQueryGetDataTrigger:
         classpath, kwargs = get_data_trigger.serialize()
         assert classpath == 
"airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger"
         assert kwargs == {
+            "as_dict": False,
             "conn_id": TEST_CONN_ID,
             "job_id": TEST_JOB_ID,
             "dataset_id": TEST_DATASET_ID,

Reply via email to