This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d43c804f2b Add optional 'location' parameter to the BigQueryInsertJobTrigger (#37282)
d43c804f2b is described below

commit d43c804f2bda3bc518682c9b2af94ea30475c879
Author: max <[email protected]>
AuthorDate: Mon Feb 12 11:33:36 2024 +0100

    Add optional 'location' parameter to the BigQueryInsertJobTrigger (#37282)
---
 airflow/providers/google/cloud/hooks/bigquery.py   | 63 ++++++++++++++++++----
 .../providers/google/cloud/operators/bigquery.py   |  5 ++
 .../google/cloud/transfers/bigquery_to_gcs.py      |  1 +
 .../google/cloud/transfers/gcs_to_bigquery.py      |  1 +
 .../providers/google/cloud/triggers/bigquery.py    | 18 ++++++-
 .../providers/google/cloud/hooks/test_bigquery.py  | 19 +++----
 .../google/cloud/triggers/test_bigquery.py         | 35 ++++++------
 7 files changed, 103 insertions(+), 39 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 729d629c31..2b252f3c3f 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -20,6 +20,7 @@
 
 from __future__ import annotations
 
+import asyncio
 import json
 import logging
 import re
@@ -3242,16 +3243,58 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
             session=cast(Session, session),
         )
 
-    async def get_job_status(self, job_id: str | None, project_id: str | None = None) -> dict[str, str]:
-        async with ClientSession() as s:
-            job_client = await self.get_job_instance(project_id, job_id, s)
-            job = await job_client.get_job()
-            status = job.get("status", {})
-            if status["state"] == "DONE":
-                if "errorResult" in status:
-                    return {"status": "error", "message": status["errorResult"]["message"]}
-                return {"status": "success", "message": "Job completed"}
-            return {"status": status["state"].lower(), "message": "Job running"}
+    async def _get_job(
+        self, job_id: str | None, project_id: str | None = None, location: str | None = None
+    ) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
+        """
+        Get BigQuery job by its ID, project ID and location.
+
+        WARNING:
+        This is a temporary workaround for the issues below, and it's not intended to be used elsewhere!
+            https://github.com/apache/airflow/issues/35833
+            https://github.com/talkiq/gcloud-aio/issues/584
+
+        This method was developed because neither the `google-cloud-bigquery` nor the `gcloud-aio-bigquery`
+        library provides asynchronous access to BigQuery jobs with a location parameter. That's why this
+        method wraps a synchronous client call with the event loop's run_in_executor() method.
+
+        This workaround must be deleted along with the method _get_job_sync() and replaced by a more robust
+        and cleaner solution in one of two cases:
+            1. The `google-cloud-bigquery` library provides an async client with a get_job method that
+            supports the optional parameter `location`
+            2. The `gcloud-aio-bigquery` library supports the `location` parameter in its get_job() method.
+        """
+        loop = asyncio.get_event_loop()
+        job = await loop.run_in_executor(None, self._get_job_sync, job_id, project_id, location)
+        return job
+
+    def _get_job_sync(self, job_id, project_id, location):
+        """
+        Get BigQuery job by its ID, project ID and location synchronously.
+
+        WARNING:
+        This is a temporary workaround for the issues below, and it's not intended to be used elsewhere!
+            https://github.com/apache/airflow/issues/35833
+            https://github.com/talkiq/gcloud-aio/issues/584
+
+        This workaround must be deleted along with the method _get_job() and replaced by a more robust
+        and cleaner solution in one of two cases:
+            1. The `google-cloud-bigquery` library provides an async client with a get_job method that
+            supports the optional parameter `location`
+            2. The `gcloud-aio-bigquery` library supports the `location` parameter in its get_job() method.
+        """
+        hook = BigQueryHook(**self._hook_kwargs)
+        return hook.get_job(job_id=job_id, project_id=project_id, location=location)
+
+    async def get_job_status(
+        self, job_id: str | None, project_id: str | None = None, location: str | None = None
+    ) -> dict[str, str]:
+        job = await self._get_job(job_id=job_id, project_id=project_id, location=location)
+        if job.state == "DONE":
+            if job.error_result:
+                return {"status": "error", "message": job.error_result["message"]}
+            return {"status": "success", "message": "Job completed"}
+        return {"status": str(job.state).lower(), "message": "Job running"}
 
     async def get_job_output(
         self,
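The docstring added above explains why `_get_job` exists: neither `google-cloud-bigquery` nor `gcloud-aio-bigquery` currently offers an async get_job call that accepts a location, so the hook pushes a blocking call onto the event loop's default executor. The snippet below is a minimal, self-contained sketch of that run_in_executor() pattern; `fetch_job_sync`, its arguments and its return payload are illustrative placeholders, not the provider's actual client call.

    from __future__ import annotations

    import asyncio
    from typing import Any


    def fetch_job_sync(job_id: str, project_id: str, location: str | None) -> dict[str, Any]:
        # Stand-in for a blocking client call (e.g. google.cloud.bigquery.Client.get_job());
        # it just returns a minimal job-like payload for illustration.
        return {"job_id": job_id, "project_id": project_id, "location": location, "state": "DONE"}


    async def fetch_job(job_id: str, project_id: str, location: str | None = None) -> dict[str, Any]:
        # Run the blocking call in the default thread-pool executor so the running
        # event loop (and any other coroutines, e.g. triggers) is not blocked.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, fetch_job_sync, job_id, project_id, location)


    if __name__ == "__main__":
        job = asyncio.run(fetch_job("job-123", "my-project", "EU"))
        print(job["state"])  # -> DONE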
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index fc90bbeed9..b391a1508c 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -313,6 +313,7 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
                         conn_id=self.gcp_conn_id,
                         job_id=job.job_id,
                         project_id=hook.project_id,
+                        location=self.location or hook.location,
                         poll_interval=self.poll_interval,
                         impersonation_chain=self.impersonation_chain,
                     ),
@@ -438,6 +439,7 @@ class BigQueryValueCheckOperator(_BigQueryDbHookMixin, SQLValueCheckOperator):
                         conn_id=self.gcp_conn_id,
                         job_id=job.job_id,
                         project_id=hook.project_id,
+                        location=self.location or hook.location,
                         sql=self.sql,
                         pass_value=self.pass_value,
                         tolerance=self.tol,
@@ -594,6 +596,7 @@ class BigQueryIntervalCheckOperator(_BigQueryDbHookMixin, SQLIntervalCheckOperat
                     second_job_id=job_2.job_id,
                     project_id=hook.project_id,
                     table=self.table,
+                    location=self.location or hook.location,
                     metrics_thresholds=self.metrics_thresholds,
                     date_filter_column=self.date_filter_column,
                     days_back=self.days_back,
@@ -1068,6 +1071,7 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
                 dataset_id=self.dataset_id,
                 table_id=self.table_id,
                 project_id=self.job_project_id or hook.project_id,
+                location=self.location or hook.location,
                 poll_interval=self.poll_interval,
                 as_dict=self.as_dict,
                 impersonation_chain=self.impersonation_chain,
@@ -2876,6 +2880,7 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMix
                         conn_id=self.gcp_conn_id,
                         job_id=self.job_id,
                         project_id=self.project_id,
+                        location=self.location or hook.location,
                         poll_interval=self.poll_interval,
                         impersonation_chain=self.impersonation_chain,
                     ),
diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
index 3ede4db32f..aeb7f46f6e 100644
--- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
@@ -261,6 +261,7 @@ class BigQueryToGCSOperator(BaseOperator):
                     conn_id=self.gcp_conn_id,
                     job_id=self._job_id,
                     project_id=self.project_id or self.hook.project_id,
+                    location=self.location or self.hook.location,
                     impersonation_chain=self.impersonation_chain,
                 ),
                 method_name="execute_complete",
diff --git a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
index 9d8ce53f4c..4e1d6e0919 100644
--- a/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
+++ b/airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -435,6 +435,7 @@ class GCSToBigQueryOperator(BaseOperator):
                         conn_id=self.gcp_conn_id,
                         job_id=self.job_id,
                         project_id=self.project_id or self.hook.project_id,
+                        location=self.location or self.hook.location,
                         impersonation_chain=self.impersonation_chain,
                     ),
                     method_name="execute_complete",
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index 3d73e91e98..bc9e812d1b 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -33,6 +33,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
     :param conn_id: Reference to google cloud connection id
     :param job_id:  The ID of the job. It will be suffixed with hash of job configuration
     :param project_id: Google Cloud Project where the job is running
+    :param location: The dataset location.
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
     :param poll_interval: polling period in seconds to check for the status. (templated)
@@ -51,6 +52,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
         conn_id: str,
         job_id: str | None,
         project_id: str | None,
+        location: str | None,
         dataset_id: str | None = None,
         table_id: str | None = None,
         poll_interval: float = 4.0,
@@ -63,6 +65,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
         self._job_conn = None
         self.dataset_id = dataset_id
         self.project_id = project_id
+        self.location = location
         self.table_id = table_id
         self.poll_interval = poll_interval
         self.impersonation_chain = impersonation_chain
@@ -76,6 +79,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -87,7 +91,9 @@ class BigQueryInsertJobTrigger(BaseTrigger):
         hook = self._get_async_hook()
         try:
             while True:
-                job_status = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
+                job_status = await hook.get_job_status(
+                    job_id=self.job_id, project_id=self.project_id, location=self.location
+                )
                 if job_status["status"] == "success":
                     yield TriggerEvent(
                         {
@@ -127,6 +133,7 @@ class BigQueryCheckTrigger(BigQueryInsertJobTrigger):
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -201,6 +208,7 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
                 "job_id": self.job_id,
                 "dataset_id": self.dataset_id,
                 "project_id": self.project_id,
+                "location": self.location,
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
@@ -253,6 +261,7 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table: table name
     :param metrics_thresholds: dictionary of ratios indexed by metrics
+    :param location: The dataset location.
     :param date_filter_column: column name. (templated)
     :param days_back: number of days between ds and the ds we want to check against. (templated)
     :param ratio_formula: ration formula. (templated)
@@ -277,6 +286,7 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
         project_id: str | None,
         table: str,
         metrics_thresholds: dict[str, int],
+        location: str | None = None,
         date_filter_column: str | None = "ds",
         days_back: SupportsAbs[int] = -7,
         ratio_formula: str = "max_over_min",
@@ -290,6 +300,7 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
             conn_id=conn_id,
             job_id=first_job_id,
             project_id=project_id,
+            location=location,
             dataset_id=dataset_id,
             table_id=table_id,
             poll_interval=poll_interval,
@@ -317,6 +328,7 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger):
                 "project_id": self.project_id,
                 "table": self.table,
                 "metrics_thresholds": self.metrics_thresholds,
+                "location": self.location,
                 "date_filter_column": self.date_filter_column,
                 "days_back": self.days_back,
                 "ratio_formula": self.ratio_formula,
@@ -414,6 +426,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
     :param tolerance: certain metrics for tolerance. (templated)
     :param dataset_id: The dataset ID of the requested table. (templated)
     :param table_id: The table ID of the requested table. (templated)
+    :param location: The dataset location.
     :param poll_interval: polling period in seconds to check for the status. (templated)
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
@@ -435,6 +448,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
         tolerance: Any = None,
         dataset_id: str | None = None,
         table_id: str | None = None,
+        location: str | None = None,
         poll_interval: float = 4.0,
         impersonation_chain: str | Sequence[str] | None = None,
     ):
@@ -444,6 +458,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
             project_id=project_id,
             dataset_id=dataset_id,
             table_id=table_id,
+            location=location,
             poll_interval=poll_interval,
             impersonation_chain=impersonation_chain,
         )
@@ -464,6 +479,7 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger):
                 "sql": self.sql,
                 "table_id": self.table_id,
                 "tolerance": self.tolerance,
+                "location": self.location,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
             },
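With the trigger changes above, `location` is accepted by the constructor and round-tripped through serialize(), so the triggerer process can rebuild the trigger with the same job location after deferral. A hedged usage sketch follows; the connection id, job id, project and region values are placeholders and assume a provider release that contains this change.

    from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger

    trigger = BigQueryInsertJobTrigger(
        conn_id="google_cloud_default",
        job_id="example_query_job_1234",
        project_id="my-gcp-project",
        location="europe-west3",  # the new optional parameter
        poll_interval=4.0,
    )

    # serialize() now includes "location" in the keyword arguments it emits.
    classpath, kwargs = trigger.serialize()
    assert kwargs["location"] == "europe-west3"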
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 27db487b61..5ca34b276f 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2155,23 +2155,18 @@ class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass):
         assert isinstance(result, Job)
 
     @pytest.mark.parametrize(
-        "job_status, expected",
+        "job_state, error_result, expected",
         [
-            ({"status": {"state": "DONE"}}, {"status": "success", "message": "Job completed"}),
-            (
-                {"status": {"state": "DONE", "errorResult": {"message": "Timeout"}}},
-                {"status": "error", "message": "Timeout"},
-            ),
-            ({"status": {"state": "running"}}, {"status": "running", "message": "Job running"}),
+            ("DONE", None, {"status": "success", "message": "Job completed"}),
+            ("DONE", {"message": "Timeout"}, {"status": "error", "message": "Timeout"}),
+            ("RUNNING", None, {"status": "running", "message": "Job running"}),
         ],
     )
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_get_job_status(self, mock_job_instance, job_status, expected):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_get_job_status(self, mock_get_job, job_state, error_result, expected):
         hook = BigQueryAsyncHook()
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = job_status
+        mock_get_job.return_value = mock.MagicMock(state=job_state, error_result=error_result)
         resp = await hook.get_job_status(job_id=JOB_ID, project_id=PROJECT_ID)
         assert resp == expected
 
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
index ea2e478d04..ed4861ca76 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -24,7 +24,7 @@ from unittest.mock import AsyncMock
 
 import pytest
 from aiohttp import ClientResponseError, RequestInfo
-from gcloud.aio.bigquery import Job, Table
+from gcloud.aio.bigquery import Table
 from multidict import CIMultiDict
 from yarl import URL
 
@@ -48,6 +48,7 @@ TEST_JOB_ID = "1234"
 TEST_GCP_PROJECT_ID = "test-project"
 TEST_DATASET_ID = "bq_dataset"
 TEST_TABLE_ID = "bq_table"
+TEST_LOCATION = "US"
 POLLING_PERIOD_SECONDS = 4.0
 TEST_SQL_QUERY = "SELECT count(*) from Any"
 TEST_PASS_VALUE = 2
@@ -73,6 +74,7 @@ def insert_job_trigger():
         project_id=TEST_GCP_PROJECT_ID,
         dataset_id=TEST_DATASET_ID,
         table_id=TEST_TABLE_ID,
+        location=TEST_LOCATION,
         poll_interval=POLLING_PERIOD_SECONDS,
         impersonation_chain=TEST_IMPERSONATION_CHAIN,
     )
@@ -86,6 +88,7 @@ def get_data_trigger():
         project_id=TEST_GCP_PROJECT_ID,
         dataset_id=TEST_DATASET_ID,
         table_id=TEST_TABLE_ID,
+        location=None,
         poll_interval=POLLING_PERIOD_SECONDS,
         impersonation_chain=TEST_IMPERSONATION_CHAIN,
     )
@@ -132,6 +135,7 @@ def check_trigger():
         project_id=TEST_GCP_PROJECT_ID,
         dataset_id=TEST_DATASET_ID,
         table_id=TEST_TABLE_ID,
+        location=None,
         poll_interval=POLLING_PERIOD_SECONDS,
         impersonation_chain=TEST_IMPERSONATION_CHAIN,
     )
@@ -166,6 +170,7 @@ class TestBigQueryInsertJobTrigger:
             "project_id": TEST_GCP_PROJECT_ID,
             "dataset_id": TEST_DATASET_ID,
             "table_id": TEST_TABLE_ID,
+            "location": TEST_LOCATION,
             "poll_interval": POLLING_PERIOD_SECONDS,
             "impersonation_chain": TEST_IMPERSONATION_CHAIN,
         }
@@ -185,13 +190,11 @@ class TestBigQueryInsertJobTrigger:
         )
 
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_bigquery_insert_job_trigger_running(self, mock_job_instance, caplog, insert_job_trigger):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_bigquery_insert_job_trigger_running(self, mock_get_job, caplog, insert_job_trigger):
         """Test that BigQuery Triggers do not fire while a query is still running."""
 
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = {"status": {"state": "running"}}
+        mock_get_job.return_value = mock.MagicMock(state="RUNNING")
         caplog.set_level(logging.INFO)
 
         task = asyncio.create_task(insert_job_trigger.run().__anext__())
@@ -245,17 +248,16 @@ class TestBigQueryGetDataTrigger:
             "dataset_id": TEST_DATASET_ID,
             "project_id": TEST_GCP_PROJECT_ID,
             "table_id": TEST_TABLE_ID,
+            "location": None,
             "poll_interval": POLLING_PERIOD_SECONDS,
         }
 
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_bigquery_get_data_trigger_running(self, mock_job_instance, caplog, get_data_trigger):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_bigquery_get_data_trigger_running(self, mock_get_job, caplog, get_data_trigger):
         """Test that BigQuery Triggers do not fire while a query is still running."""
 
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = {"status": {"state": "running"}}
+        mock_get_job.return_value = mock.MagicMock(state="running")
         caplog.set_level(logging.INFO)
 
         task = asyncio.create_task(get_data_trigger.run().__anext__())
@@ -348,13 +350,11 @@ class TestBigQueryGetDataTrigger:
 
 class TestBigQueryCheckTrigger:
     @pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
-    async def test_bigquery_check_trigger_running(self, mock_job_instance, caplog, check_trigger):
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook._get_job")
+    async def test_bigquery_check_trigger_running(self, mock_get_job, caplog, check_trigger):
         """Test that BigQuery Triggers do not fire while a query is still running."""
 
-        mock_job_client = AsyncMock(Job)
-        mock_job_instance.return_value = mock_job_client
-        mock_job_instance.return_value.get_job.return_value = {"status": {"state": "running"}}
+        mock_get_job.return_value = mock.MagicMock(state="running")
 
         task = asyncio.create_task(check_trigger.run().__anext__())
         await asyncio.sleep(0.5)
@@ -406,6 +406,7 @@ class TestBigQueryCheckTrigger:
             "dataset_id": TEST_DATASET_ID,
             "project_id": TEST_GCP_PROJECT_ID,
             "table_id": TEST_TABLE_ID,
+            "location": None,
             "poll_interval": POLLING_PERIOD_SECONDS,
         }
 
@@ -487,6 +488,7 @@ class TestBigQueryIntervalCheckTrigger:
             "second_job_id": TEST_SECOND_JOB_ID,
             "project_id": TEST_GCP_PROJECT_ID,
             "table": TEST_TABLE_ID,
+            "location": None,
             "metrics_thresholds": TEST_METRIC_THRESHOLDS,
             "date_filter_column": TEST_DATE_FILTER_COLUMN,
             "days_back": TEST_DAYS_BACK,
@@ -578,6 +580,7 @@ class TestBigQueryValueCheckTrigger:
             "job_id": TEST_JOB_ID,
             "dataset_id": TEST_DATASET_ID,
             "project_id": TEST_GCP_PROJECT_ID,
+            "location": None,
             "sql": TEST_SQL_QUERY,
             "table_id": TEST_TABLE_ID,
             "tolerance": TEST_TOLERANCE,

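At the operator level the diff forwards `self.location or hook.location` into each trigger, so setting `location` on a deferrable operator is enough for the deferred poll to look up the job in the right region. A hedged example of such a task definition follows; the project, region and query are placeholders and the operator is assumed to come from a provider release containing this commit.

    from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

    run_query = BigQueryInsertJobOperator(
        task_id="run_query",
        configuration={
            "query": {
                "query": "SELECT 1",
                "useLegacySql": False,
            }
        },
        project_id="my-gcp-project",
        location="europe-west3",  # forwarded to BigQueryInsertJobTrigger when deferring
        deferrable=True,
    )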