This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f3ca1c27abd Crash recovery for YARN cluster mode in
SparkSubmitOperator built on AIP-103 (#67473)
f3ca1c27abd is described below
commit f3ca1c27abd3344750b3e3baa21bd4bc07e60b79
Author: Amogh Desai <[email protected]>
AuthorDate: Wed Jun 10 10:51:21 2026 +0530
Crash recovery for YARN cluster mode in SparkSubmitOperator built on
AIP-103 (#67473)
---
docs/spelling_wordlist.txt | 1 +
.../providers/apache/spark/hooks/spark_submit.py | 65 ++++++--
.../apache/spark/operators/spark_submit.py | 64 ++++++--
.../unit/apache/spark/hooks/test_spark_submit.py | 132 ++++++++---------
.../apache/spark/operators/test_spark_submit.py | 163 ++++++++++++++++++++-
5 files changed, 321 insertions(+), 104 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index dac67989526..b0c7bea4c81 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -646,6 +646,7 @@ Filesystem
filesystem
filesystems
filetype
+finalStatus
fips
firebase
Firehose
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 7cf1f3248ad..3a19950696a 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -306,6 +306,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self._deploy_mode = deploy_mode
self._connection = self._resolve_connection()
self._is_yarn = "yarn" in self._connection["master"]
+ self._is_yarn_cluster_mode = self._is_yarn and
self._connection["deploy_mode"] == "cluster"
self._is_kubernetes = "k8s" in self._connection["master"]
if self._is_kubernetes and kube_client is None:
raise RuntimeError(
@@ -786,12 +787,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
)
if self._should_track_yarn_application_via_rm_api():
- # Once spark-submit exits successfully, rely on RM REST API
polling instead
- # of requiring a particular Spark log line such as "Submitted
application ...".
- # The RM REST API is the authoritative source for the
application's lifecycle.
if not self._yarn_application_id:
raise RuntimeError("No YARN application id found after
spark-submit completed.")
- self._track_yarn_application(self._yarn_application_id)
return self._driver_id
if self._should_track_driver_status and self._driver_id is None:
@@ -871,18 +868,42 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self.log.info(line)
- def _track_yarn_application(self, application_id: str) -> None:
- """Poll the YARN RM REST API until the application reaches a terminal
state."""
+ def _start_yarn_application_status_tracking(self, application_id: str) ->
None:
+ """
+ Poll the YARN ResourceManager REST API until the application reaches a
terminal state.
+
+ Raises ``RuntimeError`` if the application fails or if too many
consecutive RM
+ request failures occur.
+
+ Possible statuses (from YARN state + finalStatus):
+
+ NEW
+ Application has been created but not yet submitted to the scheduler
+ NEW_SAVING
+ Application metadata is being persisted before scheduling
+ SUBMITTED
+ Application has been submitted and is waiting to be scheduled
+ ACCEPTED
+ Application has been accepted by the scheduler and is queued
+ RUNNING
+ Application is actively executing on the cluster
+ SUCCEEDED
+ Application completed successfully (state=FINISHED,
finalStatus=SUCCEEDED)
+ FAILED
+ Application terminated unsuccessfully — covers YARN state FAILED,
KILLED,
+ or FINISHED with a non-SUCCEEDED finalStatus
+ """
self.log.info(
"Tracking YARN application %s via ResourceManager REST API
polling",
application_id,
)
poll_interval = max(self._status_poll_interval, 10)
- # Tolerate transient RM REST API failures (RM hiccup, network blip,
request
- # timeout) the same way `_start_driver_status_tracking` does for spark
- # standalone — only give up after this many consecutive failures.
consecutive_failures = 0
max_consecutive_failures = 10
+ heartbeat_interval = 10
+ poll_count = 0
+ last_state: str | None = None
+
while True:
self.log.debug("Polling YARN RM REST API for application %s",
application_id)
try:
@@ -904,13 +925,20 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
time.sleep(poll_interval)
continue
consecutive_failures = 0
+ poll_count += 1
+
+ if state != last_state:
+ self.log.info("YARN application %s status: %s",
application_id, state)
+ last_state = state
+ elif poll_count % heartbeat_interval == 0:
+ self.log.info("YARN application %s is still %s",
application_id, state)
+
if state in self._YARN_FINAL_FAILURES:
raise RuntimeError(
f"YARN application {application_id} ended with state:
{state}, "
f"final status: {final_status}"
)
if final_status == self._YARN_FINAL_SUCCESS:
- self.log.info("YARN application %s finished with SUCCEEDED",
application_id)
return
if final_status in self._YARN_FINAL_FAILURES:
raise RuntimeError(
@@ -1298,3 +1326,20 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self._kill_yarn_application(self._yarn_application_id)
self._run_post_submit_commands()
+
+ def query_yarn_application_status(self, application_id: str) -> str:
+ """
+ Return a normalized single string status for the ResumableJobMixin
interface.
+
+ - Active states (NEW, NEW_SAVING, SUBMITTED, ACCEPTED, RUNNING) are
returned as-is.
+ - Terminal states are collapsed to "SUCCEEDED" or "FAILED" with the
following rules:
+ - FINISHED + finalStatus SUCCEEDED -> "SUCCEEDED"
+ - FINISHED + any other finalStatus -> "FAILED"
+ - FAILED or KILLED -> "FAILED"
+ """
+ state, final_status =
self._query_yarn_application_status(application_id)
+ if state in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"}:
+ return state
+ if state == "FINISHED" and final_status == self._YARN_FINAL_SUCCESS:
+ return "SUCCEEDED"
+ return "FAILED"
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 902d96a225b..ac9b550409f 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -274,11 +274,39 @@ class SparkSubmitOperator(ResumableJobMixin,
BaseOperator):
hook.submit(self.application)
hook._poll_k8s_driver_via_api()
return
+ if hook._is_yarn_cluster_mode:
+ if self.reconnect_on_retry and not hook._yarn_track_via_rm_api:
+ raise ValueError(
+ "YARN cluster mode with reconnect_on_retry=True requires
yarn_track_via_rm_api=True. "
+ "The RM REST API is needed to check application status on
retry."
+ )
+ if hook._yarn_track_via_rm_api:
+ hook._validate_yarn_track_via_rm_api_config()
+ if self.reconnect_on_retry:
+ return self.execute_resumable(context)
+ # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
+ driver_id = self.submit_job(context)
+ self.poll_until_complete(driver_id, context)
+ return self.get_job_result(driver_id, context)
hook.submit(self.application)
def submit_job(self, context: Context) -> str:
if self._hook is None:
self._hook = self._get_hook()
+ if self._hook._is_yarn_cluster_mode:
+ if self._hook._conf.get("spark.yarn.submit.waitAppCompletion",
"").strip().lower() == "true":
+ raise ValueError(
+ "spark.yarn.submit.waitAppCompletion=true cannot be set
for cluster mode as it conflicts "
+ "with the need to exit spark-submit immediately to persist
the application ID for tracking. "
+ "Either remove the explicit conf or set
reconnect_on_retry=False."
+ )
+ self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
+ self._hook.submit(self.application)
+ app_id = self._hook._yarn_application_id
+ if not app_id:
+ raise RuntimeError("spark-submit did not produce a YARN
application ID")
+ self.log.info("YARN application submitted: %s", app_id)
+ return app_id
driver_id = self._hook.submit(self.application)
if not driver_id:
raise RuntimeError("spark-submit did not return a driver ID")
@@ -290,15 +318,13 @@ class SparkSubmitOperator(ResumableJobMixin,
BaseOperator):
external_id = cast("str", external_id)
if self._hook is None:
self._hook = self._get_hook()
- # The YARN and K8s branches below (and in is_job_active,
is_job_succeeded, poll_until_complete)
- # are currently unreachable: execute_resumable is only called when
_should_track_driver_status
- # is True, which requires spark:// + cluster mode. They are
scaffolding for a follow-up PR
- # that extends ResumableJobMixin support to YARN and Kubernetes.
- if self._hook._is_yarn:
- # TODO: call YARN ResourceManager REST API
- # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
- raise NotImplementedError("YARN job status not yet implemented")
+ if self._hook._is_yarn_cluster_mode:
+ return self._hook.query_yarn_application_status(external_id)
if self._hook._is_kubernetes:
+ # The K8s branches below (and in is_job_active, is_job_succeeded,
poll_until_complete)
+ # are currently unreachable: execute_resumable is only called for
spark:// cluster mode
+ # (_should_track_driver_status) or YARN cluster mode. They are
scaffolding for a follow-up PR
+ # that extends ResumableJobMixin support to Kubernetes.
# TODO: call K8s pod status API
raise NotImplementedError("K8s job status not yet implemented")
scheme = self._hook._connection.get("rest_scheme", "http")
@@ -338,9 +364,9 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
if self._hook is None:
self._hook = self._get_hook()
status = status.upper()
- if self._hook._is_yarn:
+ if self._hook._is_yarn_cluster_mode:
#
https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
- return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED",
"RUNNING")
+ return status in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED",
"RUNNING"}
if self._hook._is_kubernetes:
return status in ("PENDING", "RUNNING")
# RELAUNCHING: driver is being restarted after a failure, still alive.
@@ -352,6 +378,8 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
if self._hook is None:
self._hook = self._get_hook()
status = status.upper()
+ if self._hook._is_yarn_cluster_mode:
+ return status == "SUCCEEDED"
if self._hook._is_kubernetes:
return status == "SUCCEEDED"
# standalone and raw YARN state both use FINISHED; YARN cluster mode is handled above via the normalized SUCCEEDED
@@ -362,9 +390,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
external_id = cast("str", external_id)
if self._hook is None:
self._hook = self._get_hook()
- if self._hook._is_yarn:
- # TODO: poll YARN ResourceManager until app reaches terminal state
- raise NotImplementedError("YARN poll not yet implemented")
+ if self._hook._is_yarn_cluster_mode:
+ try:
+ self._hook._start_yarn_application_status_tracking(external_id)
+ finally:
+ self._hook._run_post_submit_commands()
+ return
if self._hook._is_kubernetes:
# TODO: poll K8s pod phase until terminal
raise NotImplementedError("K8s poll not yet implemented")
@@ -384,7 +415,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
def on_kill(self) -> None:
if self._hook is None:
self._hook = self._get_hook()
- self._hook.on_kill()
+ if self._hook._is_yarn_cluster_mode and
self._hook._yarn_application_id:
+ # spark-submit has already exited (waitAppCompletion=false), so
the hook's
+ # CLI-based kill has nothing to terminate. Kill the YARN app via
REST API instead.
+ self._hook._kill_yarn_application(self._hook._yarn_application_id)
+ else:
+ self._hook.on_kill()
def _get_hook(self) -> SparkSubmitHook:
return SparkSubmitHook(
diff --git
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
index f4a610a9408..90f923905fa 100644
--- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
@@ -1456,6 +1456,35 @@ class TestSparkSubmitHook:
hook = SparkSubmitHook(conn_id="")
assert hook._post_submit_commands == []
+ @pytest.mark.parametrize(
+ ("state", "final_status", "expected"),
+ [
+ ("NEW", "UNDEFINED", "NEW"),
+ ("NEW_SAVING", "UNDEFINED", "NEW_SAVING"),
+ ("SUBMITTED", "UNDEFINED", "SUBMITTED"),
+ ("ACCEPTED", "UNDEFINED", "ACCEPTED"),
+ ("RUNNING", "UNDEFINED", "RUNNING"),
+ ("FINISHED", "SUCCEEDED", "SUCCEEDED"),
+ ("FINISHED", "FAILED", "FAILED"),
+ ("FINISHED", "KILLED", "FAILED"),
+ ("FINISHED", "UNDEFINED", "FAILED"),
+ ("FAILED", "FAILED", "FAILED"),
+ ("FAILED", "KILLED", "FAILED"),
+ ("KILLED", "KILLED", "FAILED"),
+ ],
+ )
+ def test_query_yarn_application_status_state_mapping(self, state,
final_status, expected):
+ hook = SparkSubmitHook(conn_id="")
+ mock_response = MagicMock(spec=requests.Response)
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"app": {"state": state,
"finalStatus": final_status}}
+
+ with patch(
+ "airflow.providers.apache.spark.hooks.spark_submit.requests.get",
return_value=mock_response
+ ):
+ with patch.object(hook, "_get_yarn_rm_base_url",
return_value="http://rm.example.com:8088"):
+ assert
hook.query_yarn_application_status("application_1234_0001") == expected
+
@pytest.mark.parametrize(
("conn_id", "flag", "expected"),
[
@@ -1686,27 +1715,35 @@ class TestSparkSubmitHook:
with pytest.raises(ValueError, match="requires
`deploy_mode='cluster'`"):
hook.submit()
- @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
- def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get,
mock_sleep):
- """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+ def test_yarn_submit_does_not_poll_rm_api(self, mock_popen, mock_get):
+ """hook.submit() with yarn_track_via_rm_api=True must NOT poll RM REST
API."""
proc = MagicMock(spec=["stdout", "terminate", "wait"])
proc.stdout = iter(self._YARN_LOG_LINES)
proc.wait.return_value = 0
mock_popen.return_value = proc
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook.submit()
+
+ spark_submit_cmd = mock_popen.call_args.args[0]
+ assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
+ mock_get.assert_not_called()
+ assert hook._yarn_application_id == self._RM_APP_ID
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+ def test_yarn_status_tracking_succeeds(self, mock_get, mock_sleep):
+ """RM returns UNDEFINED then SUCCEEDED ->
_start_yarn_application_status_tracking returns normally."""
mock_get.side_effect = [
self._rm_status_resp("UNDEFINED", state="RUNNING"),
self._rm_status_resp("SUCCEEDED"),
]
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
- hook.submit()
+ hook._start_yarn_application_status_tracking(self._RM_APP_ID)
- spark_submit_cmd = mock_popen.call_args.args[0]
- assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
- proc.terminate.assert_not_called()
assert mock_get.call_count == 2
mock_sleep.assert_called_once_with(10)
for call_obj in mock_get.call_args_list:
@@ -1714,65 +1751,43 @@ class TestSparkSubmitHook:
@patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
- def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get,
mock_sleep):
+ def test_yarn_status_tracking_fails_on_killed(self, mock_get, mock_sleep):
"""RM returns KILLED -> raise with message containing app id and
KILLED."""
- proc = MagicMock(spec=["stdout", "terminate", "wait"])
- proc.stdout = iter(self._YARN_LOG_LINES)
- proc.wait.return_value = 0
- mock_popen.return_value = proc
-
mock_get.return_value = self._rm_status_resp("KILLED")
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*KILLED"):
- hook.submit()
- proc.terminate.assert_not_called()
+ hook._start_yarn_application_status_tracking(self._RM_APP_ID)
@patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
def
test_yarn_status_tracking_fails_on_failed_state_with_undefined_final_status(
- self, mock_popen, mock_get, mock_sleep
+ self, mock_get, mock_sleep
):
"""RM state FAILED with finalStatus UNDEFINED should not poll
forever."""
- proc = MagicMock(spec=["stdout", "terminate", "wait"])
- proc.stdout = iter(self._YARN_LOG_LINES)
- proc.wait.return_value = 0
- mock_popen.return_value = proc
-
mock_get.return_value = self._rm_status_resp("UNDEFINED",
state="FAILED")
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*state:
FAILED"):
- hook.submit()
+ hook._start_yarn_application_status_tracking(self._RM_APP_ID)
- proc.terminate.assert_not_called()
mock_sleep.assert_not_called()
@patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
- def test_yarn_status_tracking_fails_on_unexpected_final_status(self,
mock_popen, mock_get, mock_sleep):
+ def test_yarn_status_tracking_fails_on_unexpected_final_status(self,
mock_get, mock_sleep):
"""RM returns a non-standard finalStatus ('BOGUS') -> raise without
sleeping."""
- proc = MagicMock(spec=["stdout", "terminate", "wait"])
- proc.stdout = iter(self._YARN_LOG_LINES)
- proc.wait.return_value = 0
- mock_popen.return_value = proc
-
mock_get.return_value = self._rm_status_resp("BOGUS")
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
with pytest.raises(RuntimeError, match="unexpected final status:
BOGUS"):
- hook.submit()
+ hook._start_yarn_application_status_tracking(self._RM_APP_ID)
- proc.terminate.assert_not_called()
mock_sleep.assert_not_called()
- @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
- def
test_yarn_status_tracking_polls_without_application_submission_log(self,
mock_popen, mock_get):
- """Missing 'Submitted application' log line should not block RM REST
polling."""
+ def
test_yarn_submit_captures_app_id_without_submitted_application_log(self,
mock_popen):
+ """App ID parsed from log lines other than 'Submitted application ...'
is captured by hook.submit()."""
yarn_log_lines = [
"INFO Client: Uploading resource file:/tmp/lib.zip -> "
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
@@ -1782,14 +1797,11 @@ class TestSparkSubmitHook:
proc.stdout = iter(yarn_log_lines)
proc.wait.return_value = 0
mock_popen.return_value = proc
- mock_get.return_value = self._rm_status_resp("SUCCEEDED")
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
hook.submit()
assert hook._yarn_application_id == self._RM_APP_ID
- assert mock_get.call_args.args[0] == self._rm_status_url()
- proc.terminate.assert_not_called()
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
@@ -1820,14 +1832,8 @@ class TestSparkSubmitHook:
@patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
- def test_yarn_status_tracking_tolerates_transient_failures(self,
mock_popen, mock_get, mock_sleep):
+ def test_yarn_status_tracking_tolerates_transient_failures(self, mock_get,
mock_sleep):
"""3 consecutive 5xx responses then SUCCEEDED -> normal completion."""
- proc = MagicMock(spec=["stdout", "terminate", "wait"])
- proc.stdout = iter(self._YARN_LOG_LINES)
- proc.wait.return_value = 0
- mock_popen.return_value = proc
-
# 3 transient failures (within the 10-failure budget), then SUCCEEDED.
mock_get.side_effect = [
self._rm_failure_resp(503, "Service Unavailable"),
@@ -1837,27 +1843,21 @@ class TestSparkSubmitHook:
]
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
- hook.submit()
+ hook._start_yarn_application_status_tracking(self._RM_APP_ID)
assert mock_get.call_count == 4
@patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
- def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_popen,
mock_get, mock_sleep):
+ def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_get,
mock_sleep):
"""First requests.exceptions.Timeout, second call succeeds -> normal
completion."""
- proc = MagicMock(spec=["stdout", "terminate", "wait"])
- proc.stdout = iter(self._YARN_LOG_LINES)
- proc.wait.return_value = 0
- mock_popen.return_value = proc
-
mock_get.side_effect = [
requests.exceptions.Timeout("read timed out"),
self._rm_status_resp("SUCCEEDED"),
]
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
- hook.submit()
+ hook._start_yarn_application_status_tracking(self._RM_APP_ID)
assert mock_get.call_count == 2
# All calls must include the (connect, read) timeout tuple.
@@ -1866,20 +1866,14 @@ class TestSparkSubmitHook:
@patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
- def test_yarn_status_tracking_raises_after_too_many_failures(self,
mock_popen, mock_get, mock_sleep):
+ def test_yarn_status_tracking_raises_after_too_many_failures(self,
mock_get, mock_sleep):
"""11 consecutive 5xx responses -> raise 'Giving up tracking YARN
application'."""
- proc = MagicMock(spec=["stdout", "terminate", "wait"])
- proc.stdout = iter(self._YARN_LOG_LINES)
- proc.wait.return_value = 0
- mock_popen.return_value = proc
-
# 11 failures: 10 tolerated; the 11th trips the budget.
mock_get.side_effect = [self._rm_failure_resp() for _ in range(11)]
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
with pytest.raises(RuntimeError, match="Giving up tracking YARN
application"):
- hook.submit()
+ hook._start_yarn_application_status_tracking(self._RM_APP_ID)
assert mock_get.call_count == 11
@@ -1987,18 +1981,12 @@ class TestSparkSubmitHook:
@patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
@patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
-
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
- def test_yarn_rm_base_url_is_resolved_once_across_polling_loop(self,
mock_popen, mock_get, mock_sleep):
+ def test_yarn_rm_base_url_is_resolved_once_across_polling_loop(self,
mock_get, mock_sleep):
"""Connection lookup must run once even if the polling loop runs many
iterations.
Regression guard: a job polling every few seconds for hours must not
re-fetch
the Spark connection (and potentially re-hit a Secrets Backend) on
every iteration.
"""
- proc = MagicMock(spec=["stdout", "terminate", "wait"])
- proc.stdout = iter(self._YARN_LOG_LINES)
- proc.wait.return_value = 0
- mock_popen.return_value = proc
-
# 4 UNDEFINED iterations then SUCCEEDED -> 5 polling iterations total.
mock_get.side_effect = [
self._rm_status_resp("UNDEFINED", state="RUNNING"),
@@ -2010,7 +1998,7 @@ class TestSparkSubmitHook:
hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
with patch.object(hook, "get_connection", wraps=hook.get_connection)
as spy_get_conn:
- hook.submit()
+ hook._start_yarn_application_status_tracking(self._RM_APP_ID)
assert mock_get.call_count == 5
assert spy_get_conn.call_count == 1
diff --git
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 95ad9f5142a..56b2ad3a409 100644
---
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -495,7 +495,7 @@ class FakeTaskState:
@pytest.mark.skipif(
not AIRFLOW_V_3_3_PLUS,
- reason="ResumableJobMixin reconnect requires task_state, available in
Airflow 3.3+",
+ reason="ResumableJobMixin reconnect requires task_store, available in
Airflow 3.3+",
)
class TestSparkSubmitOperatorResumable:
def setup_method(self):
@@ -505,10 +505,12 @@ class TestSparkSubmitOperatorResumable:
def _make_operator(self, **kwargs):
return SparkSubmitOperator(task_id="test", dag=self.dag,
application="test.jar", **kwargs)
- def _make_hook(self, should_track=False, is_yarn=False,
is_kubernetes=False):
+ def _make_hook(self, should_track=False, is_yarn=False,
is_yarn_cluster=False, is_kubernetes=False):
hook = MagicMock()
hook._should_track_driver_status = should_track
+ hook._should_track_driver_via_k8s_api.return_value = False
hook._is_yarn = is_yarn
+ hook._is_yarn_cluster_mode = is_yarn_cluster
hook._is_kubernetes = is_kubernetes
hook._connection = {"master": "spark://localhost:7077"}
return hook
@@ -598,26 +600,35 @@ class TestSparkSubmitOperatorResumable:
assert polled == ["driver-new"]
@pytest.mark.parametrize(
- ("is_yarn", "is_kubernetes", "status", "expected_active",
"expected_succeeded"),
+ ("is_yarn_cluster", "is_kubernetes", "status", "expected_active",
"expected_succeeded"),
[
+ # Spark standalone cluster mode
(False, False, "RUNNING", True, False),
(False, False, "SUBMITTED", True, False),
+ (False, False, "RELAUNCHING", True, False),
+ (False, False, "UNKNOWN", True, False),
(False, False, "FINISHED", False, True),
(False, False, "FAILED", False, False),
- (True, False, "RUNNING", True, False),
- (True, False, "ACCEPTED", True, False),
+ # YARN cluster mode — synthesized statuses from
query_yarn_application_status
(True, False, "NEW", True, False),
- (True, False, "FINISHED", False, True),
+ (True, False, "NEW_SAVING", True, False),
+ (True, False, "SUBMITTED", True, False),
+ (True, False, "ACCEPTED", True, False),
+ (True, False, "RUNNING", True, False),
+ (True, False, "SUCCEEDED", False, True),
(True, False, "FAILED", False, False),
+ # Kubernetes
(False, True, "Running", True, False),
(False, True, "Pending", True, False),
(False, True, "Succeeded", False, True),
(False, True, "Failed", False, False),
],
)
- def test_job_status_mappings(self, is_yarn, is_kubernetes, status,
expected_active, expected_succeeded):
+ def test_job_status_mappings(
+ self, is_yarn_cluster, is_kubernetes, status, expected_active,
expected_succeeded
+ ):
operator = self._make_operator()
- operator._hook = self._make_hook(is_yarn=is_yarn,
is_kubernetes=is_kubernetes)
+ operator._hook = self._make_hook(is_yarn_cluster=is_yarn_cluster,
is_kubernetes=is_kubernetes)
assert operator.is_job_active(status) == expected_active
assert operator.is_job_succeeded(status) == expected_succeeded
@@ -715,6 +726,108 @@ class TestSparkSubmitOperatorResumable:
assert len(captured_urls) == 1
assert captured_urls[0].startswith("https://")
+ def test_yarn_first_run_persists_app_id_before_polling(self):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(is_yarn_cluster=True)
+ operator._hook._conf = {}
+ operator._hook._yarn_application_id = "application_1234_0001"
+ operator._hook.submit.return_value = None
+
+ task_store = FakeTaskState()
+ persisted_before_poll = []
+
+ def track_poll(external_id, context):
+ persisted_before_poll.append(task_store.get("spark_job_id"))
+
+ operator.poll_until_complete = track_poll
+ operator.execute(context={"task_store": task_store})
+
+ assert persisted_before_poll == ["application_1234_0001"]
+
+ def test_yarn_retry_reconnects_to_running_app(self):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(is_yarn_cluster=True)
+ task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
+
+ operator.get_job_status = lambda external_id, context: "RUNNING"
+ polled = []
+ operator.poll_until_complete = lambda external_id, context:
polled.append(external_id)
+
+ operator.execute(context={"task_store": task_store})
+
+ operator._hook.submit.assert_not_called()
+ assert polled == ["application_1234_0001"]
+
+ def test_yarn_retry_skips_already_succeeded_app(self):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(is_yarn_cluster=True)
+ task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
+
+ operator.get_job_status = lambda external_id, context: "SUCCEEDED"
+
+ operator.execute(context={"task_store": task_store})
+
+ operator._hook.submit.assert_not_called()
+
+ def test_yarn_retry_resubmits_after_failed_app(self):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(is_yarn_cluster=True)
+ operator._hook._conf = {}
+ operator._hook._yarn_application_id = "application_1234_0002"
+ operator._hook.submit.return_value = None
+ task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
+
+ operator.get_job_status = lambda external_id, context: "FAILED"
+ polled = []
+ operator.poll_until_complete = lambda external_id, context:
polled.append(external_id)
+
+ operator.execute(context={"task_store": task_store})
+
+ operator._hook.submit.assert_called_once_with("test.jar")
+ assert polled == ["application_1234_0002"]
+
+ def test_yarn_injects_wait_app_completion_false(self):
+ operator = self._make_operator()
+ hook = self._make_hook(is_yarn_cluster=True)
+ hook._conf = {}
+ hook._yarn_application_id = "application_1234_0001"
+ hook.submit.return_value = None
+ operator._hook = hook
+
+ operator.submit_job(context={})
+
+ assert hook._conf.get("spark.yarn.submit.waitAppCompletion") == "false"
+
+ def test_yarn_raises_if_wait_app_completion_true(self):
+ operator = self._make_operator()
+ hook = self._make_hook(is_yarn_cluster=True)
+ hook._conf = {"spark.yarn.submit.waitAppCompletion": "true"}
+ operator._hook = hook
+
+ with pytest.raises(ValueError, match="waitAppCompletion=true"):
+ operator.submit_job(context={})
+
+ def test_yarn_poll_tolerates_transient_resourcemanager_failures(self):
+ operator = self._make_operator()
+ operator._hook = self._make_hook(is_yarn_cluster=True)
+ operator._hook._status_poll_interval = 0
+
+ call_count = 0
+
+ def flaky_status(external_id):
+ nonlocal call_count
+ call_count += 1
+ if call_count < 5:
+ raise RuntimeError("RM temporarily unavailable")
+ return "SUCCEEDED"
+
+ operator.get_job_status = flaky_status
+
+ with mock.patch("time.sleep"):
+ operator.poll_until_complete("application_1234_0001", context={})
+
+ operator._hook._run_post_submit_commands.assert_called_once()
+
def test_poll_until_complete_runs_post_submit_on_failure(self):
"""post_submit_commands must run even when the driver exits with a
failure status."""
operator = self._make_operator()
@@ -733,6 +846,39 @@ class TestSparkSubmitOperatorResumable:
with pytest.raises(RuntimeError, match="FAILED"):
operator.poll_until_complete("driver-001", {})
+ def test_on_kill_sends_authenticated_kill_to_yarn_rm(self):
+ """operator.on_kill() must call _kill_yarn_application so Kerberos
auth is applied."""
+ operator = self._make_operator()
+ hook = self._make_hook(is_yarn_cluster=True)
+ hook._is_yarn_cluster_mode = True
+ hook._yarn_application_id = "application_1234_0001"
+ operator._hook = hook
+
+ operator.on_kill()
+
+
hook._kill_yarn_application.assert_called_once_with("application_1234_0001")
+
+ def test_yarn_cluster_reconnect_without_rm_api_raises(self):
+ """reconnect_on_retry=True + yarn_track_via_rm_api=False must raise -
RM API is required for resume."""
+ operator = self._make_operator(reconnect_on_retry=True)
+ hook = self._make_hook(is_yarn_cluster=True)
+ hook._yarn_track_via_rm_api = False
+ operator._hook = hook
+
+ with pytest.raises(ValueError, match="yarn_track_via_rm_api=True"):
+ operator.execute(context={})
+
+ def
test_yarn_cluster_without_rm_api_reconnect_false_falls_through_to_hook_submit(self):
+ """reconnect_on_retry=False + yarn_track_via_rm_api=False falls
through to hook.submit() - no RM polling."""
+ operator = self._make_operator(reconnect_on_retry=False)
+ hook = self._make_hook(is_yarn_cluster=True)
+ hook._yarn_track_via_rm_api = False
+ operator._hook = hook
+
+ operator.execute(context={})
+
+ hook.submit.assert_called_once_with("test.jar")
+
class TestSparkSubmitOperatorK8sTracking:
def setup_method(self):
@@ -746,6 +892,7 @@ class TestSparkSubmitOperatorK8sTracking:
hook = MagicMock()
hook._should_track_driver_status = False
hook._should_track_driver_via_k8s_api.return_value = True
+ hook._is_yarn_cluster_mode = False
return hook
def test_execute_calls_submit_then_poll_when_flag_set(self):