This is an automated email from the ASF dual-hosted git repository.

jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git


The following commit(s) were added to refs/heads/main by this push:
     new 213bcc9f2 [#3755] improvement(client-python): Support OAuth2TokenProvider for Python client (#4011)
213bcc9f2 is described below

commit 213bcc9f28102a3b472a8b2d9629525e9d00d269
Author: noidname01 <[email protected]>
AuthorDate: Fri Jul 19 10:53:57 2024 +0800

    [#3755] improvement(client-python): Support OAuth2TokenProvider for Python client (#4011)
    
    ### What changes were proposed in this pull request?
    
    * Add `OAuth2TokenProvider` and `DefaultOAuth2TokenProvider` in
    `client-python`
    * Some components and tests are not included because adding them here
    would make this PR too large; they will be added in follow-up PRs:
            - [ ] Error Handling: #4173
            - [ ] Integration Test: #4208
    * Restructure the unit-test files. This surfaced issue #4136, which is
    fixed by resetting the `GRAVITINO_USER` environment variable in the
    affected tests.
    
    ### Why are the changes needed?
    
    Fix: #3755, #4136
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit tests; tested with `./gradlew clients:client-python:unittest`.
    
    ---------
    
    Co-authored-by: TimWang <[email protected]>
---
 .../client-python/gravitino/auth/auth_constants.py |   2 +
 .../auth/default_oauth2_token_provider.py          | 133 +++++++++++++++++++
 .../gravitino/auth/oauth2_token_provider.py        |  75 +++++++++++
 .../gravitino/auth/simple_auth_provider.py         |   4 +-
 .../requests/oauth2_client_credential_request.py}  |  15 ++-
 .../dto/responses/oauth2_token_response.py         |  55 ++++++++
 .../client-python/gravitino/utils/http_client.py   |  36 ++++--
 clients/client-python/requirements-dev.txt         |   3 +-
 .../tests/integration/test_simple_auth_client.py   |   2 +
 .../unittests/auth/__init__.py}                    |   6 -
 .../tests/unittests/auth/mock_base.py              | 144 +++++++++++++++++++++
 .../unittests/auth/test_oauth2_token_provider.py   |  93 +++++++++++++
 .../{ => auth}/test_simple_auth_provider.py        |   4 +
 13 files changed, 551 insertions(+), 21 deletions(-)

diff --git a/clients/client-python/gravitino/auth/auth_constants.py 
b/clients/client-python/gravitino/auth/auth_constants.py
index 2494030fc..247abcaaa 100644
--- a/clients/client-python/gravitino/auth/auth_constants.py
+++ b/clients/client-python/gravitino/auth/auth_constants.py
@@ -21,4 +21,6 @@ under the License.
 class AuthConstants:
     HTTP_HEADER_AUTHORIZATION: str = "Authorization"
 
+    AUTHORIZATION_BEARER_HEADER: str = "Bearer "
+
     AUTHORIZATION_BASIC_HEADER: str = "Basic "
diff --git 
a/clients/client-python/gravitino/auth/default_oauth2_token_provider.py 
b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
new file mode 100644
index 000000000..3fb730395
--- /dev/null
+++ b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
@@ -0,0 +1,133 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import time
+import json
+import base64
+from typing import Optional
+from gravitino.auth.oauth2_token_provider import OAuth2TokenProvider
+from gravitino.dto.responses.oauth2_token_response import OAuth2TokenResponse
+from gravitino.dto.requests.oauth2_client_credential_request import (
+    OAuth2ClientCredentialRequest,
+)
+from gravitino.exceptions.base import GravitinoRuntimeException
+
+CLIENT_CREDENTIALS = "client_credentials"
+CREDENTIAL_SPLITTER = ":"
+TOKEN_SPLITTER = "."
+JWT_EXPIRE = "exp"
+
+
+class DefaultOAuth2TokenProvider(OAuth2TokenProvider):
+    """This class is the default implement of OAuth2TokenProvider."""
+
+    _credential: Optional[str]
+    _scope: Optional[str]
+    _path: Optional[str]
+    _token: Optional[str]
+
+    def __init__(
+        self,
+        uri: str = None,
+        credential: str = None,
+        scope: str = None,
+        path: str = None,
+    ):
+        super().__init__(uri)
+
+        self._credential = credential
+        self._scope = scope
+        self._path = path
+
+        self.validate()
+
+        self._token = self._fetch_token()
+
+    def validate(self):
+        assert (
+            self._credential and self._credential.strip()
+        ), "OAuth2TokenProvider must set credential"
+        assert self._scope and self._scope.strip(), "OAuth2TokenProvider must 
set scope"
+        assert self._path and self._path.strip(), "OAuth2TokenProvider must 
set path"
+
+    def _get_access_token(self) -> Optional[str]:
+
+        expires = self._expires_at_millis()
+
+        if expires is None:
+            return None
+
+        if expires > time.time() * 1000:
+            return self._token
+
+        self._token = self._fetch_token()
+        return self._token
+
+    def _parse_credential(self):
+        assert self._credential is not None, "Invalid credential: None"
+
+        credential_info = self._credential.split(CREDENTIAL_SPLITTER, 
maxsplit=1)
+        client_id = None
+        client_secret = None
+
+        if len(credential_info) == 2:
+            client_id, client_secret = credential_info
+        elif len(credential_info) == 1:
+            client_secret = credential_info[0]
+        else:
+            raise GravitinoRuntimeException(f"Invalid credential: 
{self._credential}")
+
+        return client_id, client_secret
+
+    def _fetch_token(self) -> str:
+
+        client_id, client_secret = self._parse_credential()
+
+        client_credential_request = OAuth2ClientCredentialRequest(
+            grant_type=CLIENT_CREDENTIALS,
+            client_id=client_id,
+            client_secret=client_secret,
+            scope=self._scope,
+        )
+
+        resp = self._client.post_form(
+            self._path, data=client_credential_request.to_dict()
+        )
+        oauth2_resp = OAuth2TokenResponse.from_json(resp.body, 
infer_missing=True)
+        oauth2_resp.validate()
+
+        return oauth2_resp.access_token()
+
+    def _expires_at_millis(self) -> int:
+        if self._token is None:
+            return None
+
+        parts = self._token.split(TOKEN_SPLITTER)
+
+        if len(parts) != 3:
+            return None
+
+        jwt = json.loads(
+            base64.b64decode(parts[1] + "=" * (-len(parts[1]) % 
4)).decode("utf-8")
+        )
+
+        if JWT_EXPIRE not in jwt or not isinstance(jwt[JWT_EXPIRE], int):
+            return None
+
+        return jwt[JWT_EXPIRE] * 1000
diff --git a/clients/client-python/gravitino/auth/oauth2_token_provider.py 
b/clients/client-python/gravitino/auth/oauth2_token_provider.py
new file mode 100644
index 000000000..5d243053f
--- /dev/null
+++ b/clients/client-python/gravitino/auth/oauth2_token_provider.py
@@ -0,0 +1,75 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+from abc import abstractmethod
+from typing import Optional
+
+from gravitino.utils.http_client import HTTPClient
+from gravitino.auth.auth_data_provider import AuthDataProvider
+from gravitino.auth.auth_constants import AuthConstants
+
+
+class OAuth2TokenProvider(AuthDataProvider):
+    """OAuth2TokenProvider will request the access token from the 
authorization server and then provide
+    the access token for every request.
+    """
+
+    # The HTTP client used to request the access token from the authorization 
server.
+    _client: HTTPClient
+
+    def __init__(self, uri: str):
+        self._client = HTTPClient(uri)
+
+    def has_token_data(self) -> bool:
+        """Judge whether AuthDataProvider can provide token data.
+
+        Returns:
+            true if the AuthDataProvider can provide token data otherwise 
false.
+        """
+        return True
+
+    def get_token_data(self) -> Optional[bytes]:
+        """Acquire the data of token for authentication. The client will set 
the token data as HTTP header
+        Authorization directly. So the return value should ensure token data 
contain the token header
+        (eg: Bearer, Basic) if necessary.
+
+        Returns:
+            the token data is used for authentication.
+        """
+        access_token = self._get_access_token()
+
+        if access_token is None:
+            return None
+
+        return (AuthConstants.AUTHORIZATION_BEARER_HEADER + 
access_token).encode(
+            "utf-8"
+        )
+
+    def close(self):
+        """Closes the OAuth2TokenProvider and releases any underlying 
resources."""
+        if self._client is not None:
+            self._client.close()
+
+    @abstractmethod
+    def _get_access_token(self) -> Optional[str]:
+        """Get the access token from the authorization server."""
+
+    @abstractmethod
+    def validate(self):
+        """Validate the OAuth2TokenProvider"""
diff --git a/clients/client-python/gravitino/auth/simple_auth_provider.py 
b/clients/client-python/gravitino/auth/simple_auth_provider.py
index ef013a7fe..96aae06a0 100644
--- a/clients/client-python/gravitino/auth/simple_auth_provider.py
+++ b/clients/client-python/gravitino/auth/simple_auth_provider.py
@@ -20,8 +20,8 @@ under the License.
 import base64
 import os
 
-from .auth_constants import AuthConstants
-from .auth_data_provider import AuthDataProvider
+from gravitino.auth.auth_constants import AuthConstants
+from gravitino.auth.auth_data_provider import AuthDataProvider
 
 
 class SimpleAuthProvider(AuthDataProvider):
diff --git a/clients/client-python/gravitino/auth/auth_constants.py 
b/clients/client-python/gravitino/dto/requests/oauth2_client_credential_request.py
similarity index 71%
copy from clients/client-python/gravitino/auth/auth_constants.py
copy to 
clients/client-python/gravitino/dto/requests/oauth2_client_credential_request.py
index 2494030fc..4d4de57a4 100644
--- a/clients/client-python/gravitino/auth/auth_constants.py
+++ 
b/clients/client-python/gravitino/dto/requests/oauth2_client_credential_request.py
@@ -17,8 +17,17 @@ specific language governing permissions and limitations
 under the License.
 """
 
+from typing import Optional
+from dataclasses import dataclass
 
-class AuthConstants:
-    HTTP_HEADER_AUTHORIZATION: str = "Authorization"
 
-    AUTHORIZATION_BASIC_HEADER: str = "Basic "
+@dataclass
+class OAuth2ClientCredentialRequest:
+
+    grant_type: str
+    client_id: Optional[str]
+    client_secret: str
+    scope: str
+
+    def to_dict(self, **kwarg):
+        return {k: v for k, v in self.__dict__.items() if v is not None}
diff --git 
a/clients/client-python/gravitino/dto/responses/oauth2_token_response.py 
b/clients/client-python/gravitino/dto/responses/oauth2_token_response.py
new file mode 100644
index 000000000..07869ec03
--- /dev/null
+++ b/clients/client-python/gravitino/dto/responses/oauth2_token_response.py
@@ -0,0 +1,55 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+from typing import Optional
+from dataclasses import dataclass, field
+from dataclasses_json import config
+
+from gravitino.dto.responses.base_response import BaseResponse
+from gravitino.auth.auth_constants import AuthConstants
+
+
+@dataclass
+class OAuth2TokenResponse(BaseResponse):
+
+    _access_token: str = field(metadata=config(field_name="access_token"))
+    _issue_token_type: Optional[str] = field(
+        metadata=config(field_name="issued_token_type")
+    )
+    _token_type: str = field(metadata=config(field_name="token_type"))
+    _expires_in: int = field(metadata=config(field_name="expires_in"))
+    _scope: str = field(metadata=config(field_name="scope"))
+    _refresh_token: Optional[str] = 
field(metadata=config(field_name="refresh_token"))
+
+    def validate(self):
+        """Validates the response.
+
+        Raise:
+            IllegalArgumentException If the response is invalid, this 
exception is thrown.
+        """
+        super().validate()
+
+        assert self._access_token is not None, "Invalid access token: None"
+        assert (
+            AuthConstants.AUTHORIZATION_BEARER_HEADER.strip().lower()
+            == self._token_type.lower()
+        ), f'Unsupported token type: {self._token_type} (must be "bearer")'
+
+    def access_token(self) -> str:
+        return self._access_token
diff --git a/clients/client-python/gravitino/utils/http_client.py 
b/clients/client-python/gravitino/utils/http_client.py
index 67504f12d..89b75d641 100644
--- a/clients/client-python/gravitino/utils/http_client.py
+++ b/clients/client-python/gravitino/utils/http_client.py
@@ -78,6 +78,17 @@ class Response:
 
 
 class HTTPClient:
+
+    FORMDATA_HEADER = {
+        "Content-Type": "application/x-www-form-urlencoded",
+        "Accept": "application/vnd.gravitino.v1+json",
+    }
+
+    JSON_HEADER = {
+        "Content-Type": "application/json",
+        "Accept": "application/vnd.gravitino.v1+json",
+    }
+
     def __init__(
         self,
         host,
@@ -139,12 +150,14 @@ class HTTPClient:
 
             return (False, err_resp)
 
+    # pylint: disable=too-many-locals
     def _request(
         self,
         method,
         endpoint,
         params=None,
         json=None,
+        data=None,
         headers=None,
         timeout=None,
         error_handler: ErrorHandler = None,
@@ -152,17 +165,17 @@ class HTTPClient:
         method = method.upper()
         request_data = None
 
-        if headers:
-            self._update_headers(headers)
+        if data:
+            request_data = urlencode(data.to_dict()).encode()
+            self._update_headers(self.FORMDATA_HEADER)
         else:
-            headers = {
-                "Content-Type": "application/json",
-                "Accept": "application/vnd.gravitino.v1+json",
-            }
-            self._update_headers(headers)
+            if json:
+                request_data = json.to_json().encode("utf-8")
 
-        if json:
-            request_data = json.to_json().encode("utf-8")
+            self._update_headers(self.JSON_HEADER)
+
+        if headers:
+            self._update_headers(headers)
 
         opener = build_opener()
         request = Request(self._build_url(endpoint, params), data=request_data)
@@ -213,6 +226,11 @@ class HTTPClient:
             "put", endpoint, json=json, error_handler=error_handler, **kwargs
         )
 
+    def post_form(self, endpoint, data=None, error_handler=None, **kwargs):
+        return self._request(
+            "post", endpoint, data=data, error_handler=error_handler**kwargs
+        )
+
     def close(self):
         self._request("close", "/")
         if self.auth_data_provider is not None:
diff --git a/clients/client-python/requirements-dev.txt 
b/clients/client-python/requirements-dev.txt
index 06f634358..e91d966a4 100644
--- a/clients/client-python/requirements-dev.txt
+++ b/clients/client-python/requirements-dev.txt
@@ -27,4 +27,5 @@ llama-index==0.10.40
 tenacity==8.3.0
 cachetools==5.3.3
 readerwriterlock==1.0.9
-docker==7.1.0
\ No newline at end of file
+docker==7.1.0
+pyjwt[crypto]==2.8.0
diff --git a/clients/client-python/tests/integration/test_simple_auth_client.py 
b/clients/client-python/tests/integration/test_simple_auth_client.py
index a4ed77fe1..5dd8a553b 100644
--- a/clients/client-python/tests/integration/test_simple_auth_client.py
+++ b/clients/client-python/tests/integration/test_simple_auth_client.py
@@ -100,6 +100,8 @@ class TestSimpleAuthClient(IntegrationTestEnv):
             )
         except Exception as e:
             logger.error("Clean test data failed: %s", e)
+        finally:
+            os.environ["GRAVITINO_USER"] = ""
 
     def init_test_env(self):
         self.gravitino_admin_client.create_metalake(
diff --git a/clients/client-python/gravitino/auth/auth_constants.py 
b/clients/client-python/tests/unittests/auth/__init__.py
similarity index 86%
copy from clients/client-python/gravitino/auth/auth_constants.py
copy to clients/client-python/tests/unittests/auth/__init__.py
index 2494030fc..c206137f1 100644
--- a/clients/client-python/gravitino/auth/auth_constants.py
+++ b/clients/client-python/tests/unittests/auth/__init__.py
@@ -16,9 +16,3 @@ KIND, either express or implied.  See the License for the
 specific language governing permissions and limitations
 under the License.
 """
-
-
-class AuthConstants:
-    HTTP_HEADER_AUTHORIZATION: str = "Authorization"
-
-    AUTHORIZATION_BASIC_HEADER: str = "Basic "
diff --git a/clients/client-python/tests/unittests/auth/mock_base.py 
b/clients/client-python/tests/unittests/auth/mock_base.py
new file mode 100644
index 000000000..f7b66c6b3
--- /dev/null
+++ b/clients/client-python/tests/unittests/auth/mock_base.py
@@ -0,0 +1,144 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import time
+import json
+from dataclasses import dataclass
+from http import HTTPStatus
+
+from dataclasses_json import dataclass_json
+import jwt
+from cryptography.hazmat.primitives import serialization as 
crypto_serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.hazmat.backends import default_backend as 
crypto_default_backend
+
+
+@dataclass
+class TestResponse:
+    body: bytes
+    status_code: int
+
+
+@dataclass_json
+@dataclass
+class TestJWT:
+    sub: str
+    exp: int
+    aud: str
+
+
+def generate_private_key():
+    key = rsa.generate_private_key(
+        backend=crypto_default_backend(), public_exponent=65537, key_size=2048
+    )
+
+    private_key = key.private_bytes(
+        crypto_serialization.Encoding.PEM,
+        crypto_serialization.PrivateFormat.PKCS8,
+        crypto_serialization.NoEncryption(),
+    )
+
+    return private_key
+
+
+JWT_PRIVATE_KEY = generate_private_key()
+GENERATED_TIME = int(time.time())
+
+
+def mock_authentication_with_error_authentication_type():
+    return TestResponse(
+        body=json.dumps(
+            {
+                "code": 0,
+                "access_token": "1",
+                "issued_token_type": "2",
+                "token_type": "3",
+                "expires_in": 1,
+                "scope": "test",
+                "refresh_token": None,
+            }
+        ).encode("utf-8"),
+        status_code=HTTPStatus.OK.value,
+    )
+
+
+def mock_authentication_with_non_jwt():
+    return TestResponse(
+        body=json.dumps(
+            {
+                "code": 0,
+                "access_token": "1",
+                "issued_token_type": "2",
+                "token_type": "bearer",
+                "expires_in": 1,
+                "scope": "test",
+                "refresh_token": None,
+            }
+        ),
+        status_code=HTTPStatus.OK.value,
+    )
+
+
+def mock_jwt(sub, exp, aud):
+    return jwt.encode(
+        TestJWT(sub, exp, aud).to_dict(),
+        JWT_PRIVATE_KEY,
+        algorithm="RS256",
+    )
+
+
+def mock_old_new_jwt():
+    return [
+        mock_jwt(sub="gravitino", exp=GENERATED_TIME - 10000, aud="service1"),
+        mock_jwt(sub="gravitino", exp=GENERATED_TIME + 10000, aud="service1"),
+    ]
+
+
+def mock_authentication_with_jwt():
+    old_access_token, new_access_token = mock_old_new_jwt()
+    return [
+        TestResponse(
+            body=json.dumps(
+                {
+                    "code": 0,
+                    "access_token": old_access_token,
+                    "issued_token_type": "2",
+                    "token_type": "bearer",
+                    "expires_in": 1,
+                    "scope": "test",
+                    "refresh_token": None,
+                }
+            ),
+            status_code=HTTPStatus.OK.value,
+        ),
+        TestResponse(
+            body=json.dumps(
+                {
+                    "code": 0,
+                    "access_token": new_access_token,
+                    "issued_token_type": "2",
+                    "token_type": "bearer",
+                    "expires_in": 1,
+                    "scope": "test",
+                    "refresh_token": None,
+                }
+            ),
+            status_code=HTTPStatus.OK.value,
+        ),
+    ]
diff --git 
a/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py 
b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
new file mode 100644
index 000000000..b60efbf04
--- /dev/null
+++ b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
@@ -0,0 +1,93 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import unittest
+from unittest.mock import patch
+
+from gravitino.auth.auth_constants import AuthConstants
+from gravitino.auth.default_oauth2_token_provider import 
DefaultOAuth2TokenProvider
+from tests.unittests.auth import mock_base
+
+OAUTH_PORT = 1082
+
+
+class TestOAuth2TokenProvider(unittest.TestCase):
+
+    def test_provider_init_exception(self):
+
+        with self.assertRaises(AssertionError):
+            _ = DefaultOAuth2TokenProvider(uri="test")
+
+        with self.assertRaises(AssertionError):
+            _ = DefaultOAuth2TokenProvider(uri="test", credential="xx")
+
+        with self.assertRaises(AssertionError):
+            _ = DefaultOAuth2TokenProvider(uri="test", credential="xx", 
scope="test")
+
+    # TODO
+    # Error Test
+
+    @patch(
+        "gravitino.utils.http_client.HTTPClient.post_form",
+        
return_value=mock_base.mock_authentication_with_error_authentication_type(),
+    )
+    def test_authentication_with_error_authentication_type(self, 
*mock_methods):
+
+        with self.assertRaises(AssertionError):
+            _ = DefaultOAuth2TokenProvider(
+                uri=f"http://127.0.0.1:{OAUTH_PORT}";,
+                credential="yy:xx",
+                path="oauth/token",
+                scope="test",
+            )
+
+    @patch(
+        "gravitino.utils.http_client.HTTPClient.post_form",
+        return_value=mock_base.mock_authentication_with_non_jwt(),
+    )
+    def test_authentication_with_non_jwt(self, *mock_methods):
+        token_provider = DefaultOAuth2TokenProvider(
+            uri=f"http://127.0.0.1:{OAUTH_PORT}";,
+            credential="yy:xx",
+            path="oauth/token",
+            scope="test",
+        )
+
+        self.assertTrue(token_provider.has_token_data())
+        self.assertIsNone(token_provider.get_token_data())
+
+    @patch(
+        "gravitino.utils.http_client.HTTPClient.post_form",
+        side_effect=mock_base.mock_authentication_with_jwt(),
+    )
+    def test_authentication_with_jwt(self, *mock_methods):
+        old_access_token, new_access_token = mock_base.mock_old_new_jwt()
+
+        token_provider = DefaultOAuth2TokenProvider(
+            uri=f"http://127.0.0.1:{OAUTH_PORT}";,
+            credential="yy:xx",
+            path="oauth/token",
+            scope="test",
+        )
+
+        self.assertNotEqual(old_access_token, new_access_token)
+        self.assertEqual(
+            token_provider.get_token_data().decode("utf-8"),
+            AuthConstants.AUTHORIZATION_BEARER_HEADER + new_access_token,
+        )
diff --git a/clients/client-python/tests/unittests/test_simple_auth_provider.py 
b/clients/client-python/tests/unittests/auth/test_simple_auth_provider.py
similarity index 91%
rename from clients/client-python/tests/unittests/test_simple_auth_provider.py
rename to 
clients/client-python/tests/unittests/auth/test_simple_auth_provider.py
index d8c10e467..c7e7fdc39 100644
--- a/clients/client-python/tests/unittests/test_simple_auth_provider.py
+++ b/clients/client-python/tests/unittests/auth/test_simple_auth_provider.py
@@ -40,6 +40,9 @@ class TestSimpleAuthProvider(unittest.TestCase):
         ).decode("utf-8")
         self.assertEqual(f"{user}:dummy", token_string)
 
+        original_gravitino_user = (
+            os.environ["GRAVITINO_USER"] if "GRAVITINO_USER" in os.environ 
else ""
+        )
         os.environ["GRAVITINO_USER"] = "test_auth2"
         provider: AuthDataProvider = SimpleAuthProvider()
         self.assertTrue(provider.has_token_data())
@@ -50,3 +53,4 @@ class TestSimpleAuthProvider(unittest.TestCase):
             token[len(AuthConstants.AUTHORIZATION_BASIC_HEADER) :]
         ).decode("utf-8")
         self.assertEqual(f"{user}:dummy", token_string)
+        os.environ["GRAVITINO_USER"] = original_gravitino_user

Reply via email to