johnbrandborg commented on code in PR #33005:
URL: https://github.com/apache/airflow/pull/33005#discussion_r1285800524


##########
airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -210,6 +213,100 @@ def _a_get_retry_object(self) -> AsyncRetrying:
         """
         return AsyncRetrying(**self.retry_args)
 
+    def _get_sp_token(self) -> str:
+        """Function to get Service Principal token."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    resp = requests.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    )
+
+                    resp.raise_for_status()
+                    jsn = resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    async def _a_get_sp_token(self) -> str:
+        """Async version of `_get_sp_token()`."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with self._session.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    ) as resp:
+                        resp.raise_for_status()
+                        jsn = await resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    @staticmethod
+    def _is_sp_token_valid(sp_token: dict) -> bool:
+        """
+        Utility function to check Service Principal token hasn't expired yet.
+
+        :param aad_token: dict with properties of AAD token

Review Comment:
   Done. I also removed `sp_token` and `aad_tokens`. OAuth tokens are placed
into `oauth_tokens`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to