Copilot commented on code in PR #9237: URL: https://github.com/apache/gravitino/pull/9237#discussion_r2562816274
########## bundles/aliyun/src/main/java/org/apache/gravitino/oss/credential/OSSTokenGenerator.java: ########## @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.gravitino.oss.credential; + +import com.aliyun.credentials.Client; +import com.aliyun.credentials.models.Config; +import com.aliyun.credentials.models.CredentialModel; +import com.aliyun.credentials.utils.AuthConstant; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Stream; +import org.apache.commons.lang3.StringUtils; +import org.apache.gravitino.credential.CredentialContext; +import org.apache.gravitino.credential.CredentialGenerator; +import org.apache.gravitino.credential.OSSTokenCredential; +import org.apache.gravitino.credential.PathBasedCredentialContext; +import org.apache.gravitino.credential.config.OSSCredentialConfig; +import org.apache.gravitino.oss.credential.policy.Condition; +import org.apache.gravitino.oss.credential.policy.Effect; +import 
org.apache.gravitino.oss.credential.policy.Policy; +import org.apache.gravitino.oss.credential.policy.Statement; +import org.apache.gravitino.oss.credential.policy.StringLike; + +/** Generates OSS token to access OSS data. */ +public class OSSTokenGenerator implements CredentialGenerator<OSSTokenCredential> { + + private final ObjectMapper objectMapper = new ObjectMapper(); + private String accessKeyId; + private String secretAccessKey; + private String roleArn; + private String externalID; + private int tokenExpireSecs; + private String region; + + @Override + public void initialize(Map<String, String> properties) { + OSSCredentialConfig credentialConfig = new OSSCredentialConfig(properties); + this.roleArn = credentialConfig.ossRoleArn(); + this.externalID = credentialConfig.externalID(); + this.tokenExpireSecs = credentialConfig.tokenExpireInSecs(); + this.accessKeyId = credentialConfig.accessKeyID(); + this.secretAccessKey = credentialConfig.secretAccessKey(); + this.region = credentialConfig.region(); + } + + @Override + public OSSTokenCredential generate(CredentialContext context) throws Exception { + if (!(context instanceof PathBasedCredentialContext)) { + return null; + } + + PathBasedCredentialContext pathContext = (PathBasedCredentialContext) context; + + CredentialModel credentialModel = + createOSSCredentialModel( + pathContext.getReadPaths(), pathContext.getWritePaths(), pathContext.getUserName()); + + return new OSSTokenCredential( + credentialModel.accessKeyId, + credentialModel.accessKeySecret, + credentialModel.securityToken, + credentialModel.expiration); + } + + private CredentialModel createOSSCredentialModel( + Set<String> readLocations, Set<String> writeLocations, String userName) { + Config config = new Config(); + config.setAccessKeyId(accessKeyId); + config.setAccessKeySecret(secretAccessKey); + config.setType(AuthConstant.RAM_ROLE_ARN); + config.setRoleArn(roleArn); + config.setRoleSessionName(getRoleName(userName)); + if 
(StringUtils.isNotBlank(externalID)) { + config.setExternalId(externalID); + } + config.setRoleSessionExpiration(tokenExpireSecs); + config.setPolicy(createPolicy(readLocations, writeLocations, region)); + // Local object and client is a simple proxy that does not require manual release + Client client = new Client(config); + return client.getCredential(); + } + + private String createPolicy( + Set<String> readLocations, Set<String> writeLocations, String region) { + Policy.Builder policyBuilder = Policy.builder().version("1"); + + Statement.Builder allowGetObjectStatementBuilder = + Statement.builder() + .effect(Effect.ALLOW) + .addAction("oss:GetObject") + .addAction("oss:GetObjectVersion"); + + Map<String, Statement.Builder> bucketListStatementBuilder = new HashMap<>(); + Map<String, Statement.Builder> bucketMetadataStatementBuilder = new HashMap<>(); + + String arnPrefix = getArnPrefix(region); + Stream.concat(readLocations.stream(), writeLocations.stream()) + .distinct() + .forEach( + location -> { + URI uri = URI.create(location); + allowGetObjectStatementBuilder.addResource(getOssUriWithArn(arnPrefix, uri)); + String bucketArn = arnPrefix + getBucketName(uri); + bucketListStatementBuilder.computeIfAbsent( + bucketArn, + key -> + Statement.builder() + .effect(Effect.ALLOW) + .addAction("oss:ListObjects") + .addResource(key) + .condition(getCondition(uri))); + bucketMetadataStatementBuilder.computeIfAbsent( + bucketArn, + key -> + Statement.builder() + .effect(Effect.ALLOW) + .addAction("oss:GetBucketLocation") + .addAction("oss:GetBucketInfo") + .addResource(key)); + }); + + if (!writeLocations.isEmpty()) { + Statement.Builder allowPutObjectStatementBuilder = + Statement.builder() + .effect(Effect.ALLOW) + .addAction("oss:PutObject") + .addAction("oss:DeleteObject"); + writeLocations.forEach( + location -> { + URI uri = URI.create(location); + allowPutObjectStatementBuilder.addResource(getOssUriWithArn(arnPrefix, uri)); + }); + 
policyBuilder.addStatement(allowPutObjectStatementBuilder.build()); + } + + if (!bucketListStatementBuilder.isEmpty()) { + bucketListStatementBuilder + .values() + .forEach(statementBuilder -> policyBuilder.addStatement(statementBuilder.build())); + } else { + policyBuilder.addStatement( + Statement.builder().effect(Effect.ALLOW).addAction("oss:ListBucket").build()); + } + bucketMetadataStatementBuilder + .values() + .forEach(statementBuilder -> policyBuilder.addStatement(statementBuilder.build())); + + policyBuilder.addStatement(allowGetObjectStatementBuilder.build()); + try { + return objectMapper.writeValueAsString(policyBuilder.build()); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + private Condition getCondition(URI uri) { + return Condition.builder() + .stringLike( + StringLike.builder() + .addPrefix(concatPathWithSep(trimLeadingSlash(uri.getPath()), "*", "/")) + .build()) + .build(); + } + + private String getArnPrefix(String region) { + if (StringUtils.isNotEmpty(region)) { + return "acs:oss:" + region + ":*:"; + } + return "acs:oss:*:*:"; + } + + private String getBucketName(URI uri) { + return uri.getHost(); + } + + private String getOssUriWithArn(String arnPrefix, URI uri) { + return arnPrefix + concatPathWithSep(removeSchemaFromOSSUri(uri), "*", "/"); + } + + private static String concatPathWithSep(String leftPath, String rightPath, String fileSep) { + if (leftPath.endsWith(fileSep) && rightPath.startsWith(fileSep)) { + return leftPath + rightPath.substring(1); + } else if (!leftPath.endsWith(fileSep) && !rightPath.startsWith(fileSep)) { + return leftPath + fileSep + rightPath; + } else { + return leftPath + rightPath; + } + } + + private String removeSchemaFromOSSUri(URI uri) { + String bucket = uri.getHost(); + String path = trimLeadingSlash(uri.getPath()); + return String.join( + "/", Stream.of(bucket, path).filter(Objects::nonNull).toArray(String[]::new)); + } + + private String trimLeadingSlash(String path) { 
+ return path.startsWith("/") ? path.substring(1) : path; + } + + private String getRoleName(String userName) { + return "gravitino_" + userName; + } + + @Override + public void close() throws IOException {} +} Review Comment: Consider adding test coverage for the newly extracted `OSSTokenGenerator` class. The existing tests in `TestCredentialProvider` only verify the credential type, but don't test the core credential generation logic that was moved from `OSSTokenProvider` to this generator, including: - The `initialize()` method with various OSS configurations - The `generate()` method's credential model creation - RAM role policy generation for different read/write path combinations - Error handling for invalid configurations or credential failures This is important because the generator contains the main business logic for creating OSS credentials with proper RAM policies. ########## bundles/aws/src/main/java/org/apache/gravitino/s3/credential/AwsIrsaCredentialGenerator.java: ########## @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.gravitino.s3.credential; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Stream; +import org.apache.commons.lang3.StringUtils; +import org.apache.gravitino.credential.AwsIrsaCredential; +import org.apache.gravitino.credential.CredentialContext; +import org.apache.gravitino.credential.CredentialGenerator; +import org.apache.gravitino.credential.PathBasedCredentialContext; +import org.apache.gravitino.credential.config.S3CredentialConfig; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider; +import software.amazon.awssdk.policybuilder.iam.IamConditionOperator; +import software.amazon.awssdk.policybuilder.iam.IamEffect; +import software.amazon.awssdk.policybuilder.iam.IamPolicy; +import software.amazon.awssdk.policybuilder.iam.IamResource; +import software.amazon.awssdk.policybuilder.iam.IamStatement; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; +import software.amazon.awssdk.services.sts.StsClientBuilder; +import software.amazon.awssdk.services.sts.model.AssumeRoleWithWebIdentityRequest; +import software.amazon.awssdk.services.sts.model.AssumeRoleWithWebIdentityResponse; +import software.amazon.awssdk.services.sts.model.Credentials; + +/** Generate AWS IRSA credentials according to the read and write paths. 
*/ +public class AwsIrsaCredentialGenerator implements CredentialGenerator<AwsIrsaCredential> { + + private WebIdentityTokenFileCredentialsProvider baseCredentialsProvider; + private String roleArn; + private int tokenExpireSecs; + private String region; + private String stsEndpoint; + + @Override + public void initialize(Map<String, String> properties) { + // Use WebIdentityTokenFileCredentialsProvider for base IRSA configuration + this.baseCredentialsProvider = WebIdentityTokenFileCredentialsProvider.create(); + + S3CredentialConfig s3CredentialConfig = new S3CredentialConfig(properties); + this.roleArn = s3CredentialConfig.s3RoleArn(); + this.tokenExpireSecs = s3CredentialConfig.tokenExpireInSecs(); + this.region = s3CredentialConfig.region(); + this.stsEndpoint = s3CredentialConfig.stsEndpoint(); + } + + @Override + public AwsIrsaCredential generate(CredentialContext context) { + if (!(context instanceof PathBasedCredentialContext)) { + // Fallback to original behavior for non-path-based contexts + AwsCredentials creds = baseCredentialsProvider.resolveCredentials(); + if (creds instanceof AwsSessionCredentials) { + AwsSessionCredentials sessionCreds = (AwsSessionCredentials) creds; + long expiration = + sessionCreds.expirationTime().isPresent() + ? sessionCreds.expirationTime().get().toEpochMilli() + : 0L; + return new AwsIrsaCredential( + sessionCreds.accessKeyId(), + sessionCreds.secretAccessKey(), + sessionCreds.sessionToken(), + expiration); + } else { + throw new IllegalStateException( + "AWS IRSA credentials must be of type AwsSessionCredentials. " + + "Check your EKS/IRSA configuration. 
Got: " + + creds.getClass().getName()); + } + } + + PathBasedCredentialContext pathBasedCredentialContext = (PathBasedCredentialContext) context; + + Credentials s3Token = + createCredentialsWithSessionPolicy( + pathBasedCredentialContext.getReadPaths(), + pathBasedCredentialContext.getWritePaths(), + pathBasedCredentialContext.getUserName()); + return new AwsIrsaCredential( + s3Token.accessKeyId(), + s3Token.secretAccessKey(), + s3Token.sessionToken(), + s3Token.expiration().toEpochMilli()); + } + + private Credentials createCredentialsWithSessionPolicy( + Set<String> readLocations, Set<String> writeLocations, String userName) { + validateInputParameters(readLocations, writeLocations, userName); + + IamPolicy sessionPolicy = createSessionPolicy(readLocations, writeLocations, region); + String webIdentityTokenFile = getValidatedWebIdentityTokenFile(); + String effectiveRoleArn = getValidatedRoleArn(roleArn); + + try { + String tokenContent = + new String(Files.readAllBytes(Paths.get(webIdentityTokenFile)), StandardCharsets.UTF_8); + if (StringUtils.isBlank(tokenContent)) { + throw new IllegalStateException( + "Web identity token file is empty: " + webIdentityTokenFile); + } + + return assumeRoleWithSessionPolicy(effectiveRoleArn, userName, tokenContent, sessionPolicy); + } catch (Exception e) { + throw new RuntimeException( + "Failed to create credentials with session policy for user: " + userName, e); + } + } + + private IamPolicy createSessionPolicy( + Set<String> readLocations, Set<String> writeLocations, String region) { + IamPolicy.Builder policyBuilder = IamPolicy.builder(); + String arnPrefix = getArnPrefix(region); + + addReadPermissions(policyBuilder, readLocations, writeLocations, arnPrefix); + if (!writeLocations.isEmpty()) { + addWritePermissions(policyBuilder, writeLocations, arnPrefix); + } + addBucketPermissions(policyBuilder, readLocations, writeLocations, arnPrefix); + + return policyBuilder.build(); + } + + private void addReadPermissions( + 
IamPolicy.Builder policyBuilder, + Set<String> readLocations, + Set<String> writeLocations, + String arnPrefix) { + IamStatement.Builder allowGetObjectStatementBuilder = + IamStatement.builder() + .effect(IamEffect.ALLOW) + .addAction("s3:GetObject") + .addAction("s3:GetObjectVersion"); + + Stream.concat(readLocations.stream(), writeLocations.stream()) + .distinct() + .forEach( + location -> { + URI uri = URI.create(location); + allowGetObjectStatementBuilder.addResource( + IamResource.create(getS3UriWithArn(arnPrefix, uri))); + }); + + policyBuilder.addStatement(allowGetObjectStatementBuilder.build()); + } + + private void addWritePermissions( + IamPolicy.Builder policyBuilder, Set<String> writeLocations, String arnPrefix) { + IamStatement.Builder allowPutObjectStatementBuilder = + IamStatement.builder() + .effect(IamEffect.ALLOW) + .addAction("s3:PutObject") + .addAction("s3:DeleteObject"); + + writeLocations.forEach( + location -> { + URI uri = URI.create(location); + allowPutObjectStatementBuilder.addResource( + IamResource.create(getS3UriWithArn(arnPrefix, uri))); + }); + + policyBuilder.addStatement(allowPutObjectStatementBuilder.build()); + } + + private void addBucketPermissions( + IamPolicy.Builder policyBuilder, + Set<String> readLocations, + Set<String> writeLocations, + String arnPrefix) { + Map<String, IamStatement.Builder> bucketListStatementBuilder = new HashMap<>(); + Map<String, IamStatement.Builder> bucketGetLocationStatementBuilder = new HashMap<>(); + + Stream.concat(readLocations.stream(), writeLocations.stream()) + .distinct() + .forEach( + location -> { + URI uri = URI.create(location); + String bucketArn = arnPrefix + getBucketName(uri); + String rawPath = trimLeadingSlash(uri.getPath()); + + bucketListStatementBuilder + .computeIfAbsent( + bucketArn, + key -> + IamStatement.builder() + .effect(IamEffect.ALLOW) + .addAction("s3:ListBucket") + .addResource(key)) + .addConditions( + IamConditionOperator.STRING_LIKE, + "s3:prefix", + 
Arrays.asList(rawPath, addWildcardToPath(rawPath))); + + bucketGetLocationStatementBuilder.computeIfAbsent( + bucketArn, + key -> + IamStatement.builder() + .effect(IamEffect.ALLOW) + .addAction("s3:GetBucketLocation") + .addResource(key)); + }); + + addStatementsToPolicy(policyBuilder, bucketListStatementBuilder); + addStatementsToPolicy(policyBuilder, bucketGetLocationStatementBuilder); + } + + private void addStatementsToPolicy( + IamPolicy.Builder policyBuilder, Map<String, IamStatement.Builder> statementBuilders) { + statementBuilders.values().forEach(builder -> policyBuilder.addStatement(builder.build())); + } + + private String getS3UriWithArn(String arnPrefix, URI uri) { + return arnPrefix + addWildcardToPath(removeSchemaFromS3Uri(uri)); + } + + private String getArnPrefix(String region) { + if (StringUtils.isNotBlank(region)) { + if (region.contains("cn-")) { + return "arn:aws-cn:s3:::"; + } else if (region.contains("us-gov-")) { + return "arn:aws-us-gov:s3:::"; + } + } + return "arn:aws:s3:::"; + } + + private static String addWildcardToPath(String path) { + return path.endsWith("/") ? path + "*" : path + "/*"; + } + + private static String removeSchemaFromS3Uri(URI uri) { + String bucket = uri.getHost(); + String path = trimLeadingSlash(uri.getPath()); + return String.join( + "/", Stream.of(bucket, path).filter(Objects::nonNull).toArray(String[]::new)); + } + + private static String trimLeadingSlash(String path) { + return path.startsWith("/") ? 
path.substring(1) : path; + } + + private static String getBucketName(URI uri) { + return uri.getHost(); + } + + private void validateInputParameters( + Set<String> readLocations, Set<String> writeLocations, String userName) { + if (StringUtils.isBlank(userName)) { + throw new IllegalArgumentException("userName cannot be null or empty"); + } + if ((readLocations == null || readLocations.isEmpty()) + && (writeLocations == null || writeLocations.isEmpty())) { + throw new IllegalArgumentException("At least one read or write location must be specified"); + } + } + + private String getValidatedWebIdentityTokenFile() { + String webIdentityTokenFile = System.getenv("AWS_WEB_IDENTITY_TOKEN_FILE"); + if (StringUtils.isBlank(webIdentityTokenFile)) { + throw new IllegalStateException( + "AWS_WEB_IDENTITY_TOKEN_FILE environment variable is not set. " + + "Ensure IRSA is properly configured in your EKS cluster."); + } + if (!Files.exists(Paths.get(webIdentityTokenFile))) { + throw new IllegalStateException( + "Web identity token file does not exist: " + webIdentityTokenFile); + } + return webIdentityTokenFile; + } + + private String getValidatedRoleArn(String configRoleArn) { + String effectiveRoleArn = + StringUtils.isNotBlank(configRoleArn) ? configRoleArn : System.getenv("AWS_ROLE_ARN"); + if (StringUtils.isBlank(effectiveRoleArn)) { + throw new IllegalStateException( + "No role ARN available. 
Either configure s3-role-arn or ensure AWS_ROLE_ARN environment variable is set."); + } + if (!effectiveRoleArn.startsWith("arn:aws")) { + throw new IllegalArgumentException("Invalid role ARN format: " + effectiveRoleArn); + } + return effectiveRoleArn; + } + + private Credentials assumeRoleWithSessionPolicy( + String roleArn, String userName, String webIdentityToken, IamPolicy sessionPolicy) { + StsClientBuilder stsBuilder = StsClient.builder(); + if (StringUtils.isNotBlank(region)) { + stsBuilder.region(Region.of(region)); + } + if (StringUtils.isNotBlank(stsEndpoint)) { + stsBuilder.endpointOverride(URI.create(stsEndpoint)); + } + + try (StsClient stsClient = stsBuilder.build()) { + AssumeRoleWithWebIdentityRequest request = + AssumeRoleWithWebIdentityRequest.builder() + .roleArn(roleArn) + .roleSessionName("gravitino_irsa_session_" + userName) + .durationSeconds(tokenExpireSecs) + .webIdentityToken(webIdentityToken) + .policy(sessionPolicy.toJson()) + .build(); + + AssumeRoleWithWebIdentityResponse response = stsClient.assumeRoleWithWebIdentity(request); + return response.credentials(); + } + } + + @Override + public void close() throws IOException {} +} Review Comment: Consider adding test coverage for the newly extracted `AwsIrsaCredentialGenerator` class. The existing tests in `TestCredentialProvider` only verify the credential type, but don't test the complex credential generation logic that was moved from `AwsIrsaCredentialProvider` to this generator, including: - The dual-mode behavior (basic IRSA vs. fine-grained path-based) - Session policy creation with proper IAM permissions - Web identity token file validation - Role ARN validation - Error handling for various failure scenarios This is particularly important given the complexity of the IRSA implementation and its dual-mode behavior. 
########## common/src/main/java/org/apache/gravitino/credential/CredentialProviderDelegator.java: ########## @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.gravitino.credential; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.util.Map; + +/** + * An abstract base class for {@link CredentialProvider} implementations that delegate the actual + * credential generation to a {@link CredentialGenerator}. It handles the lazy and reflective + * loading of the generator to isolate heavy dependencies. + * + * @param <T> The type of credential generated by this provider. + */ +public abstract class CredentialProviderDelegator<T extends Credential> + implements CredentialProvider { + + /** The properties used by the generator to generate the credential. */ + protected Map<String, String> properties; + + private volatile CredentialGenerator<T> generator; + + /** + * Initializes the provider by storing properties and loading the associated {@link + * CredentialGenerator}. + * + * @param properties A map of configuration properties for the provider. 
+ */ + @Override + public void initialize(Map<String, String> properties) { + this.properties = properties; + this.generator = loadGenerator(); + generator.initialize(properties); + } + + /** + * Delegates the credential generation to the loaded {@link CredentialGenerator}. + * + * @param context The context containing information required for credential retrieval. + * @return A {@link Credential} object. + * @throws RuntimeException if credential generation fails. + */ + @Override + public Credential getCredential(CredentialContext context) { + try { + return generator.generate(context); + } catch (Exception e) { + throw new RuntimeException( + "Failed to generate credential using " + getGeneratorClassName(), e); + } + } + + @Override + public void close() throws IOException { + if (generator != null) { + generator.close(); + } + } + + /** + * Returns the fully qualified class name of the {@link CredentialGenerator} implementation. This + * generator will be loaded via reflection to perform the actual credential creation. + * + * @return The class name of the credential generator. + */ + protected abstract String getGeneratorClassName(); + + /** + * Loads and instantiates the {@link CredentialGenerator} using reflection. + * + * <p>This implementation uses a no-argument constructor. The constructor can be non-public. + * + * @return An instance of the credential generator. + * @throws RuntimeException if the generator cannot be loaded or instantiated. 
+ */ + @SuppressWarnings("unchecked") + private CredentialGenerator<T> loadGenerator() { + try { + Class<?> generatorClass = Class.forName(getGeneratorClassName()); + Constructor<?> constructor = generatorClass.getDeclaredConstructor(); + constructor.setAccessible(true); + return (CredentialGenerator<T>) constructor.newInstance(); + } catch (Exception e) { + throw new RuntimeException( + "Failed to load or instantiate CredentialGenerator: " + getGeneratorClassName(), e); + } + } Review Comment: Consider adding test coverage for the reflective loading mechanism in `CredentialProviderDelegator`. The `loadGenerator()` method uses reflection to instantiate generators, and this critical path should be tested to ensure proper error handling when: - The generator class doesn't exist - The generator class doesn't have a no-arg constructor - The constructor throws an exception This is particularly important since these generators are loaded at runtime and failures here would only be caught during actual usage. ########## bundles/gcp/src/main/java/org/apache/gravitino/gcs/credential/GCSTokenGenerator.java: ########## @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.gravitino.gcs.credential; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.CredentialAccessBoundary; +import com.google.auth.oauth2.CredentialAccessBoundary.AccessBoundaryRule; +import com.google.auth.oauth2.DownscopedCredentials; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.common.base.Preconditions; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; +import org.apache.commons.lang3.StringUtils; +import org.apache.gravitino.credential.CredentialContext; +import org.apache.gravitino.credential.CredentialGenerator; +import org.apache.gravitino.credential.GCSTokenCredential; +import org.apache.gravitino.credential.PathBasedCredentialContext; +import org.apache.gravitino.credential.config.GCSCredentialConfig; + +/** Generate GCS access token according to the read and write paths. 
*/ +public class GCSTokenGenerator implements CredentialGenerator<GCSTokenCredential> { + + private static final String INITIAL_SCOPE = "https://www.googleapis.com/auth/cloud-platform"; + private GoogleCredentials sourceCredentials; + + @Override + public void initialize(Map<String, String> properties) { + GCSCredentialConfig gcsCredentialConfig = new GCSCredentialConfig(properties); + try { + this.sourceCredentials = + getSourceCredentials(gcsCredentialConfig).createScoped(INITIAL_SCOPE); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public GCSTokenCredential generate(CredentialContext context) throws IOException { + if (!(context instanceof PathBasedCredentialContext)) { + return null; + } + + PathBasedCredentialContext pathBasedCredentialContext = (PathBasedCredentialContext) context; + AccessToken accessToken = + getToken( + sourceCredentials, + pathBasedCredentialContext.getReadPaths(), + pathBasedCredentialContext.getWritePaths()); + + String tokenValue = accessToken.getTokenValue(); + long expireTime = accessToken.getExpirationTime().toInstant().toEpochMilli(); + return new GCSTokenCredential(tokenValue, expireTime); + } + + private AccessToken getToken( + GoogleCredentials sourceCredentials, Set<String> readLocations, Set<String> writeLocations) + throws IOException { + DownscopedCredentials downscopedCredentials = + DownscopedCredentials.newBuilder() + .setSourceCredential(sourceCredentials) + .setCredentialAccessBoundary(getAccessBoundary(readLocations, writeLocations)) + .build(); + return downscopedCredentials.refreshAccessToken(); + } + + private List<String> getReadExpressions(String bucketName, String resourcePath) { + List<String> readExpressions = new ArrayList<>(); + readExpressions.add( + String.format( + "resource.name.startsWith('projects/_/buckets/%s/objects/%s')", + bucketName, resourcePath)); + getAllResources(resourcePath) + .forEach( + parentResourcePath -> + readExpressions.add( + String.format( + 
"resource.name == 'projects/_/buckets/%s/objects/%s'", + bucketName, parentResourcePath))); + return readExpressions; + } + + // "a/b/c" will get ["a", "a/", "a/b", "a/b/", "a/b/c"] + static List<String> getAllResources(String resourcePath) { + if (resourcePath.endsWith("/")) { + resourcePath = resourcePath.substring(0, resourcePath.length() - 1); + } + if (resourcePath.isEmpty()) { + return Arrays.asList(""); + } + Preconditions.checkArgument( + !resourcePath.startsWith("/"), resourcePath + " should not start with /"); + List<String> parts = Arrays.asList(resourcePath.split("/")); + List<String> results = new ArrayList<>(); + String parent = ""; + for (int i = 0; i < parts.size() - 1; i++) { + results.add(parts.get(i)); + parent += parts.get(i) + "/"; + results.add(parent); + } + results.add(parent + parts.get(parts.size() - 1)); + return results; + } + + // Remove the first '/', and append `/` if the path does not end with '/'. + static String normalizeUriPath(String resourcePath) { + if (resourcePath.isEmpty() || "/".equals(resourcePath)) { + return ""; + } + if (resourcePath.startsWith("/")) { + resourcePath = resourcePath.substring(1); + } + if (resourcePath.endsWith("/")) { + return resourcePath; + } + return resourcePath + "/"; + } + + private CredentialAccessBoundary getAccessBoundary( + Set<String> readLocations, Set<String> writeLocations) { + Map<String, List<String>> readExpressions = new HashMap<>(); + Map<String, List<String>> writeExpressions = new HashMap<>(); + + HashSet<String> readBuckets = new HashSet<>(); + HashSet<String> writeBuckets = new HashSet<>(); + Stream.concat(readLocations.stream(), writeLocations.stream()) + .distinct() + .forEach( + location -> { + URI uri = URI.create(location); + String bucketName = getBucketName(uri); + readBuckets.add(bucketName); + String resourcePath = normalizeUriPath(uri.getPath()); + List<String> resourceExpressions = + readExpressions.computeIfAbsent(bucketName, key -> new ArrayList<>()); + 
resourceExpressions.addAll(getReadExpressions(bucketName, resourcePath)); + resourceExpressions.add( + String.format( + "api.getAttribute('storage.googleapis.com/objectListPrefix', '').startsWith('%s')", + resourcePath)); + if (writeLocations.contains(location)) { + writeBuckets.add(bucketName); + resourceExpressions = + writeExpressions.computeIfAbsent(bucketName, key -> new ArrayList<>()); + resourceExpressions.add( + String.format( + "resource.name.startsWith('projects/_/buckets/%s/objects/%s')", + bucketName, resourcePath)); + } + }); + + CredentialAccessBoundary.Builder credentialAccessBoundaryBuilder = + CredentialAccessBoundary.newBuilder(); + readBuckets.forEach( + bucket -> { + AccessBoundaryRule bucketInfoRule = + AccessBoundaryRule.newBuilder() + .setAvailableResource(toGCSBucketResource(bucket)) + .setAvailablePermissions( + Arrays.asList("inRole:roles/storage.insightsCollectorService")) + .build(); + credentialAccessBoundaryBuilder.addRule(bucketInfoRule); + List<String> readConditions = readExpressions.get(bucket); + AccessBoundaryRule rule = + getAccessBoundaryRule( + bucket, readConditions, Arrays.asList("inRole:roles/storage.objectViewer")); + if (rule != null) { + credentialAccessBoundaryBuilder.addRule(rule); + } + }); + + writeBuckets.forEach( + bucket -> { + List<String> writeConditions = writeExpressions.get(bucket); + AccessBoundaryRule rule = + getAccessBoundaryRule( + bucket, + writeConditions, + Arrays.asList("inRole:roles/storage.legacyBucketWriter")); + if (rule != null) { + credentialAccessBoundaryBuilder.addRule(rule); + } + }); + + return credentialAccessBoundaryBuilder.build(); + } + + private AccessBoundaryRule getAccessBoundaryRule( + String bucketName, List<String> resourceExpression, List<String> permissions) { + if (resourceExpression == null || resourceExpression.isEmpty()) { + return null; + } + return AccessBoundaryRule.newBuilder() + .setAvailableResource(toGCSBucketResource(bucketName)) + .setAvailabilityCondition( + 
AccessBoundaryRule.AvailabilityCondition.newBuilder() + .setExpression(String.join(" || ", resourceExpression)) + .build()) + .setAvailablePermissions(permissions) + .build(); + } + + private static String toGCSBucketResource(String bucketName) { + return "//storage.googleapis.com/projects/_/buckets/" + bucketName; + } + + private static String getBucketName(URI uri) { + return uri.getHost(); + } + + private GoogleCredentials getSourceCredentials(GCSCredentialConfig gcsCredentialConfig) + throws IOException { + String gcsCredentialFilePath = gcsCredentialConfig.gcsCredentialFilePath(); + if (StringUtils.isBlank(gcsCredentialFilePath)) { + return GoogleCredentials.getApplicationDefault(); + } + Path credentialsFilePath = Paths.get(gcsCredentialFilePath); + try (InputStream fileInputStream = Files.newInputStream(credentialsFilePath)) { + return GoogleCredentials.fromStream(fileInputStream); + } catch (NoSuchFileException e) { + throw new IOException("GCS credential file does not exist." + gcsCredentialFilePath, e); + } + } + + @Override + public void close() throws IOException {} +} Review Comment: Consider adding test coverage for the newly extracted `GCSTokenGenerator` class. While the existing tests in `TestGCSTokenProvider` verify the static utility methods (`getAllResources` and `normalizeUriPath`), there's no test coverage for the core credential generation logic that was moved to the generator, including: - The `initialize()` method with various configuration scenarios - The `generate()` method's credential creation logic - Error handling when credentials cannot be obtained This is important because the generator now contains the critical business logic that was previously tested as part of the provider. ########## bundles/gcp/src/main/java/org/apache/gravitino/gcs/credential/GCSTokenGenerator.java: ########## @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.gravitino.gcs.credential; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.CredentialAccessBoundary; +import com.google.auth.oauth2.CredentialAccessBoundary.AccessBoundaryRule; +import com.google.auth.oauth2.DownscopedCredentials; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.common.base.Preconditions; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; +import org.apache.commons.lang3.StringUtils; +import org.apache.gravitino.credential.CredentialContext; +import org.apache.gravitino.credential.CredentialGenerator; +import org.apache.gravitino.credential.GCSTokenCredential; +import org.apache.gravitino.credential.PathBasedCredentialContext; +import org.apache.gravitino.credential.config.GCSCredentialConfig; + +/** Generate GCS access token according to the read and write paths. 
*/ +public class GCSTokenGenerator implements CredentialGenerator<GCSTokenCredential> { + + private static final String INITIAL_SCOPE = "https://www.googleapis.com/auth/cloud-platform"; + private GoogleCredentials sourceCredentials; + + @Override + public void initialize(Map<String, String> properties) { + GCSCredentialConfig gcsCredentialConfig = new GCSCredentialConfig(properties); + try { + this.sourceCredentials = + getSourceCredentials(gcsCredentialConfig).createScoped(INITIAL_SCOPE); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public GCSTokenCredential generate(CredentialContext context) throws IOException { + if (!(context instanceof PathBasedCredentialContext)) { + return null; + } + + PathBasedCredentialContext pathBasedCredentialContext = (PathBasedCredentialContext) context; + AccessToken accessToken = + getToken( + sourceCredentials, + pathBasedCredentialContext.getReadPaths(), + pathBasedCredentialContext.getWritePaths()); + + String tokenValue = accessToken.getTokenValue(); + long expireTime = accessToken.getExpirationTime().toInstant().toEpochMilli(); + return new GCSTokenCredential(tokenValue, expireTime); + } + + private AccessToken getToken( + GoogleCredentials sourceCredentials, Set<String> readLocations, Set<String> writeLocations) + throws IOException { + DownscopedCredentials downscopedCredentials = + DownscopedCredentials.newBuilder() + .setSourceCredential(sourceCredentials) + .setCredentialAccessBoundary(getAccessBoundary(readLocations, writeLocations)) + .build(); + return downscopedCredentials.refreshAccessToken(); + } + + private List<String> getReadExpressions(String bucketName, String resourcePath) { + List<String> readExpressions = new ArrayList<>(); + readExpressions.add( + String.format( + "resource.name.startsWith('projects/_/buckets/%s/objects/%s')", + bucketName, resourcePath)); + getAllResources(resourcePath) + .forEach( + parentResourcePath -> + readExpressions.add( + String.format( + 
"resource.name == 'projects/_/buckets/%s/objects/%s'", + bucketName, parentResourcePath))); + return readExpressions; + } + + // "a/b/c" will get ["a", "a/", "a/b", "a/b/", "a/b/c"] + static List<String> getAllResources(String resourcePath) { + if (resourcePath.endsWith("/")) { + resourcePath = resourcePath.substring(0, resourcePath.length() - 1); + } + if (resourcePath.isEmpty()) { + return Arrays.asList(""); + } + Preconditions.checkArgument( + !resourcePath.startsWith("/"), resourcePath + " should not start with /"); + List<String> parts = Arrays.asList(resourcePath.split("/")); + List<String> results = new ArrayList<>(); + String parent = ""; + for (int i = 0; i < parts.size() - 1; i++) { + results.add(parts.get(i)); + parent += parts.get(i) + "/"; + results.add(parent); + } + results.add(parent + parts.get(parts.size() - 1)); + return results; + } + + // Remove the first '/', and append `/` if the path does not end with '/'. + static String normalizeUriPath(String resourcePath) { + if (resourcePath.isEmpty() || "/".equals(resourcePath)) { + return ""; + } + if (resourcePath.startsWith("/")) { + resourcePath = resourcePath.substring(1); + } + if (resourcePath.endsWith("/")) { + return resourcePath; + } + return resourcePath + "/"; + } + + private CredentialAccessBoundary getAccessBoundary( + Set<String> readLocations, Set<String> writeLocations) { + Map<String, List<String>> readExpressions = new HashMap<>(); + Map<String, List<String>> writeExpressions = new HashMap<>(); + + HashSet<String> readBuckets = new HashSet<>(); + HashSet<String> writeBuckets = new HashSet<>(); + Stream.concat(readLocations.stream(), writeLocations.stream()) + .distinct() + .forEach( + location -> { + URI uri = URI.create(location); + String bucketName = getBucketName(uri); + readBuckets.add(bucketName); + String resourcePath = normalizeUriPath(uri.getPath()); + List<String> resourceExpressions = + readExpressions.computeIfAbsent(bucketName, key -> new ArrayList<>()); + 
resourceExpressions.addAll(getReadExpressions(bucketName, resourcePath)); + resourceExpressions.add( + String.format( + "api.getAttribute('storage.googleapis.com/objectListPrefix', '').startsWith('%s')", + resourcePath)); + if (writeLocations.contains(location)) { + writeBuckets.add(bucketName); + resourceExpressions = + writeExpressions.computeIfAbsent(bucketName, key -> new ArrayList<>()); + resourceExpressions.add( + String.format( + "resource.name.startsWith('projects/_/buckets/%s/objects/%s')", + bucketName, resourcePath)); + } + }); + + CredentialAccessBoundary.Builder credentialAccessBoundaryBuilder = + CredentialAccessBoundary.newBuilder(); + readBuckets.forEach( + bucket -> { + AccessBoundaryRule bucketInfoRule = + AccessBoundaryRule.newBuilder() + .setAvailableResource(toGCSBucketResource(bucket)) + .setAvailablePermissions( + Arrays.asList("inRole:roles/storage.insightsCollectorService")) + .build(); + credentialAccessBoundaryBuilder.addRule(bucketInfoRule); + List<String> readConditions = readExpressions.get(bucket); + AccessBoundaryRule rule = + getAccessBoundaryRule( + bucket, readConditions, Arrays.asList("inRole:roles/storage.objectViewer")); + if (rule != null) { + credentialAccessBoundaryBuilder.addRule(rule); + } + }); + + writeBuckets.forEach( + bucket -> { + List<String> writeConditions = writeExpressions.get(bucket); + AccessBoundaryRule rule = + getAccessBoundaryRule( + bucket, + writeConditions, + Arrays.asList("inRole:roles/storage.legacyBucketWriter")); + if (rule != null) { + credentialAccessBoundaryBuilder.addRule(rule); + } + }); + + return credentialAccessBoundaryBuilder.build(); + } + + private AccessBoundaryRule getAccessBoundaryRule( + String bucketName, List<String> resourceExpression, List<String> permissions) { + if (resourceExpression == null || resourceExpression.isEmpty()) { + return null; + } + return AccessBoundaryRule.newBuilder() + .setAvailableResource(toGCSBucketResource(bucketName)) + .setAvailabilityCondition( + 
AccessBoundaryRule.AvailabilityCondition.newBuilder() + .setExpression(String.join(" || ", resourceExpression)) + .build()) + .setAvailablePermissions(permissions) + .build(); + } + + private static String toGCSBucketResource(String bucketName) { + return "//storage.googleapis.com/projects/_/buckets/" + bucketName; + } + + private static String getBucketName(URI uri) { + return uri.getHost(); + } + + private GoogleCredentials getSourceCredentials(GCSCredentialConfig gcsCredentialConfig) + throws IOException { + String gcsCredentialFilePath = gcsCredentialConfig.gcsCredentialFilePath(); + if (StringUtils.isBlank(gcsCredentialFilePath)) { + return GoogleCredentials.getApplicationDefault(); + } + Path credentialsFilePath = Paths.get(gcsCredentialFilePath); + try (InputStream fileInputStream = Files.newInputStream(credentialsFilePath)) { + return GoogleCredentials.fromStream(fileInputStream); + } catch (NoSuchFileException e) { + throw new IOException("GCS credential file does not exist." + gcsCredentialFilePath, e); Review Comment: Missing space in error message after period. The message should read: "GCS credential file does not exist. " (with a space after the period) instead of "GCS credential file does not exist." + gcsCredentialFilePath to ensure proper formatting. ```suggestion throw new IOException("GCS credential file does not exist. " + gcsCredentialFilePath, e); ``` ########## bundles/azure/src/main/java/org/apache/gravitino/abs/credential/ADLSTokenGenerator.java: ########## @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.gravitino.abs.credential; + +import com.azure.core.util.Context; +import com.azure.identity.ClientSecretCredential; +import com.azure.identity.ClientSecretCredentialBuilder; +import com.azure.storage.file.datalake.DataLakeServiceClient; +import com.azure.storage.file.datalake.DataLakeServiceClientBuilder; +import com.azure.storage.file.datalake.implementation.util.DataLakeSasImplUtil; +import com.azure.storage.file.datalake.models.UserDelegationKey; +import com.azure.storage.file.datalake.sas.DataLakeServiceSasSignatureValues; +import com.azure.storage.file.datalake.sas.PathSasPermission; +import java.io.IOException; +import java.time.OffsetDateTime; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.apache.gravitino.credential.ADLSTokenCredential; +import org.apache.gravitino.credential.CredentialContext; +import org.apache.gravitino.credential.CredentialGenerator; +import org.apache.gravitino.credential.PathBasedCredentialContext; +import org.apache.gravitino.credential.config.AzureCredentialConfig; + +/** Generates ADLS token to access ADLS data. 
*/ +public class ADLSTokenGenerator implements CredentialGenerator<ADLSTokenCredential> { + + private String storageAccountName; + private String tenantId; + private String clientId; + private String clientSecret; + private String endpoint; + private Integer tokenExpireSecs; + + @Override + public void initialize(Map<String, String> properties) { + AzureCredentialConfig azureCredentialConfig = new AzureCredentialConfig(properties); + this.storageAccountName = azureCredentialConfig.storageAccountName(); + this.tenantId = azureCredentialConfig.tenantId(); + this.clientId = azureCredentialConfig.clientId(); + this.clientSecret = azureCredentialConfig.clientSecret(); + this.endpoint = + String.format("https://%s.%s", storageAccountName, ADLSTokenCredential.ADLS_DOMAIN); + this.tokenExpireSecs = azureCredentialConfig.adlsTokenExpireInSecs(); + } + + @Override + public ADLSTokenCredential generate(CredentialContext context) { + if (!(context instanceof PathBasedCredentialContext)) { + return null; + } + + PathBasedCredentialContext pathContext = (PathBasedCredentialContext) context; + + Set<String> writePaths = pathContext.getWritePaths(); + Set<String> readPaths = pathContext.getReadPaths(); + Set<String> combinedPaths = new HashSet<>(writePaths); + combinedPaths.addAll(readPaths); + + if (combinedPaths.size() != 1) { + throw new IllegalArgumentException( + "ADLS should contain exactly one unique path, but found: " + + combinedPaths.size() + + " paths: " + + combinedPaths); + } + String uniquePath = combinedPaths.iterator().next(); + + ClientSecretCredential clientSecretCredential = + new ClientSecretCredentialBuilder() + .tenantId(tenantId) + .clientId(clientId) + .clientSecret(clientSecret) + .build(); + + DataLakeServiceClient dataLakeServiceClient = + new DataLakeServiceClientBuilder() + .endpoint(endpoint) + .credential(clientSecretCredential) + .buildClient(); + + OffsetDateTime start = OffsetDateTime.now(); + OffsetDateTime expiry = 
OffsetDateTime.now().plusSeconds(tokenExpireSecs); + UserDelegationKey userDelegationKey = dataLakeServiceClient.getUserDelegationKey(start, expiry); + + PathSasPermission pathSasPermission = + new PathSasPermission().setReadPermission(true).setListPermission(true); + if (!writePaths.isEmpty()) { + pathSasPermission + .setWritePermission(true) + .setDeletePermission(true) + .setCreatePermission(true) + .setAddPermission(true); + } + + DataLakeServiceSasSignatureValues signatureValues = + new DataLakeServiceSasSignatureValues(expiry, pathSasPermission); + ADLSLocationUtils.ADLSLocationParts locationParts = ADLSLocationUtils.parseLocation(uniquePath); + String sasToken = + new DataLakeSasImplUtil( + signatureValues, + locationParts.getContainer(), + ADLSLocationUtils.trimSlashes(locationParts.getPath()), + true) + .generateUserDelegationSas( + userDelegationKey, locationParts.getAccountName(), Context.NONE); + + return new ADLSTokenCredential( + locationParts.getAccountName(), sasToken, expiry.toInstant().toEpochMilli()); + } + + @Override + public void close() throws IOException {} +} Review Comment: Consider adding test coverage for the newly extracted `ADLSTokenGenerator` class. The existing tests in `TestCredentialProvider` only verify the credential type, but don't test the core credential generation logic that was moved from `ADLSTokenProvider` to this generator, including: - The `initialize()` method with various Azure configurations - The `generate()` method's SAS token generation logic - Error handling for invalid or missing configurations This is important because the generator contains the main business logic for creating ADLS credentials. ########## bundles/aws/src/main/java/org/apache/gravitino/s3/credential/S3TokenGenerator.java: ########## @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.gravitino.s3.credential; + +import java.io.IOException; +import java.net.URI; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Stream; +import org.apache.commons.lang3.StringUtils; +import org.apache.gravitino.credential.CredentialContext; +import org.apache.gravitino.credential.CredentialGenerator; +import org.apache.gravitino.credential.PathBasedCredentialContext; +import org.apache.gravitino.credential.S3TokenCredential; +import org.apache.gravitino.credential.config.S3CredentialConfig; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.policybuilder.iam.IamConditionOperator; +import software.amazon.awssdk.policybuilder.iam.IamEffect; +import software.amazon.awssdk.policybuilder.iam.IamPolicy; +import software.amazon.awssdk.policybuilder.iam.IamResource; +import software.amazon.awssdk.policybuilder.iam.IamStatement; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sts.StsClient; +import 
software.amazon.awssdk.services.sts.StsClientBuilder; +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest; +import software.amazon.awssdk.services.sts.model.AssumeRoleResponse; +import software.amazon.awssdk.services.sts.model.Credentials; + +/** Generate S3 token credentials according to the read and write paths. */ +public class S3TokenGenerator implements CredentialGenerator<S3TokenCredential> { + + private StsClient stsClient; + private String roleArn; + private String externalID; + private int tokenExpireSecs; + + @Override + public void initialize(Map<String, String> properties) { + S3CredentialConfig s3CredentialConfig = new S3CredentialConfig(properties); + this.roleArn = s3CredentialConfig.s3RoleArn(); + this.externalID = s3CredentialConfig.externalID(); + this.tokenExpireSecs = s3CredentialConfig.tokenExpireInSecs(); + this.stsClient = createStsClient(s3CredentialConfig); + } + + @Override + public S3TokenCredential generate(CredentialContext context) { + if (!(context instanceof PathBasedCredentialContext)) { + return null; + } + + PathBasedCredentialContext pathContext = (PathBasedCredentialContext) context; + + Credentials s3Token = + createS3Token( + pathContext.getReadPaths(), pathContext.getWritePaths(), pathContext.getUserName()); + + return new S3TokenCredential( + s3Token.accessKeyId(), + s3Token.secretAccessKey(), + s3Token.sessionToken(), + s3Token.expiration().toEpochMilli()); + } + + private StsClient createStsClient(S3CredentialConfig s3CredentialConfig) { + AwsCredentialsProvider credentialsProvider = + StaticCredentialsProvider.create( + AwsBasicCredentials.create( + s3CredentialConfig.accessKeyID(), s3CredentialConfig.secretAccessKey())); + StsClientBuilder builder = StsClient.builder().credentialsProvider(credentialsProvider); + + if (StringUtils.isNotBlank(s3CredentialConfig.region())) { + builder.region(Region.of(s3CredentialConfig.region())); + } + if (StringUtils.isNotBlank(s3CredentialConfig.stsEndpoint())) { + 
builder.endpointOverride(URI.create(s3CredentialConfig.stsEndpoint())); + } + return builder.build(); + } + + private Credentials createS3Token( + Set<String> readLocations, Set<String> writeLocations, String userName) { + IamPolicy policy = createPolicy(roleArn, readLocations, writeLocations); + AssumeRoleRequest.Builder builder = + AssumeRoleRequest.builder() + .roleArn(roleArn) + .roleSessionName("gravitino_" + userName) + .durationSeconds(tokenExpireSecs) + .policy(policy.toJson()); + + if (StringUtils.isNotBlank(externalID)) { + builder.externalId(externalID); + } + AssumeRoleResponse response = stsClient.assumeRole(builder.build()); + return response.credentials(); + } + + private IamPolicy createPolicy( + String roleArn, Set<String> readLocations, Set<String> writeLocations) { + IamPolicy.Builder policyBuilder = IamPolicy.builder(); + IamStatement.Builder allowGetObjectStatementBuilder = + IamStatement.builder() + .effect(IamEffect.ALLOW) + .addAction("s3:GetObject") + .addAction("s3:GetObjectVersion"); + Map<String, IamStatement.Builder> bucketListStatementBuilder = new HashMap<>(); + Map<String, IamStatement.Builder> bucketGetLocationStatementBuilder = new HashMap<>(); + + String arnPrefix = getArnPrefix(roleArn); + Stream.concat(readLocations.stream(), writeLocations.stream()) + .distinct() + .forEach( + location -> { + URI uri = URI.create(location); + allowGetObjectStatementBuilder.addResource( + IamResource.create(getS3UriWithArn(arnPrefix, uri))); + String bucketArn = arnPrefix + getBucketName(uri); + String rawPath = trimLeadingSlash(uri.getPath()); + bucketListStatementBuilder + .computeIfAbsent( + bucketArn, + key -> + IamStatement.builder() + .effect(IamEffect.ALLOW) + .addAction("s3:ListBucket") + .addResource(key)) + .addConditions( + IamConditionOperator.STRING_LIKE, + "s3:prefix", + Arrays.asList(rawPath, addWildcardToPath(rawPath))); + + bucketGetLocationStatementBuilder.computeIfAbsent( + bucketArn, + key -> + IamStatement.builder() + 
.effect(IamEffect.ALLOW) + .addAction("s3:GetBucketLocation") + .addResource(key)); + }); + + if (!writeLocations.isEmpty()) { + IamStatement.Builder allowPutObjectStatementBuilder = + IamStatement.builder() + .effect(IamEffect.ALLOW) + .addAction("s3:PutObject") + .addAction("s3:DeleteObject"); + writeLocations.forEach( + location -> { + URI uri = URI.create(location); + allowPutObjectStatementBuilder.addResource( + IamResource.create(getS3UriWithArn(arnPrefix, uri))); + }); + policyBuilder.addStatement(allowPutObjectStatementBuilder.build()); + } + + bucketListStatementBuilder + .values() + .forEach(builder -> policyBuilder.addStatement(builder.build())); + bucketGetLocationStatementBuilder + .values() + .forEach(builder -> policyBuilder.addStatement(builder.build())); + policyBuilder.addStatement(allowGetObjectStatementBuilder.build()); + + return policyBuilder.build(); + } + + private String getS3UriWithArn(String arnPrefix, URI uri) { + return arnPrefix + addWildcardToPath(removeSchemaFromS3Uri(uri)); + } + + private String getArnPrefix(String roleArn) { + if (roleArn.contains("aws-cn")) { + return "arn:aws-cn:s3:::"; + } else if (roleArn.contains("aws-us-gov")) { + return "arn:aws-us-gov:s3:::"; + } + return "arn:aws:s3:::"; + } + + private static String addWildcardToPath(String path) { + return path.endsWith("/") ? path + "*" : path + "/*"; + } + + private static String removeSchemaFromS3Uri(URI uri) { + String bucket = uri.getHost(); + String path = trimLeadingSlash(uri.getPath()); + return String.join( + "/", Stream.of(bucket, path).filter(Objects::nonNull).toArray(String[]::new)); + } + + private static String trimLeadingSlash(String path) { + return path.startsWith("/") ? 
path.substring(1) : path; + } + + private static String getBucketName(URI uri) { + return uri.getHost(); + } + + @Override + public void close() throws IOException { + if (stsClient != null) { + stsClient.close(); + } + } +} Review Comment: Consider adding test coverage for the newly extracted `S3TokenGenerator` class. The existing tests in `TestCredentialProvider` only verify the credential type, but don't test the core credential generation logic that was moved from `S3TokenProvider` to this generator, including: - The `initialize()` method with various AWS configurations - The `generate()` method's STS token creation logic - IAM policy generation for different read/write path combinations - Error handling for invalid configurations or STS failures This is important because the generator contains the critical business logic for creating S3 credentials with proper IAM policies. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscribe@gravitino.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org
