michaeljmarshall commented on code in PR #19849:
URL: https://github.com/apache/pulsar/pull/19849#discussion_r1161925229


##########
pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java:
##########
@@ -0,0 +1,485 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pulsar.broker.authentication.oidc;
+
+import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsBoolean;
+import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsInt;
+import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsSet;
+import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsString;
+import com.auth0.jwk.InvalidPublicKeyException;
+import com.auth0.jwk.Jwk;
+import com.auth0.jwt.JWT;
+import com.auth0.jwt.JWTVerifier;
+import com.auth0.jwt.RegisteredClaims;
+import com.auth0.jwt.algorithms.Algorithm;
+import com.auth0.jwt.exceptions.AlgorithmMismatchException;
+import com.auth0.jwt.exceptions.InvalidClaimException;
+import com.auth0.jwt.exceptions.JWTDecodeException;
+import com.auth0.jwt.exceptions.JWTVerificationException;
+import com.auth0.jwt.exceptions.SignatureVerificationException;
+import com.auth0.jwt.exceptions.TokenExpiredException;
+import com.auth0.jwt.interfaces.Claim;
+import com.auth0.jwt.interfaces.DecodedJWT;
+import io.kubernetes.client.openapi.ApiClient;
+import io.kubernetes.client.util.Config;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import java.io.File;
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.security.PublicKey;
+import java.security.interfaces.ECPublicKey;
+import java.security.interfaces.RSAPublicKey;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import javax.naming.AuthenticationException;
+import javax.net.ssl.SSLSession;
+import org.apache.pulsar.broker.ServiceConfiguration;
+import org.apache.pulsar.broker.authentication.AuthenticationDataSource;
+import org.apache.pulsar.broker.authentication.AuthenticationProvider;
+import org.apache.pulsar.broker.authentication.AuthenticationProviderToken;
+import org.apache.pulsar.broker.authentication.AuthenticationState;
+import org.apache.pulsar.broker.authentication.metrics.AuthenticationMetrics;
+import org.apache.pulsar.common.api.AuthData;
+import org.asynchttpclient.AsyncHttpClient;
+import org.asynchttpclient.AsyncHttpClientConfig;
+import org.asynchttpclient.DefaultAsyncHttpClient;
+import org.asynchttpclient.DefaultAsyncHttpClientConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An {@link AuthenticationProvider} implementation that supports the usage of 
a JSON Web Token (JWT)
+ * for client authentication. This implementation retrieves the PublicKey from 
the JWT issuer (assuming the
+ * issuer is in the configured allowed list) and then uses that Public Key to 
verify the validity of the JWT's
+ * signature.
+ *
+ * The Public Keys for a given provider are cached based on certain configured 
parameters to improve performance.
+ * The tradeoff here is that the longer Public Keys are cached, the longer an 
invalidated token could be used. One way
+ * to ensure caches are cleared is to restart all brokers.
+ *
+ * Class is called from multiple threads. The implementation must be thread 
safe. This class expects to be loaded once
+ * and then called concurrently for each new connection. The cache is backed 
by a GuavaCachedJwkProvider, which is
+ * thread-safe.
+ *
+ * Supported algorithms are: RS256, RS384, RS512, ES256, ES384, ES512 where 
the naming conventions follow
+ * this RFC: https://datatracker.ietf.org/doc/html/rfc7518#section-3.1.
+ */
+public class AuthenticationProviderOpenID implements AuthenticationProvider {
+    private static final Logger log = 
LoggerFactory.getLogger(AuthenticationProviderOpenID.class);
+
+    private static final String SIMPLE_NAME = 
AuthenticationProviderOpenID.class.getSimpleName();
+
+    // Must match the value used by the OAuth2 Client Plugin.
+    private static final String AUTH_METHOD_NAME = "token";
+
+    // This is backed by an ObjectMapper, which is thread safe. It is an 
optimization
+    // to share this for decoding JWTs for all connections to this broker.
+    private final JWT jwtLibrary = new JWT();
+
+    private Set<String> issuers;
+
+    // This caches the map from Issuer URL to the jwks_uri served at the 
/.well-known/openid-configuration endpoint
+    private OpenIDProviderMetadataCache openIDProviderMetadataCache;
+
+    // A cache used to store the results of getting the JWKS from the jwks_uri 
for an issuer.
+    private JwksCache jwksCache;
+
+    private volatile AsyncHttpClient httpClient;
+
+    // A list of supported algorithms. This is the "alg" field on the JWT.
+    // Source for strings: 
https://datatracker.ietf.org/doc/html/rfc7518#section-3.1.
+    private static final String ALG_RS256 = "RS256";
+    private static final String ALG_RS384 = "RS384";
+    private static final String ALG_RS512 = "RS512";
+    private static final String ALG_ES256 = "ES256";
+    private static final String ALG_ES384 = "ES384";
+    private static final String ALG_ES512 = "ES512";
+
+    private long acceptedTimeLeewaySeconds;
+    private FallbackDiscoveryMode fallbackDiscoveryMode;
+    private String roleClaim;
+
+    static final String ALLOWED_TOKEN_ISSUERS = "openIDAllowedTokenIssuers";
+    static final String ISSUER_TRUST_CERTS_FILE_PATH = 
"openIDTokenIssuerTrustCertsFilePath";
+    static final String FALLBACK_DISCOVERY_MODE = 
"openIDFallbackDiscoveryMode";
+    static final String ALLOWED_AUDIENCES = "openIDAllowedAudiences";
+    static final String ROLE_CLAIM = "openIDRoleClaim";
+    static final String ROLE_CLAIM_DEFAULT = "sub";
+    static final String ACCEPTED_TIME_LEEWAY_SECONDS = 
"openIDAcceptedTimeLeewaySeconds";
+    static final int ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT = 0;
+    static final String CACHE_SIZE = "openIDCacheSize";
+    static final int CACHE_SIZE_DEFAULT = 5;
+    static final String CACHE_REFRESH_AFTER_WRITE_SECONDS = 
"openIDCacheRefreshAfterWriteSeconds";
+    static final int CACHE_REFRESH_AFTER_WRITE_SECONDS_DEFAULT = 18 * 60 * 60;
+    static final String CACHE_EXPIRATION_SECONDS = 
"openIDCacheExpirationSeconds";
+    static final int CACHE_EXPIRATION_SECONDS_DEFAULT = 24 * 60 * 60;
+    static final String HTTP_CONNECTION_TIMEOUT_MILLIS = 
"openIDHttpConnectionTimeoutMillis";
+    static final int HTTP_CONNECTION_TIMEOUT_MILLIS_DEFAULT = 10_000;
+    static final String HTTP_READ_TIMEOUT_MILLIS = 
"openIDHttpReadTimeoutMillis";
+    static final int HTTP_READ_TIMEOUT_MILLIS_DEFAULT = 10_000;
+    static final String REQUIRE_HTTPS = "openIDRequireIssuersUseHttps";
+    static final boolean REQUIRE_HTTPS_DEFAULT = true;
+
+    // The list of audiences that are allowed to connect to this broker. A 
valid JWT must contain one of the audiences.
+    private String[] allowedAudiences;
+
+    @Override
+    public void initialize(ServiceConfiguration config) throws IOException {
+        this.allowedAudiences = 
validateAllowedAudiences(getConfigValueAsSet(config, ALLOWED_AUDIENCES));
+        this.roleClaim = getConfigValueAsString(config, ROLE_CLAIM, 
ROLE_CLAIM_DEFAULT);
+        this.acceptedTimeLeewaySeconds = getConfigValueAsInt(config, 
ACCEPTED_TIME_LEEWAY_SECONDS,
+                ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT);
+        boolean requireHttps = getConfigValueAsBoolean(config, REQUIRE_HTTPS, 
REQUIRE_HTTPS_DEFAULT);
+        this.fallbackDiscoveryMode = 
FallbackDiscoveryMode.valueOf(getConfigValueAsString(config,
+                FALLBACK_DISCOVERY_MODE, 
FallbackDiscoveryMode.DISABLED.name()));
+        this.issuers = validateIssuers(getConfigValueAsSet(config, 
ALLOWED_TOKEN_ISSUERS), requireHttps,
+                fallbackDiscoveryMode != FallbackDiscoveryMode.DISABLED);
+
+        int connectionTimeout = getConfigValueAsInt(config, 
HTTP_CONNECTION_TIMEOUT_MILLIS,
+                HTTP_CONNECTION_TIMEOUT_MILLIS_DEFAULT);
+        int readTimeout = getConfigValueAsInt(config, 
HTTP_READ_TIMEOUT_MILLIS, HTTP_READ_TIMEOUT_MILLIS_DEFAULT);
+        String trustCertsFilePath = getConfigValueAsString(config, 
ISSUER_TRUST_CERTS_FILE_PATH, null);
+        SslContext sslContext = null;
+        if (trustCertsFilePath != null) {
+            // Use default settings for everything but the trust store.
+            sslContext = SslContextBuilder.forClient()
+                    .trustManager(new File(trustCertsFilePath))
+                    .build();
+        }
+        AsyncHttpClientConfig clientConfig = new 
DefaultAsyncHttpClientConfig.Builder()
+                .setConnectTimeout(connectionTimeout)
+                .setReadTimeout(readTimeout)
+                .setSslContext(sslContext)
+                .build();
+        httpClient = new DefaultAsyncHttpClient(clientConfig);
+        ApiClient k8sApiClient =
+                fallbackDiscoveryMode != FallbackDiscoveryMode.DISABLED ? 
Config.defaultClient() : null;
+        this.openIDProviderMetadataCache = new 
OpenIDProviderMetadataCache(config, httpClient, k8sApiClient);
+        this.jwksCache = new JwksCache(config, httpClient, k8sApiClient);
+    }
+
+    @Override
+    public String getAuthMethodName() {
+        return AUTH_METHOD_NAME;
+    }
+
+    /**
+     * Authenticate the parameterized {@link AuthenticationDataSource} by 
verifying the issuer is an allowed issuer,
+     * then retrieving the JWKS URI from the issuer, then retrieving the 
Public key from the JWKS URI, and finally
+     * verifying the JWT signature and claims.
+     *
+     * @param authData - the authData passed by the Pulsar Broker containing 
the token.
+     * @return the role, if the JWT is authenticated, otherwise a failed 
future.
+     */
+    @Override
+    public CompletableFuture<String> 
authenticateAsync(AuthenticationDataSource authData) {
+        return authenticateTokenAsync(authData).thenApply(this::getRole);
+    }
+
+    /**
+     * Authenticate the parameterized {@link AuthenticationDataSource} and 
return the decoded JWT.
+     * @param authData - the authData containing the token.
+     * @return a completed future with the decoded JWT, if the JWT is 
authenticated. Otherwise, a failed future.
+     */
+    CompletableFuture<DecodedJWT> 
authenticateTokenAsync(AuthenticationDataSource authData) {
+        String token;
+        try {
+            token = AuthenticationProviderToken.getToken(authData);
+        } catch (AuthenticationException e) {
+            
incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
+            return CompletableFuture.failedFuture(e);
+        }
+        return authenticateToken(token)
+                .whenComplete((jwt, e) -> {
+                    if (jwt != null) {
+                        
AuthenticationMetrics.authenticateSuccess(getClass().getSimpleName(), 
getAuthMethodName());
+                    }
+                    // Failure metrics are incremented within methods above
+                });
+    }
+
+    /**
+     * Get the role from a JWT at the configured role claim field.
+     * NOTE: does not do any verification of the JWT
+     * @param jwt - token to get the role from
+     * @return the role, or null, if it is not set on the JWT
+     */
+    String getRole(DecodedJWT jwt) {
+        try {
+            Claim roleClaim = jwt.getClaim(this.roleClaim);
+            if (roleClaim.isNull()) {
+                // The claim was not present in the JWT
+                return null;
+            }
+
+            String role = roleClaim.asString();
+            if (role != null) {
+                // The role is non null only if the JSON node is a text field
+                return role;
+            }
+
+            List<String> roles = 
jwt.getClaim(this.roleClaim).asList(String.class);
+            if (roles == null || roles.size() == 0) {
+                return null;
+            } else if (roles.size() == 1) {
+                return roles.get(0);
+            } else {
+                log.debug("JWT for subject [{}] has multiple roles; using the 
first one.", jwt.getSubject());
+                return roles.get(0);
+            }
+        } catch (JWTDecodeException e) {
+            log.error("Exception while retrieving role from JWT", e);
+            return null;
+        }
+    }
+
+    /**
+     * Convert a JWT string into a {@link DecodedJWT}
+     * The benefit of using this method is that it utilizes the already 
instantiated {@link JWT} parser.
+     * WARNING: this method does not verify the authenticity of the token. It 
only decodes it.
+     *
+     * @param token - string JWT to be decoded
+     * @return a decoded JWT
+     * @throws AuthenticationException if the token string is null or if any 
part of the token contains
+     *         an invalid jwt or JSON format of each of the jwt parts.
+     */
+    DecodedJWT decodeJWT(String token) throws AuthenticationException {
+        if (token == null) {
+            
incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
+            throw new AuthenticationException("Invalid token: cannot be null");
+        }
+        try {
+            return jwtLibrary.decodeJwt(token);
+        } catch (JWTDecodeException e) {
+            
incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
+            throw new AuthenticationException("Unable to decode JWT: " + 
e.getMessage());
+        }
+    }
+
+    /**
+     * Authenticate the parameterized JWT.
+     *
+     * @param token - a nonnull JWT to authenticate
+     * @return a fully authenticated JWT, or AuthenticationException if the 
JWT is proven to be invalid in any way
+     */
+    private CompletableFuture<DecodedJWT> authenticateToken(String token) {
+        if (token == null) {
+            
incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
+            return CompletableFuture.failedFuture(new 
AuthenticationException("JWT cannot be null"));
+        }
+        final DecodedJWT jwt;
+        try {
+            jwt = decodeJWT(token);
+        } catch (AuthenticationException e) {
+            
incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
+            return CompletableFuture.failedFuture(e);
+        }
+        return verifyIssuerAndGetJwk(jwt)
+                .thenCompose(jwk -> {
+                    try {
+                        if (!jwt.getAlgorithm().equals(jwk.getAlgorithm())) {
+                            
incrementFailureMetric(AuthenticationExceptionCode.ALGORITHM_MISMATCH);
+                            return CompletableFuture.failedFuture(
+                                    new AuthenticationException("JWK's alg [" 
+ jwk.getAlgorithm()
+                                            + "] does not match JWT's alg [" + 
jwt.getAlgorithm() + "]"));
+                        }
+                        // Verify the JWT signature
+                        // Throws exception if any verification check fails
+                        return CompletableFuture
+                                .completedFuture(verifyJWT(jwk.getPublicKey(), 
jwt.getAlgorithm(), jwt));
+                    } catch (InvalidPublicKeyException e) {
+                        
incrementFailureMetric(AuthenticationExceptionCode.INVALID_PUBLIC_KEY);
+                        return CompletableFuture.failedFuture(
+                                new AuthenticationException("Invalid public 
key: " + e.getMessage()));
+                    } catch (AuthenticationException e) {
+                        return CompletableFuture.failedFuture(e);
+                    }
+                });
+    }
+
+    /**
+     * Verify the JWT's issuer (iss) claim is one of the allowed issuers and 
then retrieve the JWK from the issuer. If
+     * not, see {@link FallbackDiscoveryMode} for the fallback behavior.
+     * @param jwt - the token to use to discover the issuer's JWKS URI, which 
is then used to retrieve the issuer's
+     *            current public keys.
+     * @return a JWK that can be used to verify the JWT's signature
+     */
+    private CompletableFuture<Jwk> verifyIssuerAndGetJwk(DecodedJWT jwt) {
+        if (jwt.getIssuer() == null) {
+            
incrementFailureMetric(AuthenticationExceptionCode.UNSUPPORTED_ISSUER);
+            return CompletableFuture.failedFuture(new 
AuthenticationException("Issuer cannot be null"));
+        } else if (this.issuers.contains(jwt.getIssuer())) {
+            // Retrieve the metadata: 
https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
+            return 
openIDProviderMetadataCache.getOpenIDProviderMetadataForIssuer(jwt.getIssuer())
+                    .thenCompose(metadata -> 
jwksCache.getJwk(metadata.getJwksUri(), jwt.getKeyId()));
+        } else if (fallbackDiscoveryMode == 
FallbackDiscoveryMode.KUBERNETES_DISCOVER_TRUSTED_ISSUER) {
+            return 
openIDProviderMetadataCache.getOpenIDProviderMetadataForKubernetesApiServer(jwt.getIssuer())
+                    .thenCompose(metadata ->
+                            
openIDProviderMetadataCache.getOpenIDProviderMetadataForIssuer(metadata.getIssuer()))
+                    .thenCompose(metadata -> 
jwksCache.getJwk(metadata.getJwksUri(), jwt.getKeyId()));
+        } else if (fallbackDiscoveryMode == 
FallbackDiscoveryMode.KUBERNETES_DISCOVER_PUBLIC_KEYS) {
+            return 
openIDProviderMetadataCache.getOpenIDProviderMetadataForKubernetesApiServer(jwt.getIssuer())
+                    .thenCompose(__ -> 
jwksCache.getJwkFromKubernetesApiServer(jwt.getKeyId()));

Review Comment:
   > Why is the metadata fetched and then not used in the second case?
   
   The metadata is fetched to verify that the token's issuer matches the issuer 
in the discovery document served by the k8s API Server. This check determines 
whether the token qualifies for the fallback issuer, so that a token that does 
not qualify is rejected before its signature is verified.
   
   The subsequent call to get the public keys goes to the API Server's 
`/openid/v1/jwks` endpoint. This works for AKS and GKE, but not for EKS. 
Because the call goes to the API Server, it must be authenticated and must 
trust the cluster's custom CA certificate. The k8s client is the easiest way 
to discover those details without requiring the user to configure anything extra.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to