This is an automated email from the ASF dual-hosted git repository.

jqin pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git

commit cb057b45f67cebf1ecd35abd0e62524b8f188d6a
Author: Stephan Ewen <[email protected]>
AuthorDate: Sun Nov 22 14:08:46 2020 +0100

    [FLINK-20223][runtime] (part 2) Set user code classloader as context class loader for SplitEnumerator creation and thread
    
    Co-authored-by: Jiangjie (Becket) Qin <[email protected]>
    
      - This moves the instantiation of the SplitEnumerator out of the constructor to prevent double-instantiation
      - Add context class loaders to SplitEnumerator creation and to the coordinator thread
      - Class-loading and SplitEnumerator instantiation are purely handled in the SourceCoordinator and do not leak
        into RecreateOnResetOperatorCoordinator.
---
 .../source/coordinator/SourceCoordinator.java      |  70 +++++--
 .../coordinator/SourceCoordinatorContext.java      |   4 +
 .../coordinator/SourceCoordinatorProvider.java     |  11 +-
 .../MockOperatorCoordinatorContext.java            |  21 ++-
 .../coordinator/SourceCoordinatorContextTest.java  |   7 +-
 .../coordinator/SourceCoordinatorProviderTest.java |   3 +-
 .../source/coordinator/SourceCoordinatorTest.java  | 208 +++++++++++++++++++--
 .../coordinator/SourceCoordinatorTestBase.java     |  18 +-
 8 files changed, 296 insertions(+), 46 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java
index 10baf02..2685ca6 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java
@@ -33,6 +33,7 @@ import 
org.apache.flink.runtime.source.event.ReaderRegistrationEvent;
 import org.apache.flink.runtime.source.event.RequestSplitEvent;
 import org.apache.flink.runtime.source.event.SourceEventWrapper;
 import org.apache.flink.util.FlinkException;
+import org.apache.flink.util.TemporaryClassLoaderContext;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -52,6 +53,7 @@ import java.util.concurrent.TimeUnit;
 import static 
org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerdeUtils.readAndVerifyCoordinatorSerdeVersion;
 import static 
org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerdeUtils.readBytes;
 import static 
org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerdeUtils.writeCoordinatorSerdeVersion;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /**
  * The default implementation of the {@link OperatorCoordinator} for the 
{@link Source}.
@@ -68,7 +70,9 @@ import static 
org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerde
  */
 @Internal
 public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> 
implements OperatorCoordinator {
+
        private static final Logger LOG = 
LoggerFactory.getLogger(SourceCoordinator.class);
+
        /** The name of the operator this SourceCoordinator is associated with. 
*/
        private final String operatorName;
        /** A single-thread executor to handle all the changes to the 
coordinator. */
@@ -81,7 +85,8 @@ public class SourceCoordinator<SplitT extends SourceSplit, 
EnumChkT> implements
        private final SimpleVersionedSerializer<SplitT> splitSerializer;
        /** The context containing the states of the coordinator. */
        private final SourceCoordinatorContext<SplitT> context;
-       /** The split enumerator created from the associated Source. */
+       /** The split enumerator created from the associated Source. This one 
is created either during resetting
+        * the coordinator to a checkpoint, or when the coordinator is started. 
*/
        private SplitEnumerator<SplitT, EnumChkT> enumerator;
        /** A flag marking whether the coordinator has started. */
        private boolean started;
@@ -90,20 +95,29 @@ public class SourceCoordinator<SplitT extends SourceSplit, 
EnumChkT> implements
                        String operatorName,
                        ExecutorService coordinatorExecutor,
                        Source<?, SplitT, EnumChkT> source,
-                       SourceCoordinatorContext<SplitT> context) throws 
Exception {
+                       SourceCoordinatorContext<SplitT> context) {
                this.operatorName = operatorName;
                this.coordinatorExecutor = coordinatorExecutor;
                this.source = source;
                this.enumCheckpointSerializer = 
source.getEnumeratorCheckpointSerializer();
                this.splitSerializer = source.getSplitSerializer();
                this.context = context;
-               this.enumerator = source.createEnumerator(context);
-               this.started = false;
        }
 
        @Override
        public void start() throws Exception {
                LOG.info("Starting split enumerator for source {}.", 
operatorName);
+
+               // there are two ways the coordinator can get created:
+               //  (1) Source.restoreEnumerator(), in which case the 
'resetToCheckpoint()' method creates it
+               //  (2) Source.createEnumerator, in which case it has not been 
created, yet, and we create it here
+               if (enumerator == null) {
+                       final ClassLoader userCodeClassLoader = 
context.getCoordinatorContext().getUserCodeClassloader();
+                       try (TemporaryClassLoaderContext ignored = 
TemporaryClassLoaderContext.of(userCodeClassLoader)) {
+                               enumerator = source.createEnumerator(context);
+                       }
+               }
+
                // The start sequence is the first task in the coordinator 
executor.
                // We rely on the single-threaded coordinator executor to 
guarantee
                // the other methods are invoked after the enumerator has 
started.
@@ -117,7 +131,9 @@ public class SourceCoordinator<SplitT extends SourceSplit, 
EnumChkT> implements
                try {
                        if (started) {
                                context.close();
-                               enumerator.close();
+                               if (enumerator != null) {
+                                       enumerator.close();
+                               }
                        }
                } finally {
                        coordinatorExecutor.shutdownNow();
@@ -217,13 +233,16 @@ public class SourceCoordinator<SplitT extends 
SourceSplit, EnumChkT> implements
 
        @Override
        public void resetToCheckpoint(byte[] checkpointData) throws Exception {
-               if (started) {
-                       throw new IllegalStateException(String.format(
-                                       "The coordinator for source %s has 
started. The source coordinator state can " +
-                                       "only be reset to a checkpoint before 
it starts.", operatorName));
+               checkState(!started, "The coordinator can only be reset if it 
was not yet started");
+               assert enumerator == null;
+
+               LOG.info("Restoring SplitEnumerator of source {} from 
checkpoint.", operatorName);
+
+               final ClassLoader userCodeClassLoader = 
context.getCoordinatorContext().getUserCodeClassloader();
+               try (TemporaryClassLoaderContext ignored = 
TemporaryClassLoaderContext.of(userCodeClassLoader)) {
+                       final EnumChkT enumeratorCheckpoint = 
deserializeCheckpointAndRestoreContext(checkpointData);
+                       enumerator = source.restoreEnumerator(context, 
enumeratorCheckpoint);
                }
-               LOG.info("Resetting coordinator of source {} from checkpoint.", 
operatorName);
-               fromBytes(checkpointData);
        }
 
        // ---------------------------------------------------
@@ -249,16 +268,30 @@ public class SourceCoordinator<SplitT extends 
SourceSplit, EnumChkT> implements
         * @throws Exception When something goes wrong in serialization.
         */
        private byte[] toBytes(long checkpointId) throws Exception {
-               EnumChkT enumCkpt = enumerator.snapshotState();
+               return writeCheckpointBytes(
+                               checkpointId,
+                               enumerator.snapshotState(),
+                               context,
+                               enumCheckpointSerializer,
+                               splitSerializer);
+       }
+
+       static <SplitT extends SourceSplit, EnumChkT> byte[] 
writeCheckpointBytes(
+                       final long checkpointId,
+                       final EnumChkT enumeratorCheckpoint,
+                       final SourceCoordinatorContext<SplitT> 
coordinatorContext,
+                       final SimpleVersionedSerializer<EnumChkT> 
checkpointSerializer,
+                       final SimpleVersionedSerializer<SplitT> 
splitSerializer) throws Exception {
 
                try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                                DataOutputStream out = new 
DataOutputViewStreamWrapper(baos)) {
+
                        writeCoordinatorSerdeVersion(out);
-                       out.writeInt(enumCheckpointSerializer.getVersion());
-                       byte[] serialziedEnumChkpt = 
enumCheckpointSerializer.serialize(enumCkpt);
+                       out.writeInt(checkpointSerializer.getVersion());
+                       byte[] serialziedEnumChkpt = 
checkpointSerializer.serialize(enumeratorCheckpoint);
                        out.writeInt(serialziedEnumChkpt.length);
                        out.write(serialziedEnumChkpt);
-                       context.snapshotState(checkpointId, splitSerializer, 
out);
+                       coordinatorContext.snapshotState(checkpointId, 
splitSerializer, out);
                        out.flush();
                        return baos.toByteArray();
                }
@@ -270,16 +303,15 @@ public class SourceCoordinator<SplitT extends 
SourceSplit, EnumChkT> implements
         * @param bytes The checkpoint bytes that was returned from {@link 
#toBytes(long)}
         * @throws Exception When the deserialization failed.
         */
-       private void fromBytes(byte[] bytes) throws Exception {
+       private EnumChkT deserializeCheckpointAndRestoreContext(byte[] bytes) 
throws Exception {
                try (ByteArrayInputStream bais = new 
ByteArrayInputStream(bytes);
                                DataInputStream in = new 
DataInputViewStreamWrapper(bais)) {
                        readAndVerifyCoordinatorSerdeVersion(in);
                        int enumSerializerVersion = in.readInt();
                        int serializedEnumChkptSize = in.readInt();
                        byte[] serializedEnumChkpt = readBytes(in, 
serializedEnumChkptSize);
-                       EnumChkT enumChkpt = 
enumCheckpointSerializer.deserialize(enumSerializerVersion, 
serializedEnumChkpt);
                        context.restoreState(splitSerializer, in);
-                       enumerator = source.restoreEnumerator(context, 
enumChkpt);
+                       return 
enumCheckpointSerializer.deserialize(enumSerializerVersion, 
serializedEnumChkpt);
                }
        }
 
@@ -294,5 +326,7 @@ public class SourceCoordinator<SplitT extends SourceSplit, 
EnumChkT> implements
                if (!started) {
                        throw new IllegalStateException("The coordinator has 
not started yet.");
                }
+
+               assert enumerator != null;
        }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java
index 6141688..3ed0ab5 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java
@@ -303,6 +303,10 @@ public class SourceCoordinatorContext<SplitT extends 
SourceSplit>
                assignmentTracker.onCheckpointComplete(checkpointId);
        }
 
+       OperatorCoordinator.Context getCoordinatorContext() {
+               return operatorCoordinatorContext;
+       }
+
        // ---------------- private helper methods -----------------
 
        /**
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProvider.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProvider.java
index 4047bad..f13aa4e 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProvider.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProvider.java
@@ -66,8 +66,9 @@ public class SourceCoordinatorProvider<SplitT extends 
SourceSplit> extends Recre
        public OperatorCoordinator getCoordinator(OperatorCoordinator.Context 
context) throws Exception  {
                final String coordinatorThreadName = "SourceCoordinator-" + 
operatorName;
                CoordinatorExecutorThreadFactory coordinatorThreadFactory =
-                               new 
CoordinatorExecutorThreadFactory(coordinatorThreadName, context);
+                               new 
CoordinatorExecutorThreadFactory(coordinatorThreadName, context, 
context.getUserCodeClassloader());
                ExecutorService coordinatorExecutor = 
Executors.newSingleThreadExecutor(coordinatorThreadFactory);
+
                SimpleVersionedSerializer<SplitT> splitSerializer = 
source.getSplitSerializer();
                SourceCoordinatorContext<SplitT> sourceCoordinatorContext =
                                new 
SourceCoordinatorContext<>(coordinatorExecutor, coordinatorThreadFactory, 
numWorkerThreads,
@@ -81,14 +82,17 @@ public class SourceCoordinatorProvider<SplitT extends 
SourceSplit> extends Recre
        public static class CoordinatorExecutorThreadFactory implements 
ThreadFactory {
                private final String coordinatorThreadName;
                private final OperatorCoordinator.Context context;
+               private final ClassLoader cl;
                private Thread t;
 
                CoordinatorExecutorThreadFactory(
-                               String coordinatorThreadName,
-                               OperatorCoordinator.Context context) {
+                               final String coordinatorThreadName,
+                               final OperatorCoordinator.Context context,
+                               final ClassLoader contextClassLoader) {
                        this.coordinatorThreadName = coordinatorThreadName;
                        this.context = context;
                        this.t = null;
+                       this.cl = contextClassLoader;
                }
 
                @Override
@@ -98,6 +102,7 @@ public class SourceCoordinatorProvider<SplitT extends 
SourceSplit> extends Recre
                                                "SingleThreadExecutor.");
                        }
                        t = new Thread(r, coordinatorThreadName);
+                       t.setContextClassLoader(cl);
                        t.setUncaughtExceptionHandler((thread, throwable) -> 
context.failJob(throwable));
                        return t;
                }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/MockOperatorCoordinatorContext.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/MockOperatorCoordinatorContext.java
index ab8a56b..e7184c9 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/MockOperatorCoordinatorContext.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/MockOperatorCoordinatorContext.java
@@ -30,6 +30,7 @@ import java.util.concurrent.CompletableFuture;
 
 public class MockOperatorCoordinatorContext implements 
OperatorCoordinator.Context {
        private final OperatorID operatorID;
+       private final ClassLoader userCodeClassLoader;
        private final int numSubtasks;
        private final boolean failEventSending;
 
@@ -40,12 +41,28 @@ public class MockOperatorCoordinatorContext implements 
OperatorCoordinator.Conte
                this(operatorID, numSubtasks, true);
        }
 
-       public MockOperatorCoordinatorContext(OperatorID operatorID, int 
numSubtasks, boolean failEventSending) {
+       public MockOperatorCoordinatorContext(
+                       OperatorID operatorID,
+                       int numSubtasks,
+                       boolean failEventSending) {
+               this(operatorID, numSubtasks, failEventSending, 
MockOperatorCoordinatorContext.class.getClassLoader());
+       }
+
+       public MockOperatorCoordinatorContext(OperatorID operatorID, 
ClassLoader userCodeClassLoader) {
+               this(operatorID, 1, true, userCodeClassLoader);
+       }
+
+       public MockOperatorCoordinatorContext(
+                       OperatorID operatorID,
+                       int numSubtasks,
+                       boolean failEventSending,
+                       ClassLoader userCodeClassLoader) {
                this.operatorID = operatorID;
                this.numSubtasks = numSubtasks;
                this.eventsToOperator = new HashMap<>();
                this.jobFailed = false;
                this.failEventSending = failEventSending;
+               this.userCodeClassLoader = userCodeClassLoader;
        }
 
        @Override
@@ -79,7 +96,7 @@ public class MockOperatorCoordinatorContext implements 
OperatorCoordinator.Conte
 
        @Override
        public ClassLoader getUserCodeClassloader() {
-               return getClass().getClassLoader();
+               return userCodeClassLoader;
        }
 
        // -------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContextTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContextTest.java
index 280f6a4..70550ae 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContextTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContextTest.java
@@ -127,7 +127,7 @@ public class SourceCoordinatorContextTest extends 
SourceCoordinatorTestBase {
                                                
context.assignSplits(splitsAssignment);
                                        }
                                },
-                               "assignSpoits() should fail to assign the 
splits to a reader that is not registered.",
+                               "assignSplits() should fail to assign the 
splits to a reader that is not registered.",
                                "Cannot assign splits");
        }
 
@@ -145,7 +145,10 @@ public class SourceCoordinatorContextTest extends 
SourceCoordinatorTestBase {
                SplitAssignmentTracker<MockSourceSplit> restoredTracker = new 
SplitAssignmentTracker<>();
                SourceCoordinatorProvider.CoordinatorExecutorThreadFactory 
coordinatorThreadFactory =
                                new 
SourceCoordinatorProvider.CoordinatorExecutorThreadFactory(
-                                       TEST_OPERATOR_ID.toHexString(), 
operatorCoordinatorContext);
+                                               TEST_OPERATOR_ID.toHexString(),
+                                               operatorCoordinatorContext,
+                                               getClass().getClassLoader());
+
                try (ByteArrayInputStream bais = new 
ByteArrayInputStream(bytes);
                                DataInputStream in = new DataInputStream(bais)) 
{
                        restoredContext = new SourceCoordinatorContext<>(
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProviderTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProviderTest.java
index c0fc3a8..cb8bfd8 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProviderTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProviderTest.java
@@ -42,9 +42,11 @@ import static org.junit.Assert.assertTrue;
 /**
  * Unit tests for {@link SourceCoordinatorProvider}.
  */
+@SuppressWarnings("serial")
 public class SourceCoordinatorProviderTest {
        private static final OperatorID OPERATOR_ID = new OperatorID(1234L, 
5678L);
        private static final int NUM_SPLITS = 10;
+
        private SourceCoordinatorProvider<MockSourceSplit> provider;
 
        @Before
@@ -116,5 +118,4 @@ public class SourceCoordinatorProviderTest {
                        Duration.ofSeconds(10L),
                        "The job did not fail before timeout.");
        }
-
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java
index 0ddf971..d8182d4 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java
@@ -18,10 +18,22 @@
 
 package org.apache.flink.runtime.source.coordinator;
 
+import org.apache.flink.api.connector.source.Boundedness;
+import org.apache.flink.api.connector.source.Source;
 import org.apache.flink.api.connector.source.SourceEvent;
+import org.apache.flink.api.connector.source.SourceReader;
+import org.apache.flink.api.connector.source.SourceReaderContext;
+import org.apache.flink.api.connector.source.SplitEnumerator;
+import org.apache.flink.api.connector.source.SplitEnumeratorContext;
 import org.apache.flink.api.connector.source.mocks.MockSourceSplit;
 import org.apache.flink.api.connector.source.mocks.MockSourceSplitSerializer;
 import org.apache.flink.api.connector.source.mocks.MockSplitEnumerator;
+import 
org.apache.flink.api.connector.source.mocks.MockSplitEnumeratorCheckpointSerializer;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import 
org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.source.event.AddSplitEvent;
 import org.apache.flink.runtime.source.event.ReaderRegistrationEvent;
@@ -29,25 +41,33 @@ import 
org.apache.flink.runtime.source.event.SourceEventWrapper;
 
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
 import java.io.IOException;
+import java.net.URL;
+import java.net.URLClassLoader;
 import java.time.Duration;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
+import java.util.function.Supplier;
 
 import static org.apache.flink.core.testutils.CommonTestUtils.waitUtil;
 import static 
org.apache.flink.runtime.source.coordinator.CoordinatorTestUtils.verifyAssignment;
 import static 
org.apache.flink.runtime.source.coordinator.CoordinatorTestUtils.verifyException;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 /**
  * Unit tests for {@link SourceCoordinator}.
  */
+@SuppressWarnings("serial")
 public class SourceCoordinatorTest extends SourceCoordinatorTestBase {
 
        @Test
@@ -70,25 +90,22 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                sourceCoordinator.start();
                verifyException(() -> sourceCoordinator.resetToCheckpoint(null),
                                "Reset to checkpoint should fail after the 
coordinator has started",
-                               String.format("The coordinator for source %s 
has started. The source coordinator state can " +
-                                               "only be reset to a checkpoint 
before it starts.", OPERATOR_NAME));
+                               "The coordinator can only be reset if it was 
not yet started");
        }
 
        @Test(timeout = 10000L)
        public void testStart() throws Exception {
-               assertFalse(enumerator.started());
                sourceCoordinator.start();
-               while (!enumerator.started()) {
+               while (!getEnumerator().started()) {
                        Thread.sleep(1);
                }
        }
 
        @Test
        public void testClosed() throws Exception {
-               assertFalse(enumerator.closed());
                sourceCoordinator.start();
                sourceCoordinator.close();
-               assertTrue(enumerator.closed());
+               assertTrue(getEnumerator().closed());
        }
 
        @Test
@@ -98,9 +115,9 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                                0, new ReaderRegistrationEvent(0, 
"location_0"));
                check(() -> {
                        assertEquals("2 splits should have been assigned to 
reader 0",
-                                       4, 
enumerator.getUnassignedSplits().size());
+                                       4, 
getEnumerator().getUnassignedSplits().size());
                        assertTrue(context.registeredReaders().containsKey(0));
-                       
assertTrue(enumerator.getHandledSourceEvent().isEmpty());
+                       
assertTrue(getEnumerator().getHandledSourceEvent().isEmpty());
                        verifyAssignment(Arrays.asList("0", "3"), 
splitSplitAssignmentTracker.uncheckpointedAssignments().get(0));
                });
        }
@@ -111,8 +128,8 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                SourceEvent sourceEvent = new SourceEvent() {};
                sourceCoordinator.handleEventFromOperator(0, new 
SourceEventWrapper(sourceEvent));
                check(() -> {
-                       assertEquals(1, 
enumerator.getHandledSourceEvent().size());
-                       assertEquals(sourceEvent, 
enumerator.getHandledSourceEvent().get(0));
+                       assertEquals(1, 
getEnumerator().getHandledSourceEvent().size());
+                       assertEquals(sourceEvent, 
getEnumerator().getHandledSourceEvent().get(0));
                });
        }
 
@@ -152,7 +169,7 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                checkpointFuture1.get();
 
                // Add split 6, assign it to reader 0 and take another snapshot 
101.
-               enumerator.addNewSplits(Collections.singletonList(new 
MockSourceSplit(6)));
+               getEnumerator().addNewSplits(Collections.singletonList(new 
MockSourceSplit(6)));
 
                final CompletableFuture<byte[]> checkpointFuture2 = new 
CompletableFuture<>();
                sourceCoordinator.checkpointCoordinator(101L, 
checkpointFuture2);
@@ -161,7 +178,7 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                // check the state.
                check(() -> {
                        // There should be 4 unassigned splits.
-                       assertEquals(4, 
enumerator.getUnassignedSplits().size());
+                       assertEquals(4, 
getEnumerator().getUnassignedSplits().size());
                        verifyAssignment(
                                        Arrays.asList("0", "3"),
                                        
splitSplitAssignmentTracker.assignmentsByCheckpointId().get(100L).get(0));
@@ -196,7 +213,7 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                        }
                        
assertFalse(splitSplitAssignmentTracker.uncheckpointedAssignments().containsKey(0));
                        // The split enumerator should now contains the splits 
used to be assigned to reader 0.
-                       assertEquals(7, 
enumerator.getUnassignedSplits().size());
+                       assertEquals(7, 
getEnumerator().getUnassignedSplits().size());
                });
        }
 
@@ -215,11 +232,11 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                // Complete checkpoint 100.
                sourceCoordinator.notifyCheckpointComplete(100L);
                waitUtil(
-                               () -> 
!enumerator.getSuccessfulCheckpoints().isEmpty(),
+                               () -> 
!getEnumerator().getSuccessfulCheckpoints().isEmpty(),
                                Duration.ofMillis(1000L),
                                "The enumerator failed to process the 
successful checkpoint "
                                                + "before times out.");
-               assertEquals(100L, (long) 
enumerator.getSuccessfulCheckpoints().get(0));
+               assertEquals(100L, (long) 
getEnumerator().getSuccessfulCheckpoints().get(0));
 
                // Fail reader 0.
                sourceCoordinator.subtaskFailed(0, null);
@@ -228,13 +245,59 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                        // Reader 0 hase been unregistered.
                        assertFalse(context.registeredReaders().containsKey(0));
                        // The assigned splits are not reverted.
-                       assertEquals(4, 
enumerator.getUnassignedSplits().size());
+                       assertEquals(4, 
getEnumerator().getUnassignedSplits().size());
                        
assertFalse(splitSplitAssignmentTracker.uncheckpointedAssignments().containsKey(0));
                        
assertTrue(splitSplitAssignmentTracker.assignmentsByCheckpointId().isEmpty());
                });
        }
 
-       // -------------------------------
+       @Test
+       public void testUserClassLoaderWhenCreatingNewEnumerator() throws 
Exception {
+               final ClassLoader testClassLoader = new URLClassLoader(new 
URL[0]);
+               final OperatorCoordinator.Context context = new 
MockOperatorCoordinatorContext(new OperatorID(), testClassLoader);
+
+               final EnumeratorCreatingSource<?, ClassLoaderTestEnumerator> 
source =
+                               new 
EnumeratorCreatingSource<>(ClassLoaderTestEnumerator::new);
+               final SourceCoordinatorProvider<?> provider = new 
SourceCoordinatorProvider<>(
+                               "testOperator", context.getOperatorId(), 
source, 1);
+
+               final OperatorCoordinator coordinator = 
provider.getCoordinator(context);
+               coordinator.start();
+
+               final ClassLoaderTestEnumerator enumerator = 
source.createEnumeratorFuture.get();
+               assertSame(testClassLoader, enumerator.constructorClassLoader);
+               assertSame(testClassLoader, enumerator.threadClassLoader.get());
+
+               // cleanup
+               coordinator.close();
+       }
+
+       @Test
+       public void testUserClassLoaderWhenRestoringEnumerator() throws 
Exception {
+               final ClassLoader testClassLoader = new URLClassLoader(new 
URL[0]);
+               final OperatorCoordinator.Context context = new 
MockOperatorCoordinatorContext(new OperatorID(), testClassLoader);
+
+               final EnumeratorCreatingSource<?, ClassLoaderTestEnumerator> 
source =
+                               new 
EnumeratorCreatingSource<>(ClassLoaderTestEnumerator::new);
+               final SourceCoordinatorProvider<?> provider = new 
SourceCoordinatorProvider<>(
+                               "testOperator", context.getOperatorId(), 
source, 1);
+
+               final OperatorCoordinator coordinator = 
provider.getCoordinator(context);
+               coordinator.resetToCheckpoint(createEmptyCheckpoint(1L));
+               coordinator.start();
+
+               final ClassLoaderTestEnumerator enumerator = 
source.restoreEnumeratorFuture.get();
+               assertSame(testClassLoader, enumerator.constructorClassLoader);
+               assertSame(testClassLoader, enumerator.threadClassLoader.get());
+
+               // cleanup
+               coordinator.close();
+       }
+
+
+       // 
------------------------------------------------------------------------
+       //  test helpers
+       // 
------------------------------------------------------------------------
 
        private void check(Runnable runnable) {
                try {
@@ -243,4 +306,115 @@ public class SourceCoordinatorTest extends 
SourceCoordinatorTestBase {
                        fail("Test failed due to " + e);
                }
        }
+
+       /**
+        * Creates serialized source coordinator checkpoint bytes for a fresh context, with an empty set as
+        * the enumerator checkpoint. The result is meant to be passed to
+        * {@code OperatorCoordinator#resetToCheckpoint(byte[])}.
+        *
+        * @param checkpointId the checkpoint ID written into the checkpoint bytes
+        * @return the serialized checkpoint
+        */
+       private static byte[] createEmptyCheckpoint(long checkpointId) throws Exception {
+               final OperatorCoordinator.Context opContext = new MockOperatorCoordinatorContext(new OperatorID(), 0);
+
+               // the context is only needed while writing the checkpoint and is closed right afterwards
+               try (SourceCoordinatorContext<MockSourceSplit> emptyContext = new SourceCoordinatorContext<>(
+                               Executors.newDirectExecutorService(),
+                               new SourceCoordinatorProvider.CoordinatorExecutorThreadFactory(
+                                               "test", opContext, SourceCoordinatorTest.class.getClassLoader()),
+                               1,
+                               opContext,
+                               new MockSourceSplitSerializer())) {
+
+                       return SourceCoordinator.writeCheckpointBytes(
+                                       checkpointId,
+                                       Collections.emptySet(),
+                                       emptyContext,
+                                       new MockSplitEnumeratorCheckpointSerializer(),
+                                       new MockSourceSplitSerializer());
+               }
+       }
+
+
+       // ------------------------------------------------------------------------
+       //  test mocks
+       // ------------------------------------------------------------------------
+
+       /**
+        * A {@link SplitEnumerator} that records the thread context class loader at two points: when it is
+        * constructed, and when {@link #start()} is called. All split- and reader-related methods throw
+        * {@link UnsupportedOperationException}.
+        */
+       private static final class ClassLoaderTestEnumerator implements SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>> {
+
+               /** Completed with the context class loader of the thread that calls {@link #start()}. */
+               final CompletableFuture<ClassLoader> threadClassLoader = new CompletableFuture<>();
+
+               /** The context class loader of the thread that ran the constructor. */
+               final ClassLoader constructorClassLoader;
+
+               public ClassLoaderTestEnumerator() {
+                       final Thread constructingThread = Thread.currentThread();
+                       constructorClassLoader = constructingThread.getContextClassLoader();
+               }
+
+               @Override
+               public void start() {
+                       final ClassLoader startingLoader = Thread.currentThread().getContextClassLoader();
+                       threadClassLoader.complete(startingLoader);
+               }
+
+               @Override
+               public void close() {}
+
+               // ---- operations below are not supported by this mock ----
+
+               @Override
+               public void addReader(int subtaskId) {
+                       throw new UnsupportedOperationException();
+               }
+
+               @Override
+               public void handleSplitRequest(int subtaskId, @Nullable String requesterHostname) {
+                       throw new UnsupportedOperationException();
+               }
+
+               @Override
+               public void addSplitsBack(List<MockSourceSplit> splits, int subtaskId) {
+                       throw new UnsupportedOperationException();
+               }
+
+               @Override
+               public Set<MockSourceSplit> snapshotState() throws Exception {
+                       throw new UnsupportedOperationException();
+               }
+       }
+
+       /**
+        * A {@link Source} that obtains its enumerators from a supplied factory. Every enumerator handed
+        * out is also published through a future, so tests can inspect the exact instance that was created
+        * or restored.
+        *
+        * @param <T> the record type (readers are not supported)
+        * @param <EnumT> the concrete enumerator type produced by the factory
+        */
+       private static final class EnumeratorCreatingSource<T, EnumT extends SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>>>
+                       implements Source<T, MockSourceSplit, Set<MockSourceSplit>> {
+
+               /** Completed with the enumerator returned from {@link #createEnumerator}. */
+               final CompletableFuture<EnumT> createEnumeratorFuture = new CompletableFuture<>();
+
+               /** Completed with the enumerator returned from {@link #restoreEnumerator}. */
+               final CompletableFuture<EnumT> restoreEnumeratorFuture = new CompletableFuture<>();
+
+               private final Supplier<EnumT> enumeratorFactory;
+
+               public EnumeratorCreatingSource(Supplier<EnumT> enumeratorFactory) {
+                       this.enumeratorFactory = enumeratorFactory;
+               }
+
+               @Override
+               public Boundedness getBoundedness() {
+                       return Boundedness.CONTINUOUS_UNBOUNDED;
+               }
+
+               @Override
+               public SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>> createEnumerator(
+                               SplitEnumeratorContext<MockSourceSplit> enumContext) {
+                       return newEnumeratorInto(createEnumeratorFuture);
+               }
+
+               @Override
+               public SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>> restoreEnumerator(
+                               SplitEnumeratorContext<MockSourceSplit> enumContext,
+                               Set<MockSourceSplit> checkpoint) {
+                       // the checkpoint argument is not used; a fresh enumerator comes from the factory
+                       return newEnumeratorInto(restoreEnumeratorFuture);
+               }
+
+               @Override
+               public SourceReader<T, MockSourceSplit> createReader(SourceReaderContext readerContext) {
+                       throw new UnsupportedOperationException();
+               }
+
+               @Override
+               public SimpleVersionedSerializer<MockSourceSplit> getSplitSerializer() {
+                       return new MockSourceSplitSerializer();
+               }
+
+               @Override
+               public SimpleVersionedSerializer<Set<MockSourceSplit>> getEnumeratorCheckpointSerializer() {
+                       return new MockSplitEnumeratorCheckpointSerializer();
+               }
+
+               /** Obtains a new enumerator from the factory, completes the given future with it, and returns it. */
+               private EnumT newEnumeratorInto(CompletableFuture<EnumT> sink) {
+                       final EnumT fresh = enumeratorFactory.get();
+                       sink.complete(fresh);
+                       return fresh;
+               }
+       }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTestBase.java
index 22fb4a3..87ae8c9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTestBase.java
@@ -36,6 +36,8 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
+import static org.junit.Assert.assertNotNull;
+
 /**
  * The test base for SourceCoordinator related tests.
  */
@@ -49,7 +51,7 @@ public abstract class SourceCoordinatorTestBase {
        protected SplitAssignmentTracker<MockSourceSplit> 
splitSplitAssignmentTracker;
        protected SourceCoordinatorContext<MockSourceSplit> context;
        protected SourceCoordinator<?, ?> sourceCoordinator;
-       protected MockSplitEnumerator enumerator;
+       private MockSplitEnumerator enumerator;
 
        @Before
        public void setup() throws Exception {
@@ -58,7 +60,10 @@ public abstract class SourceCoordinatorTestBase {
                String coordinatorThreadName = TEST_OPERATOR_ID.toHexString();
                SourceCoordinatorProvider.CoordinatorExecutorThreadFactory 
coordinatorThreadFactory =
                                new 
SourceCoordinatorProvider.CoordinatorExecutorThreadFactory(
-                                       coordinatorThreadName, 
operatorCoordinatorContext);
+                                               coordinatorThreadName,
+                                               operatorCoordinatorContext,
+                                               getClass().getClassLoader());
+
                coordinatorExecutor = 
Executors.newSingleThreadExecutor(coordinatorThreadFactory);
                context = new SourceCoordinatorContext<>(
                                coordinatorExecutor,
@@ -68,7 +73,6 @@ public abstract class SourceCoordinatorTestBase {
                                new MockSourceSplitSerializer(),
                                splitSplitAssignmentTracker);
                sourceCoordinator = getNewSourceCoordinator();
-               enumerator = (MockSplitEnumerator) 
sourceCoordinator.getEnumerator();
        }
 
        @After
@@ -79,6 +83,14 @@ public abstract class SourceCoordinatorTestBase {
                }
        }
 
+       /**
+        * Returns the coordinator's enumerator. It is fetched from the coordinator on the first call and
+        * cached afterwards. The test fails if the coordinator returns no enumerator (per the assertion
+        * message, presumably because the source was not started yet).
+        */
+       protected MockSplitEnumerator getEnumerator() {
+               if (enumerator != null) {
+                       return enumerator;
+               }
+               enumerator = (MockSplitEnumerator) sourceCoordinator.getEnumerator();
+               assertNotNull("source was not started", enumerator);
+               return enumerator;
+       }
+
        // --------------------------
 
        protected SourceCoordinator getNewSourceCoordinator() throws Exception {

Reply via email to