This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3c14753b03 Fix BigQueryGetDataOperator where project_id is not being
respected in deferrable mode (#32488)
3c14753b03 is described below
commit 3c14753b03872b259ce2248eda92f7fb6f4d751b
Author: Avinash Holla Pandeshwar <[email protected]>
AuthorDate: Thu Jul 20 23:52:45 2023 +0530
Fix BigQueryGetDataOperator where project_id is not being respected in
deferrable mode (#32488)
* fixing BigQueryGetDataOperator to respect project_id as compute project.
A new parameter table_project_id will be used for specifying table storage
project.
---
.../providers/google/cloud/operators/bigquery.py | 34 +++++++++++++++++-----
.../google/cloud/operators/test_bigquery.py | 13 ++++++---
2 files changed, 36 insertions(+), 11 deletions(-)
diff --git a/airflow/providers/google/cloud/operators/bigquery.py
b/airflow/providers/google/cloud/operators/bigquery.py
index 70ab30d61a..f5e5a9634f 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -810,7 +810,11 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
:param dataset_id: The dataset ID of the requested table. (templated)
:param table_id: The table ID of the requested table. (templated)
- :param project_id: (Optional) The name of the project where the data
+ :param table_project_id: (Optional) The project ID of the requested table.
+ If None, it will be derived from the hook's project ID. (templated)
+ :param job_project_id: (Optional) Google Cloud Project where the job is running.
+ If None, it will be derived from the hook's project ID. (templated)
+ :param project_id: (Deprecated) (Optional) The name of the project where the data
will be returned from. If None, it will be derived from the hook's project ID. (templated)
:param max_results: The maximum number of records (rows) to be fetched
from the table. (templated)
@@ -837,6 +841,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
template_fields: Sequence[str] = (
"dataset_id",
"table_id",
+ "table_project_id",
+ "job_project_id",
"project_id",
"max_results",
"selected_fields",
@@ -849,6 +855,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
*,
dataset_id: str,
table_id: str,
+ table_project_id: str | None = None,
+ job_project_id: str | None = None,
project_id: str | None = None,
max_results: int = 100,
selected_fields: str | None = None,
@@ -863,8 +871,10 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
) -> None:
super().__init__(**kwargs)
+ self.table_project_id = table_project_id
self.dataset_id = dataset_id
self.table_id = table_id
+ self.job_project_id = job_project_id
self.max_results = int(max_results)
self.selected_fields = selected_fields
self.gcp_conn_id = gcp_conn_id
@@ -887,7 +897,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
return hook.insert_job(
configuration=configuration,
location=self.location,
- project_id=hook.project_id,
+ project_id=self.job_project_id or hook.project_id,
job_id=job_id,
nowait=True,
)
@@ -900,12 +910,22 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
else:
query += "*"
query += (
- f" from `{self.project_id or hook.project_id}.{self.dataset_id}"
+ f" from `{self.table_project_id or hook.project_id}.{self.dataset_id}"
f".{self.table_id}` limit {self.max_results}"
)
return query
def execute(self, context: Context):
+ if self.project_id:
+ self.log.warning(
+ "The project_id parameter is deprecated, and will be removed in a future release."
+ " Please use table_project_id instead.",
+ )
+ if not self.table_project_id:
+ self.table_project_id = self.project_id
+ else:
+ self.log.info("Ignoring project_id parameter, as table_project_id is found.")
+
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
@@ -915,7 +935,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
if not self.deferrable:
self.log.info(
"Fetching Data from %s.%s.%s max results: %s",
- self.project_id or hook.project_id,
+ self.table_project_id or hook.project_id,
self.dataset_id,
self.table_id,
self.max_results,
@@ -924,7 +944,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
schema: dict[str, list] = hook.get_schema(
dataset_id=self.dataset_id,
table_id=self.table_id,
- project_id=self.project_id,
+ project_id=self.table_project_id or hook.project_id,
)
if "fields" in schema:
self.selected_fields = ",".join([field["name"] for field
in schema["fields"]])
@@ -935,7 +955,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
max_results=self.max_results,
selected_fields=self.selected_fields,
location=self.location,
- project_id=self.project_id,
+ project_id=self.table_project_id or hook.project_id,
)
if isinstance(rows, RowIterator):
@@ -961,7 +981,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
job_id=job.job_id,
dataset_id=self.dataset_id,
table_id=self.table_id,
- project_id=hook.project_id,
+ project_id=self.job_project_id or hook.project_id,
poll_interval=self.poll_interval,
as_dict=self.as_dict,
),
diff --git a/tests/providers/google/cloud/operators/test_bigquery.py
b/tests/providers/google/cloud/operators/test_bigquery.py
index 226d1e7095..4026b4ba45 100644
--- a/tests/providers/google/cloud/operators/test_bigquery.py
+++ b/tests/providers/google/cloud/operators/test_bigquery.py
@@ -64,6 +64,7 @@ TASK_ID = "test-bq-generic-operator"
TEST_DATASET = "test-dataset"
TEST_DATASET_LOCATION = "EU"
TEST_GCP_PROJECT_ID = "test-project"
+TEST_JOB_PROJECT_ID = "test-job-project"
TEST_DELETE_CONTENTS = True
TEST_TABLE_ID = "test-table-id"
TEST_GCS_BUCKET = "test-bucket"
@@ -804,7 +805,7 @@ class TestBigQueryGetDataOperator:
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
- project_id=TEST_GCP_PROJECT_ID,
+ table_project_id=TEST_GCP_PROJECT_ID,
max_results=max_results,
selected_fields=selected_fields,
location=TEST_DATASET_LOCATION,
@@ -823,13 +824,13 @@ class TestBigQueryGetDataOperator:
)
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
- def test_generate_query__with_project_id(self, mock_hook):
+ def test_generate_query__with_table_project_id(self, mock_hook):
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
- project_id=TEST_GCP_PROJECT_ID,
+ table_project_id=TEST_GCP_PROJECT_ID,
max_results=100,
use_legacy_sql=False,
)
@@ -839,7 +840,7 @@ class TestBigQueryGetDataOperator:
)
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
- def test_generate_query__without_project_id(self, mock_hook):
+ def test_generate_query__without_table_project_id(self, mock_hook):
hook_project_id = mock_hook.project_id
operator = BigQueryGetDataOperator(
gcp_conn_id=GCP_CONN_ID,
@@ -868,6 +869,7 @@ class TestBigQueryGetDataOperator:
task_id="get_data_from_bq",
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
+ job_project_id=TEST_JOB_PROJECT_ID,
max_results=100,
selected_fields="value,name",
deferrable=True,
@@ -896,6 +898,7 @@ class TestBigQueryGetDataOperator:
task_id="get_data_from_bq",
dataset_id=TEST_DATASET,
table_id=TEST_TABLE_ID,
+ job_project_id=TEST_JOB_PROJECT_ID,
max_results=100,
deferrable=True,
as_dict=as_dict,
@@ -917,6 +920,7 @@ class TestBigQueryGetDataOperator:
task_id="get_data_from_bq",
dataset_id=TEST_DATASET,
table_id="any",
+ job_project_id=TEST_JOB_PROJECT_ID,
max_results=100,
deferrable=True,
as_dict=as_dict,
@@ -936,6 +940,7 @@ class TestBigQueryGetDataOperator:
task_id="get_data_from_bq",
dataset_id=TEST_DATASET,
table_id="any",
+ job_project_id=TEST_JOB_PROJECT_ID,
max_results=100,
deferrable=True,
as_dict=as_dict,