SameerMesiah97 commented on code in PR #68277: URL: https://github.com/apache/airflow/pull/68277#discussion_r3391338967
########## providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py: ########## @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio + +from collections.abc import AsyncIterator +from typing import Any + +import aiohttp + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +_SPARK_TERMINAL_STATES = frozenset({"FINISHED", "FAILED", "KILLED", "ERROR"}) +_SPARK_ACTIVE_STATES = frozenset({"SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN"}) + + +class SparkDriverTrigger(BaseTrigger): + """ + Async trigger that polls the Spark standalone REST API until the driver + reaches a terminal state. Used when SparkSubmitOperator runs with deferrable=True. + :param driver_id: Spark driver submission ID returned by spark-submit --rest. + :param master_urls: List of Spark master REST base URLs e.g. ["http://spark-master:6066"]. + :param poll_interval: Seconds between REST API polls. Defaults to 10. 
+ """ + + def __init__( + self, + driver_id: str, + master_urls: list[str], + poll_interval: int = 10, + ) -> None: + super().__init__() + self.driver_id = driver_id + self.master_urls = master_urls + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger", + { + "driver_id": self.driver_id, + "master_urls": self.master_urls, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Poll Spark REST API until driver reaches a terminal state.""" + while True: + status = await self._poll_driver_status() + if status is None: + yield TriggerEvent( + { + "status": "error", + "driver_id": self.driver_id, + "message": "All Spark masters unreachable", + } + ) + return + self.log.info("SparkDriverTrigger: driver=%s status=%s", self.driver_id, status) + upper = status.upper() Review Comment: I would change `upper` to `normalized_status` to better communicate the intent of using `upper()` ########## providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py: ########## @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio + +from collections.abc import AsyncIterator +from typing import Any + +import aiohttp + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +_SPARK_TERMINAL_STATES = frozenset({"FINISHED", "FAILED", "KILLED", "ERROR"}) +_SPARK_ACTIVE_STATES = frozenset({"SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN"}) + + +class SparkDriverTrigger(BaseTrigger): + """ + Async trigger that polls the Spark standalone REST API until the driver + reaches a terminal state. Used when SparkSubmitOperator runs with deferrable=True. + :param driver_id: Spark driver submission ID returned by spark-submit --rest. + :param master_urls: List of Spark master REST base URLs e.g. ["http://spark-master:6066"]. + :param poll_interval: Seconds between REST API polls. Defaults to 10. + """ + + def __init__( + self, + driver_id: str, + master_urls: list[str], + poll_interval: int = 10, + ) -> None: + super().__init__() + self.driver_id = driver_id + self.master_urls = master_urls + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger", + { + "driver_id": self.driver_id, + "master_urls": self.master_urls, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Poll Spark REST API until driver reaches a terminal state.""" + while True: + status = await self._poll_driver_status() + if status is None: + yield TriggerEvent( + { + "status": "error", + "driver_id": self.driver_id, + "message": "All Spark masters unreachable", + } + ) + return + self.log.info("SparkDriverTrigger: driver=%s status=%s", self.driver_id, status) + upper = status.upper() + if upper in _SPARK_TERMINAL_STATES: + success = upper == "FINISHED" + yield TriggerEvent( + { + "status": "success" if success else "error", + "driver_id": self.driver_id, + "driver_state": upper, + "message": f"Driver 
{self.driver_id} reached state {upper}", + } + ) + return + await asyncio.sleep(self.poll_interval) + + async def _poll_driver_status(self) -> str | None: + """Try each master URL; return driverState str or None if all fail.""" + for url in self.master_urls: + status_url = f"{url.rstrip('/')}/v1/submissions/status/{self.driver_id}" + try: + async with aiohttp.ClientSession() as session: + async with session.get(status_url, timeout=aiohttp.ClientTimeout(total=30)) as resp: + resp.raise_for_status() + data = await resp.json() + if not data.get("success"): + self.log.warning( + "Spark REST API returned success=false for %s: %s", + self.driver_id, + data.get("message", "unknown"), + ) + return "UNKNOWN" Review Comment: I am not sure why you are defaulting to `UNKNOWN` when the Spark driver status is not success. Why should the driver status always be `UNKNOWN` if not success? Also, with your current logic, a Spark driver state that cannot be retrieved will just cause the trigger to keep polling until a success state or an error is encountered. 
########## providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py: ########## @@ -920,3 +920,156 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + +class TestSparkSubmitOperatorDeferrable: + """Tests for SparkSubmitOperator deferrable=True mode.""" + + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_deferrable_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator( + task_id="test_deferrable", + dag=self.dag, + application="test.jar", + deferrable=True, + **kwargs, + ) + + def _make_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + hook._connection = {"master": "spark://myhost:7077"} + hook.submit.return_value = "driver-001" + return hook + + def test_deferrable_defaults_to_false(self): Review Comment: This test is not needed as you are testing native Python behaviour. 
########## providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py: ########## @@ -920,3 +920,156 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + +class TestSparkSubmitOperatorDeferrable: + """Tests for SparkSubmitOperator deferrable=True mode.""" + + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_deferrable_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator( + task_id="test_deferrable", + dag=self.dag, + application="test.jar", + deferrable=True, + **kwargs, + ) + + def _make_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + hook._connection = {"master": "spark://myhost:7077"} + hook.submit.return_value = "driver-001" + return hook + + def test_deferrable_defaults_to_false(self): + """deferrable must default to False — existing behaviour unchanged.""" + op = SparkSubmitOperator(task_id="t", dag=self.dag, application="app.jar") + assert op.deferrable is False + + def test_deferrable_stored_on_operator(self): Review Comment: Same here. Not needed. ########## providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py: ########## @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio + +from collections.abc import AsyncIterator +from typing import Any + +import aiohttp + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +_SPARK_TERMINAL_STATES = frozenset({"FINISHED", "FAILED", "KILLED", "ERROR"}) +_SPARK_ACTIVE_STATES = frozenset({"SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN"}) Review Comment: I don't see `_SPARK_ACTIVE_STATES` being used anywhere. Why is this needed? ########## providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py: ########## @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio + +from collections.abc import AsyncIterator +from typing import Any + +import aiohttp + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +_SPARK_TERMINAL_STATES = frozenset({"FINISHED", "FAILED", "KILLED", "ERROR"}) Review Comment: nit: no need for frozenset as this set is not being modified anywhere. ########## providers/apache/spark/provider.yaml: ########## @@ -241,3 +241,8 @@ connection-types: task-decorators: - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task name: pyspark + +triggers: Review Comment: This should come after hooks. ########## providers/apache/spark/src/airflow/providers/apache/spark/triggers/spark_submit.py: ########## @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio + +from collections.abc import AsyncIterator +from typing import Any + +import aiohttp + +from airflow.triggers.base import BaseTrigger, TriggerEvent + +_SPARK_TERMINAL_STATES = frozenset({"FINISHED", "FAILED", "KILLED", "ERROR"}) +_SPARK_ACTIVE_STATES = frozenset({"SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN"}) + + +class SparkDriverTrigger(BaseTrigger): + """ + Async trigger that polls the Spark standalone REST API until the driver + reaches a terminal state. Used when SparkSubmitOperator runs with deferrable=True. + :param driver_id: Spark driver submission ID returned by spark-submit --rest. + :param master_urls: List of Spark master REST base URLs e.g. ["http://spark-master:6066"]. + :param poll_interval: Seconds between REST API polls. Defaults to 10. + """ + + def __init__( + self, + driver_id: str, + master_urls: list[str], + poll_interval: int = 10, + ) -> None: + super().__init__() + self.driver_id = driver_id + self.master_urls = master_urls + self.poll_interval = poll_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + "airflow.providers.apache.spark.triggers.spark_submit.SparkDriverTrigger", + { + "driver_id": self.driver_id, + "master_urls": self.master_urls, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Poll Spark REST API until driver reaches a terminal state.""" + while True: + status = await self._poll_driver_status() + if status is None: + yield TriggerEvent( + { + "status": "error", + "driver_id": self.driver_id, + "message": "All Spark masters unreachable", + } + ) + return + self.log.info("SparkDriverTrigger: driver=%s status=%s", self.driver_id, status) + upper = status.upper() + if upper in _SPARK_TERMINAL_STATES: + success = upper == "FINISHED" + yield TriggerEvent( + { + "status": "success" if success else "error", + "driver_id": self.driver_id, + "driver_state": upper, + "message": f"Driver 
{self.driver_id} reached state {upper}", + } + ) + return + await asyncio.sleep(self.poll_interval) + + async def _poll_driver_status(self) -> str | None: + """Try each master URL; return driverState str or None if all fail.""" + for url in self.master_urls: + status_url = f"{url.rstrip('/')}/v1/submissions/status/{self.driver_id}" + try: + async with aiohttp.ClientSession() as session: + async with session.get(status_url, timeout=aiohttp.ClientTimeout(total=30)) as resp: + resp.raise_for_status() + data = await resp.json() + if not data.get("success"): + self.log.warning( + "Spark REST API returned success=false for %s: %s", + self.driver_id, + data.get("message", "unknown"), + ) + return "UNKNOWN" + return data["driverState"] + except Exception as exc: Review Comment: This Exception is too broad. What if there is a `ClientError/TimeoutError` in the loop? It seems like you are collapsing all of these different errors into a single log message which could be misleading. Also, if the error is not recoverable, why continue the loop? 
########## providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py: ########## @@ -920,3 +920,156 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + +class TestSparkSubmitOperatorDeferrable: + """Tests for SparkSubmitOperator deferrable=True mode.""" + + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_deferrable_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator( + task_id="test_deferrable", + dag=self.dag, + application="test.jar", + deferrable=True, + **kwargs, + ) + + def _make_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + hook._connection = {"master": "spark://myhost:7077"} + hook.submit.return_value = "driver-001" + return hook + + def test_deferrable_defaults_to_false(self): + """deferrable must default to False — existing behaviour unchanged.""" + op = SparkSubmitOperator(task_id="t", dag=self.dag, application="app.jar") + assert op.deferrable is False + + def test_deferrable_stored_on_operator(self): + """deferrable=True must be stored as self.deferrable.""" + op = self._make_operator() + assert op.deferrable is True + + def test_execute_calls_defer_when_deferrable_true(self): + """execute() must call self.defer() when deferrable=True.""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + mock_defer.assert_called_once() + call_kwargs = mock_defer.call_args.kwargs + assert call_kwargs["method_name"] == "execute_complete" + + def 
test_execute_passes_correct_args_to_trigger(self): + """execute() must pass driver_id and master_urls to SparkDriverTrigger.""" + from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger + + op = self._make_operator(status_poll_interval=15) + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-xyz"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://m1:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + trigger = mock_defer.call_args.kwargs["trigger"] + assert isinstance(trigger, SparkDriverTrigger) + assert trigger.driver_id == "driver-xyz" + assert trigger.master_urls == ["http://m1:6066"] + assert trigger.poll_interval == 15 + + def test_execute_does_not_call_hook_submit_directly(self): + """execute() in deferrable mode must use submit_job(), not hook.submit().""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer"): + op.execute(context={}) + + hook.submit.assert_not_called() + + def test_execute_complete_succeeds_on_success_event(self): + """execute_complete() must not raise when status=success.""" + op = self._make_operator() + event = { + "status": "success", + "driver_id": "driver-001", + "driver_state": "FINISHED", + "message": "Driver reached FINISHED", + } + op.execute_complete(context={}, event=event) # must not raise + + def test_execute_complete_raises_on_error_event(self): + """execute_complete() must raise AirflowException when status=error.""" + from airflow.providers.common.compat.sdk import AirflowException + + op = self._make_operator() + event = { + "status": "error", + "driver_id": "driver-001", + "driver_state": "FAILED", + "message": "Driver reached FAILED", + } + with 
pytest.raises(AirflowException, match="driver-001"): + op.execute_complete(context={}, event=event) + Review Comment: Can you add another test for malformed event payloads i.e. missing `driver_id` and `status`? ########## providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py: ########## @@ -920,3 +920,156 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + +class TestSparkSubmitOperatorDeferrable: + """Tests for SparkSubmitOperator deferrable=True mode.""" + + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_deferrable_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator( + task_id="test_deferrable", + dag=self.dag, + application="test.jar", + deferrable=True, + **kwargs, + ) + + def _make_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + hook._connection = {"master": "spark://myhost:7077"} + hook.submit.return_value = "driver-001" + return hook + + def test_deferrable_defaults_to_false(self): + """deferrable must default to False — existing behaviour unchanged.""" + op = SparkSubmitOperator(task_id="t", dag=self.dag, application="app.jar") + assert op.deferrable is False + + def test_deferrable_stored_on_operator(self): + """deferrable=True must be stored as self.deferrable.""" + op = self._make_operator() + assert op.deferrable is True + + def test_execute_calls_defer_when_deferrable_true(self): + """execute() must call self.defer() when deferrable=True.""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, 
"defer") as mock_defer: + op.execute(context={}) + + mock_defer.assert_called_once() + call_kwargs = mock_defer.call_args.kwargs + assert call_kwargs["method_name"] == "execute_complete" + + def test_execute_passes_correct_args_to_trigger(self): + """execute() must pass driver_id and master_urls to SparkDriverTrigger.""" + from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger + + op = self._make_operator(status_poll_interval=15) + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-xyz"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://m1:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + trigger = mock_defer.call_args.kwargs["trigger"] + assert isinstance(trigger, SparkDriverTrigger) + assert trigger.driver_id == "driver-xyz" + assert trigger.master_urls == ["http://m1:6066"] + assert trigger.poll_interval == 15 + + def test_execute_does_not_call_hook_submit_directly(self): + """execute() in deferrable mode must use submit_job(), not hook.submit().""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer"): + op.execute(context={}) + + hook.submit.assert_not_called() + + def test_execute_complete_succeeds_on_success_event(self): + """execute_complete() must not raise when status=success.""" + op = self._make_operator() + event = { + "status": "success", + "driver_id": "driver-001", + "driver_state": "FINISHED", + "message": "Driver reached FINISHED", + } + op.execute_complete(context={}, event=event) # must not raise + + def test_execute_complete_raises_on_error_event(self): + """execute_complete() must raise AirflowException when status=error.""" + from airflow.providers.common.compat.sdk import 
AirflowException Review Comment: Why is `AirflowException` still here? You switched to `RuntimeError` in your implementation? ########## providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py: ########## @@ -920,3 +920,156 @@ def test_execute_falls_through_to_plain_submit_when_flag_off(self): hook.submit.assert_called_once_with("test.jar") hook._poll_k8s_driver_via_api.assert_not_called() + + +class TestSparkSubmitOperatorDeferrable: + """Tests for SparkSubmitOperator deferrable=True mode.""" + + def setup_method(self): + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_deferrable_dag", schedule=None, default_args=args) + + def _make_operator(self, **kwargs): + return SparkSubmitOperator( + task_id="test_deferrable", + dag=self.dag, + application="test.jar", + deferrable=True, + **kwargs, + ) + + def _make_hook(self): + hook = MagicMock() + hook._should_track_driver_status = False + hook._should_track_driver_via_k8s_api.return_value = False + hook._connection = {"master": "spark://myhost:7077"} + hook.submit.return_value = "driver-001" + return hook + + def test_deferrable_defaults_to_false(self): + """deferrable must default to False — existing behaviour unchanged.""" + op = SparkSubmitOperator(task_id="t", dag=self.dag, application="app.jar") + assert op.deferrable is False + + def test_deferrable_stored_on_operator(self): + """deferrable=True must be stored as self.deferrable.""" + op = self._make_operator() + assert op.deferrable is True + + def test_execute_calls_defer_when_deferrable_true(self): + """execute() must call self.defer() when deferrable=True.""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + mock_defer.assert_called_once() + 
call_kwargs = mock_defer.call_args.kwargs + assert call_kwargs["method_name"] == "execute_complete" + + def test_execute_passes_correct_args_to_trigger(self): + """execute() must pass driver_id and master_urls to SparkDriverTrigger.""" + from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger + + op = self._make_operator(status_poll_interval=15) + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-xyz"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://m1:6066"]), \ + mock.patch.object(op, "defer") as mock_defer: + op.execute(context={}) + + trigger = mock_defer.call_args.kwargs["trigger"] + assert isinstance(trigger, SparkDriverTrigger) + assert trigger.driver_id == "driver-xyz" + assert trigger.master_urls == ["http://m1:6066"] + assert trigger.poll_interval == 15 + + def test_execute_does_not_call_hook_submit_directly(self): + """execute() in deferrable mode must use submit_job(), not hook.submit().""" + op = self._make_operator() + hook = self._make_hook() + op._hook = hook + + with mock.patch.object(op, "submit_job", return_value="driver-001"), \ + mock.patch.object(op, "_build_master_rest_urls", return_value=["http://myhost:6066"]), \ + mock.patch.object(op, "defer"): + op.execute(context={}) + + hook.submit.assert_not_called() + + def test_execute_complete_succeeds_on_success_event(self): + """execute_complete() must not raise when status=success.""" + op = self._make_operator() + event = { + "status": "success", + "driver_id": "driver-001", + "driver_state": "FINISHED", + "message": "Driver reached FINISHED", + } + op.execute_complete(context={}, event=event) # must not raise + + def test_execute_complete_raises_on_error_event(self): + """execute_complete() must raise AirflowException when status=error.""" + from airflow.providers.common.compat.sdk import AirflowException + + op = self._make_operator() + event = { + "status": "error", + 
"driver_id": "driver-001", + "driver_state": "FAILED", + "message": "Driver reached FAILED", + } + with pytest.raises(AirflowException, match="driver-001"): + op.execute_complete(context={}, event=event) + + def test_build_master_rest_urls_single_master(self): + """_build_master_rest_urls must return correct URL for a single master.""" + op = self._make_operator() + hook = self._make_hook() + hook._connection = { + "master": "spark://myhost:7077", + "rest_scheme": "http", + "rest_port": 6066, + } + op._hook = hook + + urls = op._build_master_rest_urls() + + assert urls == ["http://myhost:6066"] + + def test_build_master_rest_urls_ha_multiple_masters(self): + """_build_master_rest_urls must return a URL per master in HA mode.""" + op = self._make_operator() + hook = self._make_hook() + hook._connection = { + "master": "spark://m1:7077,m2:7077", + "rest_scheme": "https", + "rest_port": 6066, + } + op._hook = hook + + urls = op._build_master_rest_urls() + + assert urls == ["https://m1:6066", "https://m2:6066"] + + def test_deferrable_false_uses_sync_path(self): Review Comment: Is this not covered by the existing non-deferrable tests? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
