This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 08aee444b4e fix: `resume_glue_job_on_retry` with `xcom_push` in
`GlueJobOperator` (#62560)
08aee444b4e is described below
commit 08aee444b4e1cfeb5b76a64eaec973e61b2b31a8
Author: Henry Chen <[email protected]>
AuthorDate: Thu Mar 12 03:27:22 2026 +0800
fix: `resume_glue_job_on_retry` with `xcom_push` in `GlueJobOperator`
(#62560)
* GlueJobOperator: Recover job run via task UUID when XCom is missing
* add resume_glue_job_on_retry back
---
.../airflow/providers/amazon/aws/operators/glue.py | 86 +++++++++++-
.../tests/unit/amazon/aws/operators/test_glue.py | 154 +++++++++++++++++++++
2 files changed, 233 insertions(+), 7 deletions(-)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
index 8e9a6deab44..ce96a76b3d0 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
@@ -112,6 +112,7 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
ui_color = "#ededed"
operator_extra_links = (GlueJobRunDetailsLink(),)
+ TASK_UUID_ARG = "--airflow_task_uuid"
def __init__(
self,
@@ -138,6 +139,7 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
job_poll_interval: int | float = 6,
waiter_delay: int = 60,
waiter_max_attempts: int = 75,
+ resume_glue_job_on_retry: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -167,6 +169,7 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
self.s3_script_location: str | None = None
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
+ self.resume_glue_job_on_retry = resume_glue_job_on_retry
@property
def _hook_parameters(self):
@@ -210,19 +213,88 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
)
self.s3_script_location =
f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
+ def _get_task_uuid(self, context: Context) -> str:
+ ti = context["ti"]
+ map_index = getattr(ti, "map_index", -1)
+ if map_index is None:
+ map_index = -1
+ return f"{ti.dag_id}:{ti.task_id}:{ti.run_id}:{map_index}"
+
+ def _prepare_script_args_with_task_uuid(self, context: Context) ->
tuple[dict, str]:
+ script_args = dict(self.script_args or {})
+ if self.TASK_UUID_ARG in script_args:
+ task_uuid = str(script_args[self.TASK_UUID_ARG])
+ else:
+ task_uuid = self._get_task_uuid(context)
+ script_args[self.TASK_UUID_ARG] = task_uuid
+ return script_args, task_uuid
+
+ def _find_job_run_id_by_task_uuid(self, task_uuid: str) -> tuple[str, str]
| None:
+ next_token: str | None = None
+ while True:
+ request = {"JobName": self.job_name, "MaxResults": 50}
+ if next_token:
+ request["NextToken"] = next_token
+ response = self.hook.conn.get_job_runs(**request)
+ for job_run in response.get("JobRuns", []):
+ args = job_run.get("Arguments", {}) or {}
+ if args.get(self.TASK_UUID_ARG) == task_uuid:
+ job_run_id = job_run.get("Id")
+ job_run_state = job_run.get("JobRunState")
+ if job_run_id and job_run_state:
+ return job_run_id, job_run_state
+ next_token = response.get("NextToken")
+ if not next_token:
+ return None
+
def execute(self, context: Context):
"""
Execute AWS Glue Job from Airflow.
:return: the current Glue job ID.
"""
- self.log.info(
- "Initializing AWS Glue Job: %s. Wait for completion: %s",
- self.job_name,
- self.wait_for_completion,
- )
- glue_job_run = self.hook.initialize_job(self.script_args,
self.run_job_kwargs)
- self._job_run_id = glue_job_run["JobRunId"]
+ previous_job_run_id = None
+ script_args = self.script_args
+ task_uuid = None
+ if self.resume_glue_job_on_retry:
+ ti = context["ti"]
+ script_args, task_uuid =
self._prepare_script_args_with_task_uuid(context)
+ previous_job_run_id = ti.xcom_pull(key="glue_job_run_id",
task_ids=ti.task_id)
+ if previous_job_run_id:
+ try:
+ job_run =
self.hook.conn.get_job_run(JobName=self.job_name, RunId=previous_job_run_id)
+ state = job_run.get("JobRun", {}).get("JobRunState")
+ self.log.info("Previous Glue job_run_id: %s, state: %s",
previous_job_run_id, state)
+ if state in ("RUNNING", "STARTING"):
+ self._job_run_id = previous_job_run_id
+ except Exception:
+ self.log.warning("Failed to get previous Glue job run
state", exc_info=True)
+ elif task_uuid:
+ try:
+ existing = self._find_job_run_id_by_task_uuid(task_uuid)
+ if existing:
+ existing_job_run_id, existing_job_run_state = existing
+ self.log.info(
+ "Found Glue job_run_id by task UUID: %s, state:
%s",
+ existing_job_run_id,
+ existing_job_run_state,
+ )
+ if existing_job_run_state in ("RUNNING", "STARTING"):
+ self._job_run_id = existing_job_run_id
+ ti.xcom_push(key="glue_job_run_id",
value=self._job_run_id)
+ except Exception:
+ self.log.warning("Failed to find previous Glue job run by
task UUID", exc_info=True)
+
+ if not self._job_run_id:
+ self.log.info(
+ "Initializing AWS Glue Job: %s. Wait for completion: %s",
+ self.job_name,
+ self.wait_for_completion,
+ )
+ glue_job_run = self.hook.initialize_job(script_args,
self.run_job_kwargs)
+ self._job_run_id = glue_job_run["JobRunId"]
+ context["ti"].xcom_push(key="glue_job_run_id",
value=self._job_run_id)
+
glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.hook.conn_partition),
region_name=self.hook.conn_region_name,
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
index fedf55431a6..e7d15fbbaa8 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
@@ -432,6 +432,160 @@ class TestGlueJobOperator:
)
assert op.hook.aws_conn_id == DEFAULT_CONN
@mock.patch.object(GlueJobHook, "get_conn")
@mock.patch.object(GlueJobHook, "initialize_job")
def test_check_previous_job_id_run_reuse_in_progress(self, mock_initialize_job, mock_get_conn):
    """With resume_glue_job_on_retry=True, an in-progress previous run is reused untouched."""
    operator = GlueJobOperator(
        task_id=TASK_ID,
        job_name=JOB_NAME,
        script_location="s3://folder/file",
        aws_conn_id="aws_default",
        region_name="us-west-2",
        s3_bucket="some_bucket",
        iam_role_name="my_test_role",
        resume_glue_job_on_retry=True,
        wait_for_completion=False,
    )

    # Fake task instance whose XCom already holds a run id from a prior try.
    ti_mock = mock.MagicMock()
    context = {"ti": ti_mock}
    prior_run_id = "previous_run_12345"
    ti_mock.xcom_pull.return_value = prior_run_id

    # Glue reports that prior run as still RUNNING, so it must be picked up again.
    glue_client = mock.MagicMock()
    operator.hook.conn = glue_client
    glue_client.get_job_run.return_value = {"JobRun": {"JobRunState": "RUNNING"}}

    operator.execute(context)

    # The previous run id was adopted and no new job was started.
    assert operator._job_run_id == prior_run_id
    mock_initialize_job.assert_not_called()
    # Only pushes under key "glue_job_run_id" matter; other keys (e.g. run
    # details link) are allowed.
    pushes = [c for c in ti_mock.xcom_push.call_args_list if c[1].get("key") == "glue_job_run_id"]
    assert len(pushes) == 0, "Should not push new glue_job_run_id when reusing previous one"
@mock.patch.object(GlueJobHook, "get_conn")
@mock.patch.object(GlueJobHook, "initialize_job")
def test_check_previous_job_id_run_new_on_finished(self, mock_initialize_job, mock_get_conn):
    """A finished previous run is not reused: a fresh job starts and its id lands in XCom."""
    operator = GlueJobOperator(
        task_id=TASK_ID,
        job_name=JOB_NAME,
        script_location="s3://folder/file",
        aws_conn_id="aws_default",
        region_name="us-west-2",
        s3_bucket="some_bucket",
        iam_role_name="my_test_role",
        resume_glue_job_on_retry=True,
        wait_for_completion=False,
    )

    # Fake task instance whose XCom already holds a run id from a prior try.
    ti_mock = mock.MagicMock()
    context = {"ti": ti_mock}
    prior_run_id = "previous_run_12345"
    ti_mock.xcom_pull.return_value = prior_run_id

    # Glue reports the prior run as SUCCEEDED, so the operator must start anew.
    glue_client = mock.MagicMock()
    operator.hook.conn = glue_client
    glue_client.get_job_run.return_value = {"JobRun": {"JobRunState": "SUCCEEDED"}}

    # The freshly started job gets a new run id.
    fresh_run_id = "new_run_67890"
    mock_initialize_job.return_value = {
        "JobRunState": "RUNNING",
        "JobRunId": fresh_run_id,
    }

    operator.execute(context)

    # A new run was created and recorded for future retries.
    assert operator._job_run_id == fresh_run_id
    mock_initialize_job.assert_called_once()
    pushes = [c for c in ti_mock.xcom_push.call_args_list if c[1].get("key") == "glue_job_run_id"]
    assert len(pushes) == 1, "Should push new glue_job_run_id"
    assert pushes[0][1]["value"] == fresh_run_id
@mock.patch.object(GlueJobHook, "get_conn")
@mock.patch.object(GlueJobHook, "initialize_job")
def test_resume_glue_job_on_retry_find_job_run_by_task_uuid(self, mock_initialize_job, mock_get_conn):
    """When the XCom value is gone, the existing run is recovered via the task-UUID argument."""
    operator = GlueJobOperator(
        task_id=TASK_ID,
        job_name=JOB_NAME,
        script_location="s3://folder/file",
        aws_conn_id="aws_default",
        region_name="us-west-2",
        s3_bucket="some_bucket",
        iam_role_name="my_test_role",
        resume_glue_job_on_retry=True,
        wait_for_completion=False,
    )

    # Task instance with concrete identity fields but no XCom record.
    ti_mock = mock.MagicMock()
    ti_mock.dag_id = "test_dag_id"
    ti_mock.task_id = TASK_ID
    ti_mock.run_id = "manual__2024-01-01T00:00:00+00:00"
    ti_mock.map_index = -1
    ti_mock.xcom_pull.return_value = None
    context = {"ti": ti_mock}

    # UUID the operator is expected to derive for this task instance.
    expected_uuid = f"{ti_mock.dag_id}:{ti_mock.task_id}:{ti_mock.run_id}:{ti_mock.map_index}"

    # Glue lists one run whose Arguments carry that UUID and that is still STARTING.
    glue_client = mock.MagicMock()
    operator.hook.conn = glue_client
    glue_client.get_job_runs.return_value = {
        "JobRuns": [
            {
                "Id": "existing_run_123",
                "Arguments": {GlueJobOperator.TASK_UUID_ARG: expected_uuid},
                "JobRunState": "STARTING",
            }
        ]
    }
    glue_client.get_job_run.return_value = {"JobRun": {"JobRunState": "RUNNING"}}

    operator.execute(context)

    # The run found by UUID was adopted, no new job started, and the id was re-pushed.
    assert operator._job_run_id == "existing_run_123"
    mock_initialize_job.assert_not_called()
    pushes = [c for c in ti_mock.xcom_push.call_args_list if c[1].get("key") == "glue_job_run_id"]
    assert len(pushes) == 1, "Should push existing glue_job_run_id when found by task UUID"
    assert pushes[0][1]["value"] == "existing_run_123"
class TestGlueDataQualityOperator:
RULE_SET_NAME = "TestRuleSet"