amoghrajesh commented on code in PR #65991:
URL: https://github.com/apache/airflow/pull/65991#discussion_r3322423890
##########
providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py:
##########
@@ -268,6 +324,22 @@ def _resolve_should_track_driver_status(self) -> bool:
"""
return "spark://" in self._connection["master"] and
self._connection["deploy_mode"] == "cluster"
+ def _should_track_yarn_application_via_rm_api(self) -> bool:
+ """Return whether this submit should switch to YARN RM REST API
polling."""
+ return self._yarn_track_via_rm_api and self._is_yarn and
self._connection["deploy_mode"] == "cluster"
+
+ def _validate_yarn_track_via_rm_api_config(self) -> None:
+ """Validate that YARN RM REST API tracking can run for this submit."""
+ if not self._yarn_track_via_rm_api:
+ return
+ if not self._is_yarn:
+ raise ValueError("`yarn_track_via_rm_api=True` requires Spark
master to be YARN.")
+ if self._connection["deploy_mode"] != "cluster":
+ raise ValueError(
+ "`yarn_track_via_rm_api=True` requires
`deploy_mode='cluster'`; "
+ f"got {self._connection['deploy_mode']!r}."
+ )
Review Comment:
Just realised one thing: `_validate_yarn_track_via_rm_api_config` only
checks master and deploy mode, but not `yarn_resourcemanager_webapp_address`.
That check happens lazily on the first poll, i.e. after spark-submit has already
submitted the job. If the address is missing, the task fails with a
`ValueError`, but the YARN app is still running on the cluster as an orphaned
application, and `on_kill` can't kill it because it also needs the URL. We should
call `_get_yarn_rm_base_url()` here as an eager validation so the whole config
is checked before spark-submit fires.
##########
providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py:
##########
@@ -1366,3 +1397,378 @@ def
test_post_submit_commands_none_gives_empty_list(self):
"""Test that None post_submit_commands results in an empty list."""
hook = SparkSubmitHook(conn_id="")
assert hook._post_submit_commands == []
+
+ _YARN_LOG_LINES = [
+ "INFO Client: Requesting a new application from cluster with 1
NodeManagers",
+ "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+ "INFO Client: Submitting application application_1700000000000_0001 to
ResourceManager",
+ "INFO YarnClientImpl: Submitted application
application_1700000000000_0001",
+ "INFO Client: Application report for application_1700000000000_0001
(state: ACCEPTED)",
+ "INFO Client: Application report for application_1700000000000_0001
(state: RUNNING)",
+ "INFO Client: Application report for application_1700000000000_0001
(state: FINISHED)",
+ "INFO Client: final status: SUCCEEDED",
+ ]
+
+ _RM_BASE_URL = "http://rm.test:8088"
+ _RM_APP_ID = "application_1700000000000_0001"
+
+ @classmethod
+ def _rm_status_url(cls, app_id: str | None = None) -> str:
+ return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or
cls._RM_APP_ID}"
+
+ @classmethod
+ def _rm_kill_url(cls, app_id: str | None = None) -> str:
+ return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or
cls._RM_APP_ID}/state"
+
+ @classmethod
+ def _rm_status_resp(cls, final_status: str, state: str = "FINISHED") ->
MagicMock:
+ resp = MagicMock(spec=requests.Response)
+ resp.status_code = 200
+ resp.json.return_value = {"app": {"id": cls._RM_APP_ID, "state":
state, "finalStatus": final_status}}
+ return resp
+
+ @staticmethod
+ def _rm_failure_resp(status_code: int = 500, text: str = "Internal Server
Error") -> MagicMock:
+ resp = MagicMock(spec=requests.Response)
+ resp.status_code = status_code
+ resp.text = text
+ return resp
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_default_keeps_existing_behavior_in_yarn_cluster(self, mock_popen,
mock_get, mock_put):
+ """Flag default False -> no HTTP calls; behavior identical to today."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_cluster")
+ hook.submit()
+
+ proc.terminate.assert_not_called()
+ mock_get.assert_not_called()
+ mock_put.assert_not_called()
+ assert hook._yarn_application_id == "application_1700000000000_0001"
+
+ def test_yarn_status_tracking_requires_yarn_master(self):
+ """yarn_track_via_rm_api=True should fail fast outside YARN."""
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
yarn_track_via_rm_api=True)
+
+ with pytest.raises(ValueError, match="requires Spark master to be
YARN"):
+ hook.submit()
+
+ def test_yarn_status_tracking_requires_cluster_deploy_mode(self):
+ """yarn_track_via_rm_api=True should fail fast outside cluster deploy
mode."""
+ hook = SparkSubmitHook(
+ conn_id="spark_yarn_rm",
+ deploy_mode="client",
+ yarn_track_via_rm_api=True,
+ )
+
+ with pytest.raises(ValueError, match="requires
`deploy_mode='cluster'`"):
+ hook.submit()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get,
mock_sleep):
+ """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.side_effect = [
+ self._rm_status_resp("UNDEFINED", state="RUNNING"),
+ self._rm_status_resp("SUCCEEDED"),
+ ]
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook.submit()
+
+ spark_submit_cmd = mock_popen.call_args.args[0]
+ assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
+ proc.terminate.assert_not_called()
+ assert mock_get.call_count == 2
+ mock_sleep.assert_called_once_with(10)
+ for call_obj in mock_get.call_args_list:
+ assert call_obj.args[0] == self._rm_status_url()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get,
mock_sleep):
+ """RM returns KILLED -> raise with message containing app id and
KILLED."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.return_value = self._rm_status_resp("KILLED")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*KILLED"):
+ hook.submit()
+ proc.terminate.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def
test_yarn_status_tracking_fails_on_failed_state_with_undefined_final_status(
+ self, mock_popen, mock_get, mock_sleep
+ ):
+ """RM state FAILED with finalStatus UNDEFINED should not poll
forever."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.return_value = self._rm_status_resp("UNDEFINED",
state="FAILED")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*state:
FAILED"):
+ hook.submit()
+
+ proc.terminate.assert_not_called()
+ mock_sleep.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_fails_on_unexpected_final_status(self,
mock_popen, mock_get, mock_sleep):
+ """RM returns a non-standard finalStatus ('BOGUS') -> raise without
sleeping."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.return_value = self._rm_status_resp("BOGUS")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(RuntimeError, match="unexpected final status:
BOGUS"):
+ hook.submit()
+
+ proc.terminate.assert_not_called()
+ mock_sleep.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def
test_yarn_status_tracking_polls_without_application_submission_log(self,
mock_popen, mock_get):
+ """Missing 'Submitted application' log line should not block RM REST
polling."""
+ yarn_log_lines = [
+ "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+ "INFO Client: Submitting application
application_1700000000000_0001 to ResourceManager",
+ ]
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(yarn_log_lines)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+ mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook.submit()
+
+ assert hook._yarn_application_id == self._RM_APP_ID
+ assert mock_get.call_args.args[0] == self._rm_status_url()
+ proc.terminate.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def
test_yarn_status_tracking_checks_spark_submit_exit_code_before_polling(self,
mock_popen, mock_get):
+ """spark-submit exits non-zero -> raise BEFORE issuing any HTTP
request."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 1
+ mock_popen.return_value = proc
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(AirflowException, match="Error code is: 1"):
+ hook.submit()
+
+ proc.terminate.assert_not_called()
+ mock_get.assert_not_called()
+
+ def
test_yarn_status_tracking_rejects_conflicting_wait_app_completion_conf(self):
+ """User-set spark.yarn.submit.waitAppCompletion=true conflicts with
flag -> ValueError."""
+ hook = SparkSubmitHook(
+ conn_id="spark_yarn_rm",
+ conf={"spark.yarn.submit.waitAppCompletion": "true"},
+ yarn_track_via_rm_api=True,
+ )
+
+ with pytest.raises(ValueError,
match="spark.yarn.submit.waitAppCompletion=false"):
+ hook._build_spark_submit_command("")
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_tolerates_transient_failures(self,
mock_popen, mock_get, mock_sleep):
+ """3 consecutive 5xx responses then SUCCEEDED -> normal completion."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ # 3 transient failures (within the 10-failure budget), then SUCCEEDED.
+ mock_get.side_effect = [
+ self._rm_failure_resp(503, "Service Unavailable"),
+ self._rm_failure_resp(502, "Bad Gateway"),
+ self._rm_failure_resp(500, "Internal Server Error"),
+ self._rm_status_resp("SUCCEEDED"),
+ ]
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook.submit()
+
+ assert mock_get.call_count == 4
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_popen,
mock_get, mock_sleep):
+ """First requests.exceptions.Timeout, second call succeeds -> normal
completion."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.side_effect = [
+ requests.exceptions.Timeout("read timed out"),
+ self._rm_status_resp("SUCCEEDED"),
+ ]
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook.submit()
+
+ assert mock_get.call_count == 2
+ # All calls must include the (connect, read) timeout tuple.
+ for call_obj in mock_get.call_args_list:
+ assert call_obj.kwargs["timeout"] == (5, 30)
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_raises_after_too_many_failures(self,
mock_popen, mock_get, mock_sleep):
+ """11 consecutive 5xx responses -> raise 'Giving up tracking YARN
application'."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ # 11 failures: 10 tolerated; the 11th trips the budget.
+ mock_get.side_effect = [self._rm_failure_resp() for _ in range(11)]
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(RuntimeError, match="Giving up tracking YARN
application"):
+ hook.submit()
+
+ assert mock_get.call_count == 11
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+ def test_yarn_status_query_passes_provided_auth_to_requests(self,
mock_get):
+ """yarn_rm_auth=<sentinel AuthBase> -> requests.get called with
auth=<sentinel>."""
+
+ class _SentinelAuth(requests.auth.AuthBase):
+ def __call__(self, r):
+ return r
+
+ sentinel = _SentinelAuth()
+ mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+ hook = SparkSubmitHook(
+ conn_id="spark_yarn_rm",
+ yarn_track_via_rm_api=True,
+ yarn_rm_auth=sentinel,
+ )
+ hook._query_yarn_application_status(self._RM_APP_ID)
+
+ assert mock_get.call_args.kwargs["auth"] is sentinel
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+ def test_yarn_status_query_sends_no_auth_by_default(self, mock_get):
+ """Without yarn_rm_auth -> requests.get called with auth=None."""
+ mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook._query_yarn_application_status(self._RM_APP_ID)
+
+ assert mock_get.call_args.kwargs["auth"] is None
Review Comment:
This test just verifies `auth=None` when you never passed auth. The
forwarding mechanism is already proven by
`test_yarn_status_query_passes_provided_auth_to_requests`. These two could be
one test using `pytest.mark.parametrize` instead.
##########
providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py:
##########
@@ -268,6 +324,22 @@ def _resolve_should_track_driver_status(self) -> bool:
"""
return "spark://" in self._connection["master"] and
self._connection["deploy_mode"] == "cluster"
+ def _should_track_yarn_application_via_rm_api(self) -> bool:
+ """Return whether this submit should switch to YARN RM REST API
polling."""
+ return self._yarn_track_via_rm_api and self._is_yarn and
self._connection["deploy_mode"] == "cluster"
+
+ def _validate_yarn_track_via_rm_api_config(self) -> None:
+ """Validate that YARN RM REST API tracking can run for this submit."""
+ if not self._yarn_track_via_rm_api:
+ return
+ if not self._is_yarn:
+ raise ValueError("`yarn_track_via_rm_api=True` requires Spark
master to be YARN.")
+ if self._connection["deploy_mode"] != "cluster":
+ raise ValueError(
+ "`yarn_track_via_rm_api=True` requires
`deploy_mode='cluster'`; "
+ f"got {self._connection['deploy_mode']!r}."
+ )
Review Comment:
And please add a test for that eager validation too.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]