This is an automated email from the ASF dual-hosted git repository. jqin pushed a commit to branch release-1.11 in repository https://gitbox.apache.org/repos/asf/flink.git
commit cb057b45f67cebf1ecd35abd0e62524b8f188d6a Author: Stephan Ewen <[email protected]> AuthorDate: Sun Nov 22 14:08:46 2020 +0100 [FLINK-20223][runtime] (part 2) Set user code classloader as context class loader for SplitEnumerator creation and thread Co-authored-by: Jiangjie (Becket) Qin <[email protected]> - This moves the instantiation of the SplitEnumerator out of the constructor to prevent double-instantiation - Add context class loaders to creation and coordinator thread - Class-loading and SplitEnumerator instantiation is purely handled in the SourceCoordinator and does not leak into RecreateOnResetOperatorCoordinator. --- .../source/coordinator/SourceCoordinator.java | 70 +++++-- .../coordinator/SourceCoordinatorContext.java | 4 + .../coordinator/SourceCoordinatorProvider.java | 11 +- .../MockOperatorCoordinatorContext.java | 21 ++- .../coordinator/SourceCoordinatorContextTest.java | 7 +- .../coordinator/SourceCoordinatorProviderTest.java | 3 +- .../source/coordinator/SourceCoordinatorTest.java | 208 +++++++++++++++++++-- .../coordinator/SourceCoordinatorTestBase.java | 18 +- 8 files changed, 296 insertions(+), 46 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java index 10baf02..2685ca6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinator.java @@ -33,6 +33,7 @@ import org.apache.flink.runtime.source.event.ReaderRegistrationEvent; import org.apache.flink.runtime.source.event.RequestSplitEvent; import org.apache.flink.runtime.source.event.SourceEventWrapper; import org.apache.flink.util.FlinkException; +import org.apache.flink.util.TemporaryClassLoaderContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,6 +53,7 @@ import java.util.concurrent.TimeUnit; import static org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerdeUtils.readAndVerifyCoordinatorSerdeVersion; import static org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerdeUtils.readBytes; import static org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerdeUtils.writeCoordinatorSerdeVersion; +import static org.apache.flink.util.Preconditions.checkState; /** * The default implementation of the {@link OperatorCoordinator} for the {@link Source}. @@ -68,7 +70,9 @@ import static org.apache.flink.runtime.source.coordinator.SourceCoordinatorSerde */ @Internal public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> implements OperatorCoordinator { + private static final Logger LOG = LoggerFactory.getLogger(SourceCoordinator.class); + /** The name of the operator this SourceCoordinator is associated with. */ private final String operatorName; /** A single-thread executor to handle all the changes to the coordinator. */ @@ -81,7 +85,8 @@ public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> implements private final SimpleVersionedSerializer<SplitT> splitSerializer; /** The context containing the states of the coordinator. */ private final SourceCoordinatorContext<SplitT> context; - /** The split enumerator created from the associated Source. */ + /** The split enumerator created from the associated Source. This one is created either during resetting + * the coordinator to a checkpoint, or when the coordinator is started. */ private SplitEnumerator<SplitT, EnumChkT> enumerator; /** A flag marking whether the coordinator has started. */ private boolean started; @@ -90,20 +95,29 @@ public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> implements String operatorName, ExecutorService coordinatorExecutor, Source<?, SplitT, EnumChkT> source, - SourceCoordinatorContext<SplitT> context) throws Exception { + SourceCoordinatorContext<SplitT> context) { this.operatorName = operatorName; this.coordinatorExecutor = coordinatorExecutor; this.source = source; this.enumCheckpointSerializer = source.getEnumeratorCheckpointSerializer(); this.splitSerializer = source.getSplitSerializer(); this.context = context; - this.enumerator = source.createEnumerator(context); - this.started = false; } @Override public void start() throws Exception { LOG.info("Starting split enumerator for source {}.", operatorName); + + // there are two ways the coordinator can get created: + // (1) Source.restoreEnumerator(), in which case the 'resetToCheckpoint()' method creates it + // (2) Source.createEnumerator, in which case it has not been created, yet, and we create it here + if (enumerator == null) { + final ClassLoader userCodeClassLoader = context.getCoordinatorContext().getUserCodeClassloader(); + try (TemporaryClassLoaderContext ignored = TemporaryClassLoaderContext.of(userCodeClassLoader)) { + enumerator = source.createEnumerator(context); + } + } + // The start sequence is the first task in the coordinator executor. // We rely on the single-threaded coordinator executor to guarantee // the other methods are invoked after the enumerator has started. @@ -117,7 +131,9 @@ public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> implements try { if (started) { context.close(); - enumerator.close(); + if (enumerator != null) { + enumerator.close(); + } } } finally { coordinatorExecutor.shutdownNow(); @@ -217,13 +233,16 @@ public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> implements @Override public void resetToCheckpoint(byte[] checkpointData) throws Exception { - if (started) { - throw new IllegalStateException(String.format( - "The coordinator for source %s has started. The source coordinator state can " + - "only be reset to a checkpoint before it starts.", operatorName)); + checkState(!started, "The coordinator can only be reset if it was not yet started"); + assert enumerator == null; + + LOG.info("Restoring SplitEnumerator of source {} from checkpoint.", operatorName); + + final ClassLoader userCodeClassLoader = context.getCoordinatorContext().getUserCodeClassloader(); + try (TemporaryClassLoaderContext ignored = TemporaryClassLoaderContext.of(userCodeClassLoader)) { + final EnumChkT enumeratorCheckpoint = deserializeCheckpointAndRestoreContext(checkpointData); + enumerator = source.restoreEnumerator(context, enumeratorCheckpoint); } - LOG.info("Resetting coordinator of source {} from checkpoint.", operatorName); - fromBytes(checkpointData); } // --------------------------------------------------- @@ -249,16 +268,30 @@ public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> implements * @throws Exception When something goes wrong in serialization. */ private byte[] toBytes(long checkpointId) throws Exception { - EnumChkT enumCkpt = enumerator.snapshotState(); + return writeCheckpointBytes( + checkpointId, + enumerator.snapshotState(), + context, + enumCheckpointSerializer, + splitSerializer); + } + + static <SplitT extends SourceSplit, EnumChkT> byte[] writeCheckpointBytes( + final long checkpointId, + final EnumChkT enumeratorCheckpoint, + final SourceCoordinatorContext<SplitT> coordinatorContext, + final SimpleVersionedSerializer<EnumChkT> checkpointSerializer, + final SimpleVersionedSerializer<SplitT> splitSerializer) throws Exception { try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); DataOutputStream out = new DataOutputViewStreamWrapper(baos)) { + writeCoordinatorSerdeVersion(out); - out.writeInt(enumCheckpointSerializer.getVersion()); - byte[] serialziedEnumChkpt = enumCheckpointSerializer.serialize(enumCkpt); + out.writeInt(checkpointSerializer.getVersion()); + byte[] serialziedEnumChkpt = checkpointSerializer.serialize(enumeratorCheckpoint); out.writeInt(serialziedEnumChkpt.length); out.write(serialziedEnumChkpt); - context.snapshotState(checkpointId, splitSerializer, out); + coordinatorContext.snapshotState(checkpointId, splitSerializer, out); out.flush(); return baos.toByteArray(); } @@ -270,16 +303,15 @@ public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> implements * @param bytes The checkpoint bytes that was returned from {@link #toBytes(long)} * @throws Exception When the deserialization failed. */ - private void fromBytes(byte[] bytes) throws Exception { + private EnumChkT deserializeCheckpointAndRestoreContext(byte[] bytes) throws Exception { try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes); DataInputStream in = new DataInputViewStreamWrapper(bais)) { readAndVerifyCoordinatorSerdeVersion(in); int enumSerializerVersion = in.readInt(); int serializedEnumChkptSize = in.readInt(); byte[] serializedEnumChkpt = readBytes(in, serializedEnumChkptSize); - EnumChkT enumChkpt = enumCheckpointSerializer.deserialize(enumSerializerVersion, serializedEnumChkpt); context.restoreState(splitSerializer, in); - enumerator = source.restoreEnumerator(context, enumChkpt); + return enumCheckpointSerializer.deserialize(enumSerializerVersion, serializedEnumChkpt); } } @@ -294,5 +326,7 @@ public class SourceCoordinator<SplitT extends SourceSplit, EnumChkT> implements if (!started) { throw new IllegalStateException("The coordinator has not started yet."); } + + assert enumerator != null; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java index 6141688..3ed0ab5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContext.java @@ -303,6 +303,10 @@ public class SourceCoordinatorContext<SplitT extends SourceSplit> assignmentTracker.onCheckpointComplete(checkpointId); } + OperatorCoordinator.Context getCoordinatorContext() { + return operatorCoordinatorContext; + } + // ---------------- private helper methods ----------------- /** diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProvider.java index 4047bad..f13aa4e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProvider.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProvider.java @@ -66,8 +66,9 @@ public class SourceCoordinatorProvider<SplitT extends SourceSplit> extends Recre public OperatorCoordinator getCoordinator(OperatorCoordinator.Context context) throws Exception { final String coordinatorThreadName = "SourceCoordinator-" + operatorName; CoordinatorExecutorThreadFactory coordinatorThreadFactory = - new CoordinatorExecutorThreadFactory(coordinatorThreadName, context); + new CoordinatorExecutorThreadFactory(coordinatorThreadName, context, context.getUserCodeClassloader()); ExecutorService coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); + SimpleVersionedSerializer<SplitT> splitSerializer = source.getSplitSerializer(); SourceCoordinatorContext<SplitT> sourceCoordinatorContext = new SourceCoordinatorContext<>(coordinatorExecutor, coordinatorThreadFactory, numWorkerThreads, @@ -81,14 +82,17 @@ public class SourceCoordinatorProvider<SplitT extends SourceSplit> extends Recre public static class CoordinatorExecutorThreadFactory implements ThreadFactory { private final String coordinatorThreadName; private final OperatorCoordinator.Context context; + private final ClassLoader cl; private Thread t; CoordinatorExecutorThreadFactory( - String coordinatorThreadName, - OperatorCoordinator.Context context) { + final String coordinatorThreadName, + final OperatorCoordinator.Context context, + final ClassLoader contextClassLoader) { this.coordinatorThreadName = coordinatorThreadName; this.context = context; this.t = null; + this.cl = contextClassLoader; } @Override @@ -98,6 +102,7 @@ public class SourceCoordinatorProvider<SplitT extends SourceSplit> extends Recre "SingleThreadExecutor."); } t = new Thread(r, coordinatorThreadName); + t.setContextClassLoader(cl); t.setUncaughtExceptionHandler((thread, throwable) -> context.failJob(throwable)); return t; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/MockOperatorCoordinatorContext.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/MockOperatorCoordinatorContext.java index ab8a56b..e7184c9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/MockOperatorCoordinatorContext.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/coordination/MockOperatorCoordinatorContext.java @@ -30,6 +30,7 @@ import java.util.concurrent.CompletableFuture; public class MockOperatorCoordinatorContext implements OperatorCoordinator.Context { private final OperatorID operatorID; + private final ClassLoader userCodeClassLoader; private final int numSubtasks; private final boolean failEventSending; @@ -40,12 +41,28 @@ public class MockOperatorCoordinatorContext implements OperatorCoordinator.Conte this(operatorID, numSubtasks, true); } - public MockOperatorCoordinatorContext(OperatorID operatorID, int numSubtasks, boolean failEventSending) { + public MockOperatorCoordinatorContext( + OperatorID operatorID, + int numSubtasks, + boolean failEventSending) { + this(operatorID, numSubtasks, failEventSending, MockOperatorCoordinatorContext.class.getClassLoader()); + } + + public MockOperatorCoordinatorContext(OperatorID operatorID, ClassLoader userCodeClassLoader) { + this(operatorID, 1, true, userCodeClassLoader); + } + + public MockOperatorCoordinatorContext( + OperatorID operatorID, + int numSubtasks, + boolean failEventSending, + ClassLoader userCodeClassLoader) { this.operatorID = operatorID; this.numSubtasks = numSubtasks; this.eventsToOperator = new HashMap<>(); this.jobFailed = false; this.failEventSending = failEventSending; + this.userCodeClassLoader = userCodeClassLoader; } @Override @@ -79,7 +96,7 @@ public class MockOperatorCoordinatorContext implements OperatorCoordinator.Conte @Override public ClassLoader getUserCodeClassloader() { - return getClass().getClassLoader(); + return userCodeClassLoader; } // ------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContextTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContextTest.java index 280f6a4..70550ae 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContextTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorContextTest.java @@ -127,7 +127,7 @@ public class SourceCoordinatorContextTest extends SourceCoordinatorTestBase { context.assignSplits(splitsAssignment); } }, - "assignSpoits() should fail to assign the splits to a reader that is not registered.", + "assignSplits() should fail to assign the splits to a reader that is not registered.", "Cannot assign splits"); } @@ -145,7 +145,10 @@ public class SourceCoordinatorContextTest extends SourceCoordinatorTestBase { SplitAssignmentTracker<MockSourceSplit> restoredTracker = new SplitAssignmentTracker<>(); SourceCoordinatorProvider.CoordinatorExecutorThreadFactory coordinatorThreadFactory = new SourceCoordinatorProvider.CoordinatorExecutorThreadFactory( - TEST_OPERATOR_ID.toHexString(), operatorCoordinatorContext); + TEST_OPERATOR_ID.toHexString(), + operatorCoordinatorContext, + getClass().getClassLoader()); + try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes); DataInputStream in = new DataInputStream(bais)) { restoredContext = new SourceCoordinatorContext<>( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProviderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProviderTest.java index c0fc3a8..cb8bfd8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProviderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorProviderTest.java @@ -42,9 +42,11 @@ import static org.junit.Assert.assertTrue; /** * Unit tests for {@link SourceCoordinatorProvider}. */ +@SuppressWarnings("serial") public class SourceCoordinatorProviderTest { private static final OperatorID OPERATOR_ID = new OperatorID(1234L, 5678L); private static final int NUM_SPLITS = 10; + private SourceCoordinatorProvider<MockSourceSplit> provider; @Before @@ -116,5 +118,4 @@ public class SourceCoordinatorProviderTest { Duration.ofSeconds(10L), "The job did not fail before timeout."); } - } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java index 0ddf971..d8182d4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTest.java @@ -18,10 +18,22 @@ package org.apache.flink.runtime.source.coordinator; +import org.apache.flink.api.connector.source.Boundedness; +import org.apache.flink.api.connector.source.Source; import org.apache.flink.api.connector.source.SourceEvent; +import org.apache.flink.api.connector.source.SourceReader; +import org.apache.flink.api.connector.source.SourceReaderContext; +import org.apache.flink.api.connector.source.SplitEnumerator; +import org.apache.flink.api.connector.source.SplitEnumeratorContext; import org.apache.flink.api.connector.source.mocks.MockSourceSplit; import org.apache.flink.api.connector.source.mocks.MockSourceSplitSerializer; import org.apache.flink.api.connector.source.mocks.MockSplitEnumerator; +import org.apache.flink.api.connector.source.mocks.MockSplitEnumeratorCheckpointSerializer; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.runtime.concurrent.Executors; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinator; import org.apache.flink.runtime.operators.coordination.OperatorEvent; import org.apache.flink.runtime.source.event.AddSplitEvent; import org.apache.flink.runtime.source.event.ReaderRegistrationEvent; @@ -29,25 +41,33 @@ import org.apache.flink.runtime.source.event.SourceEventWrapper; import org.junit.Test; +import javax.annotation.Nullable; + import java.io.IOException; +import java.net.URL; +import java.net.URLClassLoader; import java.time.Duration; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; import static org.apache.flink.core.testutils.CommonTestUtils.waitUtil; import static org.apache.flink.runtime.source.coordinator.CoordinatorTestUtils.verifyAssignment; import static org.apache.flink.runtime.source.coordinator.CoordinatorTestUtils.verifyException; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** * Unit tests for {@link SourceCoordinator}. */ +@SuppressWarnings("serial") public class SourceCoordinatorTest extends SourceCoordinatorTestBase { @Test @@ -70,25 +90,22 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { sourceCoordinator.start(); verifyException(() -> sourceCoordinator.resetToCheckpoint(null), "Reset to checkpoint should fail after the coordinator has started", - String.format("The coordinator for source %s has started. The source coordinator state can " + - "only be reset to a checkpoint before it starts.", OPERATOR_NAME)); + "The coordinator can only be reset if it was not yet started"); } @Test(timeout = 10000L) public void testStart() throws Exception { - assertFalse(enumerator.started()); sourceCoordinator.start(); - while (!enumerator.started()) { + while (!getEnumerator().started()) { Thread.sleep(1); } } @Test public void testClosed() throws Exception { - assertFalse(enumerator.closed()); sourceCoordinator.start(); sourceCoordinator.close(); - assertTrue(enumerator.closed()); + assertTrue(getEnumerator().closed()); } @Test @@ -98,9 +115,9 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { 0, new ReaderRegistrationEvent(0, "location_0")); check(() -> { assertEquals("2 splits should have been assigned to reader 0", - 4, enumerator.getUnassignedSplits().size()); + 4, getEnumerator().getUnassignedSplits().size()); assertTrue(context.registeredReaders().containsKey(0)); - assertTrue(enumerator.getHandledSourceEvent().isEmpty()); + assertTrue(getEnumerator().getHandledSourceEvent().isEmpty()); verifyAssignment(Arrays.asList("0", "3"), splitSplitAssignmentTracker.uncheckpointedAssignments().get(0)); }); } @@ -111,8 +128,8 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { SourceEvent sourceEvent = new SourceEvent() {}; sourceCoordinator.handleEventFromOperator(0, new SourceEventWrapper(sourceEvent)); check(() -> { - assertEquals(1, enumerator.getHandledSourceEvent().size()); - assertEquals(sourceEvent, enumerator.getHandledSourceEvent().get(0)); + assertEquals(1, getEnumerator().getHandledSourceEvent().size()); + assertEquals(sourceEvent, getEnumerator().getHandledSourceEvent().get(0)); }); } @@ -152,7 +169,7 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { checkpointFuture1.get(); // Add split 6, assign it to reader 0 and take another snapshot 101. - enumerator.addNewSplits(Collections.singletonList(new MockSourceSplit(6))); + getEnumerator().addNewSplits(Collections.singletonList(new MockSourceSplit(6))); final CompletableFuture<byte[]> checkpointFuture2 = new CompletableFuture<>(); sourceCoordinator.checkpointCoordinator(101L, checkpointFuture2); @@ -161,7 +178,7 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { // check the state. check(() -> { // There should be 4 unassigned splits. - assertEquals(4, enumerator.getUnassignedSplits().size()); + assertEquals(4, getEnumerator().getUnassignedSplits().size()); verifyAssignment( Arrays.asList("0", "3"), splitSplitAssignmentTracker.assignmentsByCheckpointId().get(100L).get(0)); @@ -196,7 +213,7 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { } assertFalse(splitSplitAssignmentTracker.uncheckpointedAssignments().containsKey(0)); // The split enumerator should now contains the splits used to be assigned to reader 0. - assertEquals(7, enumerator.getUnassignedSplits().size()); + assertEquals(7, getEnumerator().getUnassignedSplits().size()); }); } @@ -215,11 +232,11 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { // Complete checkpoint 100. sourceCoordinator.notifyCheckpointComplete(100L); waitUtil( - () -> !enumerator.getSuccessfulCheckpoints().isEmpty(), + () -> !getEnumerator().getSuccessfulCheckpoints().isEmpty(), Duration.ofMillis(1000L), "The enumerator failed to process the successful checkpoint " + "before times out."); - assertEquals(100L, (long) enumerator.getSuccessfulCheckpoints().get(0)); + assertEquals(100L, (long) getEnumerator().getSuccessfulCheckpoints().get(0)); // Fail reader 0. sourceCoordinator.subtaskFailed(0, null); @@ -228,13 +245,59 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { // Reader 0 hase been unregistered. assertFalse(context.registeredReaders().containsKey(0)); // The assigned splits are not reverted. - assertEquals(4, enumerator.getUnassignedSplits().size()); + assertEquals(4, getEnumerator().getUnassignedSplits().size()); assertFalse(splitSplitAssignmentTracker.uncheckpointedAssignments().containsKey(0)); assertTrue(splitSplitAssignmentTracker.assignmentsByCheckpointId().isEmpty()); }); } - // ------------------------------- + @Test + public void testUserClassLoaderWhenCreatingNewEnumerator() throws Exception { + final ClassLoader testClassLoader = new URLClassLoader(new URL[0]); + final OperatorCoordinator.Context context = new MockOperatorCoordinatorContext(new OperatorID(), testClassLoader); + + final EnumeratorCreatingSource<?, ClassLoaderTestEnumerator> source = + new EnumeratorCreatingSource<>(ClassLoaderTestEnumerator::new); + final SourceCoordinatorProvider<?> provider = new SourceCoordinatorProvider<>( + "testOperator", context.getOperatorId(), source, 1); + + final OperatorCoordinator coordinator = provider.getCoordinator(context); + coordinator.start(); + + final ClassLoaderTestEnumerator enumerator = source.createEnumeratorFuture.get(); + assertSame(testClassLoader, enumerator.constructorClassLoader); + assertSame(testClassLoader, enumerator.threadClassLoader.get()); + + // cleanup + coordinator.close(); + } + + @Test + public void testUserClassLoaderWhenRestoringEnumerator() throws Exception { + final ClassLoader testClassLoader = new URLClassLoader(new URL[0]); + final OperatorCoordinator.Context context = new MockOperatorCoordinatorContext(new OperatorID(), testClassLoader); + + final EnumeratorCreatingSource<?, ClassLoaderTestEnumerator> source = + new EnumeratorCreatingSource<>(ClassLoaderTestEnumerator::new); + final SourceCoordinatorProvider<?> provider = new SourceCoordinatorProvider<>( + "testOperator", context.getOperatorId(), source, 1); + + final OperatorCoordinator coordinator = provider.getCoordinator(context); + coordinator.resetToCheckpoint(createEmptyCheckpoint(1L)); + coordinator.start(); + + final ClassLoaderTestEnumerator enumerator = source.restoreEnumeratorFuture.get(); + assertSame(testClassLoader, enumerator.constructorClassLoader); + assertSame(testClassLoader, enumerator.threadClassLoader.get()); + + // cleanup + coordinator.close(); + } + + + // ------------------------------------------------------------------------ + // test helpers + // ------------------------------------------------------------------------ private void check(Runnable runnable) { try { @@ -243,4 +306,115 @@ public class SourceCoordinatorTest extends SourceCoordinatorTestBase { fail("Test failed due to " + e); } } + + private static byte[] createEmptyCheckpoint(long checkpointId) throws Exception { + final OperatorCoordinator.Context opContext = new MockOperatorCoordinatorContext(new OperatorID(), 0); + + try (SourceCoordinatorContext<MockSourceSplit> emptyContext = new SourceCoordinatorContext<>( + Executors.newDirectExecutorService(), + new SourceCoordinatorProvider.CoordinatorExecutorThreadFactory("test", opContext, SourceCoordinatorProviderTest.class.getClassLoader()), + 1, + opContext, + new MockSourceSplitSerializer())) { + + return SourceCoordinator.writeCheckpointBytes( + checkpointId, + Collections.emptySet(), + emptyContext, + new MockSplitEnumeratorCheckpointSerializer(), + new MockSourceSplitSerializer()); + } + } + + + // ------------------------------------------------------------------------ + // test mocks + // ------------------------------------------------------------------------ + + private static final class ClassLoaderTestEnumerator implements SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>> { + + final CompletableFuture<ClassLoader> threadClassLoader = new CompletableFuture<>(); + final ClassLoader constructorClassLoader; + + public ClassLoaderTestEnumerator() { + this.constructorClassLoader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public void start() { + threadClassLoader.complete(Thread.currentThread().getContextClassLoader()); + } + + @Override + public void handleSplitRequest(int subtaskId, @Nullable String requesterHostname) { + throw new UnsupportedOperationException(); + } + + @Override + public void addSplitsBack(List<MockSourceSplit> splits, int subtaskId) { + throw new UnsupportedOperationException(); + } + + @Override + public void addReader(int subtaskId) { + throw new UnsupportedOperationException(); + } + + @Override + public Set<MockSourceSplit> snapshotState() throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + public void close() {} + } + + private static final class EnumeratorCreatingSource<T, EnumT extends SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>>> + implements Source<T, MockSourceSplit, Set<MockSourceSplit>> { + + final CompletableFuture<EnumT> createEnumeratorFuture = new CompletableFuture<>(); + final CompletableFuture<EnumT> restoreEnumeratorFuture = new CompletableFuture<>(); + private final Supplier<EnumT> enumeratorFactory; + + public EnumeratorCreatingSource(Supplier<EnumT> enumeratorFactory) { + this.enumeratorFactory = enumeratorFactory; + } + + @Override + public Boundedness getBoundedness() { + return Boundedness.CONTINUOUS_UNBOUNDED; + } + + @Override + public SourceReader<T, MockSourceSplit> createReader(SourceReaderContext readerContext) { + throw new UnsupportedOperationException(); + } + + @Override + public SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>> createEnumerator( + SplitEnumeratorContext<MockSourceSplit> enumContext) { + final EnumT enumerator = enumeratorFactory.get(); + createEnumeratorFuture.complete(enumerator); + return enumerator; + } + + @Override + public SplitEnumerator<MockSourceSplit, Set<MockSourceSplit>> restoreEnumerator( + SplitEnumeratorContext<MockSourceSplit> enumContext, + Set<MockSourceSplit> checkpoint) { + final EnumT enumerator = enumeratorFactory.get(); + restoreEnumeratorFuture.complete(enumerator); + return enumerator; + } + + @Override + public SimpleVersionedSerializer<MockSourceSplit> getSplitSerializer() { + return new MockSourceSplitSerializer(); + } + + @Override + public SimpleVersionedSerializer<Set<MockSourceSplit>> getEnumeratorCheckpointSerializer() { + return new MockSplitEnumeratorCheckpointSerializer(); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTestBase.java index 22fb4a3..87ae8c9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/source/coordinator/SourceCoordinatorTestBase.java @@ -36,6 +36,8 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import static org.junit.Assert.assertNotNull; + /** * The test base for SourceCoordinator related tests. */ @@ -49,7 +51,7 @@ public abstract class SourceCoordinatorTestBase { protected SplitAssignmentTracker<MockSourceSplit> splitSplitAssignmentTracker; protected SourceCoordinatorContext<MockSourceSplit> context; protected SourceCoordinator<?, ?> sourceCoordinator; - protected MockSplitEnumerator enumerator; + private MockSplitEnumerator enumerator; @Before public void setup() throws Exception { @@ -58,7 +60,10 @@ public abstract class SourceCoordinatorTestBase { String coordinatorThreadName = TEST_OPERATOR_ID.toHexString(); SourceCoordinatorProvider.CoordinatorExecutorThreadFactory coordinatorThreadFactory = new SourceCoordinatorProvider.CoordinatorExecutorThreadFactory( - coordinatorThreadName, operatorCoordinatorContext); + coordinatorThreadName, + operatorCoordinatorContext, + getClass().getClassLoader()); + coordinatorExecutor = Executors.newSingleThreadExecutor(coordinatorThreadFactory); context = new SourceCoordinatorContext<>( coordinatorExecutor, @@ -68,7 +73,6 @@ public abstract class SourceCoordinatorTestBase { new MockSourceSplitSerializer(), splitSplitAssignmentTracker); sourceCoordinator = getNewSourceCoordinator(); - enumerator = (MockSplitEnumerator) sourceCoordinator.getEnumerator(); } @After @@ -79,6 +83,14 @@ public abstract class SourceCoordinatorTestBase { } } + protected MockSplitEnumerator getEnumerator() { + if (enumerator == null) { + enumerator = (MockSplitEnumerator) sourceCoordinator.getEnumerator(); + assertNotNull("source was not started", enumerator); + } + return enumerator; + } + // -------------------------- protected SourceCoordinator getNewSourceCoordinator() throws Exception {
