Copilot commented on code in PR #62560:
URL: https://github.com/apache/airflow/pull/62560#discussion_r2906152464


##########
providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py:
##########
@@ -210,19 +211,88 @@ def upload_etl_script_to_s3(self):
         )
         self.s3_script_location = 
f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
 
+    def _get_task_uuid(self, context: Context) -> str:
+        ti = context["ti"]
+        map_index = getattr(ti, "map_index", -1)
+        if map_index is None:
+            map_index = -1
+        return f"{ti.dag_id}:{ti.task_id}:{ti.run_id}:{map_index}"
+
+    def _prepare_script_args_with_task_uuid(self, context: Context) -> 
tuple[dict, str]:
+        script_args = dict(self.script_args or {})
+        if self.TASK_UUID_ARG in script_args:
+            task_uuid = str(script_args[self.TASK_UUID_ARG])
+        else:
+            task_uuid = self._get_task_uuid(context)
+            script_args[self.TASK_UUID_ARG] = task_uuid
+        return script_args, task_uuid
+
+    def _find_job_run_id_by_task_uuid(self, task_uuid: str) -> tuple[str, str] 
| None:
+        next_token: str | None = None
+        while True:
+            request = {"JobName": self.job_name, "MaxResults": 50}
+            if next_token:
+                request["NextToken"] = next_token
+            response = self.hook.conn.get_job_runs(**request)
+            for job_run in response.get("JobRuns", []):
+                args = job_run.get("Arguments", {}) or {}
+                if args.get(self.TASK_UUID_ARG) == task_uuid:
+                    job_run_id = job_run.get("Id")
+                    job_run_state = job_run.get("JobRunState")
+                    if job_run_id and job_run_state:
+                        return job_run_id, job_run_state
+            next_token = response.get("NextToken")
+            if not next_token:
+                return None
+
     def execute(self, context: Context):
         """
         Execute AWS Glue Job from Airflow.
 
         :return: the current Glue job ID.
         """
-        self.log.info(
-            "Initializing AWS Glue Job: %s. Wait for completion: %s",
-            self.job_name,
-            self.wait_for_completion,
-        )
-        glue_job_run = self.hook.initialize_job(self.script_args, 
self.run_job_kwargs)
-        self._job_run_id = glue_job_run["JobRunId"]
+        previous_job_run_id = None
+        script_args = self.script_args
+        task_uuid = None
+        if self.resume_glue_job_on_retry:
+            ti = context["ti"]
+            script_args, task_uuid = 
self._prepare_script_args_with_task_uuid(context)
+            previous_job_run_id = ti.xcom_pull(key="glue_job_run_id", 
task_ids=ti.task_id)
+            if previous_job_run_id:
+                try:
+                    job_run = 
self.hook.conn.get_job_run(JobName=self.job_name, RunId=previous_job_run_id)
+                    state = job_run.get("JobRun", {}).get("JobRunState")
+                    self.log.info("Previous Glue job_run_id: %s, state: %s", 
previous_job_run_id, state)
+                    if state in ("RUNNING", "STARTING"):
+                        self._job_run_id = previous_job_run_id
+                except Exception:
+                    self.log.warning("Failed to get previous Glue job run 
state", exc_info=True)
+            elif task_uuid:
+                try:
+                    existing = self._find_job_run_id_by_task_uuid(task_uuid)
+                    if existing:
+                        existing_job_run_id, existing_job_run_state = existing
+                        self.log.info(
+                            "Found Glue job_run_id by task UUID: %s, state: 
%s",
+                            existing_job_run_id,
+                            existing_job_run_state,
+                        )
+                        if existing_job_run_state in ("RUNNING", "STARTING"):
+                            self._job_run_id = existing_job_run_id
+                            ti.xcom_push(key="glue_job_run_id", 
value=self._job_run_id)
+                except Exception:
+                    self.log.warning("Failed to find previous Glue job run by 
task UUID", exc_info=True)

Review Comment:
   As with the `except Exception` around `get_job_run`, the broad
`except Exception` around `_find_job_run_id_by_task_uuid()` risks masking
unexpected errors and silently falling back to starting a new job run. Prefer
catching only boto/Glue API exceptions (e.g., `botocore.exceptions.ClientError`
/ `BotoCoreError`) and let other exceptions fail the task so they can be fixed.



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py:
##########
@@ -210,19 +211,88 @@ def upload_etl_script_to_s3(self):
         )
         self.s3_script_location = 
f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
 
+    def _get_task_uuid(self, context: Context) -> str:
+        ti = context["ti"]
+        map_index = getattr(ti, "map_index", -1)
+        if map_index is None:
+            map_index = -1
+        return f"{ti.dag_id}:{ti.task_id}:{ti.run_id}:{map_index}"
+
+    def _prepare_script_args_with_task_uuid(self, context: Context) -> 
tuple[dict, str]:
+        script_args = dict(self.script_args or {})
+        if self.TASK_UUID_ARG in script_args:
+            task_uuid = str(script_args[self.TASK_UUID_ARG])
+        else:
+            task_uuid = self._get_task_uuid(context)
+            script_args[self.TASK_UUID_ARG] = task_uuid
+        return script_args, task_uuid
+
+    def _find_job_run_id_by_task_uuid(self, task_uuid: str) -> tuple[str, str] 
| None:
+        next_token: str | None = None
+        while True:
+            request = {"JobName": self.job_name, "MaxResults": 50}
+            if next_token:
+                request["NextToken"] = next_token
+            response = self.hook.conn.get_job_runs(**request)
+            for job_run in response.get("JobRuns", []):
+                args = job_run.get("Arguments", {}) or {}
+                if args.get(self.TASK_UUID_ARG) == task_uuid:
+                    job_run_id = job_run.get("Id")
+                    job_run_state = job_run.get("JobRunState")
+                    if job_run_id and job_run_state:
+                        return job_run_id, job_run_state
+            next_token = response.get("NextToken")
+            if not next_token:
+                return None
+
     def execute(self, context: Context):
         """
         Execute AWS Glue Job from Airflow.
 
         :return: the current Glue job ID.
         """
-        self.log.info(
-            "Initializing AWS Glue Job: %s. Wait for completion: %s",
-            self.job_name,
-            self.wait_for_completion,
-        )
-        glue_job_run = self.hook.initialize_job(self.script_args, 
self.run_job_kwargs)
-        self._job_run_id = glue_job_run["JobRunId"]
+        previous_job_run_id = None
+        script_args = self.script_args
+        task_uuid = None
+        if self.resume_glue_job_on_retry:
+            ti = context["ti"]
+            script_args, task_uuid = 
self._prepare_script_args_with_task_uuid(context)
+            previous_job_run_id = ti.xcom_pull(key="glue_job_run_id", 
task_ids=ti.task_id)
+            if previous_job_run_id:
+                try:
+                    job_run = 
self.hook.conn.get_job_run(JobName=self.job_name, RunId=previous_job_run_id)
+                    state = job_run.get("JobRun", {}).get("JobRunState")
+                    self.log.info("Previous Glue job_run_id: %s, state: %s", 
previous_job_run_id, state)
+                    if state in ("RUNNING", "STARTING"):
+                        self._job_run_id = previous_job_run_id
+                except Exception:
+                    self.log.warning("Failed to get previous Glue job run 
state", exc_info=True)

Review Comment:
   The `except Exception` around `get_job_run` will also swallow programming 
errors (e.g., `KeyError`, `AttributeError`) and silently start a new run, 
potentially reintroducing the “duplicate job runs” problem. Catch boto 
exceptions more narrowly (e.g., `botocore.exceptions.ClientError` / 
`BotoCoreError`) and let unexpected exceptions propagate.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_glue.py:
##########
@@ -432,6 +432,160 @@ def test_default_conn_passed_to_hook(self):
         )
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
+    @mock.patch.object(GlueJobHook, "get_conn")
+    @mock.patch.object(GlueJobHook, "initialize_job")
+    def test_check_previous_job_id_run_reuse_in_progress(self, 
mock_initialize_job, mock_get_conn):
+        """Test that when resume_glue_job_on_retry=True and previous job is in 
progress, it is reused."""
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            resume_glue_job_on_retry=True,
+            wait_for_completion=False,
+        )

Review Comment:
   These tests construct `GlueJobOperator(..., resume_glue_job_on_retry=True)`, 
but the operator currently doesn’t accept that argument (and `execute()` relies 
on `self.resume_glue_job_on_retry`). The PR needs to add the 
`resume_glue_job_on_retry` init parameter / attribute, otherwise these tests 
(and user code) will fail at operator instantiation.



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py:
##########
@@ -210,19 +211,88 @@ def upload_etl_script_to_s3(self):
         )
         self.s3_script_location = 
f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
 
+    def _get_task_uuid(self, context: Context) -> str:
+        ti = context["ti"]
+        map_index = getattr(ti, "map_index", -1)
+        if map_index is None:
+            map_index = -1
+        return f"{ti.dag_id}:{ti.task_id}:{ti.run_id}:{map_index}"
+
+    def _prepare_script_args_with_task_uuid(self, context: Context) -> 
tuple[dict, str]:
+        script_args = dict(self.script_args or {})
+        if self.TASK_UUID_ARG in script_args:
+            task_uuid = str(script_args[self.TASK_UUID_ARG])
+        else:
+            task_uuid = self._get_task_uuid(context)
+            script_args[self.TASK_UUID_ARG] = task_uuid
+        return script_args, task_uuid
+
+    def _find_job_run_id_by_task_uuid(self, task_uuid: str) -> tuple[str, str] 
| None:
+        next_token: str | None = None
+        while True:
+            request = {"JobName": self.job_name, "MaxResults": 50}
+            if next_token:
+                request["NextToken"] = next_token
+            response = self.hook.conn.get_job_runs(**request)
+            for job_run in response.get("JobRuns", []):
+                args = job_run.get("Arguments", {}) or {}
+                if args.get(self.TASK_UUID_ARG) == task_uuid:
+                    job_run_id = job_run.get("Id")
+                    job_run_state = job_run.get("JobRunState")
+                    if job_run_id and job_run_state:
+                        return job_run_id, job_run_state
+            next_token = response.get("NextToken")
+            if not next_token:
+                return None
+
     def execute(self, context: Context):
         """
         Execute AWS Glue Job from Airflow.
 
         :return: the current Glue job ID.
         """
-        self.log.info(
-            "Initializing AWS Glue Job: %s. Wait for completion: %s",
-            self.job_name,
-            self.wait_for_completion,
-        )
-        glue_job_run = self.hook.initialize_job(self.script_args, 
self.run_job_kwargs)
-        self._job_run_id = glue_job_run["JobRunId"]
+        previous_job_run_id = None
+        script_args = self.script_args
+        task_uuid = None
+        if self.resume_glue_job_on_retry:
+            ti = context["ti"]
+            script_args, task_uuid = 
self._prepare_script_args_with_task_uuid(context)
+            previous_job_run_id = ti.xcom_pull(key="glue_job_run_id", 
task_ids=ti.task_id)

Review Comment:
   `execute()` uses `self.resume_glue_job_on_retry`, but this class doesn’t 
define the attribute anywhere (and `__init__` doesn’t accept a 
`resume_glue_job_on_retry` kwarg). As written, this will raise `AttributeError` 
(or `TypeError` during init if the kwarg is passed). Define a keyword-only 
`resume_glue_job_on_retry: bool = False` parameter in `__init__` and set 
`self.resume_glue_job_on_retry` accordingly (and ensure it’s not forwarded to 
`super().__init__`).



##########
providers/amazon/tests/unit/amazon/aws/operators/test_glue.py:
##########
@@ -432,6 +432,160 @@ def test_default_conn_passed_to_hook(self):
         )
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
+    @mock.patch.object(GlueJobHook, "get_conn")
+    @mock.patch.object(GlueJobHook, "initialize_job")
+    def test_check_previous_job_id_run_reuse_in_progress(self, 
mock_initialize_job, mock_get_conn):
+        """Test that when resume_glue_job_on_retry=True and previous job is in 
progress, it is reused."""
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            resume_glue_job_on_retry=True,
+            wait_for_completion=False,
+        )
+
+        # Mock the context and task instance
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti}
+
+        # Simulate previous job_run_id in XCom
+        previous_job_run_id = "previous_run_12345"
+        mock_ti.xcom_pull.return_value = previous_job_run_id
+
+        # Mock the Glue client to return RUNNING state for the previous job
+        mock_glue_client = mock.MagicMock()
+        glue.hook.conn = mock_glue_client
+        mock_glue_client.get_job_run.return_value = {
+            "JobRun": {

Review Comment:
   New tests use `mock.MagicMock()` for the task instance and Glue client 
without a `spec`/`autospec`. Using a spec (e.g., `create_autospec(TaskInstance, 
instance=True)` and a minimal spec for the Glue client methods you call) helps 
prevent tests from passing when the production code calls a 
misspelled/nonexistent attribute.



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py:
##########
@@ -210,19 +211,88 @@ def upload_etl_script_to_s3(self):
         )
         self.s3_script_location = 
f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
 
+    def _get_task_uuid(self, context: Context) -> str:
+        ti = context["ti"]
+        map_index = getattr(ti, "map_index", -1)
+        if map_index is None:
+            map_index = -1
+        return f"{ti.dag_id}:{ti.task_id}:{ti.run_id}:{map_index}"
+
+    def _prepare_script_args_with_task_uuid(self, context: Context) -> 
tuple[dict, str]:
+        script_args = dict(self.script_args or {})
+        if self.TASK_UUID_ARG in script_args:
+            task_uuid = str(script_args[self.TASK_UUID_ARG])
+        else:
+            task_uuid = self._get_task_uuid(context)
+            script_args[self.TASK_UUID_ARG] = task_uuid
+        return script_args, task_uuid
+
+    def _find_job_run_id_by_task_uuid(self, task_uuid: str) -> tuple[str, str] 
| None:
+        next_token: str | None = None
+        while True:
+            request = {"JobName": self.job_name, "MaxResults": 50}
+            if next_token:
+                request["NextToken"] = next_token
+            response = self.hook.conn.get_job_runs(**request)
+            for job_run in response.get("JobRuns", []):
+                args = job_run.get("Arguments", {}) or {}
+                if args.get(self.TASK_UUID_ARG) == task_uuid:
+                    job_run_id = job_run.get("Id")
+                    job_run_state = job_run.get("JobRunState")
+                    if job_run_id and job_run_state:
+                        return job_run_id, job_run_state
+            next_token = response.get("NextToken")
+            if not next_token:
+                return None

Review Comment:
   `_find_job_run_id_by_task_uuid()` paginates through *all* Glue job runs 
until it finds a match. For long-lived jobs with many historical runs this can 
be very slow and can add noticeable latency/cost to retries. Consider bounding 
the scan (e.g., stop after N pages / N runs, or accept a `max_runs_to_check` 
parameter and document it as a “recent runs” search).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to