This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f3ca1c27abd Crash recovery for YARN cluster mode in SparkSubmitOperator built on AIP-103 (#67473)
f3ca1c27abd is described below

commit f3ca1c27abd3344750b3e3baa21bd4bc07e60b79
Author: Amogh Desai <[email protected]>
AuthorDate: Wed Jun 10 10:51:21 2026 +0530

    Crash recovery for YARN cluster mode in SparkSubmitOperator built on AIP-103 (#67473)
---
 docs/spelling_wordlist.txt                         |   1 +
 .../providers/apache/spark/hooks/spark_submit.py   |  65 ++++++--
 .../apache/spark/operators/spark_submit.py         |  64 ++++++--
 .../unit/apache/spark/hooks/test_spark_submit.py   | 132 ++++++++---------
 .../apache/spark/operators/test_spark_submit.py    | 163 ++++++++++++++++++++-
 5 files changed, 321 insertions(+), 104 deletions(-)

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index dac67989526..b0c7bea4c81 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -646,6 +646,7 @@ Filesystem
 filesystem
 filesystems
 filetype
+finalStatus
 fips
 firebase
 Firehose
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 7cf1f3248ad..3a19950696a 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -306,6 +306,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         self._deploy_mode = deploy_mode
         self._connection = self._resolve_connection()
         self._is_yarn = "yarn" in self._connection["master"]
+        self._is_yarn_cluster_mode = self._is_yarn and 
self._connection["deploy_mode"] == "cluster"
         self._is_kubernetes = "k8s" in self._connection["master"]
         if self._is_kubernetes and kube_client is None:
             raise RuntimeError(
@@ -786,12 +787,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 )
 
             if self._should_track_yarn_application_via_rm_api():
-                # Once spark-submit exits successfully, rely on RM REST API 
polling instead
-                # of requiring a particular Spark log line such as "Submitted 
application ...".
-                # The RM REST API is the authoritative source for the 
application's lifecycle.
                 if not self._yarn_application_id:
                     raise RuntimeError("No YARN application id found after 
spark-submit completed.")
-                self._track_yarn_application(self._yarn_application_id)
                 return self._driver_id
 
             if self._should_track_driver_status and self._driver_id is None:
@@ -871,18 +868,42 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
 
             self.log.info(line)
 
-    def _track_yarn_application(self, application_id: str) -> None:
-        """Poll the YARN RM REST API until the application reaches a terminal 
state."""
+    def _start_yarn_application_status_tracking(self, application_id: str) -> 
None:
+        """
+        Poll the YARN ResourceManager REST API until the application reaches a 
terminal state.
+
+        Raises ``RuntimeError`` if the application fails or if too many 
consecutive RM
+        request failures occur.
+
+        Possible statuses (from YARN state + finalStatus):
+
+        NEW
+            Application has been created but not yet submitted to the scheduler
+        NEW_SAVING
+            Application metadata is being persisted before scheduling
+        SUBMITTED
+            Application has been submitted and is waiting to be scheduled
+        ACCEPTED
+            Application has been accepted by the scheduler and is queued
+        RUNNING
+            Application is actively executing on the cluster
+        SUCCEEDED
+            Application completed successfully (state=FINISHED, 
finalStatus=SUCCEEDED)
+        FAILED
+            Application terminated unsuccessfully — covers YARN state FAILED, 
KILLED,
+            or FINISHED with a non-SUCCEEDED finalStatus
+        """
         self.log.info(
             "Tracking YARN application %s via ResourceManager REST API 
polling",
             application_id,
         )
         poll_interval = max(self._status_poll_interval, 10)
-        # Tolerate transient RM REST API failures (RM hiccup, network blip, 
request
-        # timeout) the same way `_start_driver_status_tracking` does for spark
-        # standalone — only give up after this many consecutive failures.
         consecutive_failures = 0
         max_consecutive_failures = 10
+        heartbeat_interval = 10
+        poll_count = 0
+        last_state: str | None = None
+
         while True:
             self.log.debug("Polling YARN RM REST API for application %s", 
application_id)
             try:
@@ -904,13 +925,20 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 time.sleep(poll_interval)
                 continue
             consecutive_failures = 0
+            poll_count += 1
+
+            if state != last_state:
+                self.log.info("YARN application %s status: %s", 
application_id, state)
+                last_state = state
+            elif poll_count % heartbeat_interval == 0:
+                self.log.info("YARN application %s is still %s", 
application_id, state)
+
             if state in self._YARN_FINAL_FAILURES:
                 raise RuntimeError(
                     f"YARN application {application_id} ended with state: 
{state}, "
                     f"final status: {final_status}"
                 )
             if final_status == self._YARN_FINAL_SUCCESS:
-                self.log.info("YARN application %s finished with SUCCEEDED", 
application_id)
                 return
             if final_status in self._YARN_FINAL_FAILURES:
                 raise RuntimeError(
@@ -1298,3 +1326,20 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             self._kill_yarn_application(self._yarn_application_id)
 
         self._run_post_submit_commands()
+
+    def query_yarn_application_status(self, application_id: str) -> str:
+        """
+        Return a normalized single string status for the ResumableJobMixin 
interface.
+
+        - Active states (NEW, NEW_SAVING, SUBMITTED, ACCEPTED, RUNNING) are 
returned as-is.
+        - Terminal states are collapsed to "SUCCEEDED" or "FAILED" with the 
following rules:
+            - FINISHED + finalStatus SUCCEEDED -> "SUCCEEDED"
+            - FINISHED + any other finalStatus -> "FAILED"
+            - FAILED or KILLED -> "FAILED"
+        """
+        state, final_status = 
self._query_yarn_application_status(application_id)
+        if state in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"}:
+            return state
+        if state == "FINISHED" and final_status == self._YARN_FINAL_SUCCESS:
+            return "SUCCEEDED"
+        return "FAILED"
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 902d96a225b..ac9b550409f 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -274,11 +274,39 @@ class SparkSubmitOperator(ResumableJobMixin, 
BaseOperator):
             hook.submit(self.application)
             hook._poll_k8s_driver_via_api()
             return
+        if hook._is_yarn_cluster_mode:
+            if self.reconnect_on_retry and not hook._yarn_track_via_rm_api:
+                raise ValueError(
+                    "YARN cluster mode with reconnect_on_retry=True requires 
yarn_track_via_rm_api=True. "
+                    "The RM REST API is needed to check application status on 
retry."
+                )
+            if hook._yarn_track_via_rm_api:
+                hook._validate_yarn_track_via_rm_api_config()
+                if self.reconnect_on_retry:
+                    return self.execute_resumable(context)
+                # reconnect_on_retry=False: still submit-and-poll, just skip 
task_state persistence.
+                driver_id = self.submit_job(context)
+                self.poll_until_complete(driver_id, context)
+                return self.get_job_result(driver_id, context)
         hook.submit(self.application)
 
     def submit_job(self, context: Context) -> str:
         if self._hook is None:
             self._hook = self._get_hook()
+        if self._hook._is_yarn_cluster_mode:
+            if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", 
"").strip().lower() == "true":
+                raise ValueError(
+                    "spark.yarn.submit.waitAppCompletion=true cannot be set 
for cluster mode as it conflicts"
+                    "with the need to exit spark-submit immediately to persist 
the application ID for tracking. "
+                    "Either remove the explicit conf or set 
reconnect_on_retry=False."
+                )
+            self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
+            self._hook.submit(self.application)
+            app_id = self._hook._yarn_application_id
+            if not app_id:
+                raise RuntimeError("spark-submit did not produce a YARN 
application ID")
+            self.log.info("YARN application submitted: %s", app_id)
+            return app_id
         driver_id = self._hook.submit(self.application)
         if not driver_id:
             raise RuntimeError("spark-submit did not return a driver ID")
@@ -290,15 +318,13 @@ class SparkSubmitOperator(ResumableJobMixin, 
BaseOperator):
         external_id = cast("str", external_id)
         if self._hook is None:
             self._hook = self._get_hook()
-        # The YARN and K8s branches below (and in is_job_active, 
is_job_succeeded, poll_until_complete)
-        # are currently unreachable: execute_resumable is only called when 
_should_track_driver_status
-        # is True, which requires spark:// + cluster mode. They are 
scaffolding for a follow-up PR
-        # that extends ResumableJobMixin support to YARN and Kubernetes.
-        if self._hook._is_yarn:
-            # TODO: call YARN ResourceManager REST API
-            # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
-            raise NotImplementedError("YARN job status not yet implemented")
+        if self._hook._is_yarn_cluster_mode:
+            return self._hook.query_yarn_application_status(external_id)
         if self._hook._is_kubernetes:
+            # The K8s branches below (and in is_job_active, is_job_succeeded, 
poll_until_complete)
+            # are currently unreachable: execute_resumable is only called when 
_should_track_driver_status
+            # is True, which requires spark:// + cluster mode. They are 
scaffolding for a follow-up PR
+            # that extends ResumableJobMixin support to Kubernetes.
             # TODO: call K8s pod status API
             raise NotImplementedError("K8s job status not yet implemented")
         scheme = self._hook._connection.get("rest_scheme", "http")
@@ -338,9 +364,9 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         if self._hook is None:
             self._hook = self._get_hook()
         status = status.upper()
-        if self._hook._is_yarn:
+        if self._hook._is_yarn_cluster_mode:
             # 
https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
-            return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", 
"RUNNING")
+            return status in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", 
"RUNNING"}
         if self._hook._is_kubernetes:
             return status in ("PENDING", "RUNNING")
         # RELAUNCHING: driver is being restarted after a failure, still alive.
@@ -352,6 +378,8 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         if self._hook is None:
             self._hook = self._get_hook()
         status = status.upper()
+        if self._hook._is_yarn_cluster_mode:
+            return status == "SUCCEEDED"
         if self._hook._is_kubernetes:
             return status == "SUCCEEDED"
         # standalone and YARN both use FINISHED
@@ -362,9 +390,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         external_id = cast("str", external_id)
         if self._hook is None:
             self._hook = self._get_hook()
-        if self._hook._is_yarn:
-            # TODO: poll YARN ResourceManager until app reaches terminal state
-            raise NotImplementedError("YARN poll not yet implemented")
+        if self._hook._is_yarn_cluster_mode:
+            try:
+                self._hook._start_yarn_application_status_tracking(external_id)
+            finally:
+                self._hook._run_post_submit_commands()
+            return
         if self._hook._is_kubernetes:
             # TODO: poll K8s pod phase until terminal
             raise NotImplementedError("K8s poll not yet implemented")
@@ -384,7 +415,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
     def on_kill(self) -> None:
         if self._hook is None:
             self._hook = self._get_hook()
-        self._hook.on_kill()
+        if self._hook._is_yarn_cluster_mode and 
self._hook._yarn_application_id:
+            # spark-submit has already exited (waitAppCompletion=false), so 
the hook's
+            # CLI-based kill has nothing to terminate. Kill the YARN app via 
REST API instead.
+            self._hook._kill_yarn_application(self._hook._yarn_application_id)
+        else:
+            self._hook.on_kill()
 
     def _get_hook(self) -> SparkSubmitHook:
         return SparkSubmitHook(
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
index f4a610a9408..90f923905fa 100644
--- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
@@ -1456,6 +1456,35 @@ class TestSparkSubmitHook:
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
 
+    @pytest.mark.parametrize(
+        ("state", "final_status", "expected"),
+        [
+            ("NEW", "UNDEFINED", "NEW"),
+            ("NEW_SAVING", "UNDEFINED", "NEW_SAVING"),
+            ("SUBMITTED", "UNDEFINED", "SUBMITTED"),
+            ("ACCEPTED", "UNDEFINED", "ACCEPTED"),
+            ("RUNNING", "UNDEFINED", "RUNNING"),
+            ("FINISHED", "SUCCEEDED", "SUCCEEDED"),
+            ("FINISHED", "FAILED", "FAILED"),
+            ("FINISHED", "KILLED", "FAILED"),
+            ("FINISHED", "UNDEFINED", "FAILED"),
+            ("FAILED", "FAILED", "FAILED"),
+            ("FAILED", "KILLED", "FAILED"),
+            ("KILLED", "KILLED", "FAILED"),
+        ],
+    )
+    def test_query_yarn_application_status_state_mapping(self, state, 
final_status, expected):
+        hook = SparkSubmitHook(conn_id="")
+        mock_response = MagicMock(spec=requests.Response)
+        mock_response.status_code = 200
+        mock_response.json.return_value = {"app": {"state": state, 
"finalStatus": final_status}}
+
+        with patch(
+            "airflow.providers.apache.spark.hooks.spark_submit.requests.get", 
return_value=mock_response
+        ):
+            with patch.object(hook, "_get_yarn_rm_base_url", 
return_value="http://rm.example.com:8088";):
+                assert 
hook.query_yarn_application_status("application_1234_0001") == expected
+
     @pytest.mark.parametrize(
         ("conn_id", "flag", "expected"),
         [
@@ -1686,27 +1715,35 @@ class TestSparkSubmitHook:
         with pytest.raises(ValueError, match="requires 
`deploy_mode='cluster'`"):
             hook.submit()
 
-    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
     
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
-    def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get, 
mock_sleep):
-        """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+    def test_yarn_submit_does_not_poll_rm_api(self, mock_popen, mock_get):
+        """hook.submit() with yarn_track_via_rm_api=True must NOT poll RM REST 
API."""
         proc = MagicMock(spec=["stdout", "terminate", "wait"])
         proc.stdout = iter(self._YARN_LOG_LINES)
         proc.wait.return_value = 0
         mock_popen.return_value = proc
 
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        spark_submit_cmd = mock_popen.call_args.args[0]
+        assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
+        mock_get.assert_not_called()
+        assert hook._yarn_application_id == self._RM_APP_ID
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    def test_yarn_status_tracking_succeeds(self, mock_get, mock_sleep):
+        """RM returns UNDEFINED then SUCCEEDED -> 
_start_yarn_application_status_tracking returns normally."""
         mock_get.side_effect = [
             self._rm_status_resp("UNDEFINED", state="RUNNING"),
             self._rm_status_resp("SUCCEEDED"),
         ]
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
-        hook.submit()
+        hook._start_yarn_application_status_tracking(self._RM_APP_ID)
 
-        spark_submit_cmd = mock_popen.call_args.args[0]
-        assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
-        proc.terminate.assert_not_called()
         assert mock_get.call_count == 2
         mock_sleep.assert_called_once_with(10)
         for call_obj in mock_get.call_args_list:
@@ -1714,65 +1751,43 @@ class TestSparkSubmitHook:
 
     @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
-    def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get, 
mock_sleep):
+    def test_yarn_status_tracking_fails_on_killed(self, mock_get, mock_sleep):
         """RM returns KILLED -> raise with message containing app id and 
KILLED."""
-        proc = MagicMock(spec=["stdout", "terminate", "wait"])
-        proc.stdout = iter(self._YARN_LOG_LINES)
-        proc.wait.return_value = 0
-        mock_popen.return_value = proc
-
         mock_get.return_value = self._rm_status_resp("KILLED")
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
         with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*KILLED"):
-            hook.submit()
-        proc.terminate.assert_not_called()
+            hook._start_yarn_application_status_tracking(self._RM_APP_ID)
 
     @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
     def 
test_yarn_status_tracking_fails_on_failed_state_with_undefined_final_status(
-        self, mock_popen, mock_get, mock_sleep
+        self, mock_get, mock_sleep
     ):
         """RM state FAILED with finalStatus UNDEFINED should not poll 
forever."""
-        proc = MagicMock(spec=["stdout", "terminate", "wait"])
-        proc.stdout = iter(self._YARN_LOG_LINES)
-        proc.wait.return_value = 0
-        mock_popen.return_value = proc
-
         mock_get.return_value = self._rm_status_resp("UNDEFINED", 
state="FAILED")
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
         with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*state: 
FAILED"):
-            hook.submit()
+            hook._start_yarn_application_status_tracking(self._RM_APP_ID)
 
-        proc.terminate.assert_not_called()
         mock_sleep.assert_not_called()
 
     @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
-    def test_yarn_status_tracking_fails_on_unexpected_final_status(self, 
mock_popen, mock_get, mock_sleep):
+    def test_yarn_status_tracking_fails_on_unexpected_final_status(self, 
mock_get, mock_sleep):
         """RM returns a non-standard finalStatus ('BOGUS') -> raise without 
sleeping."""
-        proc = MagicMock(spec=["stdout", "terminate", "wait"])
-        proc.stdout = iter(self._YARN_LOG_LINES)
-        proc.wait.return_value = 0
-        mock_popen.return_value = proc
-
         mock_get.return_value = self._rm_status_resp("BOGUS")
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
         with pytest.raises(RuntimeError, match="unexpected final status: 
BOGUS"):
-            hook.submit()
+            hook._start_yarn_application_status_tracking(self._RM_APP_ID)
 
-        proc.terminate.assert_not_called()
         mock_sleep.assert_not_called()
 
-    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
     
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
-    def 
test_yarn_status_tracking_polls_without_application_submission_log(self, 
mock_popen, mock_get):
-        """Missing 'Submitted application' log line should not block RM REST 
polling."""
+    def 
test_yarn_submit_captures_app_id_without_submitted_application_log(self, 
mock_popen):
+        """App ID parsed from log lines other than 'Submitted application ...' 
is captured by hook.submit()."""
         yarn_log_lines = [
             "INFO Client: Uploading resource file:/tmp/lib.zip -> "
             
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
@@ -1782,14 +1797,11 @@ class TestSparkSubmitHook:
         proc.stdout = iter(yarn_log_lines)
         proc.wait.return_value = 0
         mock_popen.return_value = proc
-        mock_get.return_value = self._rm_status_resp("SUCCEEDED")
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
         hook.submit()
 
         assert hook._yarn_application_id == self._RM_APP_ID
-        assert mock_get.call_args.args[0] == self._rm_status_url()
-        proc.terminate.assert_not_called()
 
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
     
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
@@ -1820,14 +1832,8 @@ class TestSparkSubmitHook:
 
     @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
-    def test_yarn_status_tracking_tolerates_transient_failures(self, 
mock_popen, mock_get, mock_sleep):
+    def test_yarn_status_tracking_tolerates_transient_failures(self, mock_get, 
mock_sleep):
         """3 consecutive 5xx responses then SUCCEEDED -> normal completion."""
-        proc = MagicMock(spec=["stdout", "terminate", "wait"])
-        proc.stdout = iter(self._YARN_LOG_LINES)
-        proc.wait.return_value = 0
-        mock_popen.return_value = proc
-
         # 3 transient failures (within the 10-failure budget), then SUCCEEDED.
         mock_get.side_effect = [
             self._rm_failure_resp(503, "Service Unavailable"),
@@ -1837,27 +1843,21 @@ class TestSparkSubmitHook:
         ]
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
-        hook.submit()
+        hook._start_yarn_application_status_tracking(self._RM_APP_ID)
 
         assert mock_get.call_count == 4
 
     @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
-    def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_popen, 
mock_get, mock_sleep):
+    def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_get, 
mock_sleep):
         """First requests.exceptions.Timeout, second call succeeds -> normal 
completion."""
-        proc = MagicMock(spec=["stdout", "terminate", "wait"])
-        proc.stdout = iter(self._YARN_LOG_LINES)
-        proc.wait.return_value = 0
-        mock_popen.return_value = proc
-
         mock_get.side_effect = [
             requests.exceptions.Timeout("read timed out"),
             self._rm_status_resp("SUCCEEDED"),
         ]
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
-        hook.submit()
+        hook._start_yarn_application_status_tracking(self._RM_APP_ID)
 
         assert mock_get.call_count == 2
         # All calls must include the (connect, read) timeout tuple.
@@ -1866,20 +1866,14 @@ class TestSparkSubmitHook:
 
     @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
-    def test_yarn_status_tracking_raises_after_too_many_failures(self, 
mock_popen, mock_get, mock_sleep):
+    def test_yarn_status_tracking_raises_after_too_many_failures(self, 
mock_get, mock_sleep):
         """11 consecutive 5xx responses -> raise 'Giving up tracking YARN 
application'."""
-        proc = MagicMock(spec=["stdout", "terminate", "wait"])
-        proc.stdout = iter(self._YARN_LOG_LINES)
-        proc.wait.return_value = 0
-        mock_popen.return_value = proc
-
         # 11 failures: 10 tolerated; the 11th trips the budget.
         mock_get.side_effect = [self._rm_failure_resp() for _ in range(11)]
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
         with pytest.raises(RuntimeError, match="Giving up tracking YARN 
application"):
-            hook.submit()
+            hook._start_yarn_application_status_tracking(self._RM_APP_ID)
 
         assert mock_get.call_count == 11
 
@@ -1987,18 +1981,12 @@ class TestSparkSubmitHook:
 
     @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
     @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
-    def test_yarn_rm_base_url_is_resolved_once_across_polling_loop(self, 
mock_popen, mock_get, mock_sleep):
+    def test_yarn_rm_base_url_is_resolved_once_across_polling_loop(self, 
mock_get, mock_sleep):
         """Connection lookup must run once even if the polling loop runs many 
iterations.
 
         Regression guard: a job polling every few seconds for hours must not 
re-fetch
         the Spark connection (and potentially re-hit a Secrets Backend) on 
every iteration.
         """
-        proc = MagicMock(spec=["stdout", "terminate", "wait"])
-        proc.stdout = iter(self._YARN_LOG_LINES)
-        proc.wait.return_value = 0
-        mock_popen.return_value = proc
-
         # 4 UNDEFINED iterations then SUCCEEDED -> 5 polling iterations total.
         mock_get.side_effect = [
             self._rm_status_resp("UNDEFINED", state="RUNNING"),
@@ -2010,7 +1998,7 @@ class TestSparkSubmitHook:
 
         hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
         with patch.object(hook, "get_connection", wraps=hook.get_connection) 
as spy_get_conn:
-            hook.submit()
+            hook._start_yarn_application_status_tracking(self._RM_APP_ID)
 
         assert mock_get.call_count == 5
         assert spy_get_conn.call_count == 1
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 95ad9f5142a..56b2ad3a409 100644
--- 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -495,7 +495,7 @@ class FakeTaskState:
 
 @pytest.mark.skipif(
     not AIRFLOW_V_3_3_PLUS,
-    reason="ResumableJobMixin reconnect requires task_state, available in 
Airflow 3.3+",
+    reason="ResumableJobMixin reconnect requires task_store, available in 
Airflow 3.3+",
 )
 class TestSparkSubmitOperatorResumable:
     def setup_method(self):
@@ -505,10 +505,12 @@ class TestSparkSubmitOperatorResumable:
     def _make_operator(self, **kwargs):
         return SparkSubmitOperator(task_id="test", dag=self.dag, 
application="test.jar", **kwargs)
 
-    def _make_hook(self, should_track=False, is_yarn=False, 
is_kubernetes=False):
+    def _make_hook(self, should_track=False, is_yarn=False, 
is_yarn_cluster=False, is_kubernetes=False):
         hook = MagicMock()
         hook._should_track_driver_status = should_track
+        hook._should_track_driver_via_k8s_api.return_value = False
         hook._is_yarn = is_yarn
+        hook._is_yarn_cluster_mode = is_yarn_cluster
         hook._is_kubernetes = is_kubernetes
         hook._connection = {"master": "spark://localhost:7077"}
         return hook
@@ -598,26 +600,35 @@ class TestSparkSubmitOperatorResumable:
         assert polled == ["driver-new"]
 
     @pytest.mark.parametrize(
-        ("is_yarn", "is_kubernetes", "status", "expected_active", 
"expected_succeeded"),
+        ("is_yarn_cluster", "is_kubernetes", "status", "expected_active", 
"expected_succeeded"),
         [
+            # Spark standalone cluster mode
             (False, False, "RUNNING", True, False),
             (False, False, "SUBMITTED", True, False),
+            (False, False, "RELAUNCHING", True, False),
+            (False, False, "UNKNOWN", True, False),
             (False, False, "FINISHED", False, True),
             (False, False, "FAILED", False, False),
-            (True, False, "RUNNING", True, False),
-            (True, False, "ACCEPTED", True, False),
+            # YARN cluster mode — synthesized statuses from 
query_yarn_application_status
             (True, False, "NEW", True, False),
-            (True, False, "FINISHED", False, True),
+            (True, False, "NEW_SAVING", True, False),
+            (True, False, "SUBMITTED", True, False),
+            (True, False, "ACCEPTED", True, False),
+            (True, False, "RUNNING", True, False),
+            (True, False, "SUCCEEDED", False, True),
             (True, False, "FAILED", False, False),
+            # Kubernetes
             (False, True, "Running", True, False),
             (False, True, "Pending", True, False),
             (False, True, "Succeeded", False, True),
             (False, True, "Failed", False, False),
         ],
     )
-    def test_job_status_mappings(self, is_yarn, is_kubernetes, status, 
expected_active, expected_succeeded):
+    def test_job_status_mappings(
+        self, is_yarn_cluster, is_kubernetes, status, expected_active, 
expected_succeeded
+    ):
         operator = self._make_operator()
-        operator._hook = self._make_hook(is_yarn=is_yarn, 
is_kubernetes=is_kubernetes)
+        operator._hook = self._make_hook(is_yarn_cluster=is_yarn_cluster, 
is_kubernetes=is_kubernetes)
 
         assert operator.is_job_active(status) == expected_active
         assert operator.is_job_succeeded(status) == expected_succeeded
@@ -715,6 +726,108 @@ class TestSparkSubmitOperatorResumable:
         assert len(captured_urls) == 1
         assert captured_urls[0].startswith("https://")
 
+    def test_yarn_first_run_persists_app_id_before_polling(self):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(is_yarn_cluster=True)
+        operator._hook._conf = {}
+        operator._hook._yarn_application_id = "application_1234_0001"
+        operator._hook.submit.return_value = None
+
+        task_store = FakeTaskState()
+        persisted_before_poll = []
+
+        def track_poll(external_id, context):
+            persisted_before_poll.append(task_store.get("spark_job_id"))
+
+        operator.poll_until_complete = track_poll
+        operator.execute(context={"task_store": task_store})
+
+        assert persisted_before_poll == ["application_1234_0001"]
+
+    def test_yarn_retry_reconnects_to_running_app(self):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(is_yarn_cluster=True)
+        task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
+
+        operator.get_job_status = lambda external_id, context: "RUNNING"
+        polled = []
+        operator.poll_until_complete = lambda external_id, context: polled.append(external_id)
+
+        operator.execute(context={"task_store": task_store})
+
+        operator._hook.submit.assert_not_called()
+        assert polled == ["application_1234_0001"]
+
+    def test_yarn_retry_skips_already_succeeded_app(self):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(is_yarn_cluster=True)
+        task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
+
+        operator.get_job_status = lambda external_id, context: "SUCCEEDED"
+
+        operator.execute(context={"task_store": task_store})
+
+        operator._hook.submit.assert_not_called()
+
+    def test_yarn_retry_resubmits_after_failed_app(self):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(is_yarn_cluster=True)
+        operator._hook._conf = {}
+        operator._hook._yarn_application_id = "application_1234_0002"
+        operator._hook.submit.return_value = None
+        task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
+
+        operator.get_job_status = lambda external_id, context: "FAILED"
+        polled = []
+        operator.poll_until_complete = lambda external_id, context: polled.append(external_id)
+
+        operator.execute(context={"task_store": task_store})
+
+        operator._hook.submit.assert_called_once_with("test.jar")
+        assert polled == ["application_1234_0002"]
+
+    def test_yarn_injects_wait_app_completion_false(self):
+        operator = self._make_operator()
+        hook = self._make_hook(is_yarn_cluster=True)
+        hook._conf = {}
+        hook._yarn_application_id = "application_1234_0001"
+        hook.submit.return_value = None
+        operator._hook = hook
+
+        operator.submit_job(context={})
+
+        assert hook._conf.get("spark.yarn.submit.waitAppCompletion") == "false"
+
+    def test_yarn_raises_if_wait_app_completion_true(self):
+        operator = self._make_operator()
+        hook = self._make_hook(is_yarn_cluster=True)
+        hook._conf = {"spark.yarn.submit.waitAppCompletion": "true"}
+        operator._hook = hook
+
+        with pytest.raises(ValueError, match="waitAppCompletion=true"):
+            operator.submit_job(context={})
+
+    def test_yarn_poll_tolerates_transient_resourcemanager_failures(self):
+        operator = self._make_operator()
+        operator._hook = self._make_hook(is_yarn_cluster=True)
+        operator._hook._status_poll_interval = 0
+
+        call_count = 0
+
+        def flaky_status(external_id):
+            nonlocal call_count
+            call_count += 1
+            if call_count < 5:
+                raise RuntimeError("RM temporarily unavailable")
+            return "SUCCEEDED"
+
+        operator.get_job_status = flaky_status
+
+        with mock.patch("time.sleep"):
+            operator.poll_until_complete("application_1234_0001", context={})
+
+        operator._hook._run_post_submit_commands.assert_called_once()
+
     def test_poll_until_complete_runs_post_submit_on_failure(self):
         """post_submit_commands must run even when the driver exits with a failure status."""
         operator = self._make_operator()
@@ -733,6 +846,39 @@ class TestSparkSubmitOperatorResumable:
         with pytest.raises(RuntimeError, match="FAILED"):
             operator.poll_until_complete("driver-001", {})
 
+    def test_on_kill_sends_authenticated_kill_to_yarn_rm(self):
+        """operator.on_kill() must call _kill_yarn_application so Kerberos auth is applied."""
+        operator = self._make_operator()
+        hook = self._make_hook(is_yarn_cluster=True)
+        hook._is_yarn_cluster_mode = True
+        hook._yarn_application_id = "application_1234_0001"
+        operator._hook = hook
+
+        operator.on_kill()
+
+        hook._kill_yarn_application.assert_called_once_with("application_1234_0001")
+
+    def test_yarn_cluster_reconnect_without_rm_api_raises(self):
+        """reconnect_on_retry=True + yarn_track_via_rm_api=False must raise - RM API is required for resume."""
+        operator = self._make_operator(reconnect_on_retry=True)
+        hook = self._make_hook(is_yarn_cluster=True)
+        hook._yarn_track_via_rm_api = False
+        operator._hook = hook
+
+        with pytest.raises(ValueError, match="yarn_track_via_rm_api=True"):
+            operator.execute(context={})
+
+    def test_yarn_cluster_without_rm_api_reconnect_false_falls_through_to_hook_submit(self):
+        """reconnect_on_retry=False + yarn_track_via_rm_api=False falls through to hook.submit() - no RM polling."""
+        operator = self._make_operator(reconnect_on_retry=False)
+        hook = self._make_hook(is_yarn_cluster=True)
+        hook._yarn_track_via_rm_api = False
+        operator._hook = hook
+
+        operator.execute(context={})
+
+        hook.submit.assert_called_once_with("test.jar")
+
 
 class TestSparkSubmitOperatorK8sTracking:
     def setup_method(self):
@@ -746,6 +892,7 @@ class TestSparkSubmitOperatorK8sTracking:
         hook = MagicMock()
         hook._should_track_driver_status = False
         hook._should_track_driver_via_k8s_api.return_value = True
+        hook._is_yarn_cluster_mode = False
         return hook
 
     def test_execute_calls_submit_then_poll_when_flag_set(self):


Reply via email to