johnbrandborg commented on code in PR #33005:
URL: https://github.com/apache/airflow/pull/33005#discussion_r1285798401
##########
airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -210,6 +213,100 @@ def _a_get_retry_object(self) -> AsyncRetrying:
"""
return AsyncRetrying(**self.retry_args)
def _get_sp_token(self) -> str:
    """
    Return a Service Principal OAuth access token, refreshing it when needed.

    A token cached in ``self.sp_token`` is reused while ``_is_sp_token_valid``
    accepts it. Otherwise a new token is requested from
    ``OIDC_TOKEN_SERVICE_URL`` (formatted with the connection host) using the
    ``client_credentials`` grant, authenticating with the connection's
    login/password via HTTP Basic auth. The request is retried through
    ``self._get_retry_object()``. On success the token is cached in
    ``self.sp_token`` as ``{"token": <access token>, "expires_on": <epoch seconds>}``.

    :return: the access token string.
    :raises AirflowException: if all retries fail, the endpoint answers with an
        HTTP error, or the response body lacks ``access_token``, a bearer
        ``token_type`` or ``expires_in``.
    """
    if self.sp_token and self._is_sp_token_valid(self.sp_token):
        return self.sp_token["token"]

    self.log.info("Existing Service Principal token is expired, or going to expire soon. Refreshing...")
    try:
        for attempt in self._get_retry_object():
            with attempt:
                resp = requests.post(
                    OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
                    auth=HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password),
                    data="grant_type=client_credentials&scope=all-apis",
                    headers={
                        **self.user_agent_header,
                        "Content-Type": "application/x-www-form-urlencoded",
                    },
                    timeout=self.token_timeout_seconds,
                )

                resp.raise_for_status()
                jsn = resp.json()
                # RFC 6749 section 5.1: the token_type value is case-insensitive,
                # so "bearer" must be accepted as well as "Bearer". A non-dict body
                # is rejected up front so it yields AirflowException, not AttributeError.
                if (
                    not isinstance(jsn, dict)
                    or "access_token" not in jsn
                    or str(jsn.get("token_type", "")).lower() != "bearer"
                    or "expires_in" not in jsn
                ):
                    raise AirflowException(f"Can't get necessary data from Service Principal token: {jsn}")

                token = jsn["access_token"]
                # Store an absolute expiry (epoch seconds) so validity can be checked later.
                self.sp_token = {"token": token, "expires_on": int(time.time() + jsn["expires_in"])}
                break
    except RetryError as e:
        raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.") from e
    except requests_exceptions.HTTPError as e:
        raise AirflowException(f"Response: {e.response.content}, Status Code: {e.response.status_code}") from e

    return token

async def _a_get_sp_token(self) -> str:
    """
    Async version of ``_get_sp_token()``.

    Reuses the token cached in ``self.sp_token`` while ``_is_sp_token_valid``
    accepts it. Otherwise it requests a new one with the ``client_credentials``
    grant through ``self._session``, retrying via ``self._a_get_retry_object()``,
    and caches it as ``{"token": <access token>, "expires_on": <epoch seconds>}``.

    ``self._session`` is used as an aiohttp session (``async with ... as resp`` and
    ``await resp.json()``). Authentication and HTTP errors therefore use aiohttp's
    types: aiohttp rejects ``requests``' ``HTTPBasicAuth`` for ``auth=``, and its
    ``raise_for_status()`` raises ``aiohttp.ClientResponseError``, not
    ``requests.exceptions.HTTPError``.

    :return: the access token string.
    :raises AirflowException: if all retries fail, the endpoint answers with an
        HTTP error, or the response body lacks ``access_token``, a bearer
        ``token_type`` or ``expires_in``.
    """
    import aiohttp  # already required by the aiohttp session used below

    if self.sp_token and self._is_sp_token_valid(self.sp_token):
        return self.sp_token["token"]

    self.log.info("Existing Service Principal token is expired, or going to expire soon. Refreshing...")
    try:
        async for attempt in self._a_get_retry_object():
            with attempt:
                async with self._session.post(
                    OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
                    auth=aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password),
                    data="grant_type=client_credentials&scope=all-apis",
                    headers={
                        **self.user_agent_header,
                        "Content-Type": "application/x-www-form-urlencoded",
                    },
                    timeout=self.token_timeout_seconds,
                ) as resp:
                    resp.raise_for_status()
                    jsn = await resp.json()

                # RFC 6749 section 5.1: the token_type value is case-insensitive.
                # A non-dict body is rejected so it yields AirflowException.
                if (
                    not isinstance(jsn, dict)
                    or "access_token" not in jsn
                    or str(jsn.get("token_type", "")).lower() != "bearer"
                    or "expires_in" not in jsn
                ):
                    raise AirflowException(f"Can't get necessary data from Service Principal token: {jsn}")

                token = jsn["access_token"]
                # Store an absolute expiry (epoch seconds) so validity can be checked later.
                self.sp_token = {"token": token, "expires_on": int(time.time() + jsn["expires_in"])}
                break
    except RetryError as e:
        raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.") from e
    except aiohttp.ClientResponseError as e:
        # The response body has already been released here; report aiohttp's reason text.
        raise AirflowException(f"Response: {e.message}, Status Code: {e.status}") from e

    return token

@staticmethod
def _is_sp_token_valid(sp_token: dict) -> bool:
    """
    Check that a cached Service Principal token has not expired yet.

    The token counts as valid only while its ``expires_on`` timestamp lies more
    than ``TOKEN_REFRESH_LEAD_TIME`` seconds in the future. This forces a refresh
    shortly before the token actually expires.

    :param sp_token: dict with properties of the Service Principal token; must
        contain ``expires_on``, an epoch timestamp in whole seconds.
    :return: True if the token is still valid, False otherwise.
    """
    return sp_token["expires_on"] > int(time.time()) + TOKEN_REFRESH_LEAD_TIME
Review Comment:
The return statement is cleaned up now. I have also moved the token validation
out of the get-token methods and into the `is_oauth_token_valid` static
method, so it not only checks that the token hasn't expired but also validates
the token dict. Tests have been added to check that it catches the exceptions.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]