This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ed858318f3 [FLINK-32345][state] Improve parallel download of RocksDB incremental state. (#22788)
2ed858318f3 is described below

commit 2ed858318f38eb9913e5f4b3019be6b6d0a8e6fb
Author: Stefan Richter <srich...@apache.org>
AuthorDate: Wed Jun 21 12:17:08 2023 +0200

    [FLINK-32345][state] Improve parallel download of RocksDB incremental state. (#22788)
    
    * [FLINK-32345] Improve parallel download of RocksDB incremental state.
    
    This commit improves RocksDBStateDownloader to support parallelized state
    download across multiple state types and across multiple state handles. This
    can improve our download times for scale-in.
---
 .../java/org/apache/flink/util/CollectionUtil.java |  81 +++++++++++
 .../org/apache/flink/util/CollectionUtilTest.java  |  47 ++++++
 .../state/RocksDBIncrementalCheckpointUtils.java   |   2 +-
 .../streaming/state/RocksDBStateDownloader.java    | 121 +++++++++-------
 .../streaming/state/StateHandleDownloadSpec.java   |  49 +++++++
 .../RocksDBIncrementalRestoreOperation.java        | 157 ++++++++++++---------
 .../state/RocksDBStateDownloaderTest.java          | 132 +++++++++++++----
 7 files changed, 438 insertions(+), 151 deletions(-)
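
A minimal sketch of the new call pattern (the handles, base path, registry, and
thread count below are hypothetical; the types and methods are the ones added by
this patch): callers now batch one StateHandleDownloadSpec per remote handle and
hand the whole batch to a single RocksDBStateDownloader, so its thread pool is
shared across all handles and across shared/private state.

    List<StateHandleDownloadSpec> specs = new ArrayList<>();
    for (IncrementalRemoteKeyedStateHandle handle : remoteHandles) {
        // One unique download directory per handle, as the restore operation does.
        specs.add(
                new StateHandleDownloadSpec(
                        handle, basePath.resolve(UUID.randomUUID().toString())));
    }
    try (RocksDBStateDownloader downloader = new RocksDBStateDownloader(4)) {
        downloader.transferAllStateDataToDirectory(specs, closeableRegistry);
    }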

diff --git a/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
index 8c96e3e3554..18f4c4313c4 100644
--- a/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
+++ b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java
@@ -19,6 +19,7 @@
 package org.apache.flink.util;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 
 import javax.annotation.Nullable;
 
@@ -27,7 +28,10 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -45,6 +49,9 @@ public final class CollectionUtil {
     /** A safe maximum size for arrays in the JVM. */
     public static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
 
+    /** The default load factor for hash maps created with this util class. */
+    static final float HASH_MAP_DEFAULT_LOAD_FACTOR = 0.75f;
+
     private CollectionUtil() {
         throw new AssertionError();
     }
@@ -133,4 +140,78 @@ public final class CollectionUtil {
         }
         return Collections.unmodifiableMap(map);
     }
+
+    /**
+     * Creates a new {@link HashMap} of the expected size, i.e. a hash map that will not rehash if
+     * expectedSize many keys are inserted, considering the load factor.
+     *
+     * @param expectedSize the expected size of the created hash map.
+     * @return a new hash map instance with enough capacity for the expected size.
+     * @param <K> the type of keys maintained by this map.
+     * @param <V> the type of mapped values.
+     */
+    public static <K, V> HashMap<K, V> newHashMapWithExpectedSize(int expectedSize) {
+        return new HashMap<>(
+                computeRequiredCapacity(expectedSize, HASH_MAP_DEFAULT_LOAD_FACTOR),
+                HASH_MAP_DEFAULT_LOAD_FACTOR);
+    }
+
+    /**
+     * Creates a new {@link LinkedHashMap} of the expected size, i.e. a hash map that will not
+     * rehash if expectedSize many keys are inserted, considering the load factor.
+     *
+     * @param expectedSize the expected size of the created hash map.
+     * @return a new hash map instance with enough capacity for the expected size.
+     * @param <K> the type of keys maintained by this map.
+     * @param <V> the type of mapped values.
+     */
+    public static <K, V> LinkedHashMap<K, V> newLinkedHashMapWithExpectedSize(int expectedSize) {
+        return new LinkedHashMap<>(
+                computeRequiredCapacity(expectedSize, HASH_MAP_DEFAULT_LOAD_FACTOR),
+                HASH_MAP_DEFAULT_LOAD_FACTOR);
+    }
+
+    /**
+     * Creates a new {@link HashSet} of the expected size, i.e. a hash set that will not rehash if
+     * expectedSize many unique elements are inserted, considering the load factor.
+     *
+     * @param expectedSize the expected size of the created hash set.
+     * @return a new hash set instance with enough capacity for the expected size.
+     * @param <E> the type of elements stored by this set.
+     */
+    public static <E> HashSet<E> newHashSetWithExpectedSize(int expectedSize) {
+        return new HashSet<>(
+                computeRequiredCapacity(expectedSize, HASH_MAP_DEFAULT_LOAD_FACTOR),
+                HASH_MAP_DEFAULT_LOAD_FACTOR);
+    }
+
+    /**
+     * Creates a new {@link LinkedHashSet} of the expected size, i.e. a hash set that will not
+     * rehash if expectedSize many unique elements are inserted, considering the load factor.
+     *
+     * @param expectedSize the expected size of the created hash set.
+     * @return a new hash set instance with enough capacity for the expected size.
+     * @param <E> the type of elements stored by this set.
+     */
+    public static <E> LinkedHashSet<E> newLinkedHashSetWithExpectedSize(int expectedSize) {
+        return new LinkedHashSet<>(
+                computeRequiredCapacity(expectedSize, HASH_MAP_DEFAULT_LOAD_FACTOR),
+                HASH_MAP_DEFAULT_LOAD_FACTOR);
+    }
+
+    /**
+     * Helper method to compute the right capacity for a hash map with load factor
+     * HASH_MAP_DEFAULT_LOAD_FACTOR.
+     */
+    @VisibleForTesting
+    static int computeRequiredCapacity(int expectedSize, float loadFactor) {
+        Preconditions.checkArgument(expectedSize >= 0);
+        Preconditions.checkArgument(loadFactor > 0f);
+        if (expectedSize <= 2) {
+            return expectedSize + 1;
+        }
+        return expectedSize < (Integer.MAX_VALUE / 2 + 1)
+                ? (int) ((float) expectedSize / loadFactor + 1.0f)
+                : Integer.MAX_VALUE;
+    }
 }
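
The helper above mirrors the JDK HashMap sizing rule: for expectedSize > 2, the
required initial capacity is (int) ((float) expectedSize / loadFactor + 1.0f). A
quick worked check with the default load factor, matching the test values below:

    // 100 expected entries at load factor 0.75f:
    // 100 / 0.75f = 133.33f; + 1.0f = 134.33f; the (int) cast truncates to 134.
    int capacity = (int) ((float) 100 / 0.75f + 1.0f); // 134
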
diff --git a/flink-core/src/test/java/org/apache/flink/util/CollectionUtilTest.java b/flink-core/src/test/java/org/apache/flink/util/CollectionUtilTest.java
index abeec238879..de749f9aadb 100644
--- a/flink-core/src/test/java/org/apache/flink/util/CollectionUtilTest.java
+++ b/flink-core/src/test/java/org/apache/flink/util/CollectionUtilTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.util;
 
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
 
@@ -25,6 +26,7 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
+import static org.apache.flink.util.CollectionUtil.HASH_MAP_DEFAULT_LOAD_FACTOR;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for java collection utilities. */
@@ -52,4 +54,49 @@ public class CollectionUtilTest {
         final Object element = new Object();
         assertThat(CollectionUtil.ofNullable(element)).singleElement().isEqualTo(element);
     }
+
+    @Test
+    public void testComputeCapacity() {
+        Assertions.assertEquals(
+                1, CollectionUtil.computeRequiredCapacity(0, HASH_MAP_DEFAULT_LOAD_FACTOR));
+        Assertions.assertEquals(
+                2, CollectionUtil.computeRequiredCapacity(1, HASH_MAP_DEFAULT_LOAD_FACTOR));
+        Assertions.assertEquals(
+                3, CollectionUtil.computeRequiredCapacity(2, HASH_MAP_DEFAULT_LOAD_FACTOR));
+        Assertions.assertEquals(
+                5, CollectionUtil.computeRequiredCapacity(3, HASH_MAP_DEFAULT_LOAD_FACTOR));
+        Assertions.assertEquals(
+                6, CollectionUtil.computeRequiredCapacity(4, HASH_MAP_DEFAULT_LOAD_FACTOR));
+        Assertions.assertEquals(
+                134, CollectionUtil.computeRequiredCapacity(100, HASH_MAP_DEFAULT_LOAD_FACTOR));
+        Assertions.assertEquals(
+                1334, CollectionUtil.computeRequiredCapacity(1000, HASH_MAP_DEFAULT_LOAD_FACTOR));
+        Assertions.assertEquals(
+                13334, CollectionUtil.computeRequiredCapacity(10000, HASH_MAP_DEFAULT_LOAD_FACTOR));
+
+        Assertions.assertEquals(20001, CollectionUtil.computeRequiredCapacity(10000, 0.5f));
+
+        Assertions.assertEquals(100001, CollectionUtil.computeRequiredCapacity(10000, 0.1f));
+
+        Assertions.assertEquals(
+                1431655808,
+                CollectionUtil.computeRequiredCapacity(
+                        Integer.MAX_VALUE / 2, HASH_MAP_DEFAULT_LOAD_FACTOR));
+        Assertions.assertEquals(
+                Integer.MAX_VALUE,
+                CollectionUtil.computeRequiredCapacity(
+                        1 + Integer.MAX_VALUE / 2, HASH_MAP_DEFAULT_LOAD_FACTOR));
+
+        try {
+            CollectionUtil.computeRequiredCapacity(-1, HASH_MAP_DEFAULT_LOAD_FACTOR);
+            Assertions.fail();
+        } catch (IllegalArgumentException expected) {
+        }
+
+        try {
+            CollectionUtil.computeRequiredCapacity(Integer.MIN_VALUE, HASH_MAP_DEFAULT_LOAD_FACTOR);
+            Assertions.fail();
+        } catch (IllegalArgumentException expected) {
+        }
+    }
 }
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
index 23c78675068..54121709876 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
@@ -186,7 +186,7 @@ public class RocksDBIncrementalCheckpointUtils {
             Score handleScore =
                     stateHandleEvaluator(
                             rawStateHandle, targetKeyGroupRange, overlapFractionThreshold);
-            if (handleScore.compareTo(bestScore) > 0) {
+            if (bestStateHandle == null || handleScore.compareTo(bestScore) > 0) {
                 bestStateHandle = rawStateHandle;
                 bestScore = handleScore;
             }
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java
index d06790ced8a..0a1e43e9700 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java
@@ -19,23 +19,27 @@ package org.apache.flink.contrib.streaming.state;
 
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle;
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FileUtils;
 import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.concurrent.FutureUtils;
 import org.apache.flink.util.function.ThrowingRunnable;
 
+import org.apache.flink.shaded.guava30.com.google.common.collect.Streams;
+
 import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.file.Files;
 import java.nio.file.Path;
-import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /** Help class for downloading RocksDB state files. */
 public class RocksDBStateDownloader extends RocksDBStateDataTransfer {
@@ -44,45 +48,43 @@ public class RocksDBStateDownloader extends RocksDBStateDataTransfer {
     }
 
     /**
-     * Transfer all state data to the target directory using specified number of threads.
+     * Transfer all state data to the target directory, as specified in the download requests.
      *
-     * @param restoreStateHandle Handles used to retrieve the state data.
-     * @param dest The target directory which the state data will be stored.
-     * @throws Exception Thrown if can not transfer all the state data.
+     * @param downloadRequests the list of downloads.
+     * @throws Exception If anything about the download goes wrong.
      */
     public void transferAllStateDataToDirectory(
-            IncrementalRemoteKeyedStateHandle restoreStateHandle,
-            Path dest,
-            CloseableRegistry closeableRegistry)
-            throws Exception {
-
-        final Map<StateHandleID, StreamStateHandle> sstFiles = restoreStateHandle.getSharedState();
-        final Map<StateHandleID, StreamStateHandle> miscFiles =
-                restoreStateHandle.getPrivateState();
-
-        downloadDataForAllStateHandles(sstFiles, dest, closeableRegistry);
-        downloadDataForAllStateHandles(miscFiles, dest, closeableRegistry);
-    }
-
-    /**
-     * Copies all the files from the given stream state handles to the given path, renaming the
-     * files w.r.t. their {@link StateHandleID}.
-     */
-    private void downloadDataForAllStateHandles(
-            Map<StateHandleID, StreamStateHandle> stateHandleMap,
-            Path restoreInstancePath,
+            Collection<StateHandleDownloadSpec> downloadRequests,
             CloseableRegistry closeableRegistry)
             throws Exception {
 
+        // We use this closer for fine-grained shutdown of all parallel downloading.
+        CloseableRegistry internalCloser = new CloseableRegistry();
+        // Make sure we also react to external close signals.
+        closeableRegistry.registerCloseable(internalCloser);
+        List<CompletableFuture<Void>> futures = Collections.emptyList();
         try {
-            List<Runnable> runnables =
-                    createDownloadRunnables(stateHandleMap, restoreInstancePath, closeableRegistry);
-            List<CompletableFuture<Void>> futures = new ArrayList<>(runnables.size());
-            for (Runnable runnable : runnables) {
-                futures.add(CompletableFuture.runAsync(runnable, executorService));
+            try {
+                futures =
+                        transferAllStateDataToDirectoryAsync(downloadRequests, internalCloser)
+                                .collect(Collectors.toList());
+                // Wait until either all futures completed successfully or one failed exceptionally.
+                FutureUtils.waitForAll(futures).get();
+            } finally {
+                // Unregister and close the internal closer. In a failure case, this should
+                // interrupt ongoing downloads.
+                if (closeableRegistry.unregisterCloseable(internalCloser)) {
+                    IOUtils.closeQuietly(internalCloser);
+                }
             }
-            FutureUtils.waitForAll(futures).get();
-        } catch (ExecutionException e) {
+        } catch (Exception e) {
+            // Cleanup on exception: cancel all tasks and delete the created directories
+            futures.forEach(future -> future.cancel(true));
+            downloadRequests.stream()
+                    .map(StateHandleDownloadSpec::getDownloadDestination)
+                    .map(Path::toFile)
+                    .forEach(FileUtils::deleteDirectoryQuietly);
+            // Error reporting
             Throwable throwable = ExceptionUtils.stripExecutionException(e);
             throwable = ExceptionUtils.stripException(throwable, RuntimeException.class);
             if (throwable instanceof IOException) {
@@ -93,24 +95,39 @@ public class RocksDBStateDownloader extends RocksDBStateDataTransfer {
         }
     }
 
-    private List<Runnable> createDownloadRunnables(
-            Map<StateHandleID, StreamStateHandle> stateHandleMap,
-            Path restoreInstancePath,
+    /** Asynchronously runs the specified download requests on executorService. */
+    private Stream<CompletableFuture<Void>> transferAllStateDataToDirectoryAsync(
+            Collection<StateHandleDownloadSpec> handleWithPaths,
             CloseableRegistry closeableRegistry) {
-        List<Runnable> runnables = new ArrayList<>(stateHandleMap.size());
-        for (Map.Entry<StateHandleID, StreamStateHandle> entry : stateHandleMap.entrySet()) {
-            StateHandleID stateHandleID = entry.getKey();
-            StreamStateHandle remoteFileHandle = entry.getValue();
-
-            Path path = restoreInstancePath.resolve(stateHandleID.toString());
-
-            runnables.add(
-                    ThrowingRunnable.unchecked(
-                            () ->
-                                    downloadDataForStateHandle(
-                                            path, remoteFileHandle, closeableRegistry)));
-        }
-        return runnables;
+        return handleWithPaths.stream()
+                .flatMap(
+                        downloadRequest ->
+                                // Take all files from shared and private state.
+                                Streams.concat(
+                                                downloadRequest.getStateHandle().getSharedState()
+                                                        .entrySet().stream(),
+                                                downloadRequest.getStateHandle().getPrivateState()
+                                                        .entrySet().stream())
+                                        .map(
+                                                // Create one runnable for each StreamStateHandle
+                                                entry -> {
+                                                    StateHandleID stateHandleID = entry.getKey();
+                                                    StreamStateHandle remoteFileHandle =
+                                                            entry.getValue();
+                                                    Path downloadDest =
+                                                            downloadRequest
+                                                                    .getDownloadDestination()
+                                                                    .resolve(
+                                                                            stateHandleID
+                                                                                    .toString());
+                                                    return ThrowingRunnable.unchecked(
+                                                            () ->
+                                                                    downloadDataForStateHandle(
+                                                                            downloadDest,
+                                                                            remoteFileHandle,
+                                                                            closeableRegistry));
+                                                }))
+                .map(runnable -> CompletableFuture.runAsync(runnable, executorService));
     }
 
     /** Copies the file from a single state handle to the given path. */
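
Reduced to its skeleton, the new async path turns every (download spec, file)
pair into one task on the shared executor, so parallelism is now per file across
all handles rather than per handle. A hypothetical condensed sketch of that
fan-out (filesOf() and download() stand in for the shared/private state
traversal and downloadDataForStateHandle() above):

    List<CompletableFuture<Void>> futures =
            specs.stream()
                    .flatMap(
                            spec ->
                                    filesOf(spec).stream() // shared + private files
                                            .map(
                                                    file ->
                                                            CompletableFuture.runAsync(
                                                                    () -> download(file),
                                                                    executorService)))
                    .collect(Collectors.toList());
    // Wait until all downloads finish, failing fast if any of them fails.
    FutureUtils.waitForAll(futures).get();
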
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/StateHandleDownloadSpec.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/StateHandleDownloadSpec.java
new file mode 100644
index 00000000000..93a33fdc6fa
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/StateHandleDownloadSpec.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle;
+
+import java.nio.file.Path;
+
+/**
+ * This class represents a download specification for the content of one {@link
+ * IncrementalRemoteKeyedStateHandle} to a target {@link Path}.
+ */
+public class StateHandleDownloadSpec {
+    /** The state handle to download. */
+    private final IncrementalRemoteKeyedStateHandle stateHandle;
+
+    /** The path to which the content of the state handle shall be downloaded. */
+    private final Path downloadDestination;
+
+    public StateHandleDownloadSpec(
+            IncrementalRemoteKeyedStateHandle stateHandle, Path downloadDestination) {
+        this.stateHandle = stateHandle;
+        this.downloadDestination = downloadDestination;
+    }
+
+    public IncrementalRemoteKeyedStateHandle getStateHandle() {
+        return stateHandle;
+    }
+
+    public Path getDownloadDestination() {
+        return downloadDestination;
+    }
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
index d6ec9ae6055..89998b8768a 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
@@ -27,6 +27,7 @@ import org.apache.flink.contrib.streaming.state.RocksDBOperationUtils;
 import org.apache.flink.contrib.streaming.state.RocksDBStateDownloader;
 import org.apache.flink.contrib.streaming.state.RocksDBWriteBatchWrapper;
 import org.apache.flink.contrib.streaming.state.RocksIteratorWrapper;
+import org.apache.flink.contrib.streaming.state.StateHandleDownloadSpec;
 import org.apache.flink.contrib.streaming.state.ttl.RocksDbTtlCompactFiltersManager;
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.memory.DataInputView;
@@ -46,8 +47,10 @@ import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StateSerializerProvider;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.util.CollectionUtil;
 import org.apache.flink.util.FileUtils;
 import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StateMigrationException;
 
 import org.rocksdb.ColumnFamilyDescriptor;
@@ -69,6 +72,7 @@ import java.io.InputStream;
 import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -186,12 +190,12 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
             IncrementalRemoteKeyedStateHandle incrementalRemoteKeyedStateHandle =
                     (IncrementalRemoteKeyedStateHandle) keyedStateHandle;
             restorePreviousIncrementalFilesStatus(incrementalRemoteKeyedStateHandle);
-            restoreFromRemoteState(incrementalRemoteKeyedStateHandle);
+            restoreBaseDBFromRemoteState(incrementalRemoteKeyedStateHandle);
         } else if (keyedStateHandle instanceof IncrementalLocalKeyedStateHandle) {
             IncrementalLocalKeyedStateHandle incrementalLocalKeyedStateHandle =
                     (IncrementalLocalKeyedStateHandle) keyedStateHandle;
             restorePreviousIncrementalFilesStatus(incrementalLocalKeyedStateHandle);
-            restoreFromLocalState(incrementalLocalKeyedStateHandle);
+            restoreBaseDBFromLocalState(incrementalLocalKeyedStateHandle);
         } else {
             throw unexpectedStateHandleException(
                     new Class[] {
@@ -213,20 +217,46 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
         lastCompletedCheckpointId = localKeyedStateHandle.getCheckpointId();
     }
 
-    private void restoreFromRemoteState(IncrementalRemoteKeyedStateHandle stateHandle)
+    private void restoreBaseDBFromRemoteState(IncrementalRemoteKeyedStateHandle stateHandle)
             throws Exception {
         // used as restore source for IncrementalRemoteKeyedStateHandle
         final Path tmpRestoreInstancePath =
                 instanceBasePath.getAbsoluteFile().toPath().resolve(UUID.randomUUID().toString());
+        final StateHandleDownloadSpec downloadRequest =
+                new StateHandleDownloadSpec(stateHandle, tmpRestoreInstancePath);
         try {
-            restoreFromLocalState(
-                    transferRemoteStateToLocalDirectory(tmpRestoreInstancePath, stateHandle));
+            transferRemoteStateToLocalDirectory(Collections.singletonList(downloadRequest));
+            restoreBaseDBFromDownloadedState(downloadRequest);
         } finally {
-            cleanUpPathQuietly(tmpRestoreInstancePath);
+            cleanUpPathQuietly(downloadRequest.getDownloadDestination());
         }
     }
 
-    private void restoreFromLocalState(IncrementalLocalKeyedStateHandle localKeyedStateHandle)
+    /**
+     * This helper method creates a {@link IncrementalLocalKeyedStateHandle} for state that was
+     * previously downloaded for a {@link IncrementalRemoteKeyedStateHandle} and then invokes the
+     * restore procedure for local state on the downloaded state.
+     *
+     * @param downloadedState the specification of a completed state download.
+     * @throws Exception for restore problems.
+     */
+    private void restoreBaseDBFromDownloadedState(StateHandleDownloadSpec downloadedState)
+            throws Exception {
+        // since we transferred all remote state to a local directory, we can use the same code
+        // as for local recovery.
+        IncrementalRemoteKeyedStateHandle stateHandle = downloadedState.getStateHandle();
+        restoreBaseDBFromLocalState(
+                new IncrementalLocalKeyedStateHandle(
+                        stateHandle.getBackendIdentifier(),
+                        stateHandle.getCheckpointId(),
+                        new DirectoryStateHandle(downloadedState.getDownloadDestination()),
+                        stateHandle.getKeyGroupRange(),
+                        stateHandle.getMetaStateHandle(),
+                        stateHandle.getSharedState()));
+    }
+
+    /** Restores RocksDB instance from local state. */
+    private void restoreBaseDBFromLocalState(IncrementalLocalKeyedStateHandle localKeyedStateHandle)
             throws Exception {
         KeyedBackendSerializationProxy<K> serializationProxy =
                 readMetaData(localKeyedStateHandle.getMetaDataState());
@@ -246,26 +276,13 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
                 restoreSourcePath);
     }
 
-    private IncrementalLocalKeyedStateHandle transferRemoteStateToLocalDirectory(
-            Path temporaryRestoreInstancePath, IncrementalRemoteKeyedStateHandle restoreStateHandle)
-            throws Exception {
-
+    private void transferRemoteStateToLocalDirectory(
+            Collection<StateHandleDownloadSpec> downloadRequests) throws Exception {
         try (RocksDBStateDownloader rocksDBStateDownloader =
                 new RocksDBStateDownloader(numberOfTransferringThreads)) {
             rocksDBStateDownloader.transferAllStateDataToDirectory(
-                    restoreStateHandle, temporaryRestoreInstancePath, cancelStreamRegistry);
+                    downloadRequests, cancelStreamRegistry);
         }
-
-        // since we transferred all remote state to a local directory, we can use the same code as
-        // for
-        // local recovery.
-        return new IncrementalLocalKeyedStateHandle(
-                restoreStateHandle.getBackendIdentifier(),
-                restoreStateHandle.getCheckpointId(),
-                new DirectoryStateHandle(temporaryRestoreInstancePath),
-                restoreStateHandle.getKeyGroupRange(),
-                restoreStateHandle.getMetaStateHandle(),
-                restoreStateHandle.getSharedState());
     }
 
     private void cleanUpPathQuietly(@Nonnull Path path) {
@@ -284,19 +301,39 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
     private void restoreWithRescaling(Collection<KeyedStateHandle> restoreStateHandles)
             throws Exception {
 
-        // Prepare for restore with rescaling
-        KeyedStateHandle initialHandle =
+        Preconditions.checkArgument(restoreStateHandles != null && !restoreStateHandles.isEmpty());
+
+        Map<StateHandleID, StateHandleDownloadSpec> allDownloadSpecs =
+                CollectionUtil.newHashMapWithExpectedSize(restoreStateHandles.size());
+
+        // Choose the best state handle for the initial DB
+        final KeyedStateHandle selectedInitialHandle =
                 RocksDBIncrementalCheckpointUtils.chooseTheBestStateHandleForInitial(
                         restoreStateHandles, keyGroupRange, overlapFractionThreshold);
 
-        // Init base DB instance
-        if (initialHandle != null) {
-            restoreStateHandles.remove(initialHandle);
-            initDBWithRescaling(initialHandle);
-        } else {
-            this.rocksHandle.openDB();
+        Preconditions.checkNotNull(selectedInitialHandle);
+
+        final Path absolutInstanceBasePath = instanceBasePath.getAbsoluteFile().toPath();
+
+        // Prepare and collect all the download requests to pull remote state to a local directory
+        for (KeyedStateHandle stateHandle : restoreStateHandles) {
+            if (!(stateHandle instanceof IncrementalRemoteKeyedStateHandle)) {
+                throw unexpectedStateHandleException(
+                        IncrementalRemoteKeyedStateHandle.class, stateHandle.getClass());
+            }
+            StateHandleDownloadSpec downloadRequest =
+                    new StateHandleDownloadSpec(
+                            (IncrementalRemoteKeyedStateHandle) stateHandle,
+                            absolutInstanceBasePath.resolve(UUID.randomUUID().toString()));
+            allDownloadSpecs.put(stateHandle.getStateHandleId(), downloadRequest);
         }
 
+        // Process all state downloads
+        transferRemoteStateToLocalDirectory(allDownloadSpecs.values());
+
+        // Init the base DB instance with the initial state
+        initBaseDBForRescaling(allDownloadSpecs.remove(selectedInitialHandle.getStateHandleId()));
+
         // Transfer remaining key-groups from temporary instance into base DB
         byte[] startKeyGroupPrefixBytes = new byte[keyGroupPrefixBytes];
         CompositeKeySerializationUtils.serializeKeyGroup(
@@ -306,24 +343,14 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
         CompositeKeySerializationUtils.serializeKeyGroup(
                 keyGroupRange.getEndKeyGroup() + 1, stopKeyGroupPrefixBytes);
 
-        for (KeyedStateHandle rawStateHandle : restoreStateHandles) {
-
-            if (!(rawStateHandle instanceof IncrementalRemoteKeyedStateHandle)) {
-                throw unexpectedStateHandleException(
-                        IncrementalRemoteKeyedStateHandle.class, rawStateHandle.getClass());
-            }
-
+        // Insert all remaining state through creating temporary RocksDB instances
+        for (StateHandleDownloadSpec downloadRequest : allDownloadSpecs.values()) {
             logger.info(
-                    "Starting to restore from state handle: {} with rescaling.", rawStateHandle);
-            Path temporaryRestoreInstancePath =
-                    instanceBasePath
-                            .getAbsoluteFile()
-                            .toPath()
-                            .resolve(UUID.randomUUID().toString());
+                    "Starting to restore from state handle: {} with rescaling.",
+                    downloadRequest.getStateHandle());
+
             try (RestoredDBInstance tmpRestoreDBInfo =
-                            restoreDBInstanceFromStateHandle(
-                                    (IncrementalRemoteKeyedStateHandle) rawStateHandle,
-                                    temporaryRestoreInstancePath);
+                            restoreTempDBInstanceFromDownloadedState(downloadRequest);
                     RocksDBWriteBatchWrapper writeBatchWrapper =
                             new RocksDBWriteBatchWrapper(
                                     this.rocksHandle.getDb(), writeBatchSize)) {
@@ -335,12 +362,13 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
 
                 // iterating only the requested descriptors automatically skips the default column
                 // family handle
-                for (int i = 0; i < tmpColumnFamilyDescriptors.size(); ++i) {
-                    ColumnFamilyHandle tmpColumnFamilyHandle = tmpColumnFamilyHandles.get(i);
+                for (int descIdx = 0; descIdx < tmpColumnFamilyDescriptors.size(); ++descIdx) {
+                    ColumnFamilyHandle tmpColumnFamilyHandle = tmpColumnFamilyHandles.get(descIdx);
 
                     ColumnFamilyHandle targetColumnFamilyHandle =
                             this.rocksHandle.getOrRegisterStateColumnFamilyHandle(
-                                            null, tmpRestoreDBInfo.stateMetaInfoSnapshots.get(i))
+                                            null,
+                                            tmpRestoreDBInfo.stateMetaInfoSnapshots.get(descIdx))
                                     .columnFamilyHandle;
 
                     try (RocksIteratorWrapper iterator =
@@ -369,19 +397,19 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
                     } // releases native iterator resources
                 }
                 logger.info(
-                        "Finished restoring from state handle: {} with rescaling.", rawStateHandle);
+                        "Finished restoring from state handle: {} with rescaling.",
+                        downloadRequest.getStateHandle());
             } finally {
-                cleanUpPathQuietly(temporaryRestoreInstancePath);
+                cleanUpPathQuietly(downloadRequest.getDownloadDestination());
             }
         }
     }
 
-    private void initDBWithRescaling(KeyedStateHandle initialHandle) throws Exception {
-
-        assert (initialHandle instanceof IncrementalRemoteKeyedStateHandle);
+    private void initBaseDBForRescaling(StateHandleDownloadSpec downloadedInitialState)
+            throws Exception {
 
         // 1. Restore base DB from selected initial handle
-        restoreFromRemoteState((IncrementalRemoteKeyedStateHandle) initialHandle);
+        restoreBaseDBFromDownloadedState(downloadedInitialState);
 
         // 2. Clip the base DB instance
         try {
@@ -389,7 +417,7 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
                     this.rocksHandle.getDb(),
                     this.rocksHandle.getColumnFamilyHandles(),
                     keyGroupRange,
-                    initialHandle.getKeyGroupRange(),
+                    downloadedInitialState.getStateHandle().getKeyGroupRange(),
                     keyGroupPrefixBytes);
         } catch (RocksDBException e) {
             String errMsg = "Failed to clip DB after initialization.";
@@ -441,18 +469,11 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
         }
     }
 
-    private RestoredDBInstance restoreDBInstanceFromStateHandle(
-            IncrementalRemoteKeyedStateHandle restoreStateHandle, Path temporaryRestoreInstancePath)
-            throws Exception {
-
-        try (RocksDBStateDownloader rocksDBStateDownloader =
-                new RocksDBStateDownloader(numberOfTransferringThreads)) {
-            rocksDBStateDownloader.transferAllStateDataToDirectory(
-                    restoreStateHandle, temporaryRestoreInstancePath, cancelStreamRegistry);
-        }
+    private RestoredDBInstance restoreTempDBInstanceFromDownloadedState(
+            StateHandleDownloadSpec downloadRequest) throws Exception {
 
         KeyedBackendSerializationProxy<K> serializationProxy =
-                readMetaData(restoreStateHandle.getMetaStateHandle());
+                readMetaData(downloadRequest.getStateHandle().getMetaStateHandle());
         // read meta data
         List<StateMetaInfoSnapshot> stateMetaInfoSnapshots =
                 serializationProxy.getStateMetaInfoSnapshots();
@@ -465,7 +486,7 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
 
         RocksDB restoreDb =
                 RocksDBOperationUtils.openDB(
-                        temporaryRestoreInstancePath.toString(),
+                        downloadRequest.getDownloadDestination().toString(),
                         columnFamilyDescriptors,
                         columnFamilyHandles,
                         RocksDBOperationUtils.createColumnFamilyOptions(
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloaderTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloaderTest.java
index fcce8674887..2f903644797 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloaderTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloaderTest.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.state.TestStreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.util.TestLogger;
 
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -37,6 +38,7 @@ import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -73,8 +75,10 @@ public class RocksDBStateDownloaderTest extends TestLogger {
 
         try (RocksDBStateDownloader rocksDBStateDownloader = new RocksDBStateDownloader(5)) {
             rocksDBStateDownloader.transferAllStateDataToDirectory(
-                    incrementalKeyedStateHandle,
-                    temporaryFolder.newFolder().toPath(),
+                    Collections.singletonList(
+                            new StateHandleDownloadSpec(
+                                    incrementalKeyedStateHandle,
+                                    temporaryFolder.newFolder().toPath())),
                     new CloseableRegistry());
             fail();
         } catch (Exception e) {
@@ -85,46 +89,69 @@ public class RocksDBStateDownloaderTest extends TestLogger {
     /** Tests that download files with multi-thread correctly. */
     @Test
     public void testMultiThreadRestoreCorrectly() throws Exception {
-        Random random = new Random();
-        int contentNum = 6;
-        byte[][] contents = new byte[contentNum][];
-        for (int i = 0; i < contentNum; ++i) {
-            contents[i] = new byte[random.nextInt(100000) + 1];
-            random.nextBytes(contents[i]);
+        int numRemoteHandles = 3;
+        int numSubHandles = 6;
+        byte[][][] contents = createContents(numRemoteHandles, numSubHandles);
+        List<StateHandleDownloadSpec> downloadRequests = new ArrayList<>(numRemoteHandles);
+        for (int i = 0; i < numRemoteHandles; ++i) {
+            downloadRequests.add(
+                    createDownloadRequestForContent(
+                            temporaryFolder.newFolder().toPath(), contents[i], i));
         }
 
-        List<StreamStateHandle> handles = new ArrayList<>(contentNum);
-        for (int i = 0; i < contentNum; ++i) {
-            handles.add(new ByteStreamStateHandle(String.format("state%d", i), contents[i]));
+        try (RocksDBStateDownloader rocksDBStateDownloader = new RocksDBStateDownloader(4)) {
+            rocksDBStateDownloader.transferAllStateDataToDirectory(
+                    downloadRequests, new CloseableRegistry());
         }
 
-        Map<StateHandleID, StreamStateHandle> sharedStates = new HashMap<>(contentNum);
-        Map<StateHandleID, StreamStateHandle> privateStates = new HashMap<>(contentNum);
-        for (int i = 0; i < contentNum; ++i) {
-            sharedStates.put(new StateHandleID(String.format("sharedState%d", i)), handles.get(i));
-            privateStates.put(
-                    new StateHandleID(String.format("privateState%d", i)), handles.get(i));
+        for (int i = 0; i < numRemoteHandles; ++i) {
+            StateHandleDownloadSpec downloadRequest = downloadRequests.get(i);
+            Path dstPath = downloadRequest.getDownloadDestination();
+            Assert.assertTrue(dstPath.toFile().exists());
+            for (int j = 0; j < numSubHandles; ++j) {
+                assertStateContentEqual(
+                        contents[i][j], dstPath.resolve(String.format("sharedState-%d-%d", i, j)));
+            }
         }
+    }
 
-        IncrementalRemoteKeyedStateHandle incrementalKeyedStateHandle =
-                new IncrementalRemoteKeyedStateHandle(
-                        UUID.randomUUID(),
-                        KeyGroupRange.of(0, 1),
-                        1,
-                        sharedStates,
-                        privateStates,
-                        handles.get(0));
+    /** Tests cleanup on download failures. */
+    @Test
+    public void testMultiThreadCleanupOnFailure() throws Exception {
+        int numRemoteHandles = 3;
+        int numSubHandles = 6;
+        byte[][][] contents = createContents(numRemoteHandles, numSubHandles);
+        List<StateHandleDownloadSpec> downloadRequests = new ArrayList<>(numRemoteHandles);
+        for (int i = 0; i < numRemoteHandles; ++i) {
+            downloadRequests.add(
+                    createDownloadRequestForContent(
+                            temporaryFolder.newFolder().toPath(), contents[i], i));
+        }
 
-        Path dstPath = temporaryFolder.newFolder().toPath();
+        IncrementalRemoteKeyedStateHandle stateHandle =
+                downloadRequests.get(downloadRequests.size() - 1).getStateHandle();
+
+        // Add a state handle that induces an exception
+        stateHandle
+                .getSharedState()
+                .put(
+                        new StateHandleID("error-handle"),
+                        new ThrowingStateHandle(new IOException("Test exception.")));
+
+        CloseableRegistry closeableRegistry = new CloseableRegistry();
         try (RocksDBStateDownloader rocksDBStateDownloader = new RocksDBStateDownloader(5)) {
             rocksDBStateDownloader.transferAllStateDataToDirectory(
-                    incrementalKeyedStateHandle, dstPath, new CloseableRegistry());
+                    downloadRequests, closeableRegistry);
+            fail("Exception is expected");
+        } catch (IOException ignore) {
         }
 
-        for (int i = 0; i < contentNum; ++i) {
-            assertStateContentEqual(
-                    contents[i], dstPath.resolve(String.format("sharedState%d", i)));
+        // Check that all download directories have been deleted
+        for (StateHandleDownloadSpec downloadRequest : downloadRequests) {
+            Assert.assertFalse(downloadRequest.getDownloadDestination().toFile().exists());
         }
+        // The passed-in closeable registry should not be closed by us on failure.
+        Assert.assertFalse(closeableRegistry.isClosed());
     }
 
     private void assertStateContentEqual(byte[] expected, Path path) throws IOException {
@@ -165,4 +192,49 @@ public class RocksDBStateDownloaderTest extends TestLogger {
             return 0;
         }
     }
+
+    private byte[][][] createContents(int numRemoteHandles, int numSubHandles) {
+        Random random = new Random();
+        byte[][][] contents = new byte[numRemoteHandles][numSubHandles][];
+        for (int i = 0; i < numRemoteHandles; ++i) {
+            for (int j = 0; j < numSubHandles; ++j) {
+                contents[i][j] = new byte[random.nextInt(100000) + 1];
+                random.nextBytes(contents[i][j]);
+            }
+        }
+        return contents;
+    }
+
+    private StateHandleDownloadSpec createDownloadRequestForContent(
+            Path dstPath, byte[][] content, int remoteHandleId) {
+        int numSubHandles = content.length;
+        List<StreamStateHandle> handles = new ArrayList<>(numSubHandles);
+        for (int i = 0; i < numSubHandles; ++i) {
+            handles.add(
+                    new ByteStreamStateHandle(
+                            String.format("state-%d-%d", remoteHandleId, i), content[i]));
+        }
+
+        Map<StateHandleID, StreamStateHandle> sharedStates = new HashMap<>(numSubHandles);
+        Map<StateHandleID, StreamStateHandle> privateStates = new HashMap<>(numSubHandles);
+        for (int i = 0; i < numSubHandles; ++i) {
+            sharedStates.put(
+                    new StateHandleID(String.format("sharedState-%d-%d", remoteHandleId, i)),
+                    handles.get(i));
+            privateStates.put(
+                    new StateHandleID(String.format("privateState-%d-%d", remoteHandleId, i)),
+                    handles.get(i));
+        }
+
+        IncrementalRemoteKeyedStateHandle incrementalKeyedStateHandle =
+                new IncrementalRemoteKeyedStateHandle(
+                        UUID.randomUUID(),
+                        KeyGroupRange.of(0, 1),
+                        1,
+                        sharedStates,
+                        privateStates,
+                        handles.get(0));
+
+        return new StateHandleDownloadSpec(incrementalKeyedStateHandle, dstPath);
+    }
 }

