This is an automated email from the ASF dual-hosted git repository.
mmarshall pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/pulsar.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 197d35e3673 [feat] OIDC: support JWKS refresh for missing Key ID
(#20338)
197d35e3673 is described below
commit 197d35e3673a7e9fa50c20587d5824a4f3378c7a
Author: Michael Marshall <[email protected]>
AuthorDate: Wed May 17 11:28:32 2023 -0500
[feat] OIDC: support JWKS refresh for missing Key ID (#20338)
### Motivation
When the `AuthenticationProviderOpenID` encounters an unknown Key ID (known
as the `kid` in the JWT) for a trusted token issuer, the token is rejected with
the following error `java.lang.IllegalArgumentException: No JWK found for Key
ID <kid>`. This behavior is technically valid, but it isn't ideal because it is
possible that the Identity Provider has issued new signing keys and started
using them before the Pulsar authentication provider has refreshed its cache.
This PR introduces a [...]
This PR adds a configuration called `openIDKeyIdCacheMissRefreshSeconds`
that sets the minimum length of time that must pass since the last JWKS
retrieval before an unknown Key ID causes the provider to reload the JWKS from
the Identity Provider. The `openIDKeyIdCacheMissRefreshSeconds` setting limits
the impact of an attacker invalidating the JWKS cache by sending tokens with
unknown Key IDs.
When `openIDKeyIdCacheMissRefreshSeconds <= 0`, the JWKS will be refreshed
for any missing key id when the issuer is trusted. This is only meant for
testing.
### Modifications
* Add `openIDKeyIdCacheMissRefreshSeconds` setting and default it to 5
minutes.
* Add functionality to invalidate and refresh the cache when a token has an
unknown `kid` for a trusted issuer.
### Verifying this change
New tests are added.
### Does this pull request potentially affect one of the following parts:
This adds a new configuration, but it is very minor.
### Documentation
- [x] `doc-required`
### Matching PR in forked repository
PR in forked repository: Skipping since tests passed already on my local
machine
(cherry picked from commit d92010c00e30199c5cd53021e20ba5d7bb36701c)
---
.../oidc/AuthenticationProviderOpenID.java | 2 +
.../broker/authentication/oidc/JwksCache.java | 48 ++++++++--
...uthenticationProviderOpenIDIntegrationTest.java | 100 +++++++++++++++++++++
3 files changed, 144 insertions(+), 6 deletions(-)
diff --git
a/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java
b/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java
index 00ec09bd181..2078666a08d 100644
---
a/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java
+++
b/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java
@@ -133,6 +133,8 @@ public class AuthenticationProviderOpenID implements
AuthenticationProvider {
static final int CACHE_REFRESH_AFTER_WRITE_SECONDS_DEFAULT = 18 * 60 * 60;
static final String CACHE_EXPIRATION_SECONDS =
"openIDCacheExpirationSeconds";
static final int CACHE_EXPIRATION_SECONDS_DEFAULT = 24 * 60 * 60;
+ static final String KEY_ID_CACHE_MISS_REFRESH_SECONDS =
"openIDKeyIdCacheMissRefreshSeconds";
+ static final int KEY_ID_CACHE_MISS_REFRESH_SECONDS_DEFAULT = 5 * 60;
static final String HTTP_CONNECTION_TIMEOUT_MILLIS =
"openIDHttpConnectionTimeoutMillis";
static final int HTTP_CONNECTION_TIMEOUT_MILLIS_DEFAULT = 10_000;
static final String HTTP_READ_TIMEOUT_MILLIS =
"openIDHttpReadTimeoutMillis";
diff --git
a/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/JwksCache.java
b/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/JwksCache.java
index b5e038342c2..73934e9c1e0 100644
---
a/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/JwksCache.java
+++
b/pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/JwksCache.java
@@ -24,6 +24,8 @@ import static
org.apache.pulsar.broker.authentication.oidc.AuthenticationProvide
import static
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.CACHE_REFRESH_AFTER_WRITE_SECONDS_DEFAULT;
import static
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.CACHE_SIZE;
import static
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.CACHE_SIZE_DEFAULT;
+import static
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.KEY_ID_CACHE_MISS_REFRESH_SECONDS;
+import static
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.KEY_ID_CACHE_MISS_REFRESH_SECONDS_DEFAULT;
import static
org.apache.pulsar.broker.authentication.oidc.AuthenticationProviderOpenID.incrementFailureMetric;
import static
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsInt;
import com.auth0.jwk.Jwk;
@@ -43,6 +45,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import javax.naming.AuthenticationException;
import org.apache.pulsar.broker.ServiceConfiguration;
@@ -52,7 +55,8 @@ public class JwksCache {
// Map from an issuer's JWKS URI to its JWKS. When the Issuer is not
empty, use the fallback client.
private final AsyncLoadingCache<Optional<String>, List<Jwk>> cache;
-
+ private final ConcurrentHashMap<Optional<String>, Long>
jwksLastRefreshTime = new ConcurrentHashMap<>();
+ private final long keyIdCacheMissRefreshNanos;
private final ObjectReader reader = new
ObjectMapper().readerFor(HashMap.class);
private final AsyncHttpClient httpClient;
private final OpenidApi openidApi;
@@ -61,7 +65,8 @@ public class JwksCache {
// Store the clients
this.httpClient = httpClient;
this.openidApi = apiClient != null ? new OpenidApi(apiClient) : null;
-
+ keyIdCacheMissRefreshNanos =
TimeUnit.SECONDS.toNanos(getConfigValueAsInt(config,
+ KEY_ID_CACHE_MISS_REFRESH_SECONDS,
KEY_ID_CACHE_MISS_REFRESH_SECONDS_DEFAULT));
// Configure the cache
int maxSize = getConfigValueAsInt(config, CACHE_SIZE,
CACHE_SIZE_DEFAULT);
int refreshAfterWriteSeconds = getConfigValueAsInt(config,
CACHE_REFRESH_AFTER_WRITE_SECONDS,
@@ -69,6 +74,8 @@ public class JwksCache {
int expireAfterSeconds = getConfigValueAsInt(config,
CACHE_EXPIRATION_SECONDS,
CACHE_EXPIRATION_SECONDS_DEFAULT);
AsyncCacheLoader<Optional<String>, List<Jwk>> loader = (jwksUri,
executor) -> {
+ // Store the time of the retrieval, even though it might be a
little early or the call might fail.
+ jwksLastRefreshTime.put(jwksUri, System.nanoTime());
if (jwksUri.isPresent()) {
return getJwksFromJwksUri(jwksUri.get());
} else {
@@ -87,7 +94,37 @@ public class JwksCache {
incrementFailureMetric(AuthenticationExceptionCode.ERROR_RETRIEVING_PUBLIC_KEY);
return CompletableFuture.failedFuture(new
IllegalArgumentException("jwksUri must not be null."));
}
- return cache.get(Optional.of(jwksUri)).thenApply(jwks ->
getJwkForKID(jwks, keyId));
+ return getJwkAndMaybeReload(Optional.of(jwksUri), keyId, false);
+ }
+
+ /**
+ * Retrieve the JWK for the given key ID from the given JWKS URI. If the
key ID is not found, and failOnMissingKeyId
+ * is false, then the JWKS will be reloaded from the JWKS URI (only if the
refresh interval has elapsed) and the key ID will be searched for again.
+ */
+ private CompletableFuture<Jwk> getJwkAndMaybeReload(Optional<String>
maybeJwksUri,
+ String keyId,
+ boolean
failOnMissingKeyId) {
+ return cache
+ .get(maybeJwksUri)
+ .thenCompose(jwks -> {
+ try {
+ return
CompletableFuture.completedFuture(getJwkForKID(maybeJwksUri, jwks, keyId));
+ } catch (IllegalArgumentException e) {
+ if (failOnMissingKeyId) {
+ throw e;
+ } else {
+ Long lastRefresh =
jwksLastRefreshTime.get(maybeJwksUri);
+ if (lastRefresh == null || System.nanoTime() -
lastRefresh > keyIdCacheMissRefreshNanos) {
+ // In this case, the key ID was not found, but
we haven't refreshed the JWKS in a while,
+ // so it is possible the key ID was added.
Refresh the JWKS and try again.
+ cache.synchronous().invalidate(maybeJwksUri);
+ }
+ // There is a small race condition where the JWKS
could be refreshed by another thread,
+ // so we retry getting the JWK, even though we
might not have invalidated the cache.
+ return getJwkAndMaybeReload(maybeJwksUri, keyId,
true);
+ }
+ }
+ });
}
private CompletableFuture<List<Jwk>> getJwksFromJwksUri(String jwksUri) {
@@ -119,8 +156,7 @@ public class JwksCache {
return CompletableFuture.failedFuture(new AuthenticationException(
"Failed to retrieve public key from Kubernetes API server:
Kubernetes fallback is not enabled."));
}
- return cache.get(Optional.empty(), (__, executor) ->
getJwksFromKubernetesApiServer())
- .thenApply(jwks -> getJwkForKID(jwks, keyId));
+ return getJwkAndMaybeReload(Optional.empty(), keyId, false);
}
private CompletableFuture<List<Jwk>> getJwksFromKubernetesApiServer() {
@@ -170,7 +206,7 @@ public class JwksCache {
return future;
}
- private Jwk getJwkForKID(List<Jwk> jwks, String keyId) {
+ private Jwk getJwkForKID(Optional<String> maybeJwksUri, List<Jwk> jwks,
String keyId) {
for (Jwk jwk : jwks) {
if (jwk.getId().equals(keyId)) {
return jwk;
diff --git
a/pulsar-broker-auth-oidc/src/test/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenIDIntegrationTest.java
b/pulsar-broker-auth-oidc/src/test/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenIDIntegrationTest.java
index 0075d70f599..d2d2de1a114 100644
---
a/pulsar-broker-auth-oidc/src/test/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenIDIntegrationTest.java
+++
b/pulsar-broker-auth-oidc/src/test/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenIDIntegrationTest.java
@@ -29,6 +29,7 @@ import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
import com.github.tomakehurst.wiremock.WireMockServer;
+import com.github.tomakehurst.wiremock.stubbing.Scenario;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.impl.DefaultJwtBuilder;
import io.jsonwebtoken.io.Decoders;
@@ -58,6 +59,7 @@ import
org.apache.pulsar.broker.authentication.utils.AuthTokenUtils;
import org.apache.pulsar.common.api.AuthData;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
+import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
/**
@@ -75,6 +77,7 @@ public class AuthenticationProviderOpenIDIntegrationTest {
// The valid issuer
String issuer;
String issuerWithTrailingSlash;
+ String issuerWithMissingKid;
// This issuer is configured to return an issuer in the
openid-configuration
// that does not match the issuer on the token
String issuerThatFails;
@@ -90,6 +93,7 @@ public class AuthenticationProviderOpenIDIntegrationTest {
server.start();
issuer = server.baseUrl();
issuerWithTrailingSlash = issuer + "/trailing-slash/";
+ issuerWithMissingKid = issuer + "/missing-kid";
issuerThatFails = issuer + "/fail";
issuerK8s = issuer + "/k8s";
@@ -183,6 +187,50 @@ public class AuthenticationProviderOpenIDIntegrationTest {
}
""".formatted(validJwk, n, e,
invalidJwk))));
+ server.stubFor(
+
get(urlEqualTo("/missing-kid/.well-known/openid-configuration"))
+ .willReturn(aResponse()
+ .withHeader("Content-Type", "application/json")
+ .withBody("""
+ {
+ "issuer": "%s",
+ "jwks_uri": "%s/keys"
+ }
+ """.formatted(issuerWithMissingKid,
issuerWithMissingKid))));
+
+ // Set up JWKS endpoint where it first responds without the KID, then
with the KID. This is a stateful stub.
+ // Note that the state machine is circular to make it easier to verify
the two code paths that rely on
+ // this logic.
+ server.stubFor(
+ get(urlMatching( "/missing-kid/keys"))
+ .inScenario("Changing KIDs")
+ .whenScenarioStateIs(Scenario.STARTED)
+ .willSetStateTo("serve-kid")
+ .willReturn(aResponse()
+ .withHeader("Content-Type", "application/json")
+ .withBody("{\"keys\":[]}")));
+ server.stubFor(
+ get(urlMatching( "/missing-kid/keys"))
+ .inScenario("Changing KIDs")
+ .whenScenarioStateIs("serve-kid")
+ .willSetStateTo(Scenario.STARTED)
+ .willReturn(aResponse()
+ .withHeader("Content-Type", "application/json")
+ .withBody(
+ """
+ {
+ "keys" : [
+ {
+ "kid":"%s",
+ "kty":"RSA",
+ "alg":"RS256",
+ "n":"%s",
+ "e":"%s"
+ }
+ ]
+ }
+ """.formatted(validJwk, n, e))));
+
ServiceConfiguration conf = new ServiceConfiguration();
conf.setAuthenticationEnabled(true);
conf.setAuthenticationProviders(Set.of(AuthenticationProviderOpenID.class.getName()));
@@ -207,6 +255,12 @@ public class AuthenticationProviderOpenIDIntegrationTest {
server.stop();
}
+ @BeforeMethod
+ public void beforeMethod() {
+ // Scenarios are stateful. Start each test with the correct state.
+ server.resetScenarios();
+ }
+
@Test
public void testTokenWithValidJWK() throws Exception {
String role = "superuser";
@@ -268,6 +322,52 @@ public class AuthenticationProviderOpenIDIntegrationTest {
assertTrue(e.getCause() instanceof AuthenticationException, "Found
exception: " + e.getCause());
}
}
+ @Test
+ public void testKidCacheMissWhenRefreshConfigZero() throws Exception {
+ ServiceConfiguration conf = new ServiceConfiguration();
+ conf.setAuthenticationEnabled(true);
+
conf.setAuthenticationProviders(Set.of(AuthenticationProviderOpenID.class.getName()));
+ Properties props = conf.getProperties();
+ props.setProperty(AuthenticationProviderOpenID.REQUIRE_HTTPS, "false");
+ // Allows us to retrieve the JWK immediately after the cache miss of
the KID
+
props.setProperty(AuthenticationProviderOpenID.KEY_ID_CACHE_MISS_REFRESH_SECONDS,
"0");
+ props.setProperty(AuthenticationProviderOpenID.ALLOWED_AUDIENCES,
"allowed-audience");
+ props.setProperty(AuthenticationProviderOpenID.ALLOWED_TOKEN_ISSUERS,
issuerWithMissingKid);
+
+ AuthenticationProviderOpenID provider = new
AuthenticationProviderOpenID();
+ provider.initialize(conf);
+
+ String role = "superuser";
+ String token = generateToken(validJwk, issuerWithMissingKid, role,
"allowed-audience", 0L, 0L, 10000L);
+ assertEquals(role, provider.authenticateAsync(new
AuthenticationDataCommand(token)).get());
+ }
+
+ @Test
+ public void testKidCacheMissWhenRefreshConfigLongerThanDelta() throws
Exception {
+ ServiceConfiguration conf = new ServiceConfiguration();
+ conf.setAuthenticationEnabled(true);
+
conf.setAuthenticationProviders(Set.of(AuthenticationProviderOpenID.class.getName()));
+ Properties props = conf.getProperties();
+ props.setProperty(AuthenticationProviderOpenID.REQUIRE_HTTPS, "false");
+ // This value is high enough that the provider will not refresh the JWK
+
props.setProperty(AuthenticationProviderOpenID.KEY_ID_CACHE_MISS_REFRESH_SECONDS,
"100");
+ props.setProperty(AuthenticationProviderOpenID.ALLOWED_AUDIENCES,
"allowed-audience");
+ props.setProperty(AuthenticationProviderOpenID.ALLOWED_TOKEN_ISSUERS,
issuerWithMissingKid);
+
+ AuthenticationProviderOpenID provider = new
AuthenticationProviderOpenID();
+ provider.initialize(conf);
+
+ String role = "superuser";
+ String token = generateToken(validJwk, issuerWithMissingKid, role,
"allowed-audience", 0L, 0L, 10000L);
+ try {
+ provider.authenticateAsync(new
AuthenticationDataCommand(token)).get();
+ fail("Expected exception");
+ } catch (ExecutionException e) {
+ assertTrue(e.getCause() instanceof IllegalArgumentException,
"Found exception: " + e.getCause());
+ assertTrue(e.getCause().getMessage().contains("No JWK found for
Key ID valid"),
+ "Found exception: " + e.getCause());
+ }
+ }
@Test
public void testKubernetesApiServerAsDiscoverTrustedIssuerSuccess() throws
Exception {