This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8c9e0b2a8ec Add `endpoint_prefix` to `LivyHook` (#45811)
8c9e0b2a8ec is described below
commit 8c9e0b2a8ec0e6f2842883898370741e71c0e802
Author: gpathak128 <[email protected]>
AuthorDate: Fri Jan 31 20:23:47 2025 -0500
Add `endpoint_prefix` to `LivyHook` (#45811)
* revert manual changes to the Changelog file
* adding tests for livy db hook
* adding tests for async livy hook
* reformat comment, and fix endpoint_prefix type for livy trigger
* formatting changes
* remove extra default args from expectation
* adding mock patch for mock request
---------
Co-authored-by: Giridhar Pathak <[email protected]>
---
.../airflow/providers/apache/livy/hooks/livy.py | 36 ++++++--
.../providers/apache/livy/operators/livy.py | 3 +
.../airflow/providers/apache/livy/sensors/livy.py | 3 +
.../airflow/providers/apache/livy/triggers/livy.py | 3 +
.../provider_tests/apache/livy/hooks/test_livy.py | 99 ++++++++++++++++++++++
5 files changed, 137 insertions(+), 7 deletions(-)
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
index 934985befbc..f185ccdfbe3 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
@@ -52,6 +52,11 @@ class BatchState(Enum):
SUCCESS = "success"
+def sanitize_endpoint_prefix(endpoint_prefix: str | None) -> str:
+ """Ensure that the endpoint prefix is prefixed with a slash."""
+ return f"/{endpoint_prefix.strip('/')}" if endpoint_prefix else ""
+
+
class LivyHook(HttpHook):
"""
Hook for Apache Livy through the REST API.
@@ -86,12 +91,14 @@ class LivyHook(HttpHook):
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
auth_type: Any | None = None,
+ endpoint_prefix: str | None = None,
) -> None:
super().__init__()
self.method = "POST"
self.http_conn_id = livy_conn_id
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
+ self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
if auth_type:
self.auth_type = auth_type
@@ -163,7 +170,10 @@ class LivyHook(HttpHook):
self.log.info("Submitting job %s to %s", batch_submit_body,
self.base_url)
response = self.run_method(
- method="POST", endpoint="/batches", data=batch_submit_body,
headers=self.extra_headers
+ method="POST",
+ endpoint=f"{self.endpoint_prefix}/batches",
+ data=batch_submit_body,
+ headers=self.extra_headers,
)
self.log.debug("Got response: %s", response.text)
@@ -192,7 +202,9 @@ class LivyHook(HttpHook):
self._validate_session_id(session_id)
self.log.debug("Fetching info for batch session %s", session_id)
- response = self.run_method(endpoint=f"/batches/{session_id}",
headers=self.extra_headers)
+ response = self.run_method(
+ endpoint=f"{self.endpoint_prefix}/batches/{session_id}",
headers=self.extra_headers
+ )
try:
response.raise_for_status()
@@ -217,7 +229,9 @@ class LivyHook(HttpHook):
self.log.debug("Fetching info for batch session %s", session_id)
response = self.run_method(
- endpoint=f"/batches/{session_id}/state", retry_args=retry_args,
headers=self.extra_headers
+ endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state",
+ retry_args=retry_args,
+ headers=self.extra_headers,
)
try:
@@ -244,7 +258,9 @@ class LivyHook(HttpHook):
self.log.info("Deleting batch session %s", session_id)
response = self.run_method(
- method="DELETE", endpoint=f"/batches/{session_id}",
headers=self.extra_headers
+ method="DELETE",
+ endpoint=f"{self.endpoint_prefix}/batches/{session_id}",
+ headers=self.extra_headers,
)
try:
@@ -270,7 +286,9 @@ class LivyHook(HttpHook):
self._validate_session_id(session_id)
log_params = {"from": log_start_position, "size": log_batch_size}
response = self.run_method(
- endpoint=f"/batches/{session_id}/log", data=log_params,
headers=self.extra_headers
+ endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log",
+ data=log_params,
+ headers=self.extra_headers,
)
try:
response.raise_for_status()
@@ -490,12 +508,14 @@ class LivyAsyncHook(HttpAsyncHook):
livy_conn_id: str = default_conn_name,
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
+ endpoint_prefix: str | None = None,
) -> None:
super().__init__()
self.method = "POST"
self.http_conn_id = livy_conn_id
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
+ self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
async def _do_api_call_async(
self,
@@ -624,7 +644,7 @@ class LivyAsyncHook(HttpAsyncHook):
"""
self._validate_session_id(session_id)
self.log.info("Fetching info for batch session %s", session_id)
- result = await self.run_method(endpoint=f"/batches/{session_id}/state")
+ result = await
self.run_method(endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state")
if result["status"] == "error":
self.log.info(result)
return {"batch_state": "error", "response": result, "status":
"error"}
@@ -659,7 +679,9 @@ class LivyAsyncHook(HttpAsyncHook):
"""
self._validate_session_id(session_id)
log_params = {"from": log_start_position, "size": log_batch_size}
- result = await self.run_method(endpoint=f"/batches/{session_id}/log",
data=log_params)
+ result = await self.run_method(
+ endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log",
data=log_params
+ )
if result["status"] == "error":
self.log.info(result)
return {"response": result["response"], "status": "error"}
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py
index b74e52d5e61..746ea55ccee 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py
@@ -88,6 +88,7 @@ class LivyOperator(BaseOperator):
proxy_user: str | None = None,
livy_conn_id: str = "livy_default",
livy_conn_auth_type: Any | None = None,
+ livy_endpoint_prefix: str | None = None,
polling_interval: int = 0,
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
@@ -119,6 +120,7 @@ class LivyOperator(BaseOperator):
self.spark_params = spark_params
self._livy_conn_id = livy_conn_id
self._livy_conn_auth_type = livy_conn_auth_type
+ self._livy_endpoint_prefix = livy_endpoint_prefix
self._polling_interval = polling_interval
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}
@@ -139,6 +141,7 @@ class LivyOperator(BaseOperator):
extra_headers=self._extra_headers,
extra_options=self._extra_options,
auth_type=self._livy_conn_auth_type,
+ endpoint_prefix=self._livy_endpoint_prefix,
)
def execute(self, context: Context) -> Any:
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py
index 3c1a50255ad..d0b011e2518 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py
@@ -46,6 +46,7 @@ class LivySensor(BaseSensorOperator):
livy_conn_id: str = "livy_default",
livy_conn_auth_type: Any | None = None,
extra_options: dict[str, Any] | None = None,
+ endpoint_prefix: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -54,6 +55,7 @@ class LivySensor(BaseSensorOperator):
self._livy_conn_auth_type = livy_conn_auth_type
self._livy_hook: LivyHook | None = None
self._extra_options = extra_options or {}
+ self._endpoint_prefix = endpoint_prefix
def get_hook(self) -> LivyHook:
"""
@@ -66,6 +68,7 @@ class LivySensor(BaseSensorOperator):
livy_conn_id=self._livy_conn_id,
extra_options=self._extra_options,
auth_type=self._livy_conn_auth_type,
+ endpoint_prefix=self._endpoint_prefix,
)
return self._livy_hook
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py
index 2e40b26113e..9c47706b737 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py
@@ -57,6 +57,7 @@ class LivyTrigger(BaseTrigger):
extra_headers: dict[str, Any] | None = None,
livy_hook_async: LivyAsyncHook | None = None,
execution_timeout: timedelta | None = None,
+ endpoint_prefix: str | None = None,
):
super().__init__()
self._batch_id = batch_id
@@ -67,6 +68,7 @@ class LivyTrigger(BaseTrigger):
self._extra_headers = extra_headers
self._livy_hook_async = livy_hook_async
self._execution_timeout = execution_timeout
+ self._endpoint_prefix = endpoint_prefix
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize LivyTrigger arguments and classpath."""
@@ -170,5 +172,6 @@ class LivyTrigger(BaseTrigger):
livy_conn_id=self._livy_conn_id,
extra_headers=self._extra_headers,
extra_options=self._extra_options,
+ endpoint_prefix=self._endpoint_prefix,
)
return self._livy_hook_async
diff --git a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
index 6cb6dd91160..2a1101ab44e 100644
--- a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
+++ b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
@@ -428,6 +428,77 @@ class TestLivyDbHook:
auth_type.assert_called_once_with("login", "secret")
+ @patch("airflow.providers.apache.livy.hooks.livy.LivyHook.run_method")
+ def test_post_batch_with_endpoint_prefix(self, mock_request):
+ mock_request.return_value.status_code = 201
+ mock_request.return_value.json.return_value = {
+ "id": BATCH_ID,
+ "state": BatchState.STARTING.value,
+ "log": [],
+ }
+
+ resp = LivyHook(endpoint_prefix="/livy").post_batch(file="sparkapp")
+
+ mock_request.assert_called_once_with(
+ method="POST", endpoint="/livy/batches", data=json.dumps({"file":
"sparkapp"}), headers={}
+ )
+
+ request_args = mock_request.call_args.kwargs
+ assert "data" in request_args
+ assert isinstance(request_args["data"], str)
+
+ assert isinstance(resp, int)
+ assert resp == BATCH_ID
+
+ def test_get_batch_with_endpoint_prefix(self, requests_mock):
+ requests_mock.register_uri(
+ "GET", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"id":
BATCH_ID}, status_code=200
+ )
+ resp = LivyHook(endpoint_prefix="/livy").get_batch(BATCH_ID)
+ assert isinstance(resp, dict)
+ assert "id" in resp
+
+ def test_get_batch_state_with_endpoint_prefix(self, requests_mock):
+ running = BatchState.RUNNING
+
+ requests_mock.register_uri(
+ "GET",
+ f"{MATCH_URL}/livy/batches/{BATCH_ID}/state",
+ json={"id": BATCH_ID, "state": running.value},
+ status_code=200,
+ )
+
+ state = LivyHook(endpoint_prefix="/livy").get_batch_state(BATCH_ID)
+ assert isinstance(state, BatchState)
+ assert state == running
+
+ def test_delete_batch_with_endpoint_prefix(self, requests_mock):
+ requests_mock.register_uri(
+ "DELETE", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"msg":
"deleted"}, status_code=200
+ )
+ assert LivyHook(endpoint_prefix="/livy").delete_batch(BATCH_ID) ==
{"msg": "deleted"}
+
+ @pytest.mark.parametrize(
+ "prefix",
+ ["/livy/", "livy", "/livy", "livy/"],
+ ids=["leading_and_trailing_slashes", "no_slashes", "leading_slash",
"trailing_slash"],
+ )
+ def test_endpoint_prefix_is_sanitized_simple(self, requests_mock, prefix):
+ requests_mock.register_uri(
+ "GET", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"id":
BATCH_ID}, status_code=200
+ )
+ resp = LivyHook(endpoint_prefix=prefix).get_batch(BATCH_ID)
+ assert isinstance(resp, dict)
+ assert "id" in resp
+
+ def test_endpoint_prefix_is_sanitized_multiple_path_elements(self,
requests_mock):
+ requests_mock.register_uri(
+ "GET", f"{MATCH_URL}/livy/foo/bar/batches/{BATCH_ID}", json={"id":
BATCH_ID}, status_code=200
+ )
+ resp = LivyHook(endpoint_prefix="/livy/foo/bar/").get_batch(BATCH_ID)
+ assert isinstance(resp, dict)
+ assert "id" in resp
+
class TestLivyAsyncHook:
@pytest.mark.asyncio
@@ -815,3 +886,31 @@ class TestLivyAsyncHook:
def test_check_session_id_failure(self, conn_id):
with pytest.raises(TypeError):
LivyAsyncHook._validate_session_id(None)
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method")
+ async def test_get_batch_state_with_endpoint_prefix(self, mock_run_method):
+ mock_run_method.return_value = {"status": "success", "response":
{"state": BatchState.RUNNING}}
+ hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID,
endpoint_prefix="/livy")
+ state = await hook.get_batch_state(BATCH_ID)
+ assert state == {
+ "batch_state": BatchState.RUNNING,
+ "response": "successfully fetched the batch state.",
+ "status": "success",
+ }
+ mock_run_method.assert_called_once_with(
+ endpoint=f"/livy/batches/{BATCH_ID}/state",
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method")
+ async def test_get_batch_logs_with_endpoint_prefix(self, mock_run_method):
+ mock_run_method.return_value = {"status": "success", "response": {}}
+ hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID,
endpoint_prefix="/livy")
+ state = await hook.get_batch_logs(BATCH_ID, 0, 100)
+ assert state["status"] == "success"
+
+ mock_run_method.assert_called_once_with(
+ endpoint=f"/livy/batches/{BATCH_ID}/log",
+ data={"from": 0, "size": 100},
+ )