This is an automated email from the ASF dual-hosted git repository. collado pushed a commit to branch mcollado-support-tokenexchange in repository https://gitbox.apache.org/repos/asf/polaris.git
commit cda2a620139d36ed578b1e27ddc987f43370eba9 Author: Michael Collado <[email protected]> AuthorDate: Fri Feb 7 11:21:48 2025 -0800 Address PR comments --- .../service/auth/DefaultOAuth2ApiService.java | 16 +++++----- .../service/auth/DefaultOAuth2ApiServiceTest.java | 35 +++++++++++++++++++++- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java b/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java index ad0bd2b6..5855e36a 100644 --- a/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java +++ b/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java @@ -77,28 +77,26 @@ public class DefaultOAuth2ApiService implements IcebergRestOAuth2ApiService { if (authHeader == null && clientSecret == null) { return OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_client); } - // token exchange with client id and client secret means the client has previously - // attempted to refresh an access token, but refreshing was not supported by the token broker. - // Accept the client id and secret and treat it as a new token request + // token exchange with client id and client secret in the authorization header means the client + // has previously attempted to refresh an access token, but refreshing was not supported by the + // token broker. Accept the client id and secret and treat it as a new token request if (authHeader != null && clientSecret == null && authHeader.startsWith("Basic ")) { String credentials = new String(Base64.decodeBase64(authHeader.substring(6)), UTF_8); + if (!credentials.contains(":")) { + return OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request); + } LOGGER.debug("Found credentials in auth header - treating as client_credentials"); String[] parts = credentials.split(":", 2); if (parts.length == 2) { clientId = parts[0]; clientSecret = parts[1]; - } else if (parts.length == 1) { - clientSecret = parts[0]; } else { LOGGER.debug("Don't know how to parse Basic auth header"); - OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request); + return OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request); } } TokenResponse tokenResponse; if (subjectToken != null) { - if (!tokenBroker.supportsRequestedTokenType(subjectTokenType)) { - return OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request); - } tokenResponse = tokenBroker.generateFromToken(subjectTokenType, subjectToken, grantType, scope); } else if (clientSecret != null) { diff --git a/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java b/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java index 055937bb..097f99e9 100644 --- a/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java +++ b/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java @@ -167,7 +167,7 @@ class DefaultOAuth2ApiServiceTest { } @Test - public void testReadClientSecretFromAuthHeader() { + public void testAuthHeaderRequiresValidCredentialPair() { TokenBrokerFactory tokenBrokerFactory = Mockito.mock(); RealmContext realmContext = () -> "realm"; TokenBroker tokenBroker = Mockito.mock(); @@ -188,6 +188,39 @@ class DefaultOAuth2ApiServiceTest { .requestedTokenType(TokenType.ACCESS_TOKEN) .realmContext(realmContext) .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory)); + Assertions.assertThat(response.getEntity()) + .isInstanceOf(OAuthTokenErrorResponse.class) + .asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenErrorResponse.class)) + .returns( + OAuthTokenErrorResponse.Error.invalid_request.name(), + OAuthTokenErrorResponse::getError); + } + + @Test + public void testReadClientSecretFromAuthHeader() { + TokenBrokerFactory tokenBrokerFactory = Mockito.mock(); + RealmContext realmContext = () -> "realm"; + TokenBroker tokenBroker = Mockito.mock(); + when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker); + when(tokenBroker.supportsGrantType(TokenRequestValidator.TOKEN_EXCHANGE)).thenReturn(true); + when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(true); + + when(tokenBroker.generateFromClientSecrets( + "", "secret", TokenRequestValidator.TOKEN_EXCHANGE, "scope")) + .thenReturn(new TokenResponse("token", TokenType.ACCESS_TOKEN.getValue(), 3600)); + Response response = + new InvocationBuilder() + + // here the auth header has a blank client id, providing a blank, but not null client id + .authHeader( + "Basic " + + Base64.getEncoder() + .encodeToString(":secret".getBytes(Charset.defaultCharset()))) + .scope("scope") + .grantType(TokenRequestValidator.TOKEN_EXCHANGE) + .requestedTokenType(TokenType.ACCESS_TOKEN) + .realmContext(realmContext) + .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory)); Assertions.assertThat(response.getEntity()) .isInstanceOf(OAuthTokenResponse.class) .asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenResponse.class))
