This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 08aee444b4e fix: `resume_glue_job_on_retry` with `xcom_push` in `GlueJobOperator` (#62560)
08aee444b4e is described below

commit 08aee444b4e1cfeb5b76a64eaec973e61b2b31a8
Author: Henry Chen <[email protected]>
AuthorDate: Thu Mar 12 03:27:22 2026 +0800

    fix: `resume_glue_job_on_retry` with `xcom_push` in `GlueJobOperator` (#62560)
    
    * GlueJobOperator: Recover job run via task UUID when XCom is missing
    
    * add resume_glue_job_on_retry back
---
 .../airflow/providers/amazon/aws/operators/glue.py |  86 +++++++++++-
 .../tests/unit/amazon/aws/operators/test_glue.py   | 154 +++++++++++++++++++++
 2 files changed, 233 insertions(+), 7 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
index 8e9a6deab44..ce96a76b3d0 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
@@ -112,6 +112,7 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
     ui_color = "#ededed"
 
     operator_extra_links = (GlueJobRunDetailsLink(),)
+    TASK_UUID_ARG = "--airflow_task_uuid"
 
     def __init__(
         self,
@@ -138,6 +139,7 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
         job_poll_interval: int | float = 6,
         waiter_delay: int = 60,
         waiter_max_attempts: int = 75,
+        resume_glue_job_on_retry: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -167,6 +169,7 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
         self.s3_script_location: str | None = None
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
+        self.resume_glue_job_on_retry = resume_glue_job_on_retry
 
     @property
     def _hook_parameters(self):
@@ -210,19 +213,88 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
         )
         self.s3_script_location = 
f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
 
+    def _get_task_uuid(self, context: Context) -> str:
+        ti = context["ti"]
+        map_index = getattr(ti, "map_index", -1)
+        if map_index is None:
+            map_index = -1
+        return f"{ti.dag_id}:{ti.task_id}:{ti.run_id}:{map_index}"
+
+    def _prepare_script_args_with_task_uuid(self, context: Context) -> 
tuple[dict, str]:
+        script_args = dict(self.script_args or {})
+        if self.TASK_UUID_ARG in script_args:
+            task_uuid = str(script_args[self.TASK_UUID_ARG])
+        else:
+            task_uuid = self._get_task_uuid(context)
+            script_args[self.TASK_UUID_ARG] = task_uuid
+        return script_args, task_uuid
+
+    def _find_job_run_id_by_task_uuid(self, task_uuid: str) -> tuple[str, str] 
| None:
+        next_token: str | None = None
+        while True:
+            request = {"JobName": self.job_name, "MaxResults": 50}
+            if next_token:
+                request["NextToken"] = next_token
+            response = self.hook.conn.get_job_runs(**request)
+            for job_run in response.get("JobRuns", []):
+                args = job_run.get("Arguments", {}) or {}
+                if args.get(self.TASK_UUID_ARG) == task_uuid:
+                    job_run_id = job_run.get("Id")
+                    job_run_state = job_run.get("JobRunState")
+                    if job_run_id and job_run_state:
+                        return job_run_id, job_run_state
+            next_token = response.get("NextToken")
+            if not next_token:
+                return None
+
     def execute(self, context: Context):
         """
         Execute AWS Glue Job from Airflow.
 
         :return: the current Glue job ID.
         """
-        self.log.info(
-            "Initializing AWS Glue Job: %s. Wait for completion: %s",
-            self.job_name,
-            self.wait_for_completion,
-        )
-        glue_job_run = self.hook.initialize_job(self.script_args, 
self.run_job_kwargs)
-        self._job_run_id = glue_job_run["JobRunId"]
+        previous_job_run_id = None
+        script_args = self.script_args
+        task_uuid = None
+        if self.resume_glue_job_on_retry:
+            ti = context["ti"]
+            script_args, task_uuid = 
self._prepare_script_args_with_task_uuid(context)
+            previous_job_run_id = ti.xcom_pull(key="glue_job_run_id", 
task_ids=ti.task_id)
+            if previous_job_run_id:
+                try:
+                    job_run = 
self.hook.conn.get_job_run(JobName=self.job_name, RunId=previous_job_run_id)
+                    state = job_run.get("JobRun", {}).get("JobRunState")
+                    self.log.info("Previous Glue job_run_id: %s, state: %s", 
previous_job_run_id, state)
+                    if state in ("RUNNING", "STARTING"):
+                        self._job_run_id = previous_job_run_id
+                except Exception:
+                    self.log.warning("Failed to get previous Glue job run 
state", exc_info=True)
+            elif task_uuid:
+                try:
+                    existing = self._find_job_run_id_by_task_uuid(task_uuid)
+                    if existing:
+                        existing_job_run_id, existing_job_run_state = existing
+                        self.log.info(
+                            "Found Glue job_run_id by task UUID: %s, state: 
%s",
+                            existing_job_run_id,
+                            existing_job_run_state,
+                        )
+                        if existing_job_run_state in ("RUNNING", "STARTING"):
+                            self._job_run_id = existing_job_run_id
+                            ti.xcom_push(key="glue_job_run_id", 
value=self._job_run_id)
+                except Exception:
+                    self.log.warning("Failed to find previous Glue job run by 
task UUID", exc_info=True)
+
+        if not self._job_run_id:
+            self.log.info(
+                "Initializing AWS Glue Job: %s. Wait for completion: %s",
+                self.job_name,
+                self.wait_for_completion,
+            )
+            glue_job_run = self.hook.initialize_job(script_args, 
self.run_job_kwargs)
+            self._job_run_id = glue_job_run["JobRunId"]
+            context["ti"].xcom_push(key="glue_job_run_id", 
value=self._job_run_id)
+
         glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
             
aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.hook.conn_partition),
             region_name=self.hook.conn_region_name,
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
index fedf55431a6..e7d15fbbaa8 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
@@ -432,6 +432,160 @@ class TestGlueJobOperator:
         )
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
+    @mock.patch.object(GlueJobHook, "get_conn")
+    @mock.patch.object(GlueJobHook, "initialize_job")
+    def test_check_previous_job_id_run_reuse_in_progress(self, 
mock_initialize_job, mock_get_conn):
+        """Test that when resume_glue_job_on_retry=True and previous job is in 
progress, it is reused."""
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            resume_glue_job_on_retry=True,
+            wait_for_completion=False,
+        )
+
+        # Mock the context and task instance
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti}
+
+        # Simulate previous job_run_id in XCom
+        previous_job_run_id = "previous_run_12345"
+        mock_ti.xcom_pull.return_value = previous_job_run_id
+
+        # Mock the Glue client to return RUNNING state for the previous job
+        mock_glue_client = mock.MagicMock()
+        glue.hook.conn = mock_glue_client
+        mock_glue_client.get_job_run.return_value = {
+            "JobRun": {
+                "JobRunState": "RUNNING",
+            }
+        }
+
+        # Execute the operator
+        glue.execute(mock_context)
+
+        # Verify that the previous job_run_id was reused
+        assert glue._job_run_id == previous_job_run_id
+        # Verify that initialize_job was NOT called
+        mock_initialize_job.assert_not_called()
+        # Verify that XCom push was not called for glue_job_run_id (since we 
reused the previous one)
+        # Note: xcom_push may be called for other purposes like 
glue_job_run_details
+        xcom_calls = [
+            call for call in mock_ti.xcom_push.call_args_list if 
call[1].get("key") == "glue_job_run_id"
+        ]
+        assert len(xcom_calls) == 0, "Should not push new glue_job_run_id when 
reusing previous one"
+
+    @mock.patch.object(GlueJobHook, "get_conn")
+    @mock.patch.object(GlueJobHook, "initialize_job")
+    def test_check_previous_job_id_run_new_on_finished(self, 
mock_initialize_job, mock_get_conn):
+        """Test that when previous job is finished, a new job is started and 
pushed to XCom."""
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            resume_glue_job_on_retry=True,
+            wait_for_completion=False,
+        )
+
+        # Mock the context and task instance
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti}
+
+        # Simulate previous job_run_id in XCom
+        previous_job_run_id = "previous_run_12345"
+        mock_ti.xcom_pull.return_value = previous_job_run_id
+
+        # Mock the Glue client to return SUCCEEDED state for the previous job
+        mock_glue_client = mock.MagicMock()
+        glue.hook.conn = mock_glue_client
+        mock_glue_client.get_job_run.return_value = {
+            "JobRun": {
+                "JobRunState": "SUCCEEDED",
+            }
+        }
+
+        # Mock initialize_job to return a new job run ID
+        new_job_run_id = "new_run_67890"
+        mock_initialize_job.return_value = {
+            "JobRunState": "RUNNING",
+            "JobRunId": new_job_run_id,
+        }
+
+        # Execute the operator
+        glue.execute(mock_context)
+
+        # Verify that a new job_run_id was created
+        assert glue._job_run_id == new_job_run_id
+        # Verify that initialize_job was called
+        mock_initialize_job.assert_called_once()
+        # Verify that the new job_run_id was pushed to XCom
+        xcom_calls = [
+            call for call in mock_ti.xcom_push.call_args_list if 
call[1].get("key") == "glue_job_run_id"
+        ]
+        assert len(xcom_calls) == 1, "Should push new glue_job_run_id"
+        assert xcom_calls[0][1]["value"] == new_job_run_id
+
+    @mock.patch.object(GlueJobHook, "get_conn")
+    @mock.patch.object(GlueJobHook, "initialize_job")
+    def test_resume_glue_job_on_retry_find_job_run_by_task_uuid(self, 
mock_initialize_job, mock_get_conn):
+        """Test that when XCom is missing, job run is found by task UUID."""
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            resume_glue_job_on_retry=True,
+            wait_for_completion=False,
+        )
+
+        mock_ti = mock.MagicMock()
+        mock_ti.dag_id = "test_dag_id"
+        mock_ti.task_id = TASK_ID
+        mock_ti.run_id = "manual__2024-01-01T00:00:00+00:00"
+        mock_ti.map_index = -1
+        mock_ti.xcom_pull.return_value = None
+        mock_context = {"ti": mock_ti}
+
+        task_uuid = 
f"{mock_ti.dag_id}:{mock_ti.task_id}:{mock_ti.run_id}:{mock_ti.map_index}"
+
+        mock_glue_client = mock.MagicMock()
+        glue.hook.conn = mock_glue_client
+        mock_glue_client.get_job_runs.return_value = {
+            "JobRuns": [
+                {
+                    "Id": "existing_run_123",
+                    "Arguments": {GlueJobOperator.TASK_UUID_ARG: task_uuid},
+                    "JobRunState": "STARTING",
+                }
+            ]
+        }
+        mock_glue_client.get_job_run.return_value = {
+            "JobRun": {
+                "JobRunState": "RUNNING",
+            }
+        }
+
+        glue.execute(mock_context)
+
+        assert glue._job_run_id == "existing_run_123"
+        mock_initialize_job.assert_not_called()
+        xcom_calls = [
+            call for call in mock_ti.xcom_push.call_args_list if 
call[1].get("key") == "glue_job_run_id"
+        ]
+        assert len(xcom_calls) == 1, "Should push existing glue_job_run_id 
when found by task UUID"
+        assert xcom_calls[0][1]["value"] == "existing_run_123"
+
 
 class TestGlueDataQualityOperator:
     RULE_SET_NAME = "TestRuleSet"

Reply via email to