alexott commented on code in PR #33005:
URL: https://github.com/apache/airflow/pull/33005#discussion_r1285229486


##########
airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -210,6 +213,100 @@ def _a_get_retry_object(self) -> AsyncRetrying:
         """
         return AsyncRetrying(**self.retry_args)
 
+    def _get_sp_token(self) -> str:
+        """Function to get Service Principal token."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    resp = requests.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    )
+
+                    resp.raise_for_status()
+                    jsn = resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    async def _a_get_sp_token(self) -> str:
+        """Async version of `_get_sp_token()`."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with self._session.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    ) as resp:
+                        resp.raise_for_status()
+                        jsn = await resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    @staticmethod
+    def _is_sp_token_valid(sp_token: dict) -> bool:

Review Comment:
   I don't see a difference from `_is_aad_token_valid`:
https://github.com/apache/airflow/blob/main/airflow/providers/databricks/hooks/databricks_base.py#L378C9-L378C28
 - let's use a single function for both token types. An AAD token is an OAuth
token as well, so there is no need to duplicate the code.



##########
airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -210,6 +213,100 @@ def _a_get_retry_object(self) -> AsyncRetrying:
         """
         return AsyncRetrying(**self.retry_args)
 
+    def _get_sp_token(self) -> str:
+        """Function to get Service Principal token."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    resp = requests.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    )
+
+                    resp.raise_for_status()
+                    jsn = resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    async def _a_get_sp_token(self) -> str:
+        """Async version of `_get_sp_token()`."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with self._session.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    ) as resp:
+                        resp.raise_for_status()
+                        jsn = await resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    @staticmethod
+    def _is_sp_token_valid(sp_token: dict) -> bool:
+        """
+        Utility function to check Service Principal token hasn't expired yet.
+
+        :param aad_token: dict with properties of AAD token

Review Comment:
   If we are talking about a general OAuth token, then we need to remove the mentions of AAD.



##########
airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -210,6 +213,100 @@ def _a_get_retry_object(self) -> AsyncRetrying:
         """
         return AsyncRetrying(**self.retry_args)
 
+    def _get_sp_token(self) -> str:
+        """Function to get Service Principal token."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    resp = requests.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    )
+
+                    resp.raise_for_status()
+                    jsn = resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    async def _a_get_sp_token(self) -> str:
+        """Async version of `_get_sp_token()`."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with self._session.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    ) as resp:
+                        resp.raise_for_status()

Review Comment:
   I'm not sure if we should have `await` here as well, but I'm not an async expert.



##########
airflow/providers/databricks/hooks/databricks_base.py:
##########
@@ -210,6 +213,100 @@ def _a_get_retry_object(self) -> AsyncRetrying:
         """
         return AsyncRetrying(**self.retry_args)
 
+    def _get_sp_token(self) -> str:
+        """Function to get Service Principal token."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            for attempt in self._get_retry_object():
+                with attempt:
+                    resp = requests.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    )
+
+                    resp.raise_for_status()
+                    jsn = resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    async def _a_get_sp_token(self) -> str:
+        """Async version of `_get_sp_token()`."""
+        if self.sp_token and self._is_sp_token_valid(self.sp_token):
+            return self.sp_token["token"]
+
+        self.log.info("Existing Service Principal token is expired, or going 
to expire soon. Refreshing...")
+        try:
+            async for attempt in self._a_get_retry_object():
+                with attempt:
+                    async with self._session.post(
+                        
OIDC_TOKEN_SERVICE_URL.format(self.databricks_conn.host),
+                        auth=HTTPBasicAuth(self.databricks_conn.login, 
self.databricks_conn.password),
+                        data="grant_type=client_credentials&scope=all-apis",
+                        headers={
+                            **self.user_agent_header,
+                            "Content-Type": 
"application/x-www-form-urlencoded",
+                        },
+                        timeout=self.token_timeout_seconds,
+                    ) as resp:
+                        resp.raise_for_status()
+                        jsn = await resp.json()
+                    if (
+                        "access_token" not in jsn
+                        or jsn.get("token_type") != "Bearer"
+                        or "expires_in" not in jsn
+                    ):
+                        raise AirflowException(
+                            f"Can't get necessary data from Service Principal 
token: {jsn}"
+                        )
+
+                    token = jsn["access_token"]
+                    self.sp_token = {"token": token, "expires_on": 
int(time.time() + jsn["expires_in"])}
+                    break
+        except RetryError:
+            raise AirflowException(f"API requests to Databricks failed 
{self.retry_limit} times. Giving up.")
+        except requests_exceptions.HTTPError as e:
+            raise AirflowException(f"Response: {e.response.content}, Status 
Code: {e.response.status_code}")
+
+        return token
+
+    @staticmethod
+    def _is_sp_token_valid(sp_token: dict) -> bool:
+        """
+        Utility function to check Service Principal token hasn't expired yet.
+
+        :param aad_token: dict with properties of AAD token
+        :return: true if token is valid, false otherwise
+        """
+        now = int(time.time())
+        if sp_token["expires_on"] > (now + TOKEN_REFRESH_LEAD_TIME):
+            return True
+        return False

Review Comment:
   Just `return sp_token["expires_on"] > (now + TOKEN_REFRESH_LEAD_TIME)`?
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to