chinwobble commented on issue #18999:
URL: https://github.com/apache/airflow/issues/18999#issuecomment-953359692
@eskarimov I have implemented a prototype like this:
I'm sure many improvements could be made, but this should work.
```python
# pylint: disable=abstract-method
class DatabricksHookAsync(DatabricksHook):
    """Async version of the Databricks hook.

    Reuses the connection configuration of ``DatabricksHook`` but performs
    API calls with ``aiohttp`` so callers (e.g. a Trigger) can poll the
    Databricks API without blocking the event loop.
    """

    async def get_run_state_async(
        self, run_id: str, session: ClientSession
    ) -> RunState:
        """Fetch the current state of a Databricks run.

        :param run_id: Databricks run id to query.
        :param session: Open aiohttp client session used for the request.
        :return: The run's state wrapped in a ``RunState``.
        """
        json = {"run_id": run_id}
        response = await self._do_api_call_async(GET_RUN_ENDPOINT, json, session)
        state = response["state"]
        life_cycle_state = state["life_cycle_state"]
        # result_state may not be in the state if not terminal
        result_state = state.get("result_state", None)
        state_message = state["state_message"]
        return RunState(life_cycle_state, result_state, state_message)

    async def _do_api_call_async(
        self, endpoint_info, json, session: ClientSession
    ):
        """
        Utility function to perform an API call with retries.

        :param endpoint_info: Tuple of method and endpoint
        :type endpoint_info: tuple[string, string]
        :param json: Parameters for this API call.
        :type json: dict
        :param session: Open aiohttp client session used for the request.
        :return: If the api call returns a OK status code,
            this function returns the response in JSON. Otherwise,
            we throw an AirflowException.
        :rtype: dict
        """
        method, endpoint = endpoint_info
        self.databricks_conn = self.get_connection(self.databricks_conn_id)
        if "token" in self.databricks_conn.extra_dejson:
            self.log.info("Using token auth. ")
            auth = {
                "Authorization": "Bearer " + self.databricks_conn.extra_dejson["token"]
            }
            # The host in the connection extras takes precedence over the
            # connection's host field.
            if "host" in self.databricks_conn.extra_dejson:
                host = self._parse_host(self.databricks_conn.extra_dejson["host"])
            else:
                host = self.databricks_conn.host
        else:
            raise AirflowException("DatabricksHookAsync only supports token Auth")
        url = f"https://{self._parse_host(host)}/{endpoint}"  # type: ignore
        if method == "GET":
            request_func = session.get
        elif method == "POST":
            request_func = session.post
        elif method == "PATCH":
            request_func = session.patch
        else:
            raise AirflowException("Unexpected HTTP Method: " + method)
        attempt_num = 1
        while True:
            try:
                response = await request_func(
                    url,
                    json=json if method in ("POST", "PATCH") else None,
                    params=json if method == "GET" else None,
                    headers=auth,
                    timeout=self.timeout_seconds,
                )
                response.raise_for_status()
                return await response.json()
            except ClientResponseError as err:
                if err.status < 500:
                    # In this case, the user probably made a mistake.
                    # Don't retry.
                    raise AirflowException(
                        f"Response: {err.message}, Status Code: {err.status}"
                    ) from err
                if attempt_num == self.retry_limit:
                    raise AirflowException(
                        ("API requests to Databricks failed {} times. " + "Giving up.").format(
                            self.retry_limit
                        )
                    ) from err
            attempt_num += 1
            await asyncio.sleep(self.retry_delay)
class DatabricksJobTrigger(BaseTrigger):
    """A trigger that checks every 15 seconds whether the databricks job is
    finished, and fires a single event once it is."""

    def __init__(self, run_id: str, databricks_conn_id):
        super().__init__()
        self.run_id = run_id
        self.databricks_conn_id = databricks_conn_id

    def serialize(self) -> typing.Tuple[str, typing.Dict[str, typing.Any]]:
        """Return the classpath and kwargs needed to recreate this trigger."""
        return (
            "operators.submit_to_databricks_operator.DatabricksJobTrigger",
            {
                "run_id": self.run_id,
                "databricks_conn_id": self.databricks_conn_id,
            },
        )

    async def run(self):
        """Poll the run state until it is terminal, then yield one
        TriggerEvent carrying ``(run_id, result_state)``."""
        hook = DatabricksHookAsync(self.databricks_conn_id)
        async with aiohttp.ClientSession() as session:
            while True:
                run_state = await hook.get_run_state_async(self.run_id, session)
                if run_state.is_terminal:
                    if run_state.is_successful:
                        self.log.info(
                            "Run id: %s completed successfully.", self.run_id
                        )
                    else:
                        self.log.info(
                            "Run id: %s completed and failed.", self.run_id
                        )
                    yield TriggerEvent((self.run_id, run_state.result_state))
                    # Bug fix: stop polling once the run is finished and the
                    # event has been emitted; previously the loop kept
                    # re-querying a terminal run forever.
                    return
                await asyncio.sleep(15)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]