This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8c9e0b2a8ec Add `endpoint_prefix` to `LivyHook` (#45811)
8c9e0b2a8ec is described below

commit 8c9e0b2a8ec0e6f2842883898370741e71c0e802
Author: gpathak128 <[email protected]>
AuthorDate: Fri Jan 31 20:23:47 2025 -0500

    Add `endpoint_prefix` to `LivyHook` (#45811)
    
    * revert manual changes to the Changelog file
    
    * adding tests for livy db hook
    
    * adding tests for async livy hook
    
    * reformat comment, and fix endpoint_prefix type for livy trigger
    
    * formatting changes
    
    * remove extra default args from expectation
    
    * adding mock patch for mock request
    
    ---------
    
    Co-authored-by: Giridhar Pathak <[email protected]>
---
 .../airflow/providers/apache/livy/hooks/livy.py    | 36 ++++++--
 .../providers/apache/livy/operators/livy.py        |  3 +
 .../airflow/providers/apache/livy/sensors/livy.py  |  3 +
 .../airflow/providers/apache/livy/triggers/livy.py |  3 +
 .../provider_tests/apache/livy/hooks/test_livy.py  | 99 ++++++++++++++++++++++
 5 files changed, 137 insertions(+), 7 deletions(-)

diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
index 934985befbc..f185ccdfbe3 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py
@@ -52,6 +52,11 @@ class BatchState(Enum):
     SUCCESS = "success"
 
 
+def sanitize_endpoint_prefix(endpoint_prefix: str | None) -> str:
+    """Ensure that the endpoint prefix is prefixed with a slash."""
+    return f"/{endpoint_prefix.strip('/')}" if endpoint_prefix else ""
+
+
 class LivyHook(HttpHook):
     """
     Hook for Apache Livy through the REST API.
@@ -86,12 +91,14 @@ class LivyHook(HttpHook):
         extra_options: dict[str, Any] | None = None,
         extra_headers: dict[str, Any] | None = None,
         auth_type: Any | None = None,
+        endpoint_prefix: str | None = None,
     ) -> None:
         super().__init__()
         self.method = "POST"
         self.http_conn_id = livy_conn_id
         self.extra_headers = extra_headers or {}
         self.extra_options = extra_options or {}
+        self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
         if auth_type:
             self.auth_type = auth_type
 
@@ -163,7 +170,10 @@ class LivyHook(HttpHook):
         self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url)
 
         response = self.run_method(
-            method="POST", endpoint="/batches", data=batch_submit_body, headers=self.extra_headers
+            method="POST",
+            endpoint=f"{self.endpoint_prefix}/batches",
+            data=batch_submit_body,
+            headers=self.extra_headers,
         )
         self.log.debug("Got response: %s", response.text)
 
@@ -192,7 +202,9 @@ class LivyHook(HttpHook):
         self._validate_session_id(session_id)
 
         self.log.debug("Fetching info for batch session %s", session_id)
-        response = self.run_method(endpoint=f"/batches/{session_id}", headers=self.extra_headers)
+        response = self.run_method(
+            endpoint=f"{self.endpoint_prefix}/batches/{session_id}", headers=self.extra_headers
+        )
 
         try:
             response.raise_for_status()
@@ -217,7 +229,9 @@ class LivyHook(HttpHook):
 
         self.log.debug("Fetching info for batch session %s", session_id)
         response = self.run_method(
-            endpoint=f"/batches/{session_id}/state", retry_args=retry_args, headers=self.extra_headers
+            endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state",
+            retry_args=retry_args,
+            headers=self.extra_headers,
         )
 
         try:
@@ -244,7 +258,9 @@ class LivyHook(HttpHook):
 
         self.log.info("Deleting batch session %s", session_id)
         response = self.run_method(
-            method="DELETE", endpoint=f"/batches/{session_id}", headers=self.extra_headers
+            method="DELETE",
+            endpoint=f"{self.endpoint_prefix}/batches/{session_id}",
+            headers=self.extra_headers,
         )
 
         try:
@@ -270,7 +286,9 @@ class LivyHook(HttpHook):
         self._validate_session_id(session_id)
         log_params = {"from": log_start_position, "size": log_batch_size}
         response = self.run_method(
-            endpoint=f"/batches/{session_id}/log", data=log_params, headers=self.extra_headers
+            endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log",
+            data=log_params,
+            headers=self.extra_headers,
         )
         try:
             response.raise_for_status()
@@ -490,12 +508,14 @@ class LivyAsyncHook(HttpAsyncHook):
         livy_conn_id: str = default_conn_name,
         extra_options: dict[str, Any] | None = None,
         extra_headers: dict[str, Any] | None = None,
+        endpoint_prefix: str | None = None,
     ) -> None:
         super().__init__()
         self.method = "POST"
         self.http_conn_id = livy_conn_id
         self.extra_headers = extra_headers or {}
         self.extra_options = extra_options or {}
+        self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix)
 
     async def _do_api_call_async(
         self,
@@ -624,7 +644,7 @@ class LivyAsyncHook(HttpAsyncHook):
         """
         self._validate_session_id(session_id)
         self.log.info("Fetching info for batch session %s", session_id)
-        result = await self.run_method(endpoint=f"/batches/{session_id}/state")
+        result = await self.run_method(endpoint=f"{self.endpoint_prefix}/batches/{session_id}/state")
         if result["status"] == "error":
             self.log.info(result)
             return {"batch_state": "error", "response": result, "status": "error"}
@@ -659,7 +679,9 @@ class LivyAsyncHook(HttpAsyncHook):
         """
         self._validate_session_id(session_id)
         log_params = {"from": log_start_position, "size": log_batch_size}
-        result = await self.run_method(endpoint=f"/batches/{session_id}/log", data=log_params)
+        result = await self.run_method(
+            endpoint=f"{self.endpoint_prefix}/batches/{session_id}/log", data=log_params
+        )
         if result["status"] == "error":
             self.log.info(result)
             return {"response": result["response"], "status": "error"}
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py
index b74e52d5e61..746ea55ccee 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/operators/livy.py
@@ -88,6 +88,7 @@ class LivyOperator(BaseOperator):
         proxy_user: str | None = None,
         livy_conn_id: str = "livy_default",
         livy_conn_auth_type: Any | None = None,
+        livy_endpoint_prefix: str | None = None,
         polling_interval: int = 0,
         extra_options: dict[str, Any] | None = None,
         extra_headers: dict[str, Any] | None = None,
@@ -119,6 +120,7 @@ class LivyOperator(BaseOperator):
         self.spark_params = spark_params
         self._livy_conn_id = livy_conn_id
         self._livy_conn_auth_type = livy_conn_auth_type
+        self._livy_endpoint_prefix = livy_endpoint_prefix
         self._polling_interval = polling_interval
         self._extra_options = extra_options or {}
         self._extra_headers = extra_headers or {}
@@ -139,6 +141,7 @@ class LivyOperator(BaseOperator):
             extra_headers=self._extra_headers,
             extra_options=self._extra_options,
             auth_type=self._livy_conn_auth_type,
+            endpoint_prefix=self._livy_endpoint_prefix,
         )
 
     def execute(self, context: Context) -> Any:
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py
index 3c1a50255ad..d0b011e2518 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/sensors/livy.py
@@ -46,6 +46,7 @@ class LivySensor(BaseSensorOperator):
         livy_conn_id: str = "livy_default",
         livy_conn_auth_type: Any | None = None,
         extra_options: dict[str, Any] | None = None,
+        endpoint_prefix: str | None = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -54,6 +55,7 @@ class LivySensor(BaseSensorOperator):
         self._livy_conn_auth_type = livy_conn_auth_type
         self._livy_hook: LivyHook | None = None
         self._extra_options = extra_options or {}
+        self._endpoint_prefix = endpoint_prefix
 
     def get_hook(self) -> LivyHook:
         """
@@ -66,6 +68,7 @@ class LivySensor(BaseSensorOperator):
                 livy_conn_id=self._livy_conn_id,
                 extra_options=self._extra_options,
                 auth_type=self._livy_conn_auth_type,
+                endpoint_prefix=self._endpoint_prefix,
             )
         return self._livy_hook
 
diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py
index 2e40b26113e..9c47706b737 100644
--- a/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py
+++ b/providers/apache/livy/src/airflow/providers/apache/livy/triggers/livy.py
@@ -57,6 +57,7 @@ class LivyTrigger(BaseTrigger):
         extra_headers: dict[str, Any] | None = None,
         livy_hook_async: LivyAsyncHook | None = None,
         execution_timeout: timedelta | None = None,
+        endpoint_prefix: str | None = None,
     ):
         super().__init__()
         self._batch_id = batch_id
@@ -67,6 +68,7 @@ class LivyTrigger(BaseTrigger):
         self._extra_headers = extra_headers
         self._livy_hook_async = livy_hook_async
         self._execution_timeout = execution_timeout
+        self._endpoint_prefix = endpoint_prefix
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize LivyTrigger arguments and classpath."""
@@ -170,5 +172,6 @@ class LivyTrigger(BaseTrigger):
                 livy_conn_id=self._livy_conn_id,
                 extra_headers=self._extra_headers,
                 extra_options=self._extra_options,
+                endpoint_prefix=self._endpoint_prefix,
             )
         return self._livy_hook_async
diff --git a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
index 6cb6dd91160..2a1101ab44e 100644
--- a/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
+++ b/providers/apache/livy/tests/provider_tests/apache/livy/hooks/test_livy.py
@@ -428,6 +428,77 @@ class TestLivyDbHook:
 
         auth_type.assert_called_once_with("login", "secret")
 
+    @patch("airflow.providers.apache.livy.hooks.livy.LivyHook.run_method")
+    def test_post_batch_with_endpoint_prefix(self, mock_request):
+        mock_request.return_value.status_code = 201
+        mock_request.return_value.json.return_value = {
+            "id": BATCH_ID,
+            "state": BatchState.STARTING.value,
+            "log": [],
+        }
+
+        resp = LivyHook(endpoint_prefix="/livy").post_batch(file="sparkapp")
+
+        mock_request.assert_called_once_with(
+            method="POST", endpoint="/livy/batches", data=json.dumps({"file": "sparkapp"}), headers={}
+        )
+
+        request_args = mock_request.call_args.kwargs
+        assert "data" in request_args
+        assert isinstance(request_args["data"], str)
+
+        assert isinstance(resp, int)
+        assert resp == BATCH_ID
+
+    def test_get_batch_with_endpoint_prefix(self, requests_mock):
+        requests_mock.register_uri(
+            "GET", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200
+        )
+        resp = LivyHook(endpoint_prefix="/livy").get_batch(BATCH_ID)
+        assert isinstance(resp, dict)
+        assert "id" in resp
+
+    def test_get_batch_state_with_endpoint_prefix(self, requests_mock):
+        running = BatchState.RUNNING
+
+        requests_mock.register_uri(
+            "GET",
+            f"{MATCH_URL}/livy/batches/{BATCH_ID}/state",
+            json={"id": BATCH_ID, "state": running.value},
+            status_code=200,
+        )
+
+        state = LivyHook(endpoint_prefix="/livy").get_batch_state(BATCH_ID)
+        assert isinstance(state, BatchState)
+        assert state == running
+
+    def test_delete_batch_with_endpoint_prefix(self, requests_mock):
+        requests_mock.register_uri(
+            "DELETE", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"msg": "deleted"}, status_code=200
+        )
+        assert LivyHook(endpoint_prefix="/livy").delete_batch(BATCH_ID) == {"msg": "deleted"}
+
+    @pytest.mark.parametrize(
+        "prefix",
+        ["/livy/", "livy", "/livy", "livy/"],
+        ids=["leading_and_trailing_slashes", "no_slashes", "leading_slash", "trailing_slash"],
+    )
+    def test_endpoint_prefix_is_sanitized_simple(self, requests_mock, prefix):
+        requests_mock.register_uri(
+            "GET", f"{MATCH_URL}/livy/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200
+        )
+        resp = LivyHook(endpoint_prefix=prefix).get_batch(BATCH_ID)
+        assert isinstance(resp, dict)
+        assert "id" in resp
+
+    def test_endpoint_prefix_is_sanitized_multiple_path_elements(self, requests_mock):
+        requests_mock.register_uri(
+            "GET", f"{MATCH_URL}/livy/foo/bar/batches/{BATCH_ID}", json={"id": BATCH_ID}, status_code=200
+        )
+        resp = LivyHook(endpoint_prefix="/livy/foo/bar/").get_batch(BATCH_ID)
+        assert isinstance(resp, dict)
+        assert "id" in resp
+
 
 class TestLivyAsyncHook:
     @pytest.mark.asyncio
@@ -815,3 +886,31 @@ class TestLivyAsyncHook:
     def test_check_session_id_failure(self, conn_id):
         with pytest.raises(TypeError):
             LivyAsyncHook._validate_session_id(None)
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method")
+    async def test_get_batch_state_with_endpoint_prefix(self, mock_run_method):
+        mock_run_method.return_value = {"status": "success", "response": {"state": BatchState.RUNNING}}
+        hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, endpoint_prefix="/livy")
+        state = await hook.get_batch_state(BATCH_ID)
+        assert state == {
+            "batch_state": BatchState.RUNNING,
+            "response": "successfully fetched the batch state.",
+            "status": "success",
+        }
+        mock_run_method.assert_called_once_with(
+            endpoint=f"/livy/batches/{BATCH_ID}/state",
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.run_method")
+    async def test_get_batch_logs_with_endpoint_prefix(self, mock_run_method):
+        mock_run_method.return_value = {"status": "success", "response": {}}
+        hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID, endpoint_prefix="/livy")
+        state = await hook.get_batch_logs(BATCH_ID, 0, 100)
+        assert state["status"] == "success"
+
+        mock_run_method.assert_called_once_with(
+            endpoint=f"/livy/batches/{BATCH_ID}/log",
+            data={"from": 0, "size": 100},
+        )

Reply via email to