This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ecb9a9ea78 Add retry param in databricks async operator (#30744)
ecb9a9ea78 is described below

commit ecb9a9ea78203bd1ce2f2d645d554409651ba8c1
Author: Pankaj Singh <[email protected]>
AuthorDate: Mon Apr 24 03:05:49 2023 +0530

    Add retry param in databricks async operator (#30744)
    
    * Add retry param in databricks async operator
    
    * Apply review suggestions
---
 .../providers/databricks/operators/databricks.py   |  8 +++++
 .../providers/databricks/triggers/databricks.py    | 41 +++++++++++++++++++---
 .../databricks/triggers/test_databricks.py         | 10 ++++--
 3 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 61384c8015..006da0edae 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
 
 DEFER_METHOD_NAME = "execute_complete"
 XCOM_RUN_ID_KEY = "run_id"
+XCOM_JOB_ID_KEY = "job_id"
 XCOM_RUN_PAGE_URL_KEY = "run_page_url"
 
 
@@ -104,6 +105,9 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex
     :param operator: Databricks async operator being handled
     :param context: Airflow context
     """
+    job_id = hook.get_job_id(operator.run_id)
+    if operator.do_xcom_push and context is not None:
+        context["ti"].xcom_push(key=XCOM_JOB_ID_KEY, value=job_id)
     if operator.do_xcom_push and context is not None:
         context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
     log.info("Run submitted with run_id: %s", operator.run_id)
@@ -119,6 +123,10 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex
                 run_id=operator.run_id,
                 databricks_conn_id=operator.databricks_conn_id,
                 polling_period_seconds=operator.polling_period_seconds,
+                retry_limit=operator.databricks_retry_limit,
+                retry_delay=operator.databricks_retry_delay,
+                retry_args=operator.databricks_retry_args,
+                run_page_url=run_page_url,
             ),
             method_name=DEFER_METHOD_NAME,
         )
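
The operators.py change above pushes the run's job_id to XCom (key "job_id") alongside the existing run_id before deferring. As a rough usage sketch (the "submit_run" task id and the TaskFlow wiring are illustrative assumptions, not part of this commit), a downstream task could pull both values:

    # Hypothetical downstream task; only the XCom keys ("job_id", "run_id")
    # come from this commit, everything else here is illustrative.
    from airflow.decorators import task

    @task
    def report_databricks_run(ti=None):
        job_id = ti.xcom_pull(task_ids="submit_run", key="job_id")   # XCOM_JOB_ID_KEY
        run_id = ti.xcom_pull(task_ids="submit_run", key="run_id")   # XCOM_RUN_ID_KEY
        print(f"Databricks job {job_id} produced run {run_id}")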
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
index cd2421c376..e5e56cc0ff 100644
--- a/airflow/providers/databricks/triggers/databricks.py
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -32,14 +32,36 @@ class DatabricksExecutionTrigger(BaseTrigger):
     :param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
     :param polling_period_seconds: Controls the rate of the poll for the result of this run.
         By default, the trigger will poll every 30 seconds.
+    :param retry_limit: The number of times to retry the connection in case of service outages.
+    :param retry_delay: The number of seconds to wait between retries.
+    :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+    :param run_page_url: The run page url.
     """
 
-    def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: int = 30) -> None:
+    def __init__(
+        self,
+        run_id: int,
+        databricks_conn_id: str,
+        polling_period_seconds: int = 30,
+        retry_limit: int = 3,
+        retry_delay: int = 10,
+        retry_args: dict[Any, Any] | None = None,
+        run_page_url: str | None = None,
+    ) -> None:
         super().__init__()
         self.run_id = run_id
         self.databricks_conn_id = databricks_conn_id
         self.polling_period_seconds = polling_period_seconds
-        self.hook = DatabricksHook(databricks_conn_id)
+        self.retry_limit = retry_limit
+        self.retry_delay = retry_delay
+        self.retry_args = retry_args
+        self.run_page_url = run_page_url
+        self.hook = DatabricksHook(
+            databricks_conn_id,
+            retry_limit=self.retry_limit,
+            retry_delay=self.retry_delay,
+            retry_args=retry_args,
+        )
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
@@ -48,22 +70,31 @@ class DatabricksExecutionTrigger(BaseTrigger):
                 "run_id": self.run_id,
                 "databricks_conn_id": self.databricks_conn_id,
                 "polling_period_seconds": self.polling_period_seconds,
+                "retry_limit": self.retry_limit,
+                "retry_delay": self.retry_delay,
+                "retry_args": self.retry_args,
+                "run_page_url": self.run_page_url,
             },
         )
 
     async def run(self):
         async with self.hook:
-            run_page_url = await self.hook.a_get_run_page_url(self.run_id)
             while True:
                 run_state = await self.hook.a_get_run_state(self.run_id)
                 if run_state.is_terminal:
                     yield TriggerEvent(
                         {
                             "run_id": self.run_id,
+                            "run_page_url": self.run_page_url,
                             "run_state": run_state.to_json(),
-                            "run_page_url": run_page_url,
                         }
                     )
-                    break
+                    return
                 else:
+                    self.log.info(
+                        "run-id %s in run state %s. sleeping for %s seconds",
+                        self.run_id,
+                        run_state,
+                        self.polling_period_seconds,
+                    )
                     await asyncio.sleep(self.polling_period_seconds)
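
Because the trigger now accepts and serializes its retry settings, the triggerer process that re-instantiates it rebuilds the DatabricksHook with the same tenacity-backed retry behavior. A minimal sketch of that round trip (connection id, run values and URL are illustrative assumptions, not from this commit):

    # Illustrative only: the constructor arguments and serialize() keys match
    # the diff above; the concrete values here are made up.
    from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger

    trigger = DatabricksExecutionTrigger(
        run_id=42,
        databricks_conn_id="databricks_default",
        polling_period_seconds=30,
        retry_limit=3,
        retry_delay=10,
        retry_args=None,
        run_page_url="https://example.cloud.databricks.com/#job/1/run/42",
    )
    classpath, kwargs = trigger.serialize()
    # The retry settings survive the serialize/rebuild cycle, so the hook
    # constructed in the triggerer retries the same way as on the worker.
    assert kwargs["retry_limit"] == 3 and kwargs["retry_delay"] == 10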
diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py
index 4e5da213f5..675995beb9 100644
--- a/tests/providers/databricks/triggers/test_databricks.py
+++ b/tests/providers/databricks/triggers/test_databricks.py
@@ -84,6 +84,7 @@ class TestDatabricksExecutionTrigger:
             run_id=RUN_ID,
             databricks_conn_id=DEFAULT_CONN_ID,
             polling_period_seconds=POLLING_INTERVAL_SECONDS,
+            run_page_url=RUN_PAGE_URL,
         )
 
     def test_serialize(self):
@@ -93,6 +94,10 @@ class TestDatabricksExecutionTrigger:
                 "run_id": RUN_ID,
                 "databricks_conn_id": DEFAULT_CONN_ID,
                 "polling_period_seconds": POLLING_INTERVAL_SECONDS,
+                "retry_delay": 10,
+                "retry_limit": 3,
+                "retry_args": None,
+                "run_page_url": RUN_PAGE_URL,
             },
         )
 
@@ -121,10 +126,9 @@ class TestDatabricksExecutionTrigger:
 
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep")
-    @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
     @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
-    async def test_sleep_between_retries(self, mock_get_run_state, mock_get_run_page_url, mock_sleep):
-        mock_get_run_page_url.return_value = RUN_PAGE_URL
+    async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
+
         mock_get_run_state.side_effect = [
             RunState(
                 life_cycle_state=LIFE_CYCLE_STATE_PENDING,

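The reworked test drops the a_get_run_page_url mock, since the URL is now passed to the trigger directly, and drives the poll loop by feeding a_get_run_state a sequence of states. A self-contained sketch of that pattern (connection id, values and state strings are assumptions; the real test lives in the module above and assumes pytest-asyncio):

    # Sketch of the mocking pattern: first poll non-terminal, second terminal.
    import pytest
    from unittest import mock
    from airflow.providers.databricks.hooks.databricks import RunState
    from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger

    @pytest.mark.asyncio
    @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep")
    @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
    async def test_polls_until_terminal(mock_get_run_state, mock_sleep):
        trigger = DatabricksExecutionTrigger(
            run_id=1,
            databricks_conn_id="databricks_default",  # assumed connection id
            polling_period_seconds=1,
            run_page_url="https://example/#job/1/run/1",
        )
        mock_get_run_state.side_effect = [
            RunState(life_cycle_state="PENDING", result_state="", state_message=""),
            RunState(life_cycle_state="TERMINATED", result_state="SUCCESS", state_message=""),
        ]
        event = await trigger.run().asend(None)  # first (and only) TriggerEvent
        mock_sleep.assert_called_once()          # slept exactly once before terminal state
        assert event.payload["run_id"] == 1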