This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c14753b03 Fix BigQueryGetDataOperator where project_id is not being respected in deferrable mode (#32488)
3c14753b03 is described below

commit 3c14753b03872b259ce2248eda92f7fb6f4d751b
Author: Avinash Holla Pandeshwar <[email protected]>
AuthorDate: Thu Jul 20 23:52:45 2023 +0530

    Fix BigQueryGetDataOperator where project_id is not being respected in deferrable mode (#32488)
    
    * fixing BigQueryGetDataOperator to respect project_id as compute project. A new parameter table_project_id will be used for specifying table storage project.
---
 .../providers/google/cloud/operators/bigquery.py   | 34 +++++++++++++++++-----
 .../google/cloud/operators/test_bigquery.py        | 13 ++++++---
 2 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 70ab30d61a..f5e5a9634f 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -810,7 +810,11 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
 
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
-    :param project_id: (Optional) The name of the project where the data
+    :param table_project_id: (Optional) The project ID of the requested table.
+        If None, it will be derived from the hook's project ID. (templated)
+    :param job_project_id: (Optional) Google Cloud Project where the job is running.
+        If None, it will be derived from the hook's project ID. (templated)
+    :param project_id: (Deprecated) (Optional) The name of the project where the data
         will be returned from. If None, it will be derived from the hook's project ID. (templated)
     :param max_results: The maximum number of records (rows) to be fetched
         from the table. (templated)
@@ -837,6 +841,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     template_fields: Sequence[str] = (
         "dataset_id",
         "table_id",
+        "table_project_id",
+        "job_project_id",
         "project_id",
         "max_results",
         "selected_fields",
@@ -849,6 +855,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         *,
         dataset_id: str,
         table_id: str,
+        table_project_id: str | None = None,
+        job_project_id: str | None = None,
         project_id: str | None = None,
         max_results: int = 100,
         selected_fields: str | None = None,
@@ -863,8 +871,10 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
     ) -> None:
         super().__init__(**kwargs)
 
+        self.table_project_id = table_project_id
         self.dataset_id = dataset_id
         self.table_id = table_id
+        self.job_project_id = job_project_id
         self.max_results = int(max_results)
         self.selected_fields = selected_fields
         self.gcp_conn_id = gcp_conn_id
@@ -887,7 +897,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         return hook.insert_job(
             configuration=configuration,
             location=self.location,
-            project_id=hook.project_id,
+            project_id=self.job_project_id or hook.project_id,
             job_id=job_id,
             nowait=True,
         )
@@ -900,12 +910,22 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         else:
             query += "*"
         query += (
-            f" from `{self.project_id or hook.project_id}.{self.dataset_id}"
+            f" from `{self.table_project_id or hook.project_id}.{self.dataset_id}"
             f".{self.table_id}` limit {self.max_results}"
         )
         return query
 
     def execute(self, context: Context):
+        if self.project_id:
+            self.log.warning(
+                "The project_id parameter is deprecated, and will be removed in a future release."
+                " Please use table_project_id instead.",
+            )
+            if not self.table_project_id:
+                self.table_project_id = self.project_id
+            else:
+                self.log.info("Ignoring project_id parameter, as table_project_id is found.")
+
         hook = BigQueryHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
@@ -915,7 +935,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
         if not self.deferrable:
             self.log.info(
                 "Fetching Data from %s.%s.%s max results: %s",
-                self.project_id or hook.project_id,
+                self.table_project_id or hook.project_id,
                 self.dataset_id,
                 self.table_id,
                 self.max_results,
@@ -924,7 +944,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
                 schema: dict[str, list] = hook.get_schema(
                     dataset_id=self.dataset_id,
                     table_id=self.table_id,
-                    project_id=self.project_id,
+                    project_id=self.table_project_id or hook.project_id,
                 )
                 if "fields" in schema:
+                    self.selected_fields = ",".join([field["name"] for field in schema["fields"]])
@@ -935,7 +955,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
                 max_results=self.max_results,
                 selected_fields=self.selected_fields,
                 location=self.location,
-                project_id=self.project_id,
+                project_id=self.table_project_id or hook.project_id,
             )
 
             if isinstance(rows, RowIterator):
@@ -961,7 +981,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
                 job_id=job.job_id,
                 dataset_id=self.dataset_id,
                 table_id=self.table_id,
-                project_id=hook.project_id,
+                project_id=self.job_project_id or hook.project_id,
                 poll_interval=self.poll_interval,
                 as_dict=self.as_dict,
             ),
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py
index 226d1e7095..4026b4ba45 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -64,6 +64,7 @@ TASK_ID = "test-bq-generic-operator"
 TEST_DATASET = "test-dataset"
 TEST_DATASET_LOCATION = "EU"
 TEST_GCP_PROJECT_ID = "test-project"
+TEST_JOB_PROJECT_ID = "test-job-project"
 TEST_DELETE_CONTENTS = True
 TEST_TABLE_ID = "test-table-id"
 TEST_GCS_BUCKET = "test-bucket"
@@ -804,7 +805,7 @@ class TestBigQueryGetDataOperator:
             task_id=TASK_ID,
             dataset_id=TEST_DATASET,
             table_id=TEST_TABLE_ID,
-            project_id=TEST_GCP_PROJECT_ID,
+            table_project_id=TEST_GCP_PROJECT_ID,
             max_results=max_results,
             selected_fields=selected_fields,
             location=TEST_DATASET_LOCATION,
@@ -823,13 +824,13 @@ class TestBigQueryGetDataOperator:
         )
 
     
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
-    def test_generate_query__with_project_id(self, mock_hook):
+    def test_generate_query__with_table_project_id(self, mock_hook):
         operator = BigQueryGetDataOperator(
             gcp_conn_id=GCP_CONN_ID,
             task_id=TASK_ID,
             dataset_id=TEST_DATASET,
             table_id=TEST_TABLE_ID,
-            project_id=TEST_GCP_PROJECT_ID,
+            table_project_id=TEST_GCP_PROJECT_ID,
             max_results=100,
             use_legacy_sql=False,
         )
@@ -839,7 +840,7 @@ class TestBigQueryGetDataOperator:
         )
 
     
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
-    def test_generate_query__without_project_id(self, mock_hook):
+    def test_generate_query__without_table_project_id(self, mock_hook):
         hook_project_id = mock_hook.project_id
         operator = BigQueryGetDataOperator(
             gcp_conn_id=GCP_CONN_ID,
@@ -868,6 +869,7 @@ class TestBigQueryGetDataOperator:
             task_id="get_data_from_bq",
             dataset_id=TEST_DATASET,
             table_id=TEST_TABLE_ID,
+            job_project_id=TEST_JOB_PROJECT_ID,
             max_results=100,
             selected_fields="value,name",
             deferrable=True,
@@ -896,6 +898,7 @@ class TestBigQueryGetDataOperator:
             task_id="get_data_from_bq",
             dataset_id=TEST_DATASET,
             table_id=TEST_TABLE_ID,
+            job_project_id=TEST_JOB_PROJECT_ID,
             max_results=100,
             deferrable=True,
             as_dict=as_dict,
@@ -917,6 +920,7 @@ class TestBigQueryGetDataOperator:
             task_id="get_data_from_bq",
             dataset_id=TEST_DATASET,
             table_id="any",
+            job_project_id=TEST_JOB_PROJECT_ID,
             max_results=100,
             deferrable=True,
             as_dict=as_dict,
@@ -936,6 +940,7 @@ class TestBigQueryGetDataOperator:
             task_id="get_data_from_bq",
             dataset_id=TEST_DATASET,
             table_id="any",
+            job_project_id=TEST_JOB_PROJECT_ID,
             max_results=100,
             deferrable=True,
             as_dict=as_dict,

Reply via email to