michaeljmarshall commented on code in PR #19849:
URL: https://github.com/apache/pulsar/pull/19849#discussion_r1161925461


##########
pulsar-broker-auth-oidc/src/main/java/org/apache/pulsar/broker/authentication/oidc/AuthenticationProviderOpenID.java:
##########
@@ -0,0 +1,485 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pulsar.broker.authentication.oidc;
+
+import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsBoolean;
+import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsInt;
+import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsSet;
+import static 
org.apache.pulsar.broker.authentication.oidc.ConfigUtils.getConfigValueAsString;
+import com.auth0.jwk.InvalidPublicKeyException;
+import com.auth0.jwk.Jwk;
+import com.auth0.jwt.JWT;
+import com.auth0.jwt.JWTVerifier;
+import com.auth0.jwt.RegisteredClaims;
+import com.auth0.jwt.algorithms.Algorithm;
+import com.auth0.jwt.exceptions.AlgorithmMismatchException;
+import com.auth0.jwt.exceptions.InvalidClaimException;
+import com.auth0.jwt.exceptions.JWTDecodeException;
+import com.auth0.jwt.exceptions.JWTVerificationException;
+import com.auth0.jwt.exceptions.SignatureVerificationException;
+import com.auth0.jwt.exceptions.TokenExpiredException;
+import com.auth0.jwt.interfaces.Claim;
+import com.auth0.jwt.interfaces.DecodedJWT;
+import io.kubernetes.client.openapi.ApiClient;
+import io.kubernetes.client.util.Config;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
+import java.io.File;
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.security.PublicKey;
+import java.security.interfaces.ECPublicKey;
+import java.security.interfaces.RSAPublicKey;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import javax.naming.AuthenticationException;
+import javax.net.ssl.SSLSession;
+import org.apache.pulsar.broker.ServiceConfiguration;
+import org.apache.pulsar.broker.authentication.AuthenticationDataSource;
+import org.apache.pulsar.broker.authentication.AuthenticationProvider;
+import org.apache.pulsar.broker.authentication.AuthenticationProviderToken;
+import org.apache.pulsar.broker.authentication.AuthenticationState;
+import org.apache.pulsar.broker.authentication.metrics.AuthenticationMetrics;
+import org.apache.pulsar.common.api.AuthData;
+import org.asynchttpclient.AsyncHttpClient;
+import org.asynchttpclient.AsyncHttpClientConfig;
+import org.asynchttpclient.DefaultAsyncHttpClient;
+import org.asynchttpclient.DefaultAsyncHttpClientConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An {@link AuthenticationProvider} implementation that supports the usage of a JSON Web Token (JWT)
+ * for client authentication. This implementation retrieves the PublicKey from the JWT issuer (assuming the
+ * issuer is in the configured allowed list) and then uses that Public Key to verify the validity of the JWT's
+ * signature.
+ *
+ * The Public Keys for a given provider are cached based on certain configured parameters to improve performance.
+ * The tradeoff here is that the longer Public Keys are cached, the longer an invalidated token could be used. One way
+ * to ensure caches are cleared is to restart all brokers.
+ *
+ * Class is called from multiple threads. The implementation must be thread safe. This class expects to be loaded once
+ * and then called concurrently for each new connection. The cache is backed by a GuavaCachedJwkProvider, which is
+ * thread-safe.
+ *
+ * Supported algorithms are: RS256, RS384, RS512, ES256, ES384, ES512 where the naming conventions follow
+ * this RFC: https://datatracker.ietf.org/doc/html/rfc7518#section-3.1.
+ */
+public class AuthenticationProviderOpenID implements AuthenticationProvider {
+    private static final Logger log = 
LoggerFactory.getLogger(AuthenticationProviderOpenID.class);
+
+    private static final String SIMPLE_NAME = 
AuthenticationProviderOpenID.class.getSimpleName();
+
+    // Must match the value used by the OAuth2 Client Plugin.
+    private static final String AUTH_METHOD_NAME = "token";
+
+    // This is backed by an ObjectMapper, which is thread safe. It is an optimization
+    // to share this for decoding JWTs for all connections to this broker.
+    private final JWT jwtLibrary = new JWT();
+
+    private Set<String> issuers;
+
+    // This caches the map from Issuer URL to the jwks_uri served at the /.well-known/openid-configuration endpoint
+    private OpenIDProviderMetadataCache openIDProviderMetadataCache;
+
+    // A cache used to store the results of getting the JWKS from the jwks_uri for an issuer.
+    private JwksCache jwksCache;
+
+    private volatile AsyncHttpClient httpClient;
+
+    // A list of supported algorithms. This is the "alg" field on the JWT.
+    // Source for strings: https://datatracker.ietf.org/doc/html/rfc7518#section-3.1.
+    private static final String ALG_RS256 = "RS256";
+    private static final String ALG_RS384 = "RS384";
+    private static final String ALG_RS512 = "RS512";
+    private static final String ALG_ES256 = "ES256";
+    private static final String ALG_ES384 = "ES384";
+    private static final String ALG_ES512 = "ES512";
+
+    private long acceptedTimeLeewaySeconds;
+    private FallbackDiscoveryMode fallbackDiscoveryMode;
+    private String roleClaim;
+
+    static final String ALLOWED_TOKEN_ISSUERS = "openIDAllowedTokenIssuers";
+    static final String ISSUER_TRUST_CERTS_FILE_PATH = 
"openIDTokenIssuerTrustCertsFilePath";
+    static final String FALLBACK_DISCOVERY_MODE = 
"openIDFallbackDiscoveryMode";
+    static final String ALLOWED_AUDIENCES = "openIDAllowedAudiences";
+    static final String ROLE_CLAIM = "openIDRoleClaim";
+    static final String ROLE_CLAIM_DEFAULT = "sub";
+    static final String ACCEPTED_TIME_LEEWAY_SECONDS = 
"openIDAcceptedTimeLeewaySeconds";
+    static final int ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT = 0;
+    static final String CACHE_SIZE = "openIDCacheSize";
+    static final int CACHE_SIZE_DEFAULT = 5;
+    static final String CACHE_REFRESH_AFTER_WRITE_SECONDS = 
"openIDCacheRefreshAfterWriteSeconds";
+    static final int CACHE_REFRESH_AFTER_WRITE_SECONDS_DEFAULT = 18 * 60 * 60;
+    static final String CACHE_EXPIRATION_SECONDS = 
"openIDCacheExpirationSeconds";
+    static final int CACHE_EXPIRATION_SECONDS_DEFAULT = 24 * 60 * 60;
+    static final String HTTP_CONNECTION_TIMEOUT_MILLIS = 
"openIDHttpConnectionTimeoutMillis";
+    static final int HTTP_CONNECTION_TIMEOUT_MILLIS_DEFAULT = 10_000;
+    static final String HTTP_READ_TIMEOUT_MILLIS = 
"openIDHttpReadTimeoutMillis";
+    static final int HTTP_READ_TIMEOUT_MILLIS_DEFAULT = 10_000;
+    static final String REQUIRE_HTTPS = "openIDRequireIssuersUseHttps";
+    static final boolean REQUIRE_HTTPS_DEFAULT = true;
+
+    // The list of audiences that are allowed to connect to this broker. A valid JWT must contain one of the audiences.
+    private String[] allowedAudiences;
+
+    @Override
+    public void initialize(ServiceConfiguration config) throws IOException {
+        this.allowedAudiences = 
validateAllowedAudiences(getConfigValueAsSet(config, ALLOWED_AUDIENCES));
+        this.roleClaim = getConfigValueAsString(config, ROLE_CLAIM, 
ROLE_CLAIM_DEFAULT);
+        this.acceptedTimeLeewaySeconds = getConfigValueAsInt(config, 
ACCEPTED_TIME_LEEWAY_SECONDS,
+                ACCEPTED_TIME_LEEWAY_SECONDS_DEFAULT);
+        boolean requireHttps = getConfigValueAsBoolean(config, REQUIRE_HTTPS, 
REQUIRE_HTTPS_DEFAULT);
+        this.fallbackDiscoveryMode = 
FallbackDiscoveryMode.valueOf(getConfigValueAsString(config,
+                FALLBACK_DISCOVERY_MODE, 
FallbackDiscoveryMode.DISABLED.name()));
+        this.issuers = validateIssuers(getConfigValueAsSet(config, 
ALLOWED_TOKEN_ISSUERS), requireHttps,
+                fallbackDiscoveryMode != FallbackDiscoveryMode.DISABLED);
+
+        int connectionTimeout = getConfigValueAsInt(config, 
HTTP_CONNECTION_TIMEOUT_MILLIS,
+                HTTP_CONNECTION_TIMEOUT_MILLIS_DEFAULT);
+        int readTimeout = getConfigValueAsInt(config, 
HTTP_READ_TIMEOUT_MILLIS, HTTP_READ_TIMEOUT_MILLIS_DEFAULT);
+        String trustCertsFilePath = getConfigValueAsString(config, 
ISSUER_TRUST_CERTS_FILE_PATH, null);
+        SslContext sslContext = null;
+        if (trustCertsFilePath != null) {
+            // Use default settings for everything but the trust store.
+            sslContext = SslContextBuilder.forClient()
+                    .trustManager(new File(trustCertsFilePath))
+                    .build();
+        }
+        AsyncHttpClientConfig clientConfig = new 
DefaultAsyncHttpClientConfig.Builder()
+                .setConnectTimeout(connectionTimeout)
+                .setReadTimeout(readTimeout)
+                .setSslContext(sslContext)
+                .build();
+        httpClient = new DefaultAsyncHttpClient(clientConfig);
+        ApiClient k8sApiClient =
+                fallbackDiscoveryMode != FallbackDiscoveryMode.DISABLED ? 
Config.defaultClient() : null;
+        this.openIDProviderMetadataCache = new 
OpenIDProviderMetadataCache(config, httpClient, k8sApiClient);
+        this.jwksCache = new JwksCache(config, httpClient, k8sApiClient);
+    }
+
+    @Override
+    public String getAuthMethodName() {
+        return AUTH_METHOD_NAME;
+    }
+
+    /**
+     * Authenticate the parameterized {@link AuthenticationDataSource} by verifying the issuer is an allowed issuer,
+     * then retrieving the JWKS URI from the issuer, then retrieving the Public key from the JWKS URI, and finally
+     * verifying the JWT signature and claims.
+     *
+     * @param authData - the authData passed by the Pulsar Broker containing the token.
+     * @return the role, if the JWT is authenticated, otherwise a failed future.
+     */
+    @Override
+    public CompletableFuture<String> 
authenticateAsync(AuthenticationDataSource authData) {
+        return authenticateTokenAsync(authData).thenApply(this::getRole);
+    }
+
+    /**
+     * Authenticate the parameterized {@link AuthenticationDataSource} and return the decoded JWT.
+     * @param authData - the authData containing the token.
+     * @return a completed future with the decoded JWT, if the JWT is authenticated. Otherwise, a failed future.
+     */
+    CompletableFuture<DecodedJWT> 
authenticateTokenAsync(AuthenticationDataSource authData) {
+        String token;
+        try {
+            token = AuthenticationProviderToken.getToken(authData);
+        } catch (AuthenticationException e) {
+            
incrementFailureMetric(AuthenticationExceptionCode.ERROR_DECODING_JWT);
+            return CompletableFuture.failedFuture(e);
+        }
+        return authenticateToken(token)
+                .whenComplete((jwt, e) -> {
+                    if (jwt != null) {
+                        
AuthenticationMetrics.authenticateSuccess(getClass().getSimpleName(), 
getAuthMethodName());
+                    }
+                    // Failure metrics are incremented within methods above
+                });
+    }
+
+    /**
+     * Get the role from a JWT at the configured role claim field.
+     * NOTE: does not do any verification of the JWT
+     * @param jwt - token to get the role from
+     * @return the role, or null, if it is not set on the JWT
+     */
+    String getRole(DecodedJWT jwt) {
+        try {
+            Claim roleClaim = jwt.getClaim(this.roleClaim);
+            if (roleClaim.isNull()) {
+                // The claim was not present in the JWT
+                return null;
+            }
+
+            String role = roleClaim.asString();
+            if (role != null) {
+                // The role is non null only if the JSON node is a text field
+                return role;
+            }

Review Comment:
   Great catch. I'll add this as a step to token validation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to