This is an automated email from the ASF dual-hosted git repository.

dweeks pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
     new 5e3f919  AWS: Add progressive multipart upload to S3FileIO (#1767)
5e3f919 is described below

commit 5e3f9198e5675a852df4f0e1c28b4e3cf6630f86
Author: Daniel Weeks <dwe...@apache.org>
AuthorDate: Sat Nov 21 10:01:26 2020 -0800

    AWS: Add progressive multipart upload to S3FileIO (#1767)

    * AWS: Add progressive multipart upload to S3FileIO

    * Fix test after rebase

    * Simplify the complete and ensure parts are in order

    * Add sort to stream

    * Checkstyle

    * Add abort attempt back to complete

    * Initialize only once

    * Add executor service for async tasks per errorprone

    * Address comments

    * Refactor setting encryption for requests

    * Fix defaults to no-arg AwsProperties

    * Address some failure cases and add more testing

    * Checkstyle
---
 .../java/org/apache/iceberg/aws/AwsProperties.java |  70 ++++++
 .../java/org/apache/iceberg/aws/s3/S3FileIO.java   |   1 -
 .../org/apache/iceberg/aws/s3/S3InputStream.java   |   7 +-
 .../org/apache/iceberg/aws/s3/S3OutputStream.java  | 269 ++++++++++++++++++---
 .../org/apache/iceberg/aws/s3/S3RequestUtil.java   | 124 ++++++++++
 .../apache/iceberg/aws/s3/S3OutputStreamTest.java  | 151 +++++++++---
 .../main/java/org/apache/iceberg/GuavaClasses.java |   2 +
 .../java/org/apache/iceberg/util/PropertyUtil.java |   9 +
 8 files changed, 548 insertions(+), 85 deletions(-)

diff --git a/aws/src/main/java/org/apache/iceberg/aws/AwsProperties.java b/aws/src/main/java/org/apache/iceberg/aws/AwsProperties.java
index 358a8c9..c2f58be 100644
--- a/aws/src/main/java/org/apache/iceberg/aws/AwsProperties.java
+++ b/aws/src/main/java/org/apache/iceberg/aws/AwsProperties.java
@@ -82,9 +82,40 @@ public class AwsProperties {
   public static final String GLUE_CATALOG_SKIP_ARCHIVE = "gluecatalog.skip-archive";
   public static final boolean GLUE_CATALOG_SKIP_ARCHIVE_DEFAULT = false;
 
+  /**
+   * Number of threads to use for uploading parts to S3 (shared pool across all output streams).
+   */
+  public static final String S3FILEIO_MULTIPART_UPLOAD_THREADS = "s3fileio.multipart.num-threads";
+
+  /**
+   * The size of a single part for multipart upload requests (default: 32MB).
+   */
+  public static final String S3FILEIO_MULTIPART_SIZE = "s3fileio.multipart.part.size";
+
+  /**
+   * The threshold expressed as a factor times the multipart size at which to
+   * switch from uploading using a single put object request to uploading using multipart upload
+   * (default: 1.5).
+   */
+  public static final String S3FILEIO_MULTIPART_THRESHOLD_FACTOR = "s3fileio.multipart.threshold";
+
+  /**
+   * Location to put staging files for upload to S3.
+   */
+  public static final String S3FILEIO_STAGING_DIRECTORY = "s3fileio.staging.dir";
+
+
+  static final int MIN_MULTIPART_UPLOAD_SIZE = 5 * 1024 * 1024;
+  static final int DEFAULT_MULTIPART_SIZE = 32 * 1024 * 1024;
+  static final double DEFAULT_MULTIPART_THRESHOLD = 1.5;
+
   private String s3FileIoSseType;
   private String s3FileIoSseKey;
   private String s3FileIoSseMd5;
+  private int s3FileIoMultipartUploadThreads;
+  private int s3FileIoMultiPartSize;
+  private double s3FileIoMultipartThresholdFactor;
+  private String s3fileIoStagingDirectory;
 
   private String glueCatalogId;
   private boolean glueCatalogSkipArchive;
@@ -94,6 +125,11 @@ public class AwsProperties {
     this.s3FileIoSseKey = null;
     this.s3FileIoSseMd5 = null;
 
+    this.s3FileIoMultipartUploadThreads = Runtime.getRuntime().availableProcessors();
+    this.s3FileIoMultiPartSize = DEFAULT_MULTIPART_SIZE;
+    this.s3FileIoMultipartThresholdFactor = DEFAULT_MULTIPART_THRESHOLD;
+    this.s3fileIoStagingDirectory = System.getProperty("java.io.tmpdir");
+
     this.glueCatalogId = null;
     this.glueCatalogSkipArchive = GLUE_CATALOG_SKIP_ARCHIVE_DEFAULT;
   }
@@ -111,6 +147,24 @@ public class AwsProperties {
     this.glueCatalogId = properties.get(GLUE_CATALOG_ID);
     this.glueCatalogSkipArchive = PropertyUtil.propertyAsBoolean(properties,
         AwsProperties.GLUE_CATALOG_SKIP_ARCHIVE, AwsProperties.GLUE_CATALOG_SKIP_ARCHIVE_DEFAULT);
+
+    this.s3FileIoMultipartUploadThreads = PropertyUtil.propertyAsInt(properties, S3FILEIO_MULTIPART_UPLOAD_THREADS,
+        Runtime.getRuntime().availableProcessors());
+
+    this.s3FileIoMultiPartSize = PropertyUtil.propertyAsInt(properties, S3FILEIO_MULTIPART_SIZE,
+        DEFAULT_MULTIPART_SIZE);
+
+    this.s3FileIoMultipartThresholdFactor = PropertyUtil.propertyAsDouble(properties,
+        S3FILEIO_MULTIPART_THRESHOLD_FACTOR, DEFAULT_MULTIPART_THRESHOLD);
+
+    Preconditions.checkArgument(s3FileIoMultipartThresholdFactor >= 1.0,
+        "Multipart threshold factor must be >= 1.0");
+
+    Preconditions.checkArgument(s3FileIoMultiPartSize >= MIN_MULTIPART_UPLOAD_SIZE,
+        "Minimum multipart upload object size must be larger than 5 MB.");
+
+    this.s3fileIoStagingDirectory = PropertyUtil.propertyAsString(properties, S3FILEIO_STAGING_DIRECTORY,
+        System.getProperty("java.io.tmpdir"));
   }
 
   public String s3FileIoSseType() {
@@ -152,4 +206,20 @@ public class AwsProperties {
   public void setGlueCatalogSkipArchive(boolean skipArchive) {
     this.glueCatalogSkipArchive = skipArchive;
   }
+
+  public int s3FileIoMultipartUploadThreads() {
+    return s3FileIoMultipartUploadThreads;
+  }
+
+  public int s3FileIoMultiPartSize() {
+    return s3FileIoMultiPartSize;
+  }
+
+  public double s3FileIOMultipartThresholdFactor() {
+    return s3FileIoMultipartThresholdFactor;
+  }
+
+  public String getS3fileIoStagingDirectory() {
+    return s3fileIoStagingDirectory;
+  }
 }
diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3FileIO.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3FileIO.java
index 443c58f..edcd87a 100644
--- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3FileIO.java
+++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3FileIO.java
@@ -39,7 +39,6 @@ import software.amazon.awssdk.services.s3.model.ObjectIdentifier;
 public class S3FileIO implements FileIO {
   private final SerializableSupplier<S3Client> s3;
   private AwsProperties awsProperties;
-
   private transient S3Client client;
 
   public S3FileIO() {
diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java
index 2d3962b..7d58fc4 100644
--- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java
+++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3InputStream.java
@@ -32,7 +32,6 @@ import org.slf4j.LoggerFactory;
 import software.amazon.awssdk.core.sync.ResponseTransformer;
 import software.amazon.awssdk.services.s3.S3Client;
 import software.amazon.awssdk.services.s3.model.GetObjectRequest;
-import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
 
 class S3InputStream extends SeekableInputStream {
   private static final Logger LOG = LoggerFactory.getLogger(S3InputStream.class);
@@ -139,11 +138,7 @@ class S3InputStream extends SeekableInputStream {
         .key(location.key())
         .range(String.format("bytes=%s-", pos));
 
-    if (AwsProperties.S3FILEIO_SSE_TYPE_CUSTOM.equals(awsProperties.s3FileIoSseType())) {
-      requestBuilder.sseCustomerAlgorithm(ServerSideEncryption.AES256.name());
-      requestBuilder.sseCustomerKey(awsProperties.s3FileIoSseKey());
-      requestBuilder.sseCustomerKeyMD5(awsProperties.s3FileIoSseMd5());
-    }
+    S3RequestUtil.configureEncryption(awsProperties, requestBuilder);
 
     closeStream();
     stream = s3.getObject(requestBuilder.build(), ResponseTransformer.toInputStream());
diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3OutputStream.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3OutputStream.java
index 42b9608..15fd19f 100644
--- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3OutputStream.java
+++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3OutputStream.java
@@ -19,51 +19,98 @@
 package org.apache.iceberg.aws.s3;
 
+import java.io.BufferedInputStream;
 import java.io.BufferedOutputStream;
+import java.io.ByteArrayInputStream;
 import java.io.File;
+import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
-import java.io.OutputStream;
+import java.io.InputStream;
+import java.io.SequenceInputStream;
+import java.io.UncheckedIOException;
+import java.nio.file.Files;
 import java.util.Arrays;
-import java.util.Locale;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.stream.Collectors;
 import org.apache.iceberg.aws.AwsProperties;
 import org.apache.iceberg.io.PositionOutputStream;
 import org.apache.iceberg.relocated.com.google.common.base.Joiner;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.base.Predicates;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.relocated.com.google.common.io.CountingOutputStream;
+import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
+import org.apache.iceberg.relocated.com.google.common.util.concurrent.ThreadFactoryBuilder;
+import org.apache.iceberg.util.Tasks;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import software.amazon.awssdk.core.sync.RequestBody;
 import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload;
+import software.amazon.awssdk.services.s3.model.CompletedPart;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
 import software.amazon.awssdk.services.s3.model.PutObjectRequest;
-import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
+import software.amazon.awssdk.services.s3.model.UploadPartRequest;
+import software.amazon.awssdk.services.s3.model.UploadPartResponse;
 
 class S3OutputStream extends PositionOutputStream {
   private static final Logger LOG = LoggerFactory.getLogger(S3OutputStream.class);
 
+  private static volatile ExecutorService executorService;
+
   private final StackTraceElement[] createStack;
   private final S3Client s3;
   private final S3URI location;
   private final AwsProperties awsProperties;
 
-  private final OutputStream stream;
-  private final File stagingFile;
-  private long pos = 0;
+  private CountingOutputStream stream;
+  private final List<File> stagingFiles = Lists.newArrayList();
+  private final File stagingDirectory;
+  private File currentStagingFile;
+  private String multipartUploadId;
+  private final Map<File, CompletableFuture<CompletedPart>> multiPartMap = Maps.newHashMap();
+  private final int multiPartSize;
+  private final int multiPartThresholdSize;
 
+  private long pos = 0;
   private boolean closed = false;
 
-  S3OutputStream(S3Client s3, S3URI location) throws IOException {
-    this(s3, location, new AwsProperties());
-  }
-
   S3OutputStream(S3Client s3, S3URI location, AwsProperties awsProperties) throws IOException {
+    if (executorService == null) {
+      synchronized (this) {
+        if (executorService == null) {
+          executorService = MoreExecutors.getExitingExecutorService(
+              (ThreadPoolExecutor) Executors.newFixedThreadPool(
+                  awsProperties.s3FileIoMultipartUploadThreads(),
+                  new ThreadFactoryBuilder()
+                      .setDaemon(true)
+                      .setNameFormat("iceberg-s3fileio-upload-%d")
+                      .build()));
+        }
+      }
+    }
+
     this.s3 = s3;
     this.location = location;
     this.awsProperties = awsProperties;
 
     createStack = Thread.currentThread().getStackTrace();
 
-    stagingFile = File.createTempFile("s3fileio-", ".tmp");
-    stream = new BufferedOutputStream(new FileOutputStream(stagingFile));
-    stagingFile.deleteOnExit();
+    multiPartSize = awsProperties.s3FileIoMultiPartSize();
+    multiPartThresholdSize = (int) (multiPartSize * awsProperties.s3FileIOMultipartThresholdFactor());
+    stagingDirectory = new File(awsProperties.getS3fileIoStagingDirectory());
+
+    newStream();
   }
 
   @Override
@@ -78,14 +125,60 @@ class S3OutputStream extends PositionOutputStream {
 
   @Override
   public void write(int b) throws IOException {
+    if (stream.getCount() >= multiPartSize) {
+      newStream();
+      uploadParts();
+    }
+
     stream.write(b);
     pos += 1;
+
+    // switch to multipart upload
+    if (multipartUploadId == null && pos >= multiPartThresholdSize) {
+      initializeMultiPartUpload();
+      uploadParts();
+    }
   }
 
   @Override
   public void write(byte[] b, int off, int len) throws IOException {
-    stream.write(b, off, len);
+    int remaining = len;
+    int relativeOffset = off;
+
+    // Write the remainder of the part size to the staging file
+    // and continue to write new staging files if the write is
+    // larger than the part size.
+    while (stream.getCount() + remaining > multiPartSize) {
+      int writeSize = multiPartSize - (int) stream.getCount();
+
+      stream.write(b, relativeOffset, writeSize);
+      remaining -= writeSize;
+      relativeOffset += writeSize;
+
+      newStream();
+      uploadParts();
+    }
+
+    stream.write(b, relativeOffset, remaining);
     pos += len;
+
+    // switch to multipart upload
+    if (multipartUploadId == null && pos >= multiPartThresholdSize) {
+      initializeMultiPartUpload();
+      uploadParts();
+    }
+  }
+
+  private void newStream() throws IOException {
+    if (stream != null) {
+      stream.close();
+    }
+
+    currentStagingFile = File.createTempFile("s3fileio-", ".tmp", stagingDirectory);
+    currentStagingFile.deleteOnExit();
+    stagingFiles.add(currentStagingFile);
+
+    stream = new CountingOutputStream(new BufferedOutputStream(new FileOutputStream(currentStagingFile)));
   }
 
   @Override
@@ -100,42 +193,138 @@
     try {
       stream.close();
 
-      PutObjectRequest.Builder requestBuilder = PutObjectRequest.builder()
-          .bucket(location.bucket())
-          .key(location.key());
+      completeUploads();
+    } finally {
+      cleanUpStagingFiles();
+    }
+  }
+
+  private void initializeMultiPartUpload() {
+    CreateMultipartUploadRequest.Builder requestBuilder = CreateMultipartUploadRequest.builder()
+        .bucket(location.bucket()).key(location.key());
+    S3RequestUtil.configureEncryption(awsProperties, requestBuilder);
+
+    multipartUploadId = s3.createMultipartUpload(requestBuilder.build()).uploadId();
+  }
 
-      switch (awsProperties.s3FileIoSseType().toLowerCase(Locale.ENGLISH)) {
-        case AwsProperties.S3FILEIO_SSE_TYPE_NONE:
-          break;
+  private void uploadParts() {
+    // exit if multipart has not been initiated
+    if (multipartUploadId == null) {
+      return;
+    }
 
-        case AwsProperties.S3FILEIO_SSE_TYPE_KMS:
-          requestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
-          requestBuilder.ssekmsKeyId(awsProperties.s3FileIoSseKey());
-          break;
+    stagingFiles.stream()
+        // do not upload the file currently being written
+        .filter(f -> closed || !f.equals(currentStagingFile))
+        // do not upload any files that have already been processed
+        .filter(Predicates.not(multiPartMap::containsKey))
+        .forEach(f -> {
+          UploadPartRequest.Builder requestBuilder = UploadPartRequest.builder()
+              .bucket(location.bucket())
+              .key(location.key())
+              .uploadId(multipartUploadId)
+              .partNumber(stagingFiles.indexOf(f) + 1)
+              .contentLength(f.length());
 
-        case AwsProperties.S3FILEIO_SSE_TYPE_S3:
-          requestBuilder.serverSideEncryption(ServerSideEncryption.AES256);
-          break;
+          S3RequestUtil.configureEncryption(awsProperties, requestBuilder);
 
-        case AwsProperties.S3FILEIO_SSE_TYPE_CUSTOM:
-          requestBuilder.sseCustomerAlgorithm(ServerSideEncryption.AES256.name());
-          requestBuilder.sseCustomerKey(awsProperties.s3FileIoSseKey());
-          requestBuilder.sseCustomerKeyMD5(awsProperties.s3FileIoSseMd5());
-          break;
+          UploadPartRequest uploadRequest = requestBuilder.build();
 
-        default:
-          throw new IllegalArgumentException(
-              "Cannot support given S3 encryption type: " + awsProperties.s3FileIoSseType());
-      }
+          CompletableFuture<CompletedPart> future = CompletableFuture.supplyAsync(
+              () -> {
+                UploadPartResponse response = s3.uploadPart(uploadRequest, RequestBody.fromFile(f));
+                return CompletedPart.builder().eTag(response.eTag()).partNumber(uploadRequest.partNumber()).build();
+              },
+              executorService
+          ).whenComplete((result, thrown) -> {
+            try {
+              Files.deleteIfExists(f.toPath());
+            } catch (IOException e) {
+              LOG.warn("Failed to delete staging file: {}", f, e);
+            }
 
-      s3.putObject(requestBuilder.build(), RequestBody.fromFile(stagingFile));
-    } finally {
-      if (!stagingFile.delete()) {
-        LOG.warn("Could not delete temporary file: {}", stagingFile);
+            if (thrown != null) {
+              LOG.error("Failed to upload part: {}", uploadRequest, thrown);
+              abortUpload();
+            }
+          });
+
+          multiPartMap.put(f, future);
+        });
+  }
+
+  private void completeMultiPartUpload() {
+    Preconditions.checkState(closed, "Complete upload called on open stream: " + location);
+
+    List<CompletedPart> completedParts =
+        multiPartMap.values()
+            .stream()
+            .map(CompletableFuture::join)
+            .sorted(Comparator.comparing(CompletedPart::partNumber))
+            .collect(Collectors.toList());
+
+    CompleteMultipartUploadRequest request = CompleteMultipartUploadRequest.builder()
+        .bucket(location.bucket()).key(location.key())
+        .uploadId(multipartUploadId)
+        .multipartUpload(CompletedMultipartUpload.builder().parts(completedParts).build()).build();
+
+    Tasks.foreach(request)
+        .noRetry()
+        .onFailure((r, thrown) -> {
+          LOG.error("Failed to complete multipart upload request: {}", r, thrown);
+          abortUpload();
+        })
+        .throwFailureWhenFinished()
+        .run(s3::completeMultipartUpload);
+  }
+
+  private void abortUpload() {
+    if (multipartUploadId != null) {
+      try {
+        s3.abortMultipartUpload(AbortMultipartUploadRequest.builder()
+            .bucket(location.bucket()).key(location.key()).uploadId(multipartUploadId).build());
+      } finally {
+        cleanUpStagingFiles();
       }
     }
   }
 
+  private void cleanUpStagingFiles() {
+    Tasks.foreach(stagingFiles)
+        .suppressFailureWhenFinished()
+        .onFailure((file, thrown) -> LOG.warn("Failed to delete staging file: {}", file, thrown))
+        .run(File::delete);
+  }
+
+  private void completeUploads() {
+    if (multipartUploadId == null) {
+      long contentLength = stagingFiles.stream().mapToLong(File::length).sum();
+      InputStream contentStream = new BufferedInputStream(stagingFiles.stream()
+          .map(S3OutputStream::uncheckedInputStream)
+          .reduce(SequenceInputStream::new)
+          .orElseGet(() -> new ByteArrayInputStream(new byte[0])));
+
+      PutObjectRequest.Builder requestBuilder = PutObjectRequest.builder()
+          .bucket(location.bucket())
+          .key(location.key());
+
+      S3RequestUtil.configureEncryption(awsProperties, requestBuilder);
+
+      s3.putObject(requestBuilder.build(), RequestBody.fromInputStream(contentStream, contentLength));
+    } else {
+      uploadParts();
+      completeMultiPartUpload();
+    }
+  }
+
+  private static InputStream uncheckedInputStream(File file) {
+    try {
+      return new FileInputStream(file);
+    } catch (IOException e) {
+      throw new UncheckedIOException(e);
+    }
+  }
+
   @SuppressWarnings("checkstyle:NoFinalizer")
   @Override
   protected void finalize() throws Throwable {
diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/S3RequestUtil.java b/aws/src/main/java/org/apache/iceberg/aws/s3/S3RequestUtil.java
new file mode 100644
index 0000000..5f9309e
--- /dev/null
+++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3RequestUtil.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.aws.s3;
+
+import java.util.Locale;
+import org.apache.iceberg.aws.AwsProperties;
+import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.GetObjectRequest;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
+import software.amazon.awssdk.services.s3.model.UploadPartRequest;
+
+public class S3RequestUtil {
+
+  private S3RequestUtil() {
+  }
+
+  static void configureEncryption(AwsProperties awsProperties, PutObjectRequest.Builder requestBuilder) {
+    switch (awsProperties.s3FileIoSseType().toLowerCase(Locale.ENGLISH)) {
+      case AwsProperties.S3FILEIO_SSE_TYPE_NONE:
+        break;
+
+      case AwsProperties.S3FILEIO_SSE_TYPE_KMS:
+        requestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
+        requestBuilder.ssekmsKeyId(awsProperties.s3FileIoSseKey());
+        break;
+
+      case AwsProperties.S3FILEIO_SSE_TYPE_S3:
+        requestBuilder.serverSideEncryption(ServerSideEncryption.AES256);
+        break;
+
+      case AwsProperties.S3FILEIO_SSE_TYPE_CUSTOM:
+        requestBuilder.sseCustomerAlgorithm(ServerSideEncryption.AES256.name());
+        requestBuilder.sseCustomerKey(awsProperties.s3FileIoSseKey());
+        requestBuilder.sseCustomerKeyMD5(awsProperties.s3FileIoSseMd5());
+        break;
+
+      default:
+        throw new IllegalArgumentException(
+            "Cannot support given S3 encryption type: " + awsProperties.s3FileIoSseType());
+    }
+  }
+
+  static void configureEncryption(AwsProperties awsProperties, CreateMultipartUploadRequest.Builder requestBuilder) {
+    switch (awsProperties.s3FileIoSseType().toLowerCase(Locale.ENGLISH)) {
+      case AwsProperties.S3FILEIO_SSE_TYPE_NONE:
+        break;
+
+      case AwsProperties.S3FILEIO_SSE_TYPE_KMS:
+        requestBuilder.serverSideEncryption(ServerSideEncryption.AWS_KMS);
+        requestBuilder.ssekmsKeyId(awsProperties.s3FileIoSseKey());
+        break;
+
+      case AwsProperties.S3FILEIO_SSE_TYPE_S3:
+        requestBuilder.serverSideEncryption(ServerSideEncryption.AES256);
+        break;
+
+      case AwsProperties.S3FILEIO_SSE_TYPE_CUSTOM:
+        requestBuilder.sseCustomerAlgorithm(ServerSideEncryption.AES256.name());
+        requestBuilder.sseCustomerKey(awsProperties.s3FileIoSseKey());
+        requestBuilder.sseCustomerKeyMD5(awsProperties.s3FileIoSseMd5());
+        break;
+
+      default:
+        throw new IllegalArgumentException(
+            "Cannot support given S3 encryption type: " + awsProperties.s3FileIoSseType());
+    }
+  }
+
+  static void configureEncryption(AwsProperties awsProperties, UploadPartRequest.Builder requestBuilder) {
+    switch (awsProperties.s3FileIoSseType().toLowerCase(Locale.ENGLISH)) {
+      case AwsProperties.S3FILEIO_SSE_TYPE_NONE:
+      case AwsProperties.S3FILEIO_SSE_TYPE_KMS:
+      case AwsProperties.S3FILEIO_SSE_TYPE_S3:
+        break;
+
+      case AwsProperties.S3FILEIO_SSE_TYPE_CUSTOM:
+        requestBuilder.sseCustomerAlgorithm(ServerSideEncryption.AES256.name());
+        requestBuilder.sseCustomerKey(awsProperties.s3FileIoSseKey());
+        requestBuilder.sseCustomerKeyMD5(awsProperties.s3FileIoSseMd5());
+        break;
+
+      default:
+        throw new IllegalArgumentException(
+            "Cannot support given S3 encryption type: " + awsProperties.s3FileIoSseType());
+    }
+  }
+
+  static void configureEncryption(AwsProperties awsProperties, GetObjectRequest.Builder requestBuilder) {
+    switch (awsProperties.s3FileIoSseType().toLowerCase(Locale.ENGLISH)) {
+      case AwsProperties.S3FILEIO_SSE_TYPE_NONE:
+      case AwsProperties.S3FILEIO_SSE_TYPE_KMS:
+      case AwsProperties.S3FILEIO_SSE_TYPE_S3:
+        break;
+
+      case AwsProperties.S3FILEIO_SSE_TYPE_CUSTOM:
+        requestBuilder.sseCustomerAlgorithm(ServerSideEncryption.AES256.name());
+        requestBuilder.sseCustomerKey(awsProperties.s3FileIoSseKey());
+        requestBuilder.sseCustomerKeyMD5(awsProperties.s3FileIoSseMd5());
+        break;
+
+      default:
+        throw new IllegalArgumentException(
+            "Cannot support given S3 encryption type: " + awsProperties.s3FileIoSseType());
+    }
+  }
+}
diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/S3OutputStreamTest.java b/aws/src/test/java/org/apache/iceberg/aws/s3/S3OutputStreamTest.java
index 3e1b355..b4dc1ec 100644
--- a/aws/src/test/java/org/apache/iceberg/aws/s3/S3OutputStreamTest.java
+++ b/aws/src/test/java/org/apache/iceberg/aws/s3/S3OutputStreamTest.java
@@ -21,86 +21,151 @@ package org.apache.iceberg.aws.s3;
 
 import com.adobe.testing.s3mock.junit4.S3MockRule;
 import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
 import java.util.Random;
+import java.util.UUID;
+import java.util.stream.Stream;
+import org.apache.iceberg.aws.AwsProperties;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.junit.Before;
 import org.junit.ClassRule;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import software.amazon.awssdk.core.ResponseBytes;
+import software.amazon.awssdk.core.sync.RequestBody;
 import software.amazon.awssdk.core.sync.ResponseTransformer;
 import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest;
+import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest;
 import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
 import software.amazon.awssdk.services.s3.model.GetObjectRequest;
 import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.s3.model.UploadPartRequest;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.mockito.AdditionalAnswers.delegatesTo;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
+@RunWith(MockitoJUnitRunner.class)
 public class S3OutputStreamTest {
+  private static final Logger LOG = LoggerFactory.getLogger(S3OutputStreamTest.class);
+  private static final String BUCKET = "test-bucket";
+
   @ClassRule
   public static final S3MockRule S3_MOCK_RULE = S3MockRule.builder().silent().build();
 
   private final S3Client s3 = S3_MOCK_RULE.createS3ClientV2();
+  private final S3Client s3mock = mock(S3Client.class, delegatesTo(s3));
   private final Random random = new Random(1);
+  private final Path tmpDir = Files.createTempDirectory("s3fileio-test-");
+
+  private final AwsProperties properties = new AwsProperties(ImmutableMap.of(
+      AwsProperties.S3FILEIO_MULTIPART_SIZE, Integer.toString(5 * 1024 * 1024),
+      AwsProperties.S3FILEIO_STAGING_DIRECTORY, tmpDir.toString()));
+
+  public S3OutputStreamTest() throws IOException {
+  }
 
   @Before
   public void before() {
-    s3.createBucket(CreateBucketRequest.builder().bucket("bucket").build());
+    s3.createBucket(CreateBucketRequest.builder().bucket(BUCKET).build());
   }
 
   @Test
-  public void getPos() throws IOException {
-    S3URI uri = new S3URI("s3://bucket/path/to/pos.dat");
-    int writeSize = 1024;
-
-    try (S3OutputStream stream = new S3OutputStream(s3, uri)) {
-      stream.write(new byte[writeSize]);
-      assertEquals(writeSize, stream.getPos());
-    }
+  public void testWrite() {
+    // Run tests for both byte and array write paths
+    Stream.of(true, false).forEach(arrayWrite -> {
+      // Test small file write (less than multipart threshold)
+      writeAndVerify(s3mock, randomURI(), randomData(1024), arrayWrite);
+      verify(s3mock, times(1)).putObject((PutObjectRequest) any(), (RequestBody) any());
+      reset(s3mock);
+
+      // Test file larger than part size but less than multipart threshold
+      writeAndVerify(s3mock, randomURI(), randomData(6 * 1024 * 1024), arrayWrite);
+      verify(s3mock, times(1)).putObject((PutObjectRequest) any(), (RequestBody) any());
+      reset(s3mock);
+
+      // Test file large enough to trigger multipart upload
+      writeAndVerify(s3mock, randomURI(), randomData(10 * 1024 * 1024), arrayWrite);
+      verify(s3mock, times(2)).uploadPart((UploadPartRequest) any(), (RequestBody) any());
+      reset(s3mock);
+
+      // Test uploading many parts
+      writeAndVerify(s3mock, randomURI(), randomData(22 * 1024 * 1024), arrayWrite);
+      verify(s3mock, times(5)).uploadPart((UploadPartRequest) any(), (RequestBody) any());
+      reset(s3mock);
+    });
   }
 
   @Test
-  public void testWrite() throws IOException {
-    S3URI uri = new S3URI("s3://bucket/path/to/out.dat");
-    int size = 5 * 1024 * 1024;
-    byte [] expected = new byte[size];
-    random.nextBytes(expected);
-
-    try (S3OutputStream stream = new S3OutputStream(s3, uri)) {
-      for (int i = 0; i < size; i++) {
-        stream.write(expected[i]);
-        assertEquals(i + 1, stream.getPos());
-      }
-    }
-
-    byte [] actual = readS3Data(uri);
+  public void testAbortAfterFailedPartUpload() {
+    doThrow(new RuntimeException()).when(s3mock).uploadPart((UploadPartRequest) any(), (RequestBody) any());
 
-    assertArrayEquals(expected, actual);
+    try (S3OutputStream stream = new S3OutputStream(s3mock, randomURI(), properties)) {
+      stream.write(randomData(10 * 1024 * 1024));
+    } catch (Exception e) {
+      verify(s3mock, atLeastOnce()).abortMultipartUpload((AbortMultipartUploadRequest) any());
+    }
   }
 
   @Test
-  public void testWriteArray() throws IOException {
-    S3URI uri = new S3URI("s3://bucket/path/to/array-out.dat");
-    byte [] expected = new byte[5 * 1024 * 1024];
-    random.nextBytes(expected);
-
-    try (S3OutputStream stream = new S3OutputStream(s3, uri)) {
-      stream.write(expected);
-      assertEquals(expected.length, stream.getPos());
-    }
-
-    byte [] actual = readS3Data(uri);
+  public void testAbortMultipart() {
+    doThrow(new RuntimeException()).when(s3mock).completeMultipartUpload((CompleteMultipartUploadRequest) any());
 
-    assertArrayEquals(expected, actual);
+    try (S3OutputStream stream = new S3OutputStream(s3mock, randomURI(), properties)) {
+      stream.write(randomData(10 * 1024 * 1024));
+    } catch (Exception e) {
+      verify(s3mock).abortMultipartUpload((AbortMultipartUploadRequest) any());
+    }
   }
 
   @Test
   public void testMultipleClose() throws IOException {
-    S3URI uri = new S3URI("s3://bucket/path/to/array-out.dat");
-    S3OutputStream stream = new S3OutputStream(s3, uri);
+    S3OutputStream stream = new S3OutputStream(s3, randomURI(), properties);
     stream.close();
     stream.close();
   }
 
+  private void writeAndVerify(S3Client client, S3URI uri, byte [] data, boolean arrayWrite) {
+    try (S3OutputStream stream = new S3OutputStream(client, uri, properties)) {
+      if (arrayWrite) {
+        stream.write(data);
+        assertEquals(data.length, stream.getPos());
+      } else {
+        for (int i = 0; i < data.length; i++) {
+          stream.write(data[i]);
+          assertEquals(i + 1, stream.getPos());
+        }
+      }
+    } catch (IOException e) {
+      throw new UncheckedIOException(e);
+    }
+
+    byte[] actual = readS3Data(uri);
+    assertArrayEquals(data, actual);
+
+    // Verify all staging files are cleaned up
+    try {
+      assertEquals(0, Files.list(tmpDir).count());
+    } catch (IOException e) {
+      throw new UncheckedIOException(e);
+    }
+  }
+
   private byte[] readS3Data(S3URI uri) {
     ResponseBytes<GetObjectResponse> data =
         s3.getObject(GetObjectRequest.builder().bucket(uri.bucket()).key(uri.key()).build(),
@@ -108,4 +173,14 @@ public class S3OutputStreamTest {
     return data.asByteArray();
   }
+
+  private byte[] randomData(int size) {
+    byte [] result = new byte[size];
+    random.nextBytes(result);
+    return result;
+  }
+
+  private S3URI randomURI() {
+    return new S3URI(String.format("s3://%s/data/%s.dat", BUCKET, UUID.randomUUID()));
+  }
 }
diff --git a/bundled-guava/src/main/java/org/apache/iceberg/GuavaClasses.java b/bundled-guava/src/main/java/org/apache/iceberg/GuavaClasses.java
index 107e420..81c9351 100644
--- a/bundled-guava/src/main/java/org/apache/iceberg/GuavaClasses.java
+++ b/bundled-guava/src/main/java/org/apache/iceberg/GuavaClasses.java
@@ -46,6 +46,7 @@ import com.google.common.collect.Streams;
 import com.google.common.hash.HashFunction;
 import com.google.common.hash.Hasher;
 import com.google.common.hash.Hashing;
+import com.google.common.io.CountingOutputStream;
 import com.google.common.io.Files;
 import com.google.common.primitives.Bytes;
 import com.google.common.util.concurrent.MoreExecutors;
@@ -90,6 +91,7 @@ public class GuavaClasses {
     MoreExecutors.class.getName();
     ThreadFactoryBuilder.class.getName();
     Iterables.class.getName();
+    CountingOutputStream.class.getName();
   }
 
 }
diff --git a/core/src/main/java/org/apache/iceberg/util/PropertyUtil.java b/core/src/main/java/org/apache/iceberg/util/PropertyUtil.java
index 2df88c0..e47eb7b 100644
--- a/core/src/main/java/org/apache/iceberg/util/PropertyUtil.java
+++ b/core/src/main/java/org/apache/iceberg/util/PropertyUtil.java
@@ -35,6 +35,15 @@ public class PropertyUtil {
     return defaultValue;
   }
 
+  public static double propertyAsDouble(Map<String, String> properties,
+                                        String property, double defaultValue) {
+    String value = properties.get(property);
+    if (value != null) {
+      return Double.parseDouble(properties.get(property));
+    }
+    return defaultValue;
+  }
+
   public static int propertyAsInt(Map<String, String> properties,
                                   String property, int defaultValue) {
     String value = properties.get(property);