This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e237041142 Add logic to handle on_kill for BigQueryInsertJobOperator when deferrable=True (#38912)
e237041142 is described below

commit e237041142e36349cc62e743105c91b04ddf4253
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Mon Apr 15 16:51:04 2024 +0545

    Add logic to handle on_kill for BigQueryInsertJobOperator when deferrable=True (#38912)
---
 airflow/providers/google/cloud/hooks/bigquery.py   | 25 +++++++++++
 .../providers/google/cloud/operators/bigquery.py   |  1 +
 .../providers/google/cloud/triggers/bigquery.py    | 11 +++++
 .../providers/google/cloud/hooks/test_bigquery.py  | 49 ++++++++++++++++++++++
 .../google/cloud/triggers/test_bigquery.py         | 36 ++++++++++++++++
 5 files changed, 122 insertions(+)

diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 0594ce4351..f270e256fb 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -3388,6 +3388,31 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
             job_query_resp = await job_client.query(query_request, cast(Session, session))
             return job_query_resp["jobReference"]["jobId"]
 
+    async def cancel_job(self, job_id: str, project_id: str | None, location: str | None) -> None:
+        """
+        Cancel a BigQuery job.
+
+        :param job_id: ID of the job to cancel.
+        :param project_id: Google Cloud Project where the job was running.
+        :param location: Location where the job was running.
+        """
+        async with ClientSession() as session:
+            token = await self.get_token(session=session)
+            job = Job(job_id=job_id, project=project_id, location=location, token=token, session=session)  # type: ignore[arg-type]
+
+            self.log.info(
+                "Attempting to cancel BigQuery job: %s in project: %s, location: %s",
+                job_id,
+                project_id,
+                location,
+            )
+            try:
+                await job.cancel()
+                self.log.info("Job %s cancellation requested.", job_id)
+            except Exception as e:
+                self.log.error("Failed to cancel BigQuery job %s: %s", job_id, str(e))
+                raise
+
     def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
         """Convert a response from BigQuery to records.
 
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 68b423fb46..9da97afc2a 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -2903,6 +2903,7 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMix
                         location=self.location or hook.location,
                         poll_interval=self.poll_interval,
                         impersonation_chain=self.impersonation_chain,
+                        cancel_on_kill=self.cancel_on_kill,
                     ),
                     method_name="execute_complete",
                 )
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index eafa4825be..fd01705261 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -57,6 +57,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
         table_id: str | None = None,
         poll_interval: float = 4.0,
         impersonation_chain: str | Sequence[str] | None = None,
+        cancel_on_kill: bool = True,
     ):
         super().__init__()
         self.log.info("Using the connection  %s .", conn_id)
@@ -69,6 +70,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
         self.table_id = table_id
         self.poll_interval = poll_interval
         self.impersonation_chain = impersonation_chain
+        self.cancel_on_kill = cancel_on_kill
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize BigQueryInsertJobTrigger arguments and classpath."""
@@ -83,6 +85,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
                 "table_id": self.table_id,
                 "poll_interval": self.poll_interval,
                 "impersonation_chain": self.impersonation_chain,
+                "cancel_on_kill": self.cancel_on_kill,
             },
         )
 
@@ -113,6 +116,14 @@ class BigQueryInsertJobTrigger(BaseTrigger):
                         self.poll_interval,
                     )
                     await asyncio.sleep(self.poll_interval)
+        except asyncio.CancelledError:
+            self.log.info("Task was killed.")
+            if self.job_id and self.cancel_on_kill:
+                await hook.cancel_job(  # type: ignore[union-attr]
+                    job_id=self.job_id, project_id=self.project_id, location=self.location
+                )
+            else:
+                self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id)
         except Exception as e:
             self.log.exception("Exception occurred while checking for query completion")
             yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 9118116287..37096b0ff3 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2190,6 +2190,55 @@ class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass):
         resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID)
         assert resp == response
 
+    @pytest.mark.asyncio
+    @pytest.mark.db_test
+    @mock.patch("google.auth.default")
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Job")
+    async def test_cancel_job_success(self, mock_job, mock_auth_default):
+        mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
+        mock_credentials.token = "ACCESS_TOKEN"
+        mock_auth_default.return_value = (mock_credentials, PROJECT_ID)
+        job_id = "test_job_id"
+        project_id = "test_project"
+        location = "US"
+
+        mock_job_instance = AsyncMock()
+        mock_job_instance.cancel.return_value = None
+        mock_job.return_value = mock_job_instance
+
+        await self.hook.cancel_job(job_id=job_id, project_id=project_id, location=location)
+
+        mock_job_instance.cancel.assert_called_once()
+
+    @pytest.mark.asyncio
+    @pytest.mark.db_test
+    @mock.patch("google.auth.default")
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Job")
+    async def test_cancel_job_failure(self, mock_job, mock_auth_default):
+        """
+        Test that BigQueryAsyncHook handles exceptions during job cancellation correctly.
+        """
+        mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
+        mock_credentials.token = "ACCESS_TOKEN"
+        mock_auth_default.return_value = (mock_credentials, PROJECT_ID)
+
+        mock_job_instance = AsyncMock()
+        mock_job_instance.cancel.side_effect = Exception("Cancellation failed")
+        mock_job.return_value = mock_job_instance
+
+        hook = BigQueryAsyncHook()
+
+        job_id = "test_job_id"
+        project_id = "test_project"
+        location = "US"
+
+        with pytest.raises(Exception) as excinfo:
+            await hook.cancel_job(job_id=job_id, project_id=project_id, location=location)
+
+        assert "Cancellation failed" in str(excinfo.value), "Exception message not passed correctly"
+
+        mock_job_instance.cancel.assert_called_once()
+
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession")
     @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
index 9eec245f83..367b4850de 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -165,6 +165,7 @@ class TestBigQueryInsertJobTrigger:
         classpath, kwargs = insert_job_trigger.serialize()
         assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger"
         assert kwargs == {
+            "cancel_on_kill": True,
             "conn_id": TEST_CONN_ID,
             "job_id": TEST_JOB_ID,
             "project_id": TEST_GCP_PROJECT_ID,
@@ -233,6 +234,41 @@ class TestBigQueryInsertJobTrigger:
         actual = await generator.asend(None)
         assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
+    @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+    async def test_bigquery_insert_job_trigger_cancellation(
+        self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
+    ):
+        """
+        Test that BigQueryInsertJobTrigger handles cancellation correctly, logs the appropriate message,
+        and conditionally cancels the job based on the `cancel_on_kill` attribute.
+        """
+        insert_job_trigger.cancel_on_kill = True
+        insert_job_trigger.job_id = "1234"
+
+        mock_get_job_status.side_effect = [
+            {"status": "running", "message": "Job is still running"},
+            asyncio.CancelledError(),
+        ]
+
+        mock_cancel_job.return_value = asyncio.Future()
+        mock_cancel_job.return_value.set_result(None)
+
+        caplog.set_level(logging.INFO)
+
+        try:
+            async for _ in insert_job_trigger.run():
+                pass
+        except asyncio.CancelledError:
+            pass
+
+        assert (
+            "Task was killed" in caplog.text
+            or "Bigquery job status is running. Sleeping for 4.0 seconds." in caplog.text
+        ), "Expected messages about task status or cancellation not found in log."
+        mock_cancel_job.assert_awaited_once()
+
 
 class TestBigQueryGetDataTrigger:
     def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):

Reply via email to