This is an automated email from the ASF dual-hosted git repository.

collado pushed a commit to branch mcollado-support-tokenexchange
in repository https://gitbox.apache.org/repos/asf/polaris.git

commit cda2a620139d36ed578b1e27ddc987f43370eba9
Author: Michael Collado <[email protected]>
AuthorDate: Fri Feb 7 11:21:48 2025 -0800

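    Address PR comments

    Reject Basic auth credentials that lack a ':' separator with
    invalid_request instead of treating the whole value as the client
    secret, and actually return the invalid_request error that was
    previously built but discarded. Remove the
    supportsRequestedTokenType(subjectTokenType) check before
    generateFromToken. Add a test for a Basic auth header with a blank
    client id and a test for a header without a valid credential pair.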
---
 .../service/auth/DefaultOAuth2ApiService.java      | 16 +++++-----
 .../service/auth/DefaultOAuth2ApiServiceTest.java  | 35 +++++++++++++++++++++-
 2 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java b/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java
index ad0bd2b6..5855e36a 100644
--- a/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java
+++ b/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java
@@ -77,28 +77,26 @@ public class DefaultOAuth2ApiService implements IcebergRestOAuth2ApiService {
     if (authHeader == null && clientSecret == null) {
       return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_client);
     }
-    // token exchange with client id and client secret means the client has 
previously
-    // attempted to refresh an access token, but refreshing was not supported 
by the token broker.
-    // Accept the client id and secret and treat it as a new token request
+    // token exchange with client id and client secret in the authorization 
header means the client
+    // has previously attempted to refresh an access token, but refreshing was 
not supported by the
+    // token broker. Accept the client id and secret and treat it as a new 
token request
     if (authHeader != null && clientSecret == null && 
authHeader.startsWith("Basic ")) {
       String credentials = new 
String(Base64.decodeBase64(authHeader.substring(6)), UTF_8);
+      if (!credentials.contains(":")) {
+        return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request);
+      }
       LOGGER.debug("Found credentials in auth header - treating as 
client_credentials");
       String[] parts = credentials.split(":", 2);
       if (parts.length == 2) {
         clientId = parts[0];
         clientSecret = parts[1];
-      } else if (parts.length == 1) {
-        clientSecret = parts[0];
       } else {
         LOGGER.debug("Don't know how to parse Basic auth header");
-        
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request);
+        return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request);
       }
     }
     TokenResponse tokenResponse;
     if (subjectToken != null) {
-      if (!tokenBroker.supportsRequestedTokenType(subjectTokenType)) {
-        return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request);
-      }
       tokenResponse =
           tokenBroker.generateFromToken(subjectTokenType, subjectToken, 
grantType, scope);
     } else if (clientSecret != null) {
diff --git a/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java b/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java
index 055937bb..097f99e9 100644
--- a/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java
+++ b/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java
@@ -167,7 +167,7 @@ class DefaultOAuth2ApiServiceTest {
   }
 
   @Test
-  public void testReadClientSecretFromAuthHeader() {
+  public void testAuthHeaderRequiresValidCredentialPair() {
     TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
     RealmContext realmContext = () -> "realm";
     TokenBroker tokenBroker = Mockito.mock();
@@ -188,6 +188,39 @@ class DefaultOAuth2ApiServiceTest {
             .requestedTokenType(TokenType.ACCESS_TOKEN)
             .realmContext(realmContext)
             .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
+    Assertions.assertThat(response.getEntity())
+        .isInstanceOf(OAuthTokenErrorResponse.class)
+        
.asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenErrorResponse.class))
+        .returns(
+            OAuthTokenErrorResponse.Error.invalid_request.name(),
+            OAuthTokenErrorResponse::getError);
+  }
+
+  @Test
+  public void testReadClientSecretFromAuthHeader() {
+    TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
+    RealmContext realmContext = () -> "realm";
+    TokenBroker tokenBroker = Mockito.mock();
+    when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker);
+    
when(tokenBroker.supportsGrantType(TokenRequestValidator.TOKEN_EXCHANGE)).thenReturn(true);
+    
when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(true);
+
+    when(tokenBroker.generateFromClientSecrets(
+            "", "secret", TokenRequestValidator.TOKEN_EXCHANGE, "scope"))
+        .thenReturn(new TokenResponse("token", 
TokenType.ACCESS_TOKEN.getValue(), 3600));
+    Response response =
+        new InvocationBuilder()
+
+            // here the auth header has a blank client id, providing a blank, 
but not null client id
+            .authHeader(
+                "Basic "
+                    + Base64.getEncoder()
+                        
.encodeToString(":secret".getBytes(Charset.defaultCharset())))
+            .scope("scope")
+            .grantType(TokenRequestValidator.TOKEN_EXCHANGE)
+            .requestedTokenType(TokenType.ACCESS_TOKEN)
+            .realmContext(realmContext)
+            .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
     Assertions.assertThat(response.getEntity())
         .isInstanceOf(OAuthTokenResponse.class)
         .asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenResponse.class))

Reply via email to