jerqi commented on code in PR #7883:
URL: https://github.com/apache/gravitino/pull/7883#discussion_r2246813856


##########
server-common/src/main/java/org/apache/gravitino/server/authentication/JwksTokenValidator.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.gravitino.server.authentication;
+
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.jwk.source.JWKSource;
+import com.nimbusds.jose.proc.JWSKeySelector;
+import com.nimbusds.jose.proc.JWSVerificationKeySelector;
+import com.nimbusds.jose.proc.SecurityContext;
+import com.nimbusds.jwt.JWTClaimsSet;
+import com.nimbusds.jwt.SignedJWT;
+import com.nimbusds.jwt.proc.DefaultJWTProcessor;
+import java.security.Principal;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.gravitino.Config;
+import org.apache.gravitino.UserPrincipal;
+import org.apache.gravitino.exceptions.UnauthorizedException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Generic JWKS-based OAuth token validator that uses Nimbus JOSE + JWT 
library to validate JWT
+ * tokens from any OAuth provider that exposes a JWKS endpoint.
+ *
+ * <p>This validator supports OAuth providers that publish their public keys 
via JWKS (JSON Web Key
+ * Set) endpoints
+ */
+public class JwksTokenValidator implements OAuthTokenValidator {
+  private static final Logger LOG = 
LoggerFactory.getLogger(JwksTokenValidator.class);
+
+  private String jwksUri;
+  private String expectedIssuer;
+  private String principalField;
+  private long allowSkewSeconds;
+
+  @Override
+  public void initialize(Config config) {
+    this.jwksUri = config.get(OAuthConfig.JWKS_URI);
+    this.expectedIssuer = config.get(OAuthConfig.AUTHORITY);
+    this.principalField = config.get(OAuthConfig.PRINCIPAL_FIELD);
+    this.allowSkewSeconds = config.get(OAuthConfig.ALLOW_SKEW_SECONDS);
+
+    LOG.info("Initializing JWKS token validator");
+
+    if (StringUtils.isBlank(jwksUri)) {
+      throw new IllegalArgumentException(
+          "JWKS URI must be configured when using JWKS-based OAuth providers");
+    }
+
+    // Validate JWKS URI format
+    try {
+      new java.net.URL(jwksUri);
+    } catch (Exception e) {
+      LOG.error("Invalid JWKS URI format: {}", jwksUri);
+      throw new IllegalArgumentException("Invalid JWKS URI format: " + 
jwksUri, e);
+    }
+  }
+
+  @Override
+  public Principal validateToken(String token, String serviceAudience) {
+    if (token == null || token.trim().isEmpty()) {
+      LOG.error("Token is null or empty");
+      throw new UnauthorizedException("Token cannot be null or empty");
+    }
+
+    if (serviceAudience == null || serviceAudience.trim().isEmpty()) {
+      LOG.error("Service audience is null or empty");
+      throw new UnauthorizedException("Service audience cannot be null or 
empty");
+    }
+
+    try {
+      SignedJWT signedJWT = SignedJWT.parse(token);
+
+      // Set up JWKS source and processor
+      JWKSource<SecurityContext> jwkSource = createJwkSource();
+      JWSAlgorithm algorithm = 
JWSAlgorithm.parse(signedJWT.getHeader().getAlgorithm().getName());
+      JWSKeySelector<SecurityContext> keySelector =
+          new JWSVerificationKeySelector<>(algorithm, jwkSource);
+
+      DefaultJWTProcessor<SecurityContext> jwtProcessor = new 
DefaultJWTProcessor<>();
+      jwtProcessor.setJWSKeySelector(keySelector);
+
+      // Configure claims verification
+      JWTClaimsSet.Builder expectedClaimsBuilder = new JWTClaimsSet.Builder();
+
+      // Set expected issuer if configured
+      if (StringUtils.isNotBlank(expectedIssuer)) {
+        expectedClaimsBuilder.issuer(expectedIssuer);
+      }
+
+      // Set expected audience if provided
+      if (StringUtils.isNotBlank(serviceAudience)) {
+        expectedClaimsBuilder.audience(serviceAudience);
+      }
+
+      com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier<SecurityContext> 
claimsVerifier =

Review Comment:
   Could we avoid using the long package name?



##########
server-common/src/main/java/org/apache/gravitino/server/authentication/JwksTokenValidator.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.gravitino.server.authentication;
+
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.jwk.source.JWKSource;
+import com.nimbusds.jose.proc.JWSKeySelector;
+import com.nimbusds.jose.proc.JWSVerificationKeySelector;
+import com.nimbusds.jose.proc.SecurityContext;
+import com.nimbusds.jwt.JWTClaimsSet;
+import com.nimbusds.jwt.SignedJWT;
+import com.nimbusds.jwt.proc.DefaultJWTProcessor;
+import java.security.Principal;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.gravitino.Config;
+import org.apache.gravitino.UserPrincipal;
+import org.apache.gravitino.exceptions.UnauthorizedException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Generic JWKS-based OAuth token validator that uses Nimbus JOSE + JWT 
library to validate JWT
+ * tokens from any OAuth provider that exposes a JWKS endpoint.
+ *
+ * <p>This validator supports OAuth providers that publish their public keys 
via JWKS (JSON Web Key
+ * Set) endpoints
+ */
+public class JwksTokenValidator implements OAuthTokenValidator {

Review Comment:
   Could we give a better name for this? I can't get the difference from the 
default validator by the name.



##########
server-common/src/main/java/org/apache/gravitino/server/authentication/JwksTokenValidator.java:
##########
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.gravitino.server.authentication;
+
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.jwk.source.JWKSource;
+import com.nimbusds.jose.proc.JWSKeySelector;
+import com.nimbusds.jose.proc.JWSVerificationKeySelector;
+import com.nimbusds.jose.proc.SecurityContext;
+import com.nimbusds.jwt.JWTClaimsSet;
+import com.nimbusds.jwt.SignedJWT;
+import com.nimbusds.jwt.proc.DefaultJWTProcessor;
+import java.security.Principal;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.gravitino.Config;
+import org.apache.gravitino.UserPrincipal;
+import org.apache.gravitino.exceptions.UnauthorizedException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Generic JWKS-based OAuth token validator that uses Nimbus JOSE + JWT 
library to validate JWT
+ * tokens from any OAuth provider that exposes a JWKS endpoint.
+ *
+ * <p>This validator supports OAuth providers that publish their public keys 
via JWKS (JSON Web Key
+ * Set) endpoints
+ */
+public class JwksTokenValidator implements OAuthTokenValidator {
+  private static final Logger LOG = 
LoggerFactory.getLogger(JwksTokenValidator.class);
+
+  private String jwksUri;
+  private String expectedIssuer;
+  private String principalField;
+  private long allowSkewSeconds;
+
+  @Override
+  public void initialize(Config config) {
+    this.jwksUri = config.get(OAuthConfig.JWKS_URI);
+    this.expectedIssuer = config.get(OAuthConfig.AUTHORITY);
+    this.principalField = config.get(OAuthConfig.PRINCIPAL_FIELD);
+    this.allowSkewSeconds = config.get(OAuthConfig.ALLOW_SKEW_SECONDS);
+
+    LOG.info("Initializing JWKS token validator");
+
+    if (StringUtils.isBlank(jwksUri)) {
+      throw new IllegalArgumentException(
+          "JWKS URI must be configured when using JWKS-based OAuth providers");
+    }
+
+    // Validate JWKS URI format
+    try {
+      new java.net.URL(jwksUri);
+    } catch (Exception e) {
+      LOG.error("Invalid JWKS URI format: {}", jwksUri);
+      throw new IllegalArgumentException("Invalid JWKS URI format: " + 
jwksUri, e);
+    }
+  }
+
+  @Override
+  public Principal validateToken(String token, String serviceAudience) {
+    if (token == null || token.trim().isEmpty()) {
+      LOG.error("Token is null or empty");
+      throw new UnauthorizedException("Token cannot be null or empty");
+    }
+
+    if (serviceAudience == null || serviceAudience.trim().isEmpty()) {
+      LOG.error("Service audience is null or empty");
+      throw new UnauthorizedException("Service audience cannot be null or 
empty");
+    }
+
+    try {
+      SignedJWT signedJWT = SignedJWT.parse(token);
+
+      // Set up JWKS source and processor
+      JWKSource<SecurityContext> jwkSource = createJwkSource();
+      JWSAlgorithm algorithm = 
JWSAlgorithm.parse(signedJWT.getHeader().getAlgorithm().getName());
+      JWSKeySelector<SecurityContext> keySelector =
+          new JWSVerificationKeySelector<>(algorithm, jwkSource);
+
+      DefaultJWTProcessor<SecurityContext> jwtProcessor = new 
DefaultJWTProcessor<>();
+      jwtProcessor.setJWSKeySelector(keySelector);
+
+      // Configure claims verification
+      JWTClaimsSet.Builder expectedClaimsBuilder = new JWTClaimsSet.Builder();
+
+      // Set expected issuer if configured
+      if (StringUtils.isNotBlank(expectedIssuer)) {
+        expectedClaimsBuilder.issuer(expectedIssuer);
+      }
+
+      // Set expected audience if provided
+      if (StringUtils.isNotBlank(serviceAudience)) {
+        expectedClaimsBuilder.audience(serviceAudience);
+      }
+
+      com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier<SecurityContext> 
claimsVerifier =
+          new com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier<SecurityContext>(
+              expectedClaimsBuilder.build(), null);
+
+      // Set clock skew tolerance
+      claimsVerifier.setMaxClockSkew((int) allowSkewSeconds);
+      jwtProcessor.setJWTClaimsSetVerifier(claimsVerifier);
+
+      // Process and validate the token
+      JWTClaimsSet validatedClaims = jwtProcessor.process(signedJWT, null);
+
+      String principal = extractPrincipal(validatedClaims);
+      if (principal == null) {
+        LOG.error("No valid principal found in token");
+        throw new UnauthorizedException("No valid principal found in token");
+      }
+
+      return new UserPrincipal(principal);
+
+    } catch (Exception e) {
+      LOG.error("JWKS JWT validation error: {}", e.getMessage());
+      throw new UnauthorizedException(e, "JWKS JWT validation error");
+    }
+  }
+
+  /** Creates a JWK source from the configured JWKS URI. */
+  private JWKSource<SecurityContext> createJwkSource() throws Exception {
+    try {
+      return com.nimbusds.jose.jwk.source.JWKSourceBuilder.create(new 
java.net.URL(jwksUri))
+          .build();
+    } catch (Exception e) {
+      LOG.error("Failed to create JWKS source from URI: {}", jwksUri, e);
+      throw new Exception("Failed to create JWKS source: " + e.getMessage(), 
e);
+    }
+  }
+
+  /** Extracts the principal from the validated JWT claims using configured 
field or fallbacks. */
+  private String extractPrincipal(JWTClaimsSet validatedClaims) {
+    // Try the configured principal field first
+    if (StringUtils.isNotBlank(principalField)) {
+      String principal = (String) validatedClaims.getClaim(principalField);
+      if (principal != null) {
+        return principal;
+      }
+    }
+
+    // Fallback to common user identity fields
+    String[] fallbackFields = {"unique_name", "upn", "email", 
"preferred_username", "sub"};

Review Comment:
   I'm not sure whether to make these configurable.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to