[FLINK-6531] [checkpoints] Ensure proper classloading for user-defined checkpoint hooks
This closes #3868 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/aa8a90a5 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/aa8a90a5 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/aa8a90a5 Branch: refs/heads/master Commit: aa8a90a588b4d72fc585731bea233495f0690364 Parents: 72aa262 Author: Stephan Ewen <[email protected]> Authored: Wed May 10 22:03:49 2017 +0200 Committer: Stephan Ewen <[email protected]> Committed: Thu May 11 14:11:07 2017 +0200 ---------------------------------------------------------------------- .../executiongraph/ExecutionGraphBuilder.java | 13 +- .../tasks/JobCheckpointingSettings.java | 12 +- .../CheckpointSettingsSerializableTest.java | 122 +++++++++++++++++++ .../api/graph/StreamingJobGraphGenerator.java | 20 ++- .../WithMasterCheckpointHookConfigTest.java | 10 +- 5 files changed, 167 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/aa8a90a5/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java index aa28fbc..0e76cfb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphBuilder.java @@ -51,6 +51,7 @@ import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings; import org.apache.flink.runtime.state.AbstractStateBackend; import org.apache.flink.runtime.state.StateBackend; import org.apache.flink.util.DynamicCodeLoadingException; +import org.apache.flink.util.SerializedValue; import org.slf4j.Logger; import javax.annotation.Nullable; @@ -240,13 +241,21 @@ public class ExecutionGraphBuilder { // instantiate the user-defined checkpoint hooks - final MasterTriggerRestoreHook.Factory[] hookFactories = snapshotSettings.getMasterHooks(); + final SerializedValue<MasterTriggerRestoreHook.Factory[]> serializedHooks = snapshotSettings.getMasterHooks(); final List<MasterTriggerRestoreHook<?>> hooks; - if (hookFactories == null || hookFactories.length == 0) { + if (serializedHooks == null) { hooks = Collections.emptyList(); } else { + final MasterTriggerRestoreHook.Factory[] hookFactories; + try { + hookFactories = serializedHooks.deserializeValue(classLoader); + } + catch (IOException | ClassNotFoundException e) { + throw new JobExecutionException(jobId, "Could not instantiate user-defined checkpoint hooks", e); + } + hooks = new ArrayList<>(hookFactories.length); for (MasterTriggerRestoreHook.Factory factory : hookFactories) { hooks.add(factory.create()); http://git-wip-us.apache.org/repos/asf/flink/blob/aa8a90a5/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java index 3dd037e..a30a2ba 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java @@ -21,6 +21,7 @@ package org.apache.flink.runtime.jobgraph.tasks; import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.StateBackend; +import org.apache.flink.util.SerializedValue; import javax.annotation.Nullable; import java.util.List; @@ -58,7 +59,8 @@ public class JobCheckpointingSettings implements java.io.Serializable { private final StateBackend defaultStateBackend; /** (Factories for) hooks that are executed on the checkpoint coordinator */ - private final MasterTriggerRestoreHook.Factory[] masterHooks; + @Nullable + private final SerializedValue<MasterTriggerRestoreHook.Factory[]> masterHooks; /** * Flag indicating whether exactly once checkpoint mode has been configured. @@ -96,7 +98,7 @@ public class JobCheckpointingSettings implements java.io.Serializable { int maxConcurrentCheckpoints, ExternalizedCheckpointSettings externalizedCheckpointSettings, @Nullable StateBackend defaultStateBackend, - @Nullable MasterTriggerRestoreHook.Factory[] masterHooks, + @Nullable SerializedValue<MasterTriggerRestoreHook.Factory[]> masterHooks, boolean isExactlyOnce) { // sanity checks @@ -115,8 +117,7 @@ public class JobCheckpointingSettings implements java.io.Serializable { this.externalizedCheckpointSettings = requireNonNull(externalizedCheckpointSettings); this.defaultStateBackend = defaultStateBackend; this.isExactlyOnce = isExactlyOnce; - - this.masterHooks = masterHooks != null ? masterHooks : new MasterTriggerRestoreHook.Factory[0]; + this.masterHooks = masterHooks; } // -------------------------------------------------------------------------------------------- @@ -158,7 +159,8 @@ public class JobCheckpointingSettings implements java.io.Serializable { return defaultStateBackend; } - public MasterTriggerRestoreHook.Factory[] getMasterHooks() { + @Nullable + public SerializedValue<MasterTriggerRestoreHook.Factory[]> getMasterHooks() { return masterHooks; } http://git-wip-us.apache.org/repos/asf/flink/blob/aa8a90a5/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointSettingsSerializableTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointSettingsSerializableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointSettingsSerializableTest.java new file mode 100644 index 0000000..0246180 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointSettingsSerializableTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.time.Time; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.testutils.CommonTestUtils; +import org.apache.flink.metrics.groups.UnregisteredMetricsGroup; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionGraphBuilder; +import org.apache.flink.runtime.executiongraph.restart.NoRestartStrategy; +import org.apache.flink.runtime.instance.SlotProvider; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; +import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings; +import org.apache.flink.runtime.testingUtils.TestingUtils; +import org.apache.flink.util.SerializedValue; +import org.apache.flink.util.TestLogger; + +import org.junit.Test; + +import java.io.Serializable; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.Collections; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * This test validates that the checkpoint settings serialize correctly + * in the presence of user-defined objects. + */ +public class CheckpointSettingsSerializableTest extends TestLogger { + + @Test + public void testClassLoaderForCheckpointHooks() throws Exception { + final ClassLoader classLoader = new URLClassLoader(new URL[0], getClass().getClassLoader()); + final Serializable outOfClassPath = CommonTestUtils.createObjectForClassNotInClassPath(classLoader); + + final MasterTriggerRestoreHook.Factory[] hooks = { + new TestFactory(outOfClassPath) }; + final SerializedValue<MasterTriggerRestoreHook.Factory[]> serHooks = new SerializedValue<>(hooks); + + final JobCheckpointingSettings checkpointingSettings = new JobCheckpointingSettings( + Collections.<JobVertexID>emptyList(), + Collections.<JobVertexID>emptyList(), + Collections.<JobVertexID>emptyList(), + 1000L, + 10000L, + 0L, + 1, + ExternalizedCheckpointSettings.none(), + null, + serHooks, + true); + + final JobGraph jobGraph = new JobGraph(new JobID(), "test job"); + jobGraph.setSnapshotSettings(checkpointingSettings); + + // to serialize/deserialize the job graph to see if the behavior is correct under + // distributed execution + final JobGraph copy = CommonTestUtils.createCopySerializable(jobGraph); + + final ExecutionGraph eg = ExecutionGraphBuilder.buildGraph( + null, + copy, + new Configuration(), + TestingUtils.defaultExecutor(), + TestingUtils.defaultExecutor(), + mock(SlotProvider.class), + classLoader, + new StandaloneCheckpointRecoveryFactory(), + Time.seconds(10), + new NoRestartStrategy(), + new UnregisteredMetricsGroup(), + 10, + log); + + assertEquals(1, eg.getCheckpointCoordinator().getNumberOfRegisteredMasterHooks()); + } + + // ------------------------------------------------------------------------ + + private static final class TestFactory implements MasterTriggerRestoreHook.Factory { + + private static final long serialVersionUID = -612969579110202607L; + + private final Serializable payload; + + TestFactory(Serializable payload) { + this.payload = payload; + } + + @SuppressWarnings("unchecked") + @Override + public <V> MasterTriggerRestoreHook<V> create() { + MasterTriggerRestoreHook<V> hook = mock(MasterTriggerRestoreHook.class); + when(hook.getIdentifier()).thenReturn("id"); + return hook; + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/aa8a90a5/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index b3a6cf8..6d1af72 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -66,6 +66,8 @@ import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.streaming.runtime.tasks.StreamIterationHead; import org.apache.flink.streaming.runtime.tasks.StreamIterationTail; +import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.SerializedValue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -642,6 +644,22 @@ public class StreamingJobGraphGenerator { } } + // because the hooks can have user-defined code, they need to be stored as + // eagerly serialized values + final SerializedValue<MasterTriggerRestoreHook.Factory[]> serializedHooks; + if (hooks.isEmpty()) { + serializedHooks = null; + } else { + try { + MasterTriggerRestoreHook.Factory[] asArray = + hooks.toArray(new MasterTriggerRestoreHook.Factory[hooks.size()]); + serializedHooks = new SerializedValue<>(asArray); + } + catch (IOException e) { + throw new FlinkRuntimeException("Trigger/restore hook is not serializable", e); + } + } + // --- done, put it all together --- JobCheckpointingSettings settings = new JobCheckpointingSettings( @@ -650,7 +668,7 @@ public class StreamingJobGraphGenerator { cfg.getMaxConcurrentCheckpoints(), externalizedCheckpointSettings, streamGraph.getStateBackend(), - hooks.toArray(new MasterTriggerRestoreHook.Factory[hooks.size()]), + serializedHooks, isExactlyOnce); jobGraph.setSnapshotSettings(settings); http://git-wip-us.apache.org/repos/asf/flink/blob/aa8a90a5/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java index b5a95eb..8065cf1 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java @@ -29,6 +29,7 @@ import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.sink.DiscardingSink; import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.util.SerializedValue; import org.junit.Test; import javax.annotation.Nullable; @@ -80,10 +81,15 @@ public class WithMasterCheckpointHookConfigTest { .addSink(new DiscardingSink<String>()); final JobGraph jg = env.getStreamGraph().getJobGraph(); - assertEquals(hooks.size(), jg.getCheckpointingSettings().getMasterHooks().length); + + SerializedValue<Factory[]> serializedConfiguredHooks = jg.getCheckpointingSettings().getMasterHooks(); + assertNotNull(serializedConfiguredHooks); + + Factory[] configuredHooks = serializedConfiguredHooks.deserializeValue(getClass().getClassLoader()); + assertEquals(hooks.size(), configuredHooks.length); // check that all hooks are contained and exist exactly once - for (Factory f : jg.getCheckpointingSettings().getMasterHooks()) { + for (Factory f : configuredHooks) { MasterTriggerRestoreHook<?> hook = f.create(); assertTrue(hooks.remove(hook)); }
