This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c95aa7863f Consolidate hook management in LivyOperator (#34431)
c95aa7863f is described below
commit c95aa7863f3efb5b9acebe3d2c5c3a146c6de1bf
Author: Hussein Awala <[email protected]>
AuthorDate: Mon Sep 18 21:12:43 2023 +0200
Consolidate hook management in LivyOperator (#34431)
* Consolidate hook management in LivyOperator
* use AirflowProviderDeprecationWarning
---
airflow/providers/apache/livy/operators/livy.py | 56 ++++++++++++----------
tests/providers/apache/livy/operators/test_livy.py | 28 ++---------
2 files changed, 34 insertions(+), 50 deletions(-)
diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py
index 046f33324e..bcf36b50ca 100644
--- a/airflow/providers/apache/livy/operators/livy.py
+++ b/airflow/providers/apache/livy/operators/livy.py
@@ -17,11 +17,14 @@
"""This module contains the Apache Livy operator."""
from __future__ import annotations
+from functools import cached_property
from time import sleep
from typing import TYPE_CHECKING, Any, Sequence
+from deprecated.classic import deprecated
+
from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook
from airflow.providers.apache.livy.triggers.livy import LivyTrigger
@@ -119,41 +122,43 @@ class LivyOperator(BaseOperator):
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}
- self._livy_hook: LivyHook | None = None
self._batch_id: int | str
self.retry_args = retry_args
self.deferrable = deferrable
- def get_hook(self) -> LivyHook:
+ @cached_property
+ def hook(self) -> LivyHook:
"""
Get valid hook.
- :return: hook
+ :return: LivyHook
"""
- if self._livy_hook is None or not isinstance(self._livy_hook, LivyHook):
- self._livy_hook = LivyHook(
- livy_conn_id=self._livy_conn_id,
- extra_headers=self._extra_headers,
- extra_options=self._extra_options,
- auth_type=self._livy_conn_auth_type,
- )
- return self._livy_hook
+ return LivyHook(
+ livy_conn_id=self._livy_conn_id,
+ extra_headers=self._extra_headers,
+ extra_options=self._extra_options,
+ auth_type=self._livy_conn_auth_type,
+ )
+
+ @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
+ def get_hook(self) -> LivyHook:
+ """Get valid hook."""
+ return self.hook
def execute(self, context: Context) -> Any:
- self._batch_id = self.get_hook().post_batch(**self.spark_params)
+ self._batch_id = self.hook.post_batch(**self.spark_params)
self.log.info("Generated batch-id is %s", self._batch_id)
# Wait for the job to complete
if not self.deferrable:
if self._polling_interval > 0:
self.poll_for_termination(self._batch_id)
- context["ti"].xcom_push(key="app_id",
value=self.get_hook().get_batch(self._batch_id)["appId"])
+ context["ti"].xcom_push(key="app_id",
value=self.hook.get_batch(self._batch_id)["appId"])
return self._batch_id
- hook = self.get_hook()
- state = hook.get_batch_state(self._batch_id,
retry_args=self.retry_args)
+ state = self.hook.get_batch_state(self._batch_id,
retry_args=self.retry_args)
self.log.debug("Batch with id %s is in state: %s", self._batch_id,
state.value)
- if state not in hook.TERMINAL_STATES:
+ if state not in self.hook.TERMINAL_STATES:
self.defer(
timeout=self.execution_timeout,
trigger=LivyTrigger(
@@ -168,11 +173,11 @@ class LivyOperator(BaseOperator):
)
else:
self.log.info("Batch with id %s terminated with state: %s",
self._batch_id, state.value)
- hook.dump_batch_logs(self._batch_id)
+ self.hook.dump_batch_logs(self._batch_id)
if state != BatchState.SUCCESS:
raise AirflowException(f"Batch {self._batch_id} did not succeed")
- context["ti"].xcom_push(key="app_id",
value=self.get_hook().get_batch(self._batch_id)["appId"])
+ context["ti"].xcom_push(key="app_id",
value=self.hook.get_batch(self._batch_id)["appId"])
return self._batch_id
def poll_for_termination(self, batch_id: int | str) -> None:
@@ -181,14 +186,13 @@ class LivyOperator(BaseOperator):
:param batch_id: id of the batch session to monitor.
"""
- hook = self.get_hook()
- state = hook.get_batch_state(batch_id, retry_args=self.retry_args)
- while state not in hook.TERMINAL_STATES:
+ state = self.hook.get_batch_state(batch_id, retry_args=self.retry_args)
+ while state not in self.hook.TERMINAL_STATES:
self.log.debug("Batch with id %s is in state: %s", batch_id,
state.value)
sleep(self._polling_interval)
- state = hook.get_batch_state(batch_id, retry_args=self.retry_args)
+ state = self.hook.get_batch_state(batch_id,
retry_args=self.retry_args)
self.log.info("Batch with id %s terminated with state: %s", batch_id,
state.value)
- hook.dump_batch_logs(batch_id)
+ self.hook.dump_batch_logs(batch_id)
if state != BatchState.SUCCESS:
raise AirflowException(f"Batch {batch_id} did not succeed")
@@ -198,7 +202,7 @@ class LivyOperator(BaseOperator):
def kill(self) -> None:
"""Delete the current batch session."""
if self._batch_id is not None:
- self.get_hook().delete_batch(self._batch_id)
+ self.hook.delete_batch(self._batch_id)
def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""
@@ -218,5 +222,5 @@ class LivyOperator(BaseOperator):
self.task_id,
event["response"],
)
- context["ti"].xcom_push(key="app_id",
value=self.get_hook().get_batch(event["batch_id"])["appId"])
+ context["ti"].xcom_push(key="app_id",
value=self.hook.get_batch(event["batch_id"])["appId"])
return event["batch_id"]
diff --git a/tests/providers/apache/livy/operators/test_livy.py b/tests/providers/apache/livy/operators/test_livy.py
index 452abe182a..04e796f092 100644
--- a/tests/providers/apache/livy/operators/test_livy.py
+++ b/tests/providers/apache/livy/operators/test_livy.py
@@ -24,7 +24,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.models.dag import DAG
-from airflow.providers.apache.livy.hooks.livy import BatchState, LivyHook
+from airflow.providers.apache.livy.hooks.livy import BatchState
from airflow.providers.apache.livy.operators.livy import LivyOperator
from airflow.utils import db, timezone
@@ -63,7 +63,6 @@ class TestLivyOperator:
mock_livy.side_effect = side_effect
task = LivyOperator(file="sparkapp", polling_interval=1, dag=self.dag,
task_id="livy_example")
- task._livy_hook = task.get_hook()
task.poll_for_termination(BATCH_ID)
mock_livy.assert_called_with(BATCH_ID, retry_args=None)
@@ -87,7 +86,6 @@ class TestLivyOperator:
mock_livy.side_effect = side_effect
task = LivyOperator(file="sparkapp", polling_interval=1, dag=self.dag,
task_id="livy_example")
- task._livy_hook = task.get_hook()
with pytest.raises(AirflowException):
task.poll_for_termination(BATCH_ID)
@@ -147,14 +145,6 @@ class TestLivyOperator:
mock_delete.assert_called_once_with(BATCH_ID)
- def test_injected_hook(self):
- def_hook = LivyHook(livy_conn_id="livyunittest")
-
- task = LivyOperator(file="sparkapp", dag=self.dag,
task_id="livy_example")
- task._livy_hook = def_hook
-
- assert task.get_hook() == def_hook
-
@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
return_value=BatchState.SUCCESS,
@@ -171,7 +161,7 @@ class TestLivyOperator:
polling_interval=1,
)
caplog.clear()
- with caplog.at_level(level=logging.INFO,
logger=task.get_hook().log.name):
+ with caplog.at_level(level=logging.INFO, logger=task.hook.log.name):
task.execute(context=self.mock_context)
assert "first_line" in caplog.messages
@@ -200,7 +190,6 @@ class TestLivyOperator:
task = LivyOperator(
file="sparkapp", polling_interval=1, dag=self.dag,
task_id="livy_example", deferrable=True
)
- task._livy_hook = task.get_hook()
task.poll_for_termination(BATCH_ID)
mock_livy.assert_called_with(BATCH_ID, retry_args=None)
@@ -226,7 +215,6 @@ class TestLivyOperator:
task = LivyOperator(
file="sparkapp", polling_interval=1, dag=self.dag,
task_id="livy_example", deferrable=True
)
- task._livy_hook = task.get_hook()
with pytest.raises(AirflowException):
task.poll_for_termination(BATCH_ID)
@@ -287,7 +275,7 @@ class TestLivyOperator:
)
task.execute(context=self.mock_context)
- assert task.get_hook().extra_options == extra_options
+ assert task.hook.extra_options == extra_options
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch",
return_value=BATCH_ID)
@@ -315,14 +303,6 @@ class TestLivyOperator:
mock_delete.assert_called_once_with(BATCH_ID)
- def test_injected_hook_deferrable(self):
- def_hook = LivyHook(livy_conn_id="livyunittest")
-
- task = LivyOperator(file="sparkapp", dag=self.dag,
task_id="livy_example", deferrable=True)
- task._livy_hook = def_hook
-
- assert task.get_hook() == def_hook
-
@patch(
"airflow.providers.apache.livy.operators.livy.LivyHook.get_batch_state",
return_value=BatchState.SUCCESS,
@@ -341,7 +321,7 @@ class TestLivyOperator:
)
caplog.clear()
- with caplog.at_level(level=logging.INFO,
logger=task.get_hook().log.name):
+ with caplog.at_level(level=logging.INFO, logger=task.hook.log.name):
task.execute(context=self.mock_context)
assert "first_line" in caplog.messages