This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e237041142 Add logic to handle on_kill for BigQueryInsertJobOperator when deferrable=True (#38912)
e237041142 is described below
commit e237041142e36349cc62e743105c91b04ddf4253
Author: Ankit Chaurasia <[email protected]>
AuthorDate: Mon Apr 15 16:51:04 2024 +0545
Add logic to handle on_kill for BigQueryInsertJobOperator when deferrable=True (#38912)
---
airflow/providers/google/cloud/hooks/bigquery.py | 25 +++++++++++
.../providers/google/cloud/operators/bigquery.py | 1 +
.../providers/google/cloud/triggers/bigquery.py | 11 +++++
.../providers/google/cloud/hooks/test_bigquery.py | 49 ++++++++++++++++++++++
.../google/cloud/triggers/test_bigquery.py | 36 ++++++++++++++++
5 files changed, 122 insertions(+)
diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py
index 0594ce4351..f270e256fb 100644
--- a/airflow/providers/google/cloud/hooks/bigquery.py
+++ b/airflow/providers/google/cloud/hooks/bigquery.py
@@ -3388,6 +3388,31 @@ class BigQueryAsyncHook(GoogleBaseAsyncHook):
job_query_resp = await job_client.query(query_request,
cast(Session, session))
return job_query_resp["jobReference"]["jobId"]
+ async def cancel_job(self, job_id: str, project_id: str | None, location:
str | None) -> None:
+ """
+ Cancel a BigQuery job.
+
+ :param job_id: ID of the job to cancel.
+ :param project_id: Google Cloud Project where the job was running.
+ :param location: Location where the job was running.
+ """
+ async with ClientSession() as session:
+ token = await self.get_token(session=session)
+ job = Job(job_id=job_id, project=project_id, location=location,
token=token, session=session) # type: ignore[arg-type]
+
+ self.log.info(
+ "Attempting to cancel BigQuery job: %s in project: %s,
location: %s",
+ job_id,
+ project_id,
+ location,
+ )
+ try:
+ await job.cancel()
+ self.log.info("Job %s cancellation requested.", job_id)
+ except Exception as e:
+ self.log.error("Failed to cancel BigQuery job %s: %s", job_id,
str(e))
+ raise
+
def get_records(self, query_results: dict[str, Any], as_dict: bool =
False) -> list[Any]:
"""Convert a response from BigQuery to records.
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index 68b423fb46..9da97afc2a 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -2903,6 +2903,7 @@ class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMix
location=self.location or hook.location,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
+ cancel_on_kill=self.cancel_on_kill,
),
method_name="execute_complete",
)
diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py
index eafa4825be..fd01705261 100644
--- a/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/airflow/providers/google/cloud/triggers/bigquery.py
@@ -57,6 +57,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
table_id: str | None = None,
poll_interval: float = 4.0,
impersonation_chain: str | Sequence[str] | None = None,
+ cancel_on_kill: bool = True,
):
super().__init__()
self.log.info("Using the connection %s .", conn_id)
@@ -69,6 +70,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
self.table_id = table_id
self.poll_interval = poll_interval
self.impersonation_chain = impersonation_chain
+ self.cancel_on_kill = cancel_on_kill
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize BigQueryInsertJobTrigger arguments and classpath."""
@@ -83,6 +85,7 @@ class BigQueryInsertJobTrigger(BaseTrigger):
"table_id": self.table_id,
"poll_interval": self.poll_interval,
"impersonation_chain": self.impersonation_chain,
+ "cancel_on_kill": self.cancel_on_kill,
},
)
@@ -113,6 +116,14 @@ class BigQueryInsertJobTrigger(BaseTrigger):
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)
+ except asyncio.CancelledError:
+ self.log.info("Task was killed.")
+ if self.job_id and self.cancel_on_kill:
+ await hook.cancel_job( # type: ignore[union-attr]
+ job_id=self.job_id, project_id=self.project_id,
location=self.location
+ )
+ else:
+ self.log.info("Skipping to cancel job: %s:%s.%s",
self.project_id, self.location, self.job_id)
except Exception as e:
self.log.exception("Exception occurred while checking for query
completion")
yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py
index 9118116287..37096b0ff3 100644
--- a/tests/providers/google/cloud/hooks/test_bigquery.py
+++ b/tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2190,6 +2190,55 @@ class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass):
resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID)
assert resp == response
+ @pytest.mark.asyncio
+ @pytest.mark.db_test
+ @mock.patch("google.auth.default")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Job")
+ async def test_cancel_job_success(self, mock_job, mock_auth_default):
+ mock_credentials =
mock.MagicMock(spec=google.auth.compute_engine.Credentials)
+ mock_credentials.token = "ACCESS_TOKEN"
+ mock_auth_default.return_value = (mock_credentials, PROJECT_ID)
+ job_id = "test_job_id"
+ project_id = "test_project"
+ location = "US"
+
+ mock_job_instance = AsyncMock()
+ mock_job_instance.cancel.return_value = None
+ mock_job.return_value = mock_job_instance
+
+ await self.hook.cancel_job(job_id=job_id, project_id=project_id,
location=location)
+
+ mock_job_instance.cancel.assert_called_once()
+
+ @pytest.mark.asyncio
+ @pytest.mark.db_test
+ @mock.patch("google.auth.default")
+ @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Job")
+ async def test_cancel_job_failure(self, mock_job, mock_auth_default):
+ """
+ Test that BigQueryAsyncHook handles exceptions during job cancellation
correctly.
+ """
+ mock_credentials =
mock.MagicMock(spec=google.auth.compute_engine.Credentials)
+ mock_credentials.token = "ACCESS_TOKEN"
+ mock_auth_default.return_value = (mock_credentials, PROJECT_ID)
+
+ mock_job_instance = AsyncMock()
+ mock_job_instance.cancel.side_effect = Exception("Cancellation failed")
+ mock_job.return_value = mock_job_instance
+
+ hook = BigQueryAsyncHook()
+
+ job_id = "test_job_id"
+ project_id = "test_project"
+ location = "US"
+
+ with pytest.raises(Exception) as excinfo:
+ await hook.cancel_job(job_id=job_id, project_id=project_id,
location=location)
+
+ assert "Cancellation failed" in str(excinfo.value), "Exception message
not passed correctly"
+
+ mock_job_instance.cancel.assert_called_once()
+
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance")
diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py
index 9eec245f83..367b4850de 100644
--- a/tests/providers/google/cloud/triggers/test_bigquery.py
+++ b/tests/providers/google/cloud/triggers/test_bigquery.py
@@ -165,6 +165,7 @@ class TestBigQueryInsertJobTrigger:
classpath, kwargs = insert_job_trigger.serialize()
assert classpath ==
"airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger"
assert kwargs == {
+ "cancel_on_kill": True,
"conn_id": TEST_CONN_ID,
"job_id": TEST_JOB_ID,
"project_id": TEST_GCP_PROJECT_ID,
@@ -233,6 +234,41 @@ class TestBigQueryInsertJobTrigger:
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "Test exception"})
== actual
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job")
+
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
+ async def test_bigquery_insert_job_trigger_cancellation(
+ self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger
+ ):
+ """
+ Test that BigQueryInsertJobTrigger handles cancellation correctly,
logs the appropriate message,
+ and conditionally cancels the job based on the `cancel_on_kill`
attribute.
+ """
+ insert_job_trigger.cancel_on_kill = True
+ insert_job_trigger.job_id = "1234"
+
+ mock_get_job_status.side_effect = [
+ {"status": "running", "message": "Job is still running"},
+ asyncio.CancelledError(),
+ ]
+
+ mock_cancel_job.return_value = asyncio.Future()
+ mock_cancel_job.return_value.set_result(None)
+
+ caplog.set_level(logging.INFO)
+
+ try:
+ async for _ in insert_job_trigger.run():
+ pass
+ except asyncio.CancelledError:
+ pass
+
+ assert (
+ "Task was killed" in caplog.text
+ or "Bigquery job status is running. Sleeping for 4.0 seconds." in
caplog.text
+ ), "Expected messages about task status or cancellation not found in
log."
+ mock_cancel_job.assert_awaited_once()
+
class TestBigQueryGetDataTrigger:
def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):