This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 44271cbb3 [#4173] improvement(client-python): Add OAuth Error Handler
and related exceptions, test cases in client-python (#4324)
44271cbb3 is described below
commit 44271cbb3bdaea4179761a5aa1dcb620196a8166
Author: noidname01 <[email protected]>
AuthorDate: Thu Aug 1 14:43:41 2024 +0800
[#4173] improvement(client-python): Add OAuth Error Handler and related
exceptions, test cases in client-python (#4324)
### What changes were proposed in this pull request?
* Add an OAuth error handler, the related exceptions, and unit tests to
`client-python`, modeled on the existing implementation in `client-java`
### Why are the changes needed?
Fix: #4173
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests were added; run them with `./gradlew client:client-python:test`
---------
Co-authored-by: TimWang <[email protected]>
---
.../auth/default_oauth2_token_provider.py | 5 +-
.../dto/responses/oauth2_error_response.py | 40 +++++++++++++++
clients/client-python/gravitino/exceptions/base.py | 8 +++
.../exceptions/handlers/oauth_error_handler.py | 58 ++++++++++++++++++++++
.../client-python/gravitino/utils/http_client.py | 13 ++++-
.../tests/unittests/auth/mock_base.py | 26 ++++++++++
.../unittests/auth/test_oauth2_token_provider.py | 30 ++++++++++-
7 files changed, 175 insertions(+), 5 deletions(-)
diff --git
a/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
index 3fb730395..beefc90c4 100644
--- a/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
+++ b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
@@ -27,6 +27,7 @@ from gravitino.dto.requests.oauth2_client_credential_request
import (
OAuth2ClientCredentialRequest,
)
from gravitino.exceptions.base import GravitinoRuntimeException
+from gravitino.exceptions.handlers.oauth_error_handler import
OAUTH_ERROR_HANDLER
CLIENT_CREDENTIALS = "client_credentials"
CREDENTIAL_SPLITTER = ":"
@@ -107,7 +108,9 @@ class DefaultOAuth2TokenProvider(OAuth2TokenProvider):
)
resp = self._client.post_form(
- self._path, data=client_credential_request.to_dict()
+ self._path,
+ data=client_credential_request,
+ error_handler=OAUTH_ERROR_HANDLER,
)
oauth2_resp = OAuth2TokenResponse.from_json(resp.body,
infer_missing=True)
oauth2_resp.validate()
diff --git
a/clients/client-python/gravitino/dto/responses/oauth2_error_response.py
b/clients/client-python/gravitino/dto/responses/oauth2_error_response.py
new file mode 100644
index 000000000..f7e472c13
--- /dev/null
+++ b/clients/client-python/gravitino/dto/responses/oauth2_error_response.py
@@ -0,0 +1,40 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+from dataclasses import dataclass, field
+from dataclasses_json import config
+
+from gravitino.dto.responses.error_response import ErrorResponse
+
+
+@dataclass
+class OAuth2ErrorResponse(ErrorResponse):
+ """Represents the response of an OAuth2 error."""
+
+ _type: str = field(metadata=config(field_name="error"))
+ _message: str = field(metadata=config(field_name="error_description"))
+
+ def type(self):
+ return self._type
+
+ def message(self):
+ return self._message
+
+ def validate(self):
+ assert self._type is not None, "OAuthErrorResponse should contain type"
diff --git a/clients/client-python/gravitino/exceptions/base.py
b/clients/client-python/gravitino/exceptions/base.py
index 418304d7d..7700e151a 100644
--- a/clients/client-python/gravitino/exceptions/base.py
+++ b/clients/client-python/gravitino/exceptions/base.py
@@ -85,3 +85,11 @@ class
UnsupportedOperationException(GravitinoRuntimeException):
class UnknownError(RuntimeError):
"""An exception thrown when other unknown exception is thrown"""
+
+
+class UnauthorizedException(GravitinoRuntimeException):
+ """An exception thrown when a user is not authorized to perform an
action."""
+
+
+class BadRequestException(GravitinoRuntimeException):
+ """An exception thrown when the request is invalid."""
diff --git
a/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py
b/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py
new file mode 100644
index 000000000..ede4d58ae
--- /dev/null
+++ b/clients/client-python/gravitino/exceptions/handlers/oauth_error_handler.py
@@ -0,0 +1,58 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+from gravitino.exceptions.base import UnauthorizedException,
BadRequestException
+from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse
+from gravitino.exceptions.handlers.rest_error_handler import RestErrorHandler
+
+INVALID_CLIENT_ERROR = "invalid_client"
+INVALID_REQUEST_ERROR = "invalid_request"
+INVALID_GRANT_ERROR = "invalid_grant"
+UNAUTHORIZED_CLIENT_ERROR = "unauthorized_client"
+UNSUPPORTED_GRANT_TYPE_ERROR = "unsupported_grant_type"
+INVALID_SCOPE_ERROR = "invalid_scope"
+
+
+class OAuthErrorHandler(RestErrorHandler):
+
+ def handle(self, error_response: OAuth2ErrorResponse):
+
+ error_message = error_response.message()
+ exception_type = error_response.type()
+
+ if exception_type == INVALID_CLIENT_ERROR:
+ raise UnauthorizedException(
+ f"Not authorized: {exception_type}: {error_message}"
+ )
+
+ if exception_type in [
+ INVALID_REQUEST_ERROR,
+ INVALID_GRANT_ERROR,
+ UNAUTHORIZED_CLIENT_ERROR,
+ UNSUPPORTED_GRANT_TYPE_ERROR,
+ INVALID_SCOPE_ERROR,
+ ]:
+ raise BadRequestException(
+ f"Malformed request: {exception_type}: {error_message}"
+ )
+
+ super().handle(error_response)
+
+
+OAUTH_ERROR_HANDLER = OAuthErrorHandler()
diff --git a/clients/client-python/gravitino/utils/http_client.py
b/clients/client-python/gravitino/utils/http_client.py
index 89b75d641..678942bb4 100644
--- a/clients/client-python/gravitino/utils/http_client.py
+++ b/clients/client-python/gravitino/utils/http_client.py
@@ -37,6 +37,7 @@ from gravitino.typing import JSONType
from gravitino.constants.timeout import TIMEOUT
from gravitino.dto.responses.error_response import ErrorResponse
+from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse
from gravitino.exceptions.base import RESTException, UnknownError
from gravitino.exceptions.handlers.error_handler import ErrorHandler
@@ -145,11 +146,19 @@ class HTTPClient:
ErrorResponse.generate_error_response(RESTException,
err.reason),
)
- err_resp = ErrorResponse.from_json(err_body, infer_missing=True)
+ err_resp = self._parse_error_response(err_body)
err_resp.validate()
return (False, err_resp)
+ def _parse_error_response(self, err_body: bytes) -> ErrorResponse:
+ json_err_body = _json.loads(err_body)
+
+ if "code" in json_err_body:
+ return ErrorResponse.from_json(err_body, infer_missing=True)
+
+ return OAuth2ErrorResponse.from_json(err_body, infer_missing=True)
+
# pylint: disable=too-many-locals
def _request(
self,
@@ -228,7 +237,7 @@ class HTTPClient:
def post_form(self, endpoint, data=None, error_handler=None, **kwargs):
return self._request(
- "post", endpoint, data=data, error_handler=error_handler**kwargs
+ "post", endpoint, data=data, error_handler=error_handler, **kwargs
)
def close(self):
diff --git a/clients/client-python/tests/unittests/auth/mock_base.py
b/clients/client-python/tests/unittests/auth/mock_base.py
index f7b66c6b3..2becd5457 100644
--- a/clients/client-python/tests/unittests/auth/mock_base.py
+++ b/clients/client-python/tests/unittests/auth/mock_base.py
@@ -28,6 +28,12 @@ from cryptography.hazmat.primitives import serialization as
crypto_serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend as
crypto_default_backend
+from gravitino.dto.responses.oauth2_error_response import OAuth2ErrorResponse
+from gravitino.exceptions.handlers.oauth_error_handler import (
+ INVALID_CLIENT_ERROR,
+ INVALID_GRANT_ERROR,
+)
+
@dataclass
class TestResponse:
@@ -61,6 +67,26 @@ JWT_PRIVATE_KEY = generate_private_key()
GENERATED_TIME = int(time.time())
+def mock_authentication_invalid_client_error():
+ return (
+ False,
+ OAuth2ErrorResponse.from_json(
+ json.dumps({"error": INVALID_CLIENT_ERROR, "error_description":
"invalid"}),
+ infer_missing=True,
+ ),
+ )
+
+
+def mock_authentication_invalid_grant_error():
+ return (
+ False,
+ OAuth2ErrorResponse.from_json(
+ json.dumps({"error": INVALID_GRANT_ERROR, "error_description":
"invalid"}),
+ infer_missing=True,
+ ),
+ )
+
+
def mock_authentication_with_error_authentication_type():
return TestResponse(
body=json.dumps(
diff --git
a/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
index b60efbf04..7d9ef9e25 100644
--- a/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
+++ b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
@@ -22,6 +22,7 @@ from unittest.mock import patch
from gravitino.auth.auth_constants import AuthConstants
from gravitino.auth.default_oauth2_token_provider import
DefaultOAuth2TokenProvider
+from gravitino.exceptions.base import BadRequestException,
UnauthorizedException
from tests.unittests.auth import mock_base
OAUTH_PORT = 1082
@@ -40,8 +41,33 @@ class TestOAuth2TokenProvider(unittest.TestCase):
with self.assertRaises(AssertionError):
_ = DefaultOAuth2TokenProvider(uri="test", credential="xx",
scope="test")
- # TODO
- # Error Test
+ @patch(
+ "gravitino.utils.http_client.HTTPClient._make_request",
+ return_value=mock_base.mock_authentication_invalid_client_error(),
+ )
+ def test_authertication_invalid_client_error(self, *mock_methods):
+
+ with self.assertRaises(UnauthorizedException):
+ _ = DefaultOAuth2TokenProvider(
+ uri=f"http://127.0.0.1:{OAUTH_PORT}",
+ credential="yy:xx",
+ path="oauth/token",
+ scope="test",
+ )
+
+ @patch(
+ "gravitino.utils.http_client.HTTPClient._make_request",
+ return_value=mock_base.mock_authentication_invalid_grant_error(),
+ )
+ def test_authertication_invalid_grant_error(self, *mock_methods):
+
+ with self.assertRaises(BadRequestException):
+ _ = DefaultOAuth2TokenProvider(
+ uri=f"http://127.0.0.1:{OAUTH_PORT}",
+ credential="yy:xx",
+ path="oauth/token",
+ scope="test",
+ )
@patch(
"gravitino.utils.http_client.HTTPClient.post_form",