johnbrandborg commented on code in PR #33005:
URL: https://github.com/apache/airflow/pull/33005#discussion_r1285798401


##########
airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -210,6 +213,100 @@ def _a_get_retry_object(self) -> AsyncRetrying:
         """
         return AsyncRetrying(**self.retry_args)
 
+    def _get_sp_token(self) -> str:
+        """Function to get Service Principal token."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    resp = requests.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    )
+
+                    resp.raise_for_status()
+                    jsn = resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    async def _a_get_sp_token(self) -> str:
+        """Async version of `_get_sp_token()`."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with self._session.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    ) as resp:
+                        resp.raise_for_status()
+                        jsn = await resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    @staticmethod
+    def _is_sp_token_valid(sp_token: dict) -> bool:
+        """
+        Utility function to check Service Principal token hasn't expired yet.
+
+        :param sp_token: dict with properties of Service Principal token
+        :return: true if token is valid, false otherwise
+        """
+        now = int(time.time())
+        if sp_token["expires_on"] > (now + TOKEN_REFRESH_LEAD_TIME):
+            return True
+        return False

Review Comment:
   The return statement is cleaned up now, and I have moved the token
validation from the get-token methods into the `is_oauth_token_valid` static
method. It now not only checks that the token hasn't expired, but also
validates the token dict. Tests have been added to check that it catches the
exceptions.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to