This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ecb9a9ea78 Add retry param in databricks async operator (#30744)
ecb9a9ea78 is described below
commit ecb9a9ea78203bd1ce2f2d645d554409651ba8c1
Author: Pankaj Singh <[email protected]>
AuthorDate: Mon Apr 24 03:05:49 2023 +0530
Add retry param in databricks async operator (#30744)
* Add retry param in databricks async operator
* Apply review suggestions
---
.../providers/databricks/operators/databricks.py | 8 +++++
.../providers/databricks/triggers/databricks.py | 41 +++++++++++++++++++---
.../databricks/triggers/test_databricks.py | 10 ++++--
3 files changed, 51 insertions(+), 8 deletions(-)
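For context, a minimal DAG sketch of where the operator-level retry settings come from; after this change they are also forwarded to the trigger used while the task is deferred. All names, the connection id, and the job JSON below are illustrative placeholders, not taken from this commit:

from __future__ import annotations

import pendulum

from airflow import DAG
from airflow.providers.databricks.operators.databricks import (
    DatabricksSubmitRunDeferrableOperator,
)

with DAG(
    dag_id="example_databricks_deferrable_retry",  # hypothetical dag_id
    start_date=pendulum.datetime(2023, 4, 1, tz="UTC"),
    schedule=None,
) as dag:
    DatabricksSubmitRunDeferrableOperator(
        task_id="notebook_run",
        databricks_conn_id="databricks_default",
        json={
            "new_cluster": {"spark_version": "12.2.x-scala2.12", "num_workers": 1},
            "notebook_task": {"notebook_path": "/Shared/example"},  # placeholder path
        },
        polling_period_seconds=30,
        # With this commit, these retry settings also reach the async hook
        # inside DatabricksExecutionTrigger, not just the pre-deferral API calls:
        databricks_retry_limit=5,
        databricks_retry_delay=15,
        databricks_retry_args=None,  # optionally, kwargs for tenacity.Retrying
    )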
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 61384c8015..006da0edae 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -36,6 +36,7 @@ if TYPE_CHECKING:
DEFER_METHOD_NAME = "execute_complete"
XCOM_RUN_ID_KEY = "run_id"
+XCOM_JOB_ID_KEY = "job_id"
XCOM_RUN_PAGE_URL_KEY = "run_page_url"
@@ -104,6 +105,9 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex
:param operator: Databricks async operator being handled
:param context: Airflow context
"""
+ job_id = hook.get_job_id(operator.run_id)
+ if operator.do_xcom_push and context is not None:
+ context["ti"].xcom_push(key=XCOM_JOB_ID_KEY, value=job_id)
if operator.do_xcom_push and context is not None:
context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
log.info("Run submitted with run_id: %s", operator.run_id)
@@ -119,6 +123,10 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex
run_id=operator.run_id,
databricks_conn_id=operator.databricks_conn_id,
polling_period_seconds=operator.polling_period_seconds,
+ retry_limit=operator.databricks_retry_limit,
+ retry_delay=operator.databricks_retry_delay,
+ retry_args=operator.databricks_retry_args,
+ run_page_url=run_page_url,
),
method_name=DEFER_METHOD_NAME,
)
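To illustrate the deferral above, a sketch of constructing the trigger directly with the newly forwarded arguments; the values are placeholders, assuming a run already submitted via the hook:

from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger

trigger = DatabricksExecutionTrigger(
    run_id=42,  # placeholder run id
    databricks_conn_id="databricks_default",
    polling_period_seconds=30,
    retry_limit=5,
    retry_delay=15,
    retry_args=None,
    run_page_url="https://example.cloud.databricks.com/#job/1/run/42",  # placeholder
)
# The trigger now builds its DatabricksHook with the same retry settings,
# so transient API failures during polling are retried as well.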
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
index cd2421c376..e5e56cc0ff 100644
--- a/airflow/providers/databricks/triggers/databricks.py
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -32,14 +32,36 @@ class DatabricksExecutionTrigger(BaseTrigger):
:param databricks_conn_id: Reference to the :ref:`Databricks connection <howto/connection:databricks>`.
:param polling_period_seconds: Controls the rate of the poll for the result of this run.
By default, the trigger will poll every 30 seconds.
+ :param retry_limit: The number of times to retry the connection in case of service outages.
+ :param retry_delay: The number of seconds to wait between retries.
+ :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+ :param run_page_url: The run page url.
"""
- def __init__(self, run_id: int, databricks_conn_id: str, polling_period_seconds: int = 30) -> None:
+ def __init__(
+ self,
+ run_id: int,
+ databricks_conn_id: str,
+ polling_period_seconds: int = 30,
+ retry_limit: int = 3,
+ retry_delay: int = 10,
+ retry_args: dict[Any, Any] | None = None,
+ run_page_url: str | None = None,
+ ) -> None:
super().__init__()
self.run_id = run_id
self.databricks_conn_id = databricks_conn_id
self.polling_period_seconds = polling_period_seconds
- self.hook = DatabricksHook(databricks_conn_id)
+ self.retry_limit = retry_limit
+ self.retry_delay = retry_delay
+ self.retry_args = retry_args
+ self.run_page_url = run_page_url
+ self.hook = DatabricksHook(
+ databricks_conn_id,
+ retry_limit=self.retry_limit,
+ retry_delay=self.retry_delay,
+ retry_args=retry_args,
+ )
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
@@ -48,22 +70,31 @@ class DatabricksExecutionTrigger(BaseTrigger):
"run_id": self.run_id,
"databricks_conn_id": self.databricks_conn_id,
"polling_period_seconds": self.polling_period_seconds,
+ "retry_limit": self.retry_limit,
+ "retry_delay": self.retry_delay,
+ "retry_args": self.retry_args,
+ "run_page_url": self.run_page_url,
},
)
async def run(self):
async with self.hook:
- run_page_url = await self.hook.a_get_run_page_url(self.run_id)
while True:
run_state = await self.hook.a_get_run_state(self.run_id)
if run_state.is_terminal:
yield TriggerEvent(
{
"run_id": self.run_id,
+ "run_page_url": self.run_page_url,
"run_state": run_state.to_json(),
- "run_page_url": run_page_url,
}
)
- break
+ return
else:
+ self.log.info(
+ "run-id %s in run state %s. sleeping for %s seconds",
+ self.run_id,
+ run_state,
+ self.polling_period_seconds,
+ )
await asyncio.sleep(self.polling_period_seconds)
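Because the triggerer re-creates a trigger from the (classpath, kwargs) pair returned by serialize(), each new field has to round-trip through those kwargs. A quick sketch of that contract, with placeholder values:

from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger

trigger = DatabricksExecutionTrigger(
    run_id=42,
    databricks_conn_id="databricks_default",
    retry_limit=5,
    retry_delay=15,
    run_page_url="https://example.cloud.databricks.com/#job/1/run/42",
)
classpath, kwargs = trigger.serialize()
rehydrated = DatabricksExecutionTrigger(**kwargs)  # roughly what the triggerer does
assert rehydrated.retry_limit == 5
assert rehydrated.run_page_url == trigger.run_page_url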
diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py
index 4e5da213f5..675995beb9 100644
--- a/tests/providers/databricks/triggers/test_databricks.py
+++ b/tests/providers/databricks/triggers/test_databricks.py
@@ -84,6 +84,7 @@ class TestDatabricksExecutionTrigger:
run_id=RUN_ID,
databricks_conn_id=DEFAULT_CONN_ID,
polling_period_seconds=POLLING_INTERVAL_SECONDS,
+ run_page_url=RUN_PAGE_URL,
)
def test_serialize(self):
@@ -93,6 +94,10 @@ class TestDatabricksExecutionTrigger:
"run_id": RUN_ID,
"databricks_conn_id": DEFAULT_CONN_ID,
"polling_period_seconds": POLLING_INTERVAL_SECONDS,
+ "retry_delay": 10,
+ "retry_limit": 3,
+ "retry_args": None,
+ "run_page_url": RUN_PAGE_URL,
},
)
@@ -121,10 +126,9 @@ class TestDatabricksExecutionTrigger:
@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep")
- @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url")
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
- async def test_sleep_between_retries(self, mock_get_run_state, mock_get_run_page_url, mock_sleep):
- mock_get_run_page_url.return_value = RUN_PAGE_URL
+ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep):
+
mock_get_run_state.side_effect = [
RunState(
life_cycle_state=LIFE_CYCLE_STATE_PENDING,
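The hunk above is truncated here. For completeness, a self-contained pytest-style sketch in the spirit of the updated tests, with placeholder constants rather than the repository's fixtures, asserting that a terminal run now yields the run_page_url supplied at construction instead of one fetched inside run():

from unittest import mock

import pytest

from airflow.providers.databricks.hooks.databricks import RunState
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger

RUN_ID = 42  # placeholder
RUN_PAGE_URL = "https://example.cloud.databricks.com/#job/1/run/42"  # placeholder


@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state")
async def test_terminal_event_carries_run_page_url(mock_get_run_state):
    # A terminal state ends the polling loop on the first iteration.
    mock_get_run_state.return_value = RunState(
        life_cycle_state="TERMINATED", result_state="SUCCESS", state_message=""
    )
    trigger = DatabricksExecutionTrigger(
        run_id=RUN_ID,
        databricks_conn_id="databricks_default",
        run_page_url=RUN_PAGE_URL,
    )
    async for event in trigger.run():
        assert event.payload["run_page_url"] == RUN_PAGE_URL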