This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bf5ab8fc46 Fix bug in LivyOperator when its trigger times out (#38916)
bf5ab8fc46 is described below
commit bf5ab8fc462b1f35a45b5cfc8940d06fb0e698dd
Author: mateuslatrova <[email protected]>
AuthorDate: Sun Apr 14 08:54:16 2024 -0300
Fix bug in LivyOperator when its trigger times out (#38916)
When a LivyOperator was instantiated with deferrable=True and its batch job
ran longer than the configured execution_timeout, Airflow would detect the
timeout, cancel the trigger, and then try to kill the task via the
'on_kill' method.
But that would fail, raising an AttributeError, because the _batch_id
attribute was only annotated, never assigned, in the constructor method.
From now on, the LivyTrigger will time out on its own before Airflow does,
and it will send an event to the LivyOperator signaling that a timeout
happened. This way, the operator can stop the running Livy batch job and
fail the task instance gracefully.
---
airflow/providers/apache/livy/operators/livy.py | 9 ++++--
airflow/providers/apache/livy/triggers/livy.py | 27 +++++++++++++++-
tests/providers/apache/livy/operators/test_livy.py | 37 ++++++++++++++++++++++
tests/providers/apache/livy/triggers/test_livy.py | 30 ++++++++++++++++++
4 files changed, 100 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/apache/livy/operators/livy.py
b/airflow/providers/apache/livy/operators/livy.py
index edaa4e15d8..b761118a5c 100644
--- a/airflow/providers/apache/livy/operators/livy.py
+++ b/airflow/providers/apache/livy/operators/livy.py
@@ -124,7 +124,7 @@ class LivyOperator(BaseOperator):
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}
- self._batch_id: int | str
+ self._batch_id: int | str | None = None
self.retry_args = retry_args
self.deferrable = deferrable
@@ -170,6 +170,7 @@ class LivyOperator(BaseOperator):
polling_interval=self._polling_interval,
extra_options=self._extra_options,
extra_headers=self._extra_headers,
+ execution_timeout=self.execution_timeout,
),
method_name="execute_complete",
)
@@ -217,8 +218,12 @@ class LivyOperator(BaseOperator):
for log_line in event["log_lines"]:
self.log.info(log_line)
- if event["status"] == "error":
+ if event["status"] == "timeout":
+ self.hook.delete_batch(event["batch_id"])
+
+ if event["status"] in ["error", "timeout"]:
raise AirflowException(event["response"])
+
self.log.info(
"%s completed with response %s",
self.task_id,
diff --git a/airflow/providers/apache/livy/triggers/livy.py
b/airflow/providers/apache/livy/triggers/livy.py
index d6203b4324..298d1e5f87 100644
--- a/airflow/providers/apache/livy/triggers/livy.py
+++ b/airflow/providers/apache/livy/triggers/livy.py
@@ -20,6 +20,7 @@
from __future__ import annotations
import asyncio
+from datetime import datetime, timedelta, timezone
from typing import Any, AsyncIterator
from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook
@@ -54,6 +55,7 @@ class LivyTrigger(BaseTrigger):
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
livy_hook_async: LivyAsyncHook | None = None,
+ execution_timeout: timedelta | None = None,
):
super().__init__()
self._batch_id = batch_id
@@ -63,6 +65,7 @@ class LivyTrigger(BaseTrigger):
self._extra_options = extra_options
self._extra_headers = extra_headers
self._livy_hook_async = livy_hook_async
+ self._execution_timeout = execution_timeout
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize LivyTrigger arguments and classpath."""
@@ -76,6 +79,7 @@ class LivyTrigger(BaseTrigger):
"extra_options": self._extra_options,
"extra_headers": self._extra_headers,
"livy_hook_async": self._livy_hook_async,
+ "execution_timeout": self._execution_timeout,
},
)
@@ -113,16 +117,37 @@ class LivyTrigger(BaseTrigger):
:param batch_id: id of the batch session to monitor.
"""
+ if self._execution_timeout is not None:
+ timeout_datetime = datetime.now(timezone.utc) +
self._execution_timeout
+ else:
+ timeout_datetime = None
+ batch_execution_timed_out = False
hook = self._get_async_hook()
state = await hook.get_batch_state(batch_id)
self.log.info("Batch with id %s is in state: %s", batch_id,
state["batch_state"].value)
while state["batch_state"] not in hook.TERMINAL_STATES:
self.log.info("Batch with id %s is in state: %s", batch_id,
state["batch_state"].value)
+ batch_execution_timed_out = (
+ timeout_datetime is not None and datetime.now(timezone.utc) >
timeout_datetime
+ )
+ if batch_execution_timed_out:
+ break
self.log.info("Sleeping for %s seconds", self._polling_interval)
await asyncio.sleep(self._polling_interval)
state = await hook.get_batch_state(batch_id)
- self.log.info("Batch with id %s terminated with state: %s", batch_id,
state["batch_state"].value)
log_lines = await hook.dump_batch_logs(batch_id)
+ if batch_execution_timed_out:
+ self.log.info(
+ "Batch with id %s did not terminate, but it reached execution
timeout.",
+ batch_id,
+ )
+ return {
+ "status": "timeout",
+ "batch_id": batch_id,
+ "response": f"Batch {batch_id} timed out",
+ "log_lines": log_lines,
+ }
+ self.log.info("Batch with id %s terminated with state: %s", batch_id,
state["batch_state"].value)
if state["batch_state"] != BatchState.SUCCESS:
return {
"status": "error",
diff --git a/tests/providers/apache/livy/operators/test_livy.py
b/tests/providers/apache/livy/operators/test_livy.py
index 02e8231eb2..4e128cbec8 100644
--- a/tests/providers/apache/livy/operators/test_livy.py
+++ b/tests/providers/apache/livy/operators/test_livy.py
@@ -280,6 +280,19 @@ class TestLivyOperator:
task.execute(context=self.mock_context)
assert task.hook.extra_options == extra_options
+
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
+ def
test_when_kill_is_called_right_after_construction_it_should_not_raise_attribute_error(
+ self, mock_delete_batch
+ ):
+ task = LivyOperator(
+ livy_conn_id="livyunittest",
+ file="sparkapp",
+ dag=self.dag,
+ task_id="livy_example",
+ )
+ task.kill()
+ mock_delete_batch.assert_not_called()
+
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch",
return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch",
return_value=GET_BATCH)
@@ -380,6 +393,30 @@ class TestLivyOperator:
)
self.mock_context["ti"].xcom_push.assert_not_called()
+ @patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch",
return_value=BATCH_ID)
+
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
+ def test_execute_complete_timeout(self, mock_delete, mock_post):
+ task = LivyOperator(
+ livy_conn_id="livyunittest",
+ file="sparkapp",
+ dag=self.dag,
+ task_id="livy_example",
+ polling_interval=1,
+ deferrable=True,
+ )
+ with pytest.raises(AirflowException):
+ task.execute_complete(
+ context=self.mock_context,
+ event={
+ "status": "timeout",
+ "log_lines": ["mock log"],
+ "batch_id": BATCH_ID,
+ "response": "mock timeout",
+ },
+ )
+ mock_delete.assert_called_once_with(BATCH_ID)
+ self.mock_context["ti"].xcom_push.assert_not_called()
+
@pytest.mark.db_test
def test_spark_params_templating(create_task_instance_of_operator):
diff --git a/tests/providers/apache/livy/triggers/test_livy.py
b/tests/providers/apache/livy/triggers/test_livy.py
index ac1464ffd4..df85a84bac 100644
--- a/tests/providers/apache/livy/triggers/test_livy.py
+++ b/tests/providers/apache/livy/triggers/test_livy.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import asyncio
+from datetime import timedelta
from unittest import mock
import pytest
@@ -46,6 +47,7 @@ class TestLivyTrigger:
"extra_options": None,
"extra_headers": None,
"livy_hook_async": None,
+ "execution_timeout": None,
}
@pytest.mark.asyncio
@@ -195,3 +197,31 @@ class TestLivyTrigger:
# TriggerEvent was not returned
assert task.done() is False
asyncio.get_event_loop().stop()
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_state")
+
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.dump_batch_logs")
+ async def test_livy_trigger_poll_for_termination_timeout(
+ self, mock_dump_batch_logs, mock_get_batch_state
+ ):
+ """
+ Test if poll_for_termination() returns timeout response when execution
times out.
+ """
+ mock_get_batch_state.return_value = {"batch_state": BatchState.RUNNING}
+ mock_dump_batch_logs.return_value = ["mock_log"]
+ trigger = LivyTrigger(
+ batch_id=1,
+ spark_params={},
+ livy_conn_id=LivyHook.default_conn_name,
+ polling_interval=1,
+ execution_timeout=timedelta(seconds=0),
+ )
+
+ task = await trigger.poll_for_termination(1)
+
+ assert task == {
+ "status": "timeout",
+ "batch_id": 1,
+ "response": "Batch 1 timed out",
+ "log_lines": ["mock_log"],
+ }