This is an automated email from the ASF dual-hosted git repository.

collado pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/polaris.git


The following commit(s) were added to refs/heads/main by this push:
     new 42af733d Update DefaultOAuth2ApiService to support multiple token types and client secret without id (#952)
42af733d is described below

commit 42af733d63acf5b2f81917055fa295550be4fd6b
Author: Michael Collado <[email protected]>
AuthorDate: Thu Feb 13 20:55:26 2025 -0800

    Update DefaultOAuth2ApiService to support multiple token types and client secret without id (#952)
    
    * Update DefaultOAuth2ApiService to support multiple token types and client secret without id
    
    * Address PR comments
    
    * Update TokenBroker interface to accept requestedTokenType
    
    * Add check for requested token type
---
 .../service/quarkus/auth/JWTRSAKeyPairTest.java    |   7 +-
 .../quarkus/auth/JWTSymmetricKeyGeneratorTest.java |   7 +-
 .../service/auth/DefaultOAuth2ApiService.java      |  55 ++--
 .../org/apache/polaris/service/auth/JWTBroker.java |  17 +-
 .../service/auth/NoneTokenBrokerFactory.java       |  12 +-
 .../apache/polaris/service/auth/TokenBroker.java   |  70 ++++-
 .../service/auth/DefaultOAuth2ApiServiceTest.java  | 328 +++++++++++++++++++++
 7 files changed, 457 insertions(+), 39 deletions(-)

diff --git 
a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTRSAKeyPairTest.java
 
b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTRSAKeyPairTest.java
index 6627fbf9..8eb89fb4 100644
--- 
a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTRSAKeyPairTest.java
+++ 
b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTRSAKeyPairTest.java
@@ -41,6 +41,7 @@ import org.apache.polaris.service.auth.PemUtils;
 import org.apache.polaris.service.auth.TokenBroker;
 import org.apache.polaris.service.auth.TokenRequestValidator;
 import org.apache.polaris.service.auth.TokenResponse;
+import org.apache.polaris.service.types.TokenType;
 import org.junit.jupiter.api.Test;
 import org.mockito.Mockito;
 
@@ -76,7 +77,11 @@ public class JWTRSAKeyPairTest {
         new JWTRSAKeyPair(metastoreManager, session, 420, publicFileLocation, 
privateFileLocation);
     TokenResponse token =
         tokenBroker.generateFromClientSecrets(
-            clientId, mainSecret, TokenRequestValidator.CLIENT_CREDENTIALS, 
scope);
+            clientId,
+            mainSecret,
+            TokenRequestValidator.CLIENT_CREDENTIALS,
+            scope,
+            TokenType.ACCESS_TOKEN);
     assertThat(token).isNotNull();
     assertThat(token.getExpiresIn()).isEqualTo(420);
 
diff --git 
a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTSymmetricKeyGeneratorTest.java
 
b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTSymmetricKeyGeneratorTest.java
index 039e575c..787e329e 100644
--- 
a/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTSymmetricKeyGeneratorTest.java
+++ 
b/quarkus/service/src/test/java/org/apache/polaris/service/quarkus/auth/JWTSymmetricKeyGeneratorTest.java
@@ -34,6 +34,7 @@ import org.apache.polaris.service.auth.JWTSymmetricKeyBroker;
 import org.apache.polaris.service.auth.TokenBroker;
 import org.apache.polaris.service.auth.TokenRequestValidator;
 import org.apache.polaris.service.auth.TokenResponse;
+import org.apache.polaris.service.types.TokenType;
 import org.junit.jupiter.api.Test;
 import org.mockito.Mockito;
 
@@ -64,7 +65,11 @@ public class JWTSymmetricKeyGeneratorTest {
         new JWTSymmetricKeyBroker(metastoreManager, metaStoreSession, 666, () 
-> "polaris");
     TokenResponse token =
         generator.generateFromClientSecrets(
-            clientId, mainSecret, TokenRequestValidator.CLIENT_CREDENTIALS, 
"PRINCIPAL_ROLE:TEST");
+            clientId,
+            mainSecret,
+            TokenRequestValidator.CLIENT_CREDENTIALS,
+            "PRINCIPAL_ROLE:TEST",
+            TokenType.ACCESS_TOKEN);
     assertThat(token).isNotNull();
 
     JWTVerifier verifier = 
JWT.require(Algorithm.HMAC256("polaris")).withIssuer("polaris").build();
diff --git 
a/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java
 
b/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java
index f35dd370..4055d6e5 100644
--- 
a/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java
+++ 
b/service/common/src/main/java/org/apache/polaris/service/auth/DefaultOAuth2ApiService.java
@@ -43,7 +43,6 @@ public class DefaultOAuth2ApiService implements 
IcebergRestOAuth2ApiService {
 
   private static final Logger LOGGER = 
LoggerFactory.getLogger(DefaultOAuth2ApiService.class);
 
-  private static final String CLIENT_CREDENTIALS = "client_credentials";
   private static final String BEARER = "bearer";
 
   private final TokenBrokerFactory tokenBrokerFactory;
@@ -75,43 +74,39 @@ public class DefaultOAuth2ApiService implements 
IcebergRestOAuth2ApiService {
     if (!tokenBroker.supportsRequestedTokenType(requestedTokenType)) {
       return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request);
     }
-    if (authHeader == null && clientId == null) {
+    if (authHeader == null && clientSecret == null) {
       return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_client);
     }
-    if (authHeader != null && clientId == null && authHeader.startsWith("Basic 
")) {
+    // token exchange with client id and client secret in the authorization 
header means the client
+    // has previously attempted to refresh an access token, but refreshing was 
not supported by the
+    // token broker. Accept the client id and secret and treat it as a new 
token request
+    if (authHeader != null && clientSecret == null && 
authHeader.startsWith("Basic ")) {
       String credentials = new 
String(Base64.decodeBase64(authHeader.substring(6)), UTF_8);
       if (!credentials.contains(":")) {
-        return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_client);
+        return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request);
       }
       LOGGER.debug("Found credentials in auth header - treating as 
client_credentials");
       String[] parts = credentials.split(":", 2);
-      clientId = parts[0];
-      clientSecret = parts[1];
+      if (parts.length == 2) {
+        clientId = parts[0];
+        clientSecret = parts[1];
+      } else {
+        LOGGER.debug("Don't know how to parse Basic auth header");
+        return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request);
+      }
+    }
+    TokenResponse tokenResponse;
+    if (clientSecret != null) {
+      tokenResponse =
+          tokenBroker.generateFromClientSecrets(
+              clientId, clientSecret, grantType, scope, requestedTokenType);
+    } else if (subjectToken != null) {
+      tokenResponse =
+          tokenBroker.generateFromToken(
+              subjectTokenType, subjectToken, grantType, scope, 
requestedTokenType);
+    } else {
+      return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.invalid_request);
     }
-    TokenResponse tokenResponse =
-        switch (subjectTokenType) {
-          case TokenType.ID_TOKEN,
-                  TokenType.REFRESH_TOKEN,
-                  TokenType.JWT,
-                  TokenType.SAML1,
-                  TokenType.SAML2 ->
-              new TokenResponse(OAuthTokenErrorResponse.Error.invalid_request);
-          case TokenType.ACCESS_TOKEN -> {
-            // token exchange with client id and client secret means the 
client has previously
-            // attempted to refresh
-            // an access token, but refreshing was not supported by the token 
broker. Accept the
-            // client id and
-            // secret and treat it as a new token request
-            if (clientId != null && clientSecret != null) {
-              yield tokenBroker.generateFromClientSecrets(
-                  clientId, clientSecret, CLIENT_CREDENTIALS, scope);
-            } else {
-              yield tokenBroker.generateFromToken(subjectTokenType, 
subjectToken, grantType, scope);
-            }
-          }
-          case null ->
-              tokenBroker.generateFromClientSecrets(clientId, clientSecret, 
grantType, scope);
-        };
     if (tokenResponse == null) {
       return 
OAuthUtils.getResponseFromError(OAuthTokenErrorResponse.Error.unsupported_grant_type);
     }
diff --git 
a/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java 
b/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java
index cd4b24d1..17a87a31 100644
--- 
a/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java
+++ 
b/service/common/src/main/java/org/apache/polaris/service/auth/JWTBroker.java
@@ -98,8 +98,15 @@ public abstract class JWTBroker implements TokenBroker {
 
   @Override
   public TokenResponse generateFromToken(
-      TokenType tokenType, String subjectToken, String grantType, String 
scope) {
-    if (!TokenType.ACCESS_TOKEN.equals(tokenType)) {
+      TokenType subjectTokenType,
+      String subjectToken,
+      String grantType,
+      String scope,
+      TokenType requestedTokenType) {
+    if (!TokenType.ACCESS_TOKEN.equals(requestedTokenType)) {
+      return new TokenResponse(OAuthTokenErrorResponse.Error.invalid_request);
+    }
+    if (!TokenType.ACCESS_TOKEN.equals(subjectTokenType)) {
       return new TokenResponse(OAuthTokenErrorResponse.Error.invalid_request);
     }
     if (StringUtils.isBlank(subjectToken)) {
@@ -121,7 +128,11 @@ public abstract class JWTBroker implements TokenBroker {
 
   @Override
   public TokenResponse generateFromClientSecrets(
-      String clientId, String clientSecret, String grantType, String scope) {
+      String clientId,
+      String clientSecret,
+      String grantType,
+      String scope,
+      TokenType requestedTokenType) {
     // Initial sanity checks
     TokenRequestValidator validator = new TokenRequestValidator();
     Optional<OAuthTokenErrorResponse.Error> initialValidationResponse =
diff --git 
a/service/common/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java
 
b/service/common/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java
index 9f642a2b..175c2afa 100644
--- 
a/service/common/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java
+++ 
b/service/common/src/main/java/org/apache/polaris/service/auth/NoneTokenBrokerFactory.java
@@ -42,13 +42,21 @@ public class NoneTokenBrokerFactory implements 
TokenBrokerFactory {
 
         @Override
         public TokenResponse generateFromClientSecrets(
-            String clientId, String clientSecret, String grantType, String 
scope) {
+            String clientId,
+            String clientSecret,
+            String grantType,
+            String scope,
+            TokenType requestedTokenType) {
           return null;
         }
 
         @Override
         public TokenResponse generateFromToken(
-            TokenType tokenType, String subjectToken, String grantType, String 
scope) {
+            TokenType subjectTokenType,
+            String subjectToken,
+            String grantType,
+            String scope,
+            TokenType requestedTokenType) {
           return null;
         }
 
diff --git 
a/service/common/src/main/java/org/apache/polaris/service/auth/TokenBroker.java 
b/service/common/src/main/java/org/apache/polaris/service/auth/TokenBroker.java
index 31b6c400..62969938 100644
--- 
a/service/common/src/main/java/org/apache/polaris/service/auth/TokenBroker.java
+++ 
b/service/common/src/main/java/org/apache/polaris/service/auth/TokenBroker.java
@@ -34,11 +34,77 @@ public interface TokenBroker {
 
   boolean supportsRequestedTokenType(TokenType tokenType);
 
+  /**
+   * Generate a token from client secrets without specifying the requested 
token type
+   *
+   * @param clientId
+   * @param clientSecret
+   * @param grantType
+   * @param scope
+   * @return the response indicating an error or the requested token
+   * @deprecated - use the method with the requested token type
+   */
+  @Deprecated
+  default TokenResponse generateFromClientSecrets(
+      final String clientId,
+      final String clientSecret,
+      final String grantType,
+      final String scope) {
+    return generateFromClientSecrets(
+        clientId, clientSecret, grantType, scope, TokenType.ACCESS_TOKEN);
+  }
+
+  /**
+   * Generate a token from client secrets
+   *
+   * @param clientId
+   * @param clientSecret
+   * @param grantType
+   * @param scope
+   * @param requestedTokenType
+   * @return the response indicating an error or the requested token
+   */
   TokenResponse generateFromClientSecrets(
-      final String clientId, final String clientSecret, final String 
grantType, final String scope);
+      final String clientId,
+      final String clientSecret,
+      final String grantType,
+      final String scope,
+      TokenType requestedTokenType);
+
+  /**
+   * Generate a token from an existing token of a specified type without 
specifying the requested
+   * token type
+   *
+   * @param subjectTokenType
+   * @param subjectToken
+   * @param grantType
+   * @param scope
+   * @return the response indicating an error or the requested token
+   * @deprecated - use the method with the requested token type
+   */
+  @Deprecated
+  default TokenResponse generateFromToken(
+      TokenType subjectTokenType, String subjectToken, final String grantType, 
final String scope) {
+    return generateFromToken(
+        subjectTokenType, subjectToken, grantType, scope, 
TokenType.ACCESS_TOKEN);
+  }
 
+  /**
+   * Generate a token from an existing token of a specified type
+   *
+   * @param subjectTokenType
+   * @param subjectToken
+   * @param grantType
+   * @param scope
+   * @param requestedTokenType
+   * @return the response indicating an error or the requested token
+   */
   TokenResponse generateFromToken(
-      TokenType tokenType, String subjectToken, final String grantType, final 
String scope);
+      TokenType subjectTokenType,
+      String subjectToken,
+      final String grantType,
+      final String scope,
+      TokenType requestedTokenType);
 
   DecodedToken verify(String token);
 
diff --git 
a/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java
 
b/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java
new file mode 100644
index 00000000..ea119e5c
--- /dev/null
+++ 
b/service/common/src/test/java/org/apache/polaris/service/auth/DefaultOAuth2ApiServiceTest.java
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.polaris.service.auth;
+
+import static org.mockito.Mockito.when;
+
+import jakarta.ws.rs.core.Response;
+import jakarta.ws.rs.core.SecurityContext;
+import java.nio.charset.Charset;
+import java.util.Base64;
+import org.apache.iceberg.rest.responses.OAuthTokenResponse;
+import org.apache.polaris.core.context.RealmContext;
+import org.apache.polaris.service.types.TokenType;
+import org.assertj.core.api.Assertions;
+import org.assertj.core.api.InstanceOfAssertFactories;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+class DefaultOAuth2ApiServiceTest {
+  private static final String CLIENT_CREDENTIALS = "client_credentials";
+
+  @Test
+  public void testNoSupportGrantType() {
+    TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
+    RealmContext realmContext = () -> "realm";
+    TokenBroker tokenBroker = Mockito.mock();
+    when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker);
+    when(tokenBroker.supportsGrantType(CLIENT_CREDENTIALS)).thenReturn(false);
+    
when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(true);
+    when(tokenBroker.generateFromClientSecrets(
+            "client", "secret", CLIENT_CREDENTIALS, "scope", 
TokenType.ACCESS_TOKEN))
+        .thenReturn(new TokenResponse("token", 
TokenType.ACCESS_TOKEN.getValue(), 3600));
+    Response response =
+        new InvocationBuilder()
+            .scope("scope")
+            .clientId("client")
+            .clientSecret("secret")
+            .grantType(CLIENT_CREDENTIALS)
+            .requestedTokenType(TokenType.ACCESS_TOKEN)
+            .realmContext(realmContext)
+            .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
+    Assertions.assertThat(response.getEntity())
+        .isInstanceOf(OAuthTokenErrorResponse.class)
+        
.asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenErrorResponse.class))
+        .returns(
+            OAuthTokenErrorResponse.Error.unsupported_grant_type.name(),
+            OAuthTokenErrorResponse::getError);
+  }
+
+  @Test
+  public void testNoSupportRequestedTokenType() {
+    TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
+    RealmContext realmContext = () -> "realm";
+    TokenBroker tokenBroker = Mockito.mock();
+    when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker);
+    when(tokenBroker.supportsGrantType(CLIENT_CREDENTIALS)).thenReturn(true);
+    
when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(false);
+    when(tokenBroker.generateFromClientSecrets(
+            "client", "secret", CLIENT_CREDENTIALS, "scope", 
TokenType.ACCESS_TOKEN))
+        .thenReturn(new TokenResponse("token", 
TokenType.ACCESS_TOKEN.getValue(), 3600));
+    Response response =
+        new InvocationBuilder()
+            .scope("scope")
+            .clientId("client")
+            .clientSecret("secret")
+            .grantType(CLIENT_CREDENTIALS)
+            .requestedTokenType(TokenType.ACCESS_TOKEN)
+            .realmContext(realmContext)
+            .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
+    Assertions.assertThat(response.getEntity())
+        .isInstanceOf(OAuthTokenErrorResponse.class)
+        
.asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenErrorResponse.class))
+        .returns(
+            OAuthTokenErrorResponse.Error.invalid_request.name(),
+            OAuthTokenErrorResponse::getError);
+  }
+
+  @Test
+  public void testSupportClientIdNoSecret() {
+    TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
+    RealmContext realmContext = () -> "realm";
+    TokenBroker tokenBroker = Mockito.mock();
+    when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker);
+    when(tokenBroker.supportsGrantType(CLIENT_CREDENTIALS)).thenReturn(true);
+    
when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(true);
+    when(tokenBroker.generateFromClientSecrets(
+            null, "secret", CLIENT_CREDENTIALS, "scope", 
TokenType.ACCESS_TOKEN))
+        .thenReturn(new TokenResponse("token", 
TokenType.ACCESS_TOKEN.getValue(), 3600));
+    Response response =
+        new InvocationBuilder()
+            .scope("scope")
+            .clientSecret("secret")
+            .grantType(CLIENT_CREDENTIALS)
+            .requestedTokenType(TokenType.ACCESS_TOKEN)
+            .realmContext(realmContext)
+            .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
+    Assertions.assertThat(response.getEntity())
+        .isInstanceOf(OAuthTokenResponse.class)
+        .asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenResponse.class))
+        .returns("token", OAuthTokenResponse::token);
+  }
+
+  @Test
+  public void testSupportClientIdAndSecret() {
+    TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
+    RealmContext realmContext = () -> "realm";
+    TokenBroker tokenBroker = Mockito.mock();
+    when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker);
+    when(tokenBroker.supportsGrantType(CLIENT_CREDENTIALS)).thenReturn(true);
+    
when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(true);
+    when(tokenBroker.generateFromClientSecrets(
+            "client", "secret", CLIENT_CREDENTIALS, "scope", 
TokenType.ACCESS_TOKEN))
+        .thenReturn(new TokenResponse("token", 
TokenType.ACCESS_TOKEN.getValue(), 3600));
+    Response response =
+        new InvocationBuilder()
+            .scope("scope")
+            .clientId("client")
+            .clientSecret("secret")
+            .grantType(CLIENT_CREDENTIALS)
+            .requestedTokenType(TokenType.ACCESS_TOKEN)
+            .realmContext(realmContext)
+            .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
+    Assertions.assertThat(response.getEntity())
+        .isInstanceOf(OAuthTokenResponse.class)
+        .asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenResponse.class))
+        .returns("token", OAuthTokenResponse::token);
+  }
+
+  @Test
+  public void testReadClientCredentialsFromAuthHeader() {
+    TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
+    RealmContext realmContext = () -> "realm";
+    TokenBroker tokenBroker = Mockito.mock();
+    when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker);
+    
when(tokenBroker.supportsGrantType(TokenRequestValidator.TOKEN_EXCHANGE)).thenReturn(true);
+    
when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(true);
+    when(tokenBroker.generateFromClientSecrets(
+            "client",
+            "secret",
+            TokenRequestValidator.TOKEN_EXCHANGE,
+            "scope",
+            TokenType.ACCESS_TOKEN))
+        .thenReturn(new TokenResponse("token", 
TokenType.ACCESS_TOKEN.getValue(), 3600));
+    Response response =
+        new InvocationBuilder()
+            .authHeader(
+                "Basic "
+                    + Base64.getEncoder()
+                        
.encodeToString("client:secret".getBytes(Charset.defaultCharset())))
+            .scope("scope")
+            .grantType(TokenRequestValidator.TOKEN_EXCHANGE)
+            .requestedTokenType(TokenType.ACCESS_TOKEN)
+            .realmContext(realmContext)
+            .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
+    Assertions.assertThat(response.getEntity())
+        .isInstanceOf(OAuthTokenResponse.class)
+        .asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenResponse.class))
+        .returns("token", OAuthTokenResponse::token);
+  }
+
+  @Test
+  public void testAuthHeaderRequiresValidCredentialPair() {
+    TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
+    RealmContext realmContext = () -> "realm";
+    TokenBroker tokenBroker = Mockito.mock();
+    when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker);
+    
when(tokenBroker.supportsGrantType(TokenRequestValidator.TOKEN_EXCHANGE)).thenReturn(true);
+    
when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(true);
+    when(tokenBroker.generateFromClientSecrets(
+            null, "secret", TokenRequestValidator.TOKEN_EXCHANGE, "scope", 
TokenType.ACCESS_TOKEN))
+        .thenReturn(new TokenResponse("token", 
TokenType.ACCESS_TOKEN.getValue(), 3600));
+    Response response =
+        new InvocationBuilder()
+            .authHeader(
+                "Basic "
+                    + Base64.getEncoder()
+                        
.encodeToString("secret".getBytes(Charset.defaultCharset())))
+            .scope("scope")
+            .grantType(TokenRequestValidator.TOKEN_EXCHANGE)
+            .requestedTokenType(TokenType.ACCESS_TOKEN)
+            .realmContext(realmContext)
+            .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
+    Assertions.assertThat(response.getEntity())
+        .isInstanceOf(OAuthTokenErrorResponse.class)
+        
.asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenErrorResponse.class))
+        .returns(
+            OAuthTokenErrorResponse.Error.invalid_request.name(),
+            OAuthTokenErrorResponse::getError);
+  }
+
+  @Test
+  public void testReadClientSecretFromAuthHeader() {
+    TokenBrokerFactory tokenBrokerFactory = Mockito.mock();
+    RealmContext realmContext = () -> "realm";
+    TokenBroker tokenBroker = Mockito.mock();
+    when(tokenBrokerFactory.apply(realmContext)).thenReturn(tokenBroker);
+    
when(tokenBroker.supportsGrantType(TokenRequestValidator.TOKEN_EXCHANGE)).thenReturn(true);
+    
when(tokenBroker.supportsRequestedTokenType(TokenType.ACCESS_TOKEN)).thenReturn(true);
+
+    when(tokenBroker.generateFromClientSecrets(
+            "", "secret", TokenRequestValidator.TOKEN_EXCHANGE, "scope", 
TokenType.ACCESS_TOKEN))
+        .thenReturn(new TokenResponse("token", 
TokenType.ACCESS_TOKEN.getValue(), 3600));
+    Response response =
+        new InvocationBuilder()
+
+            // here the auth header has a blank client id, providing a blank, 
but not null client id
+            .authHeader(
+                "Basic "
+                    + Base64.getEncoder()
+                        
.encodeToString(":secret".getBytes(Charset.defaultCharset())))
+            .scope("scope")
+            .grantType(TokenRequestValidator.TOKEN_EXCHANGE)
+            .requestedTokenType(TokenType.ACCESS_TOKEN)
+            .realmContext(realmContext)
+            .invoke(new DefaultOAuth2ApiService(tokenBrokerFactory));
+    Assertions.assertThat(response.getEntity())
+        .isInstanceOf(OAuthTokenResponse.class)
+        .asInstanceOf(InstanceOfAssertFactories.type(OAuthTokenResponse.class))
+        .returns("token", OAuthTokenResponse::token);
+  }
+
+  private static final class InvocationBuilder {
+    private String authHeader;
+    private String grantType;
+    private String scope;
+    private String clientId;
+    private String clientSecret;
+    private TokenType requestedTokenType;
+    private String subjectToken;
+    private TokenType subjectTokenType;
+    private String actorToken;
+    private TokenType actorTokenType;
+    private RealmContext realmContext;
+    private SecurityContext securityContext;
+
+    public InvocationBuilder authHeader(String authHeader) {
+      this.authHeader = authHeader;
+      return this;
+    }
+
+    public InvocationBuilder grantType(String grantType) {
+      this.grantType = grantType;
+      return this;
+    }
+
+    public InvocationBuilder scope(String scope) {
+      this.scope = scope;
+      return this;
+    }
+
+    public InvocationBuilder clientId(String clientId) {
+      this.clientId = clientId;
+      return this;
+    }
+
+    public InvocationBuilder clientSecret(String clientSecret) {
+      this.clientSecret = clientSecret;
+      return this;
+    }
+
+    public InvocationBuilder requestedTokenType(TokenType requestedTokenType) {
+      this.requestedTokenType = requestedTokenType;
+      return this;
+    }
+
+    public InvocationBuilder subjectToken(String subjectToken) {
+      this.subjectToken = subjectToken;
+      return this;
+    }
+
+    public InvocationBuilder subjectTokenType(TokenType subjectTokenType) {
+      this.subjectTokenType = subjectTokenType;
+      return this;
+    }
+
+    public InvocationBuilder actorToken(String actorToken) {
+      this.actorToken = actorToken;
+      return this;
+    }
+
+    public InvocationBuilder actorTokenType(TokenType actorTokenType) {
+      this.actorTokenType = actorTokenType;
+      return this;
+    }
+
+    public InvocationBuilder realmContext(RealmContext realmContext) {
+      this.realmContext = realmContext;
+      return this;
+    }
+
+    public InvocationBuilder securityContext(SecurityContext securityContext) {
+      this.securityContext = securityContext;
+      return this;
+    }
+
+    public Response invoke(DefaultOAuth2ApiService instance) {
+      return instance.getToken(
+          authHeader,
+          grantType,
+          scope,
+          clientId,
+          clientSecret,
+          requestedTokenType,
+          subjectToken,
+          subjectTokenType,
+          actorToken,
+          actorTokenType,
+          realmContext,
+          securityContext);
+    }
+  }
+}

Reply via email to