nailo2c commented on code in PR #65991:
URL: https://github.com/apache/airflow/pull/65991#discussion_r3326213993


##########
providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py:
##########
@@ -268,6 +324,22 @@ def _resolve_should_track_driver_status(self) -> bool:
         """
         return "spark://" in self._connection["master"] and 
self._connection["deploy_mode"] == "cluster"
 
+    def _should_track_yarn_application_via_rm_api(self) -> bool:
+        """Return whether this submit should switch to YARN RM REST API 
polling."""
+        return self._yarn_track_via_rm_api and self._is_yarn and 
self._connection["deploy_mode"] == "cluster"
+
+    def _validate_yarn_track_via_rm_api_config(self) -> None:
+        """Validate that YARN RM REST API tracking can run for this submit."""
+        if not self._yarn_track_via_rm_api:
+            return
+        if not self._is_yarn:
+            raise ValueError("`yarn_track_via_rm_api=True` requires Spark 
master to be YARN.")
+        if self._connection["deploy_mode"] != "cluster":
+            raise ValueError(
+                "`yarn_track_via_rm_api=True` requires 
`deploy_mode='cluster'`; "
+                f"got {self._connection['deploy_mode']!r}."
+            )

Review Comment:
   fixed



##########
providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py:
##########
@@ -1366,3 +1397,378 @@ def 
test_post_submit_commands_none_gives_empty_list(self):
         """Test that None post_submit_commands results in an empty list."""
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
+
+    _YARN_LOG_LINES = [
+        "INFO Client: Requesting a new application from cluster with 1 
NodeManagers",
+        "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+        
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+        "INFO Client: Submitting application application_1700000000000_0001 to 
ResourceManager",
+        "INFO YarnClientImpl: Submitted application 
application_1700000000000_0001",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: ACCEPTED)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: RUNNING)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: FINISHED)",
+        "INFO Client: final status: SUCCEEDED",
+    ]
+
+    _RM_BASE_URL = "http://rm.test:8088"
+    _RM_APP_ID = "application_1700000000000_0001"
+
+    @classmethod
+    def _rm_status_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}"
+
+    @classmethod
+    def _rm_kill_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}/state"
+
+    @classmethod
+    def _rm_status_resp(cls, final_status: str, state: str = "FINISHED") -> 
MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = 200
+        resp.json.return_value = {"app": {"id": cls._RM_APP_ID, "state": 
state, "finalStatus": final_status}}
+        return resp
+
+    @staticmethod
+    def _rm_failure_resp(status_code: int = 500, text: str = "Internal Server 
Error") -> MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = status_code
+        resp.text = text
+        return resp
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_default_keeps_existing_behavior_in_yarn_cluster(self, mock_popen, 
mock_get, mock_put):
+        """Flag default False -> no HTTP calls; behavior identical to today."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")
+        hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+        mock_put.assert_not_called()
+        assert hook._yarn_application_id == "application_1700000000000_0001"
+
+    def test_yarn_status_tracking_requires_yarn_master(self):
+        """yarn_track_via_rm_api=True should fail fast outside YARN."""
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
yarn_track_via_rm_api=True)
+
+        with pytest.raises(ValueError, match="requires Spark master to be 
YARN"):
+            hook.submit()
+
+    def test_yarn_status_tracking_requires_cluster_deploy_mode(self):
+        """yarn_track_via_rm_api=True should fail fast outside cluster deploy 
mode."""
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_rm",
+            deploy_mode="client",
+            yarn_track_via_rm_api=True,
+        )
+
+        with pytest.raises(ValueError, match="requires 
`deploy_mode='cluster'`"):
+            hook.submit()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.side_effect = [
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        spark_submit_cmd = mock_popen.call_args.args[0]
+        assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
+        proc.terminate.assert_not_called()
+        assert mock_get.call_count == 2
+        mock_sleep.assert_called_once_with(10)
+        for call_obj in mock_get.call_args_list:
+            assert call_obj.args[0] == self._rm_status_url()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns KILLED -> raise with message containing app id and 
KILLED."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("KILLED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*KILLED"):
+            hook.submit()
+        proc.terminate.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_fails_on_failed_state_with_undefined_final_status(
+        self, mock_popen, mock_get, mock_sleep
+    ):
+        """RM state FAILED with finalStatus UNDEFINED should not poll 
forever."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("UNDEFINED", 
state="FAILED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*state: 
FAILED"):
+            hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_sleep.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_unexpected_final_status(self, 
mock_popen, mock_get, mock_sleep):
+        """RM returns a non-standard finalStatus ('BOGUS') -> raise without 
sleeping."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("BOGUS")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match="unexpected final status: 
BOGUS"):
+            hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_sleep.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_polls_without_application_submission_log(self, 
mock_popen, mock_get):
+        """Missing 'Submitted application' log line should not block RM REST 
polling."""
+        yarn_log_lines = [
+            "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+            
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+            "INFO Client: Submitting application 
application_1700000000000_0001 to ResourceManager",
+        ]
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(yarn_log_lines)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+        mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        assert hook._yarn_application_id == self._RM_APP_ID
+        assert mock_get.call_args.args[0] == self._rm_status_url()
+        proc.terminate.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_checks_spark_submit_exit_code_before_polling(self, 
mock_popen, mock_get):
+        """spark-submit exits non-zero -> raise BEFORE issuing any HTTP 
request."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 1
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(AirflowException, match="Error code is: 1"):
+            hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+
+    def 
test_yarn_status_tracking_rejects_conflicting_wait_app_completion_conf(self):
+        """User-set spark.yarn.submit.waitAppCompletion=true conflicts with 
flag -> ValueError."""
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_rm",
+            conf={"spark.yarn.submit.waitAppCompletion": "true"},
+            yarn_track_via_rm_api=True,
+        )
+
+        with pytest.raises(ValueError, 
match="spark.yarn.submit.waitAppCompletion=false"):
+            hook._build_spark_submit_command("")
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_tolerates_transient_failures(self, 
mock_popen, mock_get, mock_sleep):
+        """3 consecutive 5xx responses then SUCCEEDED -> normal completion."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        # 3 transient failures (within the 10-failure budget), then SUCCEEDED.
+        mock_get.side_effect = [
+            self._rm_failure_resp(503, "Service Unavailable"),
+            self._rm_failure_resp(502, "Bad Gateway"),
+            self._rm_failure_resp(500, "Internal Server Error"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        assert mock_get.call_count == 4
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_popen, 
mock_get, mock_sleep):
+        """First requests.exceptions.Timeout, second call succeeds -> normal 
completion."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.side_effect = [
+            requests.exceptions.Timeout("read timed out"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        assert mock_get.call_count == 2
+        # All calls must include the (connect, read) timeout tuple.
+        for call_obj in mock_get.call_args_list:
+            assert call_obj.kwargs["timeout"] == (5, 30)
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_raises_after_too_many_failures(self, 
mock_popen, mock_get, mock_sleep):
+        """11 consecutive 5xx responses -> raise 'Giving up tracking YARN 
application'."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        # 11 failures: 10 tolerated; the 11th trips the budget.
+        mock_get.side_effect = [self._rm_failure_resp() for _ in range(11)]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match="Giving up tracking YARN 
application"):
+            hook.submit()
+
+        assert mock_get.call_count == 11
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    def test_yarn_status_query_passes_provided_auth_to_requests(self, 
mock_get):
+        """yarn_rm_auth=<sentinel AuthBase> -> requests.get called with 
auth=<sentinel>."""
+
+        class _SentinelAuth(requests.auth.AuthBase):
+            def __call__(self, r):
+                return r
+
+        sentinel = _SentinelAuth()
+        mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_rm",
+            yarn_track_via_rm_api=True,
+            yarn_rm_auth=sentinel,
+        )
+        hook._query_yarn_application_status(self._RM_APP_ID)
+
+        assert mock_get.call_args.kwargs["auth"] is sentinel
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    def test_yarn_status_query_sends_no_auth_by_default(self, mock_get):
+        """Without yarn_rm_auth -> requests.get called with auth=None."""
+        mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook._query_yarn_application_status(self._RM_APP_ID)
+
+        assert mock_get.call_args.kwargs["auth"] is None

Review Comment:
   fixed



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to