This is an automated email from the ASF dual-hosted git repository.

mmarshall pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/pulsar.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 197d35e3673 [feat] OIDC: support JWKS refresh for missing Key ID 
(#20338)
197d35e3673 is described below

commit 197d35e3673a7e9fa50c20587d5824a4f3378c7a
Author: Michael Marshall <[email protected]>
AuthorDate: Wed May 17 11:28:32 2023 -0500

    [feat] OIDC: support JWKS refresh for missing Key ID (#20338)
    
    ### Motivation
    
    When the `AuthenticationProviderOpenID` encounters an unknown Key ID (known 
as the `kid` in the JWT) for a trusted token issuer, the token is rejected with 
the following error `java.lang.IllegalArgumentException: No JWK found for Key 
ID <kid>`. This behavior is technically valid, but it isn't ideal because it is 
possible that the Identity Provider has issued new signing keys and started 
using them before the Pulsar authentication provider has refreshed its cache. 
This PR introduces a  [...]
    
    This PR adds a configuration called `openIDKeyIdCacheMissRefreshSeconds` 
that sets the minimum length of time that must have passed since the last JWKS 
load before a Key ID cache miss causes the provider to reload the JWKS from the 
Identity Provider. The `openIDKeyIdCacheMissRefreshSeconds` setting limits the 
impact of an attacker who tries to invalidate the JWKS cache repeatedly by 
sending tokens with unknown Key IDs.
    
    When `openIDKeyIdCacheMissRefreshSeconds <= 0`, the JWKS will be refreshed 
on every missing Key ID for a trusted issuer. This disables the rate limiting 
and is only meant for testing.
    
    ### Modifications
    
    * Add the `openIDKeyIdCacheMissRefreshSeconds` setting with a default of 
300 seconds (5 minutes).
    * Add functionality to invalidate and refresh the cache when a token has an 
unknown `kid` for a trusted issuer.
    
    ### Verifying this change
    
    New tests are added.
    
    ### Does this pull request potentially affect one of the following parts:
    
    This adds a new configuration, but it is very minor.
    
    ### Documentation
    
    - [x] `doc-required`
    
    ### Matching PR in forked repository
    
    PR in forked repository: Skipping since tests passed already on my local 
machine
    
    (cherry picked from commit d92010c00e30199c5cd53021e20ba5d7bb36701c)
---
 .../oidc/AuthenticationProviderOpenID.java         |   2 +
 .../broker/authentication/oidc/JwksCache.java      |  48 ++++++++--
 ...uthenticationProviderOpenIDIntegrationTest.java | 100 +++++++++++++++++++++
 3 files changed, 144 insertions(+), 6 deletions(-)

diff --git 
a/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java
 
b/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java
index 00ec09bd181..2078666a08d 100644
--- 
a/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java
+++ 
b/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java
@@ -133,6 +133,8 @@ public class AuthenticationProviderOpenID implements 
AuthenticationProvider {
     static final int CACHE_REFRESH_AFTER_WRITE_SECONDS_DEFAULT = 18 * 60 * 60;
     static final String CACHE_EXPIRATION_SECONDS = 
"openIDCacheExpirationSeconds";
     static final int CACHE_EXPIRATION_SECONDS_DEFAULT = 24 * 60 * 60;
+    static final String KEY_ID_CACHE_MISS_REFRESH_SECONDS = 
"openIDKeyIdCacheMissRefreshSeconds";
+    static final int KEY_ID_CACHE_MISS_REFRESH_SECONDS_DEFAULT = 5 * 60;
     static final String HTTP_CONNECTION_TIMEOUT_MILLIS = 
"openIDHttpConnectionTimeoutMillis";
     static final int HTTP_CONNECTION_TIMEOUT_MILLIS_DEFAULT = 10_000;
     static final String HTTP_READ_TIMEOUT_MILLIS = 
"openIDHttpReadTimeoutMillis";
diff --git 
a/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/JwksCache.java
 
b/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/JwksCache.java
index b5e038342c2..73934e9c1e0 100644
--- 
a/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/JwksCache.java
+++ 
b/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/JwksCache.java
@@ -24,6 +24,8 @@ import static 
org.apache.pulsar.broker.authentication.oidc.AuthenticationProvide
 import static 
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.CACHE_REFRESH_AFTER_WRITE_SECONDS_DEFAULT;
 import static 
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.CACHE_SIZE;
 import static 
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.CACHE_SIZE_DEFAULT;
+import static 
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.KEY_ID_CACHE_MISS_REFRESH_SECONDS;
+import static 
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.KEY_ID_CACHE_MISS_REFRESH_SECONDS_DEFAULT;
 import static 
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.incrementFailureMetric;
 import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsInt;
 import com.auth0.jwk.Jwk;
@@ -43,6 +45,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 import javax.naming.AuthenticationException;
 import org.apache.pulsar.broker.ServiceConfiguration;
@@ -52,7 +55,8 @@ public class JwksCache {
 
     // Map from an issuer's JWKS URI to its JWKS. When the Issuer is not 
empty, use the fallback client.
     private final AsyncLoadingCache<Optional<String>, List<Jwk>> cache;
-
+    private final ConcurrentHashMap<Optional<String>, Long> 
jwksLastRefreshTime = new ConcurrentHashMap<>();
+    private final long keyIdCacheMissRefreshNanos;
     private final ObjectReader reader = new 
ObjectMapper().readerFor(HashMap.class);
     private final AsyncHttpClient httpClient;
     private final OpenidApi openidApi;
@@ -61,7 +65,8 @@ public class JwksCache {
         // Store the clients
         this.httpClient = httpClient;
         this.openidApi = apiClient != null ? new OpenidApi(apiClient) : null;
-
+        keyIdCacheMissRefreshNanos = 
TimeUnit.SECONDS.toNanos(getConfigValueAsInt(config,
+                KEY_ID_CACHE_MISS_REFRESH_SECONDS, 
KEY_ID_CACHE_MISS_REFRESH_SECONDS_DEFAULT));
         // Configure the cache
         int maxSize = getConfigValueAsInt(config, CACHE_SIZE, 
CACHE_SIZE_DEFAULT);
         int refreshAfterWriteSeconds = getConfigValueAsInt(config, 
CACHE_REFRESH_AFTER_WRITE_SECONDS,
@@ -69,6 +74,8 @@ public class JwksCache {
         int expireAfterSeconds = getConfigValueAsInt(config, 
CACHE_EXPIRATION_SECONDS,
                 CACHE_EXPIRATION_SECONDS_DEFAULT);
         AsyncCacheLoader<Optional<String>, List<Jwk>> loader = (jwksUri, 
executor) -> {
+            // Store the time of the retrieval, even though it might be a 
little early or the call might fail.
+            jwksLastRefreshTime.put(jwksUri, System.nanoTime());
             if (jwksUri.isPresent()) {
                 return getJwksFromJwksUri(jwksUri.get());
             } else {
@@ -87,7 +94,37 @@ public class JwksCache {
             
incrementFailureMetric(AuthenticationExceptionCode.ERROR_RETRIEVING_PUBLIC_KEY);
             return CompletableFuture.failedFuture(new 
IllegalArgumentException("jwksUri must not be null."));
         }
-        return cache.get(Optional.of(jwksUri)).thenApply(jwks -> 
getJwkForKID(jwks, keyId));
+        return getJwkAndMaybeReload(Optional.of(jwksUri), keyId, false);
+    }
+
+    /**
+     * Retrieve the JWK for the given key ID from the JWKS for the given JWKS 
URI. If the key ID is not found and
+     * failOnMissingKeyId is false, invalidate the cached JWKS if it is older 
than the refresh interval and retry once.
+     */
+    private CompletableFuture<Jwk> getJwkAndMaybeReload(Optional<String> 
maybeJwksUri,
+                                                        String keyId,
+                                                        boolean 
failOnMissingKeyId) {
+        return cache
+                .get(maybeJwksUri)
+                .thenCompose(jwks -> {
+                    try {
+                        return 
CompletableFuture.completedFuture(getJwkForKID(maybeJwksUri, jwks, keyId));
+                    } catch (IllegalArgumentException e) {
+                        if (failOnMissingKeyId) {
+                            throw e;
+                        } else {
+                            Long lastRefresh = 
jwksLastRefreshTime.get(maybeJwksUri);
+                            if (lastRefresh == null || System.nanoTime() - 
lastRefresh > keyIdCacheMissRefreshNanos) {
+                                // In this case, the key ID was not found, but 
we haven't refreshed the JWKS in a while,
+                                // so it is possible the key ID was added. 
Refresh the JWKS and try again.
+                                cache.synchronous().invalidate(maybeJwksUri);
+                            }
+                            // There is a small race condition where the JWKS 
could be refreshed by another thread,
+                            // so we retry getting the JWK, even though we 
might not have invalidated the cache.
+                            return getJwkAndMaybeReload(maybeJwksUri, keyId, 
true);
+                        }
+                    }
+                });
     }
 
     private CompletableFuture<List<Jwk>> getJwksFromJwksUri(String jwksUri) {
@@ -119,8 +156,7 @@ public class JwksCache {
             return CompletableFuture.failedFuture(new AuthenticationException(
                     "Failed to retrieve public key from Kubernetes API server: 
Kubernetes fallback is not enabled."));
         }
-        return cache.get(Optional.empty(), (__, executor) -> 
getJwksFromKubernetesApiServer())
-                .thenApply(jwks -> getJwkForKID(jwks, keyId));
+        return getJwkAndMaybeReload(Optional.empty(), keyId, false);
     }
 
     private CompletableFuture<List<Jwk>> getJwksFromKubernetesApiServer() {
@@ -170,7 +206,7 @@ public class JwksCache {
         return future;
     }
 
-    private Jwk getJwkForKID(List<Jwk> jwks, String keyId) {
+    private Jwk getJwkForKID(Optional<String> maybeJwksUri, List<Jwk> jwks, 
String keyId) {
         for (Jwk jwk : jwks) {
             if (jwk.getId().equals(keyId)) {
                 return jwk;
diff --git 
a/pulsar-broker-auth-oidc/src/test/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenIDIntegrationTest.java
 
b/pulsar-broker-auth-oidc/src/test/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenIDIntegrationTest.java
index 0075d70f599..d2d2de1a114 100644
--- 
a/pulsar-broker-auth-oidc/src/test/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenIDIntegrationTest.java
+++ 
b/pulsar-broker-auth-oidc/src/test/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenIDIntegrationTest.java
@@ -29,6 +29,7 @@ import static org.testng.Assert.assertNull;
 import static org.testng.Assert.assertTrue;
 import static org.testng.Assert.fail;
 import com.github.tomakehurst.wiremock.WireMockServer;
+import com.github.tomakehurst.wiremock.stubbing.Scenario;
 import io.jsonwebtoken.SignatureAlgorithm;
 import io.jsonwebtoken.impl.DefaultJwtBuilder;
 import io.jsonwebtoken.io.Decoders;
@@ -58,6 +59,7 @@ import 
org.apache.pulsar.broker.authentication.utils.AuthTokenUtils;
 import org.apache.pulsar.common.api.AuthData;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeClass;
+import org.testng.annotations.BeforeMethod;
 import org.testng.annotations.Test;
 
 /**
@@ -75,6 +77,7 @@ public class AuthenticationProviderOpenIDIntegrationTest {
     // The valid issuer
     String issuer;
     String issuerWithTrailingSlash;
+    String issuerWithMissingKid;
     // This issuer is configured to return an issuer in the 
openid-configuration
     // that does not match the issuer on the token
     String issuerThatFails;
@@ -90,6 +93,7 @@ public class AuthenticationProviderOpenIDIntegrationTest {
         server.start();
         issuer = server.baseUrl();
         issuerWithTrailingSlash = issuer + "/trailing-slash/";
+        issuerWithMissingKid = issuer + "/missing-kid";
         issuerThatFails = issuer + "/fail";
         issuerK8s = issuer + "/k8s";
 
@@ -183,6 +187,50 @@ public class AuthenticationProviderOpenIDIntegrationTest {
                                         }
                                         """.formatted(validJwk, n, e, 
invalidJwk))));
 
+        server.stubFor(
+                
get(urlEqualTo("/missing-kid/.well-known/openid-configuration"))
+                        .willReturn(aResponse()
+                                .withHeader("Content-Type", "application/json")
+                                .withBody("""
+                                        {
+                                          "issuer": "%s",
+                                          "jwks_uri": "%s/keys"
+                                        }
+                                        """.formatted(issuerWithMissingKid, 
issuerWithMissingKid))));
+
+        // Set up JWKS endpoint where it first responds without the KID, then 
with the KID. This is a stateful stub.
+        // Note that the state machine is circular to make it easier to verify 
the two code paths that rely on
+        // this logic.
+        server.stubFor(
+                get(urlMatching( "/missing-kid/keys"))
+                        .inScenario("Changing KIDs")
+                        .whenScenarioStateIs(Scenario.STARTED)
+                        .willSetStateTo("serve-kid")
+                        .willReturn(aResponse()
+                                .withHeader("Content-Type", "application/json")
+                                .withBody("{\"keys\":[]}")));
+        server.stubFor(
+                get(urlMatching( "/missing-kid/keys"))
+                        .inScenario("Changing KIDs")
+                        .whenScenarioStateIs("serve-kid")
+                        .willSetStateTo(Scenario.STARTED)
+                        .willReturn(aResponse()
+                                .withHeader("Content-Type", "application/json")
+                                .withBody(
+                                        """
+                                        {
+                                            "keys" : [
+                                                {
+                                                "kid":"%s",
+                                                "kty":"RSA",
+                                                "alg":"RS256",
+                                                "n":"%s",
+                                                "e":"%s"
+                                                }
+                                            ]
+                                        }
+                                        """.formatted(validJwk, n, e))));
+
         ServiceConfiguration conf = new ServiceConfiguration();
         conf.setAuthenticationEnabled(true);
         
conf.setAuthenticationProviders(Set.of(AuthenticationProviderOpenID.class.getName()));
@@ -207,6 +255,12 @@ public class AuthenticationProviderOpenIDIntegrationTest {
         server.stop();
     }
 
+    @BeforeMethod
+    public void beforeMethod() {
+        // Scenarios are stateful. Start each test with the correct state.
+        server.resetScenarios();
+    }
+
     @Test
     public void testTokenWithValidJWK() throws Exception {
         String role = "superuser";
@@ -268,6 +322,52 @@ public class AuthenticationProviderOpenIDIntegrationTest {
             assertTrue(e.getCause() instanceof AuthenticationException, "Found 
exception: " + e.getCause());
         }
     }
+    @Test
+    public void testKidCacheMissWhenRefreshConfigZero() throws Exception {
+        ServiceConfiguration conf = new ServiceConfiguration();
+        conf.setAuthenticationEnabled(true);
+        
conf.setAuthenticationProviders(Set.of(AuthenticationProviderOpenID.class.getName()));
+        Properties props = conf.getProperties();
+        props.setProperty(AuthenticationProviderOpenID.REQUIRE_HTTPS, "false");
+        // Allows us to retrieve the JWK immediately after the cache miss of 
the KID
+        
props.setProperty(AuthenticationProviderOpenID.KEY_ID_CACHE_MISS_REFRESH_SECONDS,
 "0");
+        props.setProperty(AuthenticationProviderOpenID.ALLOWED_AUDIENCES, 
"allowed-audience");
+        props.setProperty(AuthenticationProviderOpenID.ALLOWED_TOKEN_ISSUERS, 
issuerWithMissingKid);
+
+        AuthenticationProviderOpenID provider = new 
AuthenticationProviderOpenID();
+        provider.initialize(conf);
+
+        String role = "superuser";
+        String token = generateToken(validJwk, issuerWithMissingKid, role, 
"allowed-audience", 0L, 0L, 10000L);
+        assertEquals(role, provider.authenticateAsync(new 
AuthenticationDataCommand(token)).get());
+    }
+
+    @Test
+    public void testKidCacheMissWhenRefreshConfigLongerThanDelta() throws 
Exception {
+        ServiceConfiguration conf = new ServiceConfiguration();
+        conf.setAuthenticationEnabled(true);
+        
conf.setAuthenticationProviders(Set.of(AuthenticationProviderOpenID.class.getName()));
+        Properties props = conf.getProperties();
+        props.setProperty(AuthenticationProviderOpenID.REQUIRE_HTTPS, "false");
+        // This value is high enough that the provider will not refresh the JWK
+        
props.setProperty(AuthenticationProviderOpenID.KEY_ID_CACHE_MISS_REFRESH_SECONDS,
 "100");
+        props.setProperty(AuthenticationProviderOpenID.ALLOWED_AUDIENCES, 
"allowed-audience");
+        props.setProperty(AuthenticationProviderOpenID.ALLOWED_TOKEN_ISSUERS, 
issuerWithMissingKid);
+
+        AuthenticationProviderOpenID provider = new 
AuthenticationProviderOpenID();
+        provider.initialize(conf);
+
+        String role = "superuser";
+        String token = generateToken(validJwk, issuerWithMissingKid, role, 
"allowed-audience", 0L, 0L, 10000L);
+        try {
+            provider.authenticateAsync(new 
AuthenticationDataCommand(token)).get();
+            fail("Expected exception");
+        } catch (ExecutionException e) {
+            assertTrue(e.getCause() instanceof IllegalArgumentException, 
"Found exception: " + e.getCause());
+            assertTrue(e.getCause().getMessage().contains("No JWK found for 
Key ID valid"),
+                    "Found exception: " + e.getCause());
+        }
+    }
 
     @Test
     public void testKubernetesApiServerAsDiscoverTrustedIssuerSuccess() throws 
Exception {

Reply via email to