http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java index 38130d4..3dd037e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/JobCheckpointingSettings.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.jobgraph.tasks; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.state.StateBackend; @@ -32,21 +33,21 @@ import static java.util.Objects.requireNonNull; * need to participate. */ public class JobCheckpointingSettings implements java.io.Serializable { - + private static final long serialVersionUID = -2593319571078198180L; - + private final List<JobVertexID> verticesToTrigger; private final List<JobVertexID> verticesToAcknowledge; private final List<JobVertexID> verticesToConfirm; - + private final long checkpointInterval; - + private final long checkpointTimeout; - + private final long minPauseBetweenCheckpoints; - + private final int maxConcurrentCheckpoints; /** Settings for externalized checkpoints. */ @@ -56,6 +57,9 @@ public class JobCheckpointingSettings implements java.io.Serializable { @Nullable private final StateBackend defaultStateBackend; + /** (Factories for) hooks that are executed on the checkpoint coordinator */ + private final MasterTriggerRestoreHook.Factory[] masterHooks; + /** * Flag indicating whether exactly once checkpoint mode has been configured. * If <code>false</code>, at least once mode has been configured. 
This is @@ -77,12 +81,30 @@ public class JobCheckpointingSettings implements java.io.Serializable { @Nullable StateBackend defaultStateBackend, boolean isExactlyOnce) { + this(verticesToTrigger, verticesToAcknowledge, verticesToConfirm, + checkpointInterval, checkpointTimeout, minPauseBetweenCheckpoints, maxConcurrentCheckpoints, + externalizedCheckpointSettings, defaultStateBackend, null, isExactlyOnce); + } + + public JobCheckpointingSettings( + List<JobVertexID> verticesToTrigger, + List<JobVertexID> verticesToAcknowledge, + List<JobVertexID> verticesToConfirm, + long checkpointInterval, + long checkpointTimeout, + long minPauseBetweenCheckpoints, + int maxConcurrentCheckpoints, + ExternalizedCheckpointSettings externalizedCheckpointSettings, + @Nullable StateBackend defaultStateBackend, + @Nullable MasterTriggerRestoreHook.Factory[] masterHooks, + boolean isExactlyOnce) { + // sanity checks if (checkpointInterval < 1 || checkpointTimeout < 1 || minPauseBetweenCheckpoints < 0 || maxConcurrentCheckpoints < 1) { throw new IllegalArgumentException(); } - + this.verticesToTrigger = requireNonNull(verticesToTrigger); this.verticesToAcknowledge = requireNonNull(verticesToAcknowledge); this.verticesToConfirm = requireNonNull(verticesToConfirm); @@ -93,14 +115,16 @@ public class JobCheckpointingSettings implements java.io.Serializable { this.externalizedCheckpointSettings = requireNonNull(externalizedCheckpointSettings); this.defaultStateBackend = defaultStateBackend; this.isExactlyOnce = isExactlyOnce; + + this.masterHooks = masterHooks != null ? 
masterHooks : new MasterTriggerRestoreHook.Factory[0]; } - + // -------------------------------------------------------------------------------------------- public List<JobVertexID> getVerticesToTrigger() { return verticesToTrigger; } - + public List<JobVertexID> getVerticesToAcknowledge() { return verticesToAcknowledge; } @@ -134,12 +158,16 @@ public class JobCheckpointingSettings implements java.io.Serializable { return defaultStateBackend; } + public MasterTriggerRestoreHook.Factory[] getMasterHooks() { + return masterHooks; + } + public boolean isExactlyOnce() { return isExactlyOnce; } // -------------------------------------------------------------------------------------------- - + @Override public String toString() { return String.format("SnapshotSettings: interval=%d, timeout=%d, pause-between=%d, " +
http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java new file mode 100644 index 0000000..0ec4606 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.runtime.concurrent.Executors; +import org.apache.flink.runtime.concurrent.impl.FlinkCompletableFuture; +import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; +import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; + +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Executor; + +import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import static org.mockito.Matchers.eq; +import static org.mockito.Matchers.isNull; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for the user-defined hooks that the checkpoint coordinator can call. 
+ */ +public class CheckpointCoordinatorMasterHooksTest { + + // ------------------------------------------------------------------------ + // hook registration + // ------------------------------------------------------------------------ + + /** + * This method tests that hooks with the same identifier are not registered + * multiple times. + */ + @Test + public void testDeduplicateOnRegister() { + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(new JobID()); + + MasterTriggerRestoreHook<?> hook1 = mock(MasterTriggerRestoreHook.class); + when(hook1.getIdentifier()).thenReturn("test id"); + + MasterTriggerRestoreHook<?> hook2 = mock(MasterTriggerRestoreHook.class); + when(hook2.getIdentifier()).thenReturn("test id"); + + MasterTriggerRestoreHook<?> hook3 = mock(MasterTriggerRestoreHook.class); + when(hook3.getIdentifier()).thenReturn("anotherId"); + + assertTrue(cc.addMasterHook(hook1)); + assertFalse(cc.addMasterHook(hook2)); + assertTrue(cc.addMasterHook(hook3)); + } + + /** + * Test that validates correct exceptions when supplying hooks with invalid IDs. 
+ */ + @Test + public void testNullOrInvalidId() { + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(new JobID()); + + try { + cc.addMasterHook(null); + fail("expected an exception"); + } catch (NullPointerException ignored) {} + + try { + cc.addMasterHook(mock(MasterTriggerRestoreHook.class)); + fail("expected an exception"); + } catch (IllegalArgumentException ignored) {} + + try { + MasterTriggerRestoreHook<?> hook = mock(MasterTriggerRestoreHook.class); + when(hook.getIdentifier()).thenReturn(" "); + + cc.addMasterHook(hook); + fail("expected an exception"); + } catch (IllegalArgumentException ignored) {} + } + + // ------------------------------------------------------------------------ + // trigger / restore behavior + // ------------------------------------------------------------------------ + + @Test + public void testHooksAreCalledOnTrigger() throws Exception { + final String id1 = "id1"; + final String id2 = "id2"; + + final String state1 = "the-test-string-state"; + final byte[] state1serialized = new StringSerializer().serialize(state1); + + final long state2 = 987654321L; + final byte[] state2serialized = new LongSerializer().serialize(state2); + + final MasterTriggerRestoreHook<String> statefulHook1 = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook1.getIdentifier()).thenReturn(id1); + when(statefulHook1.createCheckpointDataSerializer()).thenReturn(new StringSerializer()); + when(statefulHook1.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenReturn(FlinkCompletableFuture.completed(state1)); + + final MasterTriggerRestoreHook<Long> statefulHook2 = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook2.getIdentifier()).thenReturn(id2); + when(statefulHook2.createCheckpointDataSerializer()).thenReturn(new LongSerializer()); + when(statefulHook2.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenReturn(FlinkCompletableFuture.completed(state2)); + + final 
MasterTriggerRestoreHook<Void> statelessHook = mockGeneric(MasterTriggerRestoreHook.class); + when(statelessHook.getIdentifier()).thenReturn("some-id"); + + // create the checkpoint coordinator + final JobID jid = new JobID(); + final ExecutionAttemptID execId = new ExecutionAttemptID(); + final ExecutionVertex ackVertex = mockExecutionVertex(execId); + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex); + + cc.addMasterHook(statefulHook1); + cc.addMasterHook(statelessHook); + cc.addMasterHook(statefulHook2); + + // trigger a checkpoint + assertTrue(cc.triggerCheckpoint(System.currentTimeMillis(), false)); + assertEquals(1, cc.getNumberOfPendingCheckpoints()); + + verify(statefulHook1, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)); + verify(statefulHook2, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)); + verify(statelessHook, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class)); + + final long checkpointId = cc.getPendingCheckpoints().values().iterator().next().getCheckpointId(); + cc.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, execId, checkpointId)); + assertEquals(0, cc.getNumberOfPendingCheckpoints()); + + assertEquals(1, cc.getNumberOfRetainedSuccessfulCheckpoints()); + final CompletedCheckpoint chk = cc.getCheckpointStore().getLatestCheckpoint(); + + final Collection<MasterState> masterStates = chk.getMasterHookStates(); + assertEquals(2, masterStates.size()); + + for (MasterState ms : masterStates) { + if (ms.name().equals(id1)) { + assertArrayEquals(state1serialized, ms.bytes()); + assertEquals(StringSerializer.VERSION, ms.version()); + } + else if (ms.name().equals(id2)) { + assertArrayEquals(state2serialized, ms.bytes()); + assertEquals(LongSerializer.VERSION, ms.version()); + } + else { + fail("unrecognized state name: " + ms.name()); + } + } + } + + @Test + public void testHooksAreCalledOnRestore() throws Exception { + final String id1 = "id1"; 
+ final String id2 = "id2"; + + final String state1 = "the-test-string-state"; + final byte[] state1serialized = new StringSerializer().serialize(state1); + + final long state2 = 987654321L; + final byte[] state2serialized = new LongSerializer().serialize(state2); + + final List<MasterState> masterHookStates = Arrays.asList( + new MasterState(id1, state1serialized, StringSerializer.VERSION), + new MasterState(id2, state2serialized, LongSerializer.VERSION)); + + final MasterTriggerRestoreHook<String> statefulHook1 = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook1.getIdentifier()).thenReturn(id1); + when(statefulHook1.createCheckpointDataSerializer()).thenReturn(new StringSerializer()); + when(statefulHook1.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenThrow(new Exception("not expected")); + + final MasterTriggerRestoreHook<Long> statefulHook2 = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook2.getIdentifier()).thenReturn(id2); + when(statefulHook2.createCheckpointDataSerializer()).thenReturn(new LongSerializer()); + when(statefulHook2.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenThrow(new Exception("not expected")); + + final MasterTriggerRestoreHook<Void> statelessHook = mockGeneric(MasterTriggerRestoreHook.class); + when(statelessHook.getIdentifier()).thenReturn("some-id"); + + final JobID jid = new JobID(); + final long checkpointId = 13L; + + final CompletedCheckpoint checkpoint = new CompletedCheckpoint( + jid, checkpointId, 123L, 125L, + Collections.<JobVertexID, TaskState>emptyMap(), + masterHookStates, + CheckpointProperties.forStandardCheckpoint(), + null, + null); + + final ExecutionAttemptID execId = new ExecutionAttemptID(); + final ExecutionVertex ackVertex = mockExecutionVertex(execId); + final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex); + + cc.addMasterHook(statefulHook1); + cc.addMasterHook(statelessHook); + 
cc.addMasterHook(statefulHook2); + + cc.getCheckpointStore().addCheckpoint(checkpoint); + cc.restoreLatestCheckpointedState( + Collections.<JobVertexID, ExecutionJobVertex>emptyMap(), + true, + false); + + verify(statefulHook1, times(1)).restoreCheckpoint(eq(checkpointId), eq(state1)); + verify(statefulHook2, times(1)).restoreCheckpoint(eq(checkpointId), eq(state2)); + verify(statelessHook, times(1)).restoreCheckpoint(eq(checkpointId), isNull(Void.class)); + } + + @Test + public void checkUnMatchedStateOnRestore() throws Exception { + final String id1 = "id1"; + final String id2 = "id2"; + + final String state1 = "the-test-string-state"; + final byte[] state1serialized = new StringSerializer().serialize(state1); + + final long state2 = 987654321L; + final byte[] state2serialized = new LongSerializer().serialize(state2); + + final List<MasterState> masterHookStates = Arrays.asList( + new MasterState(id1, state1serialized, StringSerializer.VERSION), + new MasterState(id2, state2serialized, LongSerializer.VERSION)); + + final MasterTriggerRestoreHook<String> statefulHook = mockGeneric(MasterTriggerRestoreHook.class); + when(statefulHook.getIdentifier()).thenReturn(id1); + when(statefulHook.createCheckpointDataSerializer()).thenReturn(new StringSerializer()); + when(statefulHook.triggerCheckpoint(anyLong(), anyLong(), any(Executor.class))) + .thenThrow(new Exception("not expected")); + + final MasterTriggerRestoreHook<Void> statelessHook = mockGeneric(MasterTriggerRestoreHook.class); + when(statelessHook.getIdentifier()).thenReturn("some-id"); + + final JobID jid = new JobID(); + final long checkpointId = 44L; + + final CompletedCheckpoint checkpoint = new CompletedCheckpoint( + jid, checkpointId, 123L, 125L, + Collections.<JobVertexID, TaskState>emptyMap(), + masterHookStates, + CheckpointProperties.forStandardCheckpoint(), + null, + null); + + final ExecutionAttemptID execId = new ExecutionAttemptID(); + final ExecutionVertex ackVertex = mockExecutionVertex(execId); 
+ final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex); + + cc.addMasterHook(statefulHook); + cc.addMasterHook(statelessHook); + + cc.getCheckpointStore().addCheckpoint(checkpoint); + + // since we have unmatched state, this should fail + try { + cc.restoreLatestCheckpointedState( + Collections.<JobVertexID, ExecutionJobVertex>emptyMap(), + true, + false); + fail("exception expected"); + } + catch (IllegalStateException ignored) {} + + // permitting unmatched state should succeed + cc.restoreLatestCheckpointedState( + Collections.<JobVertexID, ExecutionJobVertex>emptyMap(), + true, + true); + + verify(statefulHook, times(1)).restoreCheckpoint(eq(checkpointId), eq(state1)); + verify(statelessHook, times(1)).restoreCheckpoint(eq(checkpointId), isNull(Void.class)); + } + + // ------------------------------------------------------------------------ + // failure scenarios + // ------------------------------------------------------------------------ + + @Test + public void testSerializationFailsOnTrigger() { + } + + @Test + public void testHookCallFailsOnTrigger() { + } + + @Test + public void testDeserializationFailsOnRestore() { + } + + @Test + public void testHookCallFailsOnRestore() { + } + + @Test + public void testTypeIncompatibleWithSerializerOnStore() { + } + + @Test + public void testTypeIncompatibleWithHookOnRestore() { + } + + // ------------------------------------------------------------------------ + // utilities + // ------------------------------------------------------------------------ + + private static CheckpointCoordinator instantiateCheckpointCoordinator(JobID jid, ExecutionVertex... 
ackVertices) { + return new CheckpointCoordinator( + jid, + 10000000L, + 600000L, + 0L, + 1, + ExternalizedCheckpointSettings.none(), + new ExecutionVertex[0], + ackVertices, + new ExecutionVertex[0], + new StandaloneCheckpointIDCounter(), + new StandaloneCompletedCheckpointStore(10), + null, + Executors.directExecutor()); + } + + private static <T> T mockGeneric(Class<?> clazz) { + @SuppressWarnings("unchecked") + Class<T> typedClass = (Class<T>) clazz; + return mock(typedClass); + } + + // ------------------------------------------------------------------------ + + private static final class StringSerializer implements SimpleVersionedSerializer<String> { + + static final int VERSION = 77; + + @Override + public int getVersion() { + return VERSION; + } + + @Override + public byte[] serialize(String checkpointData) throws IOException { + return checkpointData.getBytes(StandardCharsets.UTF_8); + } + + @Override + public String deserialize(int version, byte[] serialized) throws IOException { + if (version != VERSION) { + throw new IOException("version mismatch"); + } + return new String(serialized, StandardCharsets.UTF_8); + } + } + + // ------------------------------------------------------------------------ + + private static final class LongSerializer implements SimpleVersionedSerializer<Long> { + + static final int VERSION = 5; + + @Override + public int getVersion() { + return VERSION; + } + + @Override + public byte[] serialize(Long checkpointData) throws IOException { + final byte[] bytes = new byte[8]; + ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).putLong(0, checkpointData); + return bytes; + } + + @Override + public Long deserialize(int version, byte[] serialized) throws IOException { + assertEquals(VERSION, version); + assertEquals(8, serialized.length); + + return ByteBuffer.wrap(serialized).order(ByteOrder.LITTLE_ENDIAN).getLong(0); + } + } +} 
http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index 4a36dd2..fc6e516 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -273,7 +273,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { Map<JobVertexID, TaskState> taskGroupStates, CheckpointProperties props) { - super(jobId, checkpointId, timestamp, Long.MAX_VALUE, taskGroupStates, props); + super(jobId, checkpointId, timestamp, Long.MAX_VALUE, taskGroupStates, null, props); } @Override http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index 0b759d4..652cc76 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -45,7 +45,7 @@ import static org.mockito.Mockito.verify; public class CompletedCheckpointTest { @Rule - public TemporaryFolder tmpFolder = new TemporaryFolder(); + public final TemporaryFolder tmpFolder = new TemporaryFolder(); /** * Tests that persistent checkpoints discard their header file. 
@@ -61,7 +61,10 @@ public class CompletedCheckpointTest { // Verify discard call is forwarded to state CompletedCheckpoint checkpoint = new CompletedCheckpoint( - new JobID(), 0, 0, 1, taskStates, CheckpointProperties.forStandardCheckpoint(), + new JobID(), 0, 0, 1, + taskStates, + Collections.<MasterState>emptyList(), + CheckpointProperties.forStandardCheckpoint(), new FileStateHandle(new Path(file.toURI()), file.length()), file.getAbsolutePath()); @@ -81,8 +84,12 @@ public class CompletedCheckpointTest { boolean discardSubsumed = true; CheckpointProperties props = new CheckpointProperties(false, false, discardSubsumed, true, true, true, true); + CompletedCheckpoint checkpoint = new CompletedCheckpoint( - new JobID(), 0, 0, 1, taskStates, props); + new JobID(), 0, 0, 1, + taskStates, + Collections.<MasterState>emptyList(), + props); SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); checkpoint.registerSharedStates(sharedStateRegistry); @@ -117,7 +124,10 @@ public class CompletedCheckpointTest { // Keep CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); CompletedCheckpoint checkpoint = new CompletedCheckpoint( - new JobID(), 0, 0, 1, new HashMap<>(taskStates), props, + new JobID(), 0, 0, 1, + new HashMap<>(taskStates), + Collections.<MasterState>emptyList(), + props, new FileStateHandle(new Path(file.toURI()), file.length()), externalPath); @@ -132,7 +142,10 @@ public class CompletedCheckpointTest { // Discard props = new CheckpointProperties(false, false, true, true, true, true, true); checkpoint = new CompletedCheckpoint( - new JobID(), 0, 0, 1, new HashMap<>(taskStates), props); + new JobID(), 0, 0, 1, + new HashMap<>(taskStates), + Collections.<MasterState>emptyList(), + props); checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(1)).discardState(); @@ -155,6 +168,7 @@ public class CompletedCheckpointTest { 0, 1, new HashMap<>(taskStates), + 
Collections.<MasterState>emptyList(), CheckpointProperties.forStandardCheckpoint()); CompletedCheckpointStats.DiscardCallback callback = mock(CompletedCheckpointStats.DiscardCallback.class); http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java index 5fce62e..1f038bd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java @@ -105,6 +105,7 @@ public class ExecutionGraphCheckpointCoordinatorTest { Collections.<ExecutionJobVertex>emptyList(), Collections.<ExecutionJobVertex>emptyList(), Collections.<ExecutionJobVertex>emptyList(), + Collections.<MasterTriggerRestoreHook<?>>emptyList(), counter, store, null, http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java new file mode 100644 index 0000000..7d9874e --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.savepoint; + +import org.apache.flink.configuration.ConfigConstants; +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.SubtaskState; +import org.apache.flink.runtime.checkpoint.TaskState; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupRangeOffsets; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle.StateMetaInfo; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; +import org.apache.flink.util.StringUtils; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** + * A collection of utility methods for testing the (de)serialization of + * checkpoint metadata for persistence. 
+ */ +public class CheckpointTestUtils { + + /** + * Creates a random collection of TaskState objects containing various types of state handles. + */ + public static Collection<TaskState> createTaskStates(int numTaskStates, int numSubtasksPerTask) { + return createTaskStates(new Random(), numTaskStates, numSubtasksPerTask); + } + + /** + * Creates a random collection of TaskState objects containing various types of state handles. + */ + public static Collection<TaskState> createTaskStates( + Random random, + int numTaskStates, + int numSubtasksPerTask) { + + List<TaskState> taskStates = new ArrayList<>(numTaskStates); + + for (int stateIdx = 0; stateIdx < numTaskStates; ++stateIdx) { + + int chainLength = 1 + random.nextInt(8); + + TaskState taskState = new TaskState(new JobVertexID(), numSubtasksPerTask, 128, chainLength); + + int noNonPartitionableStateAtIndex = random.nextInt(chainLength); + int noOperatorStateBackendAtIndex = random.nextInt(chainLength); + int noOperatorStateStreamAtIndex = random.nextInt(chainLength); + + boolean hasKeyedBackend = random.nextInt(4) != 0; + boolean hasKeyedStream = random.nextInt(4) != 0; + + for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) { + + List<StreamStateHandle> nonPartitionableStates = new ArrayList<>(chainLength); + List<OperatorStateHandle> operatorStatesBackend = new ArrayList<>(chainLength); + List<OperatorStateHandle> operatorStatesStream = new ArrayList<>(chainLength); + + for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) { + + StreamStateHandle nonPartitionableState = + new TestByteStreamStateHandleDeepCompare("a-" + chainIdx, ("Hi-" + chainIdx).getBytes( + ConfigConstants.DEFAULT_CHARSET)); + StreamStateHandle operatorStateBackend = + new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET)); + StreamStateHandle operatorStateStream = + new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + 
chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET)); + Map<String, StateMetaInfo> offsetsMap = new HashMap<>(); + offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); + + if (chainIdx != noNonPartitionableStateAtIndex) { + nonPartitionableStates.add(nonPartitionableState); + } + + if (chainIdx != noOperatorStateBackendAtIndex) { + OperatorStateHandle operatorStateHandleBackend = + new OperatorStateHandle(offsetsMap, operatorStateBackend); + operatorStatesBackend.add(operatorStateHandleBackend); + } + + if (chainIdx != noOperatorStateStreamAtIndex) { + OperatorStateHandle operatorStateHandleStream = + new OperatorStateHandle(offsetsMap, operatorStateStream); + operatorStatesStream.add(operatorStateHandleStream); + } + } + + KeyGroupsStateHandle keyedStateBackend = null; + KeyGroupsStateHandle keyedStateStream = null; + + if (hasKeyedBackend) { + keyedStateBackend = new KeyGroupsStateHandle( + new KeyGroupRangeOffsets(1, 1, new long[]{42}), + new TestByteStreamStateHandleDeepCompare("c", "Hello" + .getBytes(ConfigConstants.DEFAULT_CHARSET))); + } + + if (hasKeyedStream) { + keyedStateStream = new KeyGroupsStateHandle( + new KeyGroupRangeOffsets(1, 1, new long[]{23}), + new TestByteStreamStateHandleDeepCompare("d", "World" + .getBytes(ConfigConstants.DEFAULT_CHARSET))); + } + + taskState.putState(subtaskIdx, new SubtaskState( + new ChainedStateHandle<>(nonPartitionableStates), + new ChainedStateHandle<>(operatorStatesBackend), + new ChainedStateHandle<>(operatorStatesStream), + keyedStateStream, + keyedStateBackend)); + } + + taskStates.add(taskState); + } + + return taskStates; + } + + /** + * Creates a bunch of random master states. 
+ */ + public static Collection<MasterState> createRandomMasterStates(Random random, int num) { + final ArrayList<MasterState> states = new ArrayList<>(num); + + for (int i = 0; i < num; i++) { + int version = random.nextInt(10); + String name = StringUtils.getRandomString(random, 5, 500); + byte[] bytes = new byte[random.nextInt(5000) + 1]; + random.nextBytes(bytes); + + states.add(new MasterState(name, bytes, version)); + } + + return states; + } + + /** + * Asserts that two MasterStates are equal. + * + * <p>The MasterState avoids overriding {@code equals()} on purpose, because equality is not well + * defined in the raw contents. + */ + public static void assertMasterStateEquality(MasterState a, MasterState b) { + assertEquals(a.version(), b.version()); + assertEquals(a.name(), b.name()); + assertArrayEquals(a.bytes(), b.bytes()); + + } + + // ------------------------------------------------------------------------ + + /** utility class, not meant to be instantiated */ + private CheckpointTestUtils() {} +} http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java index c66b29d..20b1e57 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java @@ -67,7 +67,7 @@ public class SavepointLoaderTest { JobID jobId = new JobID(); // Store savepoint - SavepointV1 savepoint = new SavepointV1(checkpointId, taskStates.values()); + SavepointV2 savepoint = new SavepointV2(checkpointId, taskStates.values()); String path = 
SavepointStore.storeSavepoint(tmp.getAbsolutePath(), savepoint); ExecutionJobVertex vertex = mock(ExecutionJobVertex.class); http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java index 1eb8055..cf79282 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.Path; +import org.apache.flink.runtime.checkpoint.MasterState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.junit.Rule; @@ -68,7 +69,7 @@ public class SavepointStoreTest { // Store String savepointDirectory = SavepointStore.createSavepointDirectory(root, new JobID()); - SavepointV1 stored = new SavepointV1(1929292, SavepointV1Test.createTaskStates(4, 24)); + SavepointV2 stored = new SavepointV2(1929292, CheckpointTestUtils.createTaskStates(4, 24)); String path = SavepointStore.storeSavepoint(savepointDirectory, stored); list = rootFile.listFiles(); @@ -77,7 +78,10 @@ public class SavepointStoreTest { // Load Savepoint loaded = SavepointStore.loadSavepoint(path, Thread.currentThread().getContextClassLoader()); - assertEquals(stored, loaded); + + assertEquals(stored.getCheckpointId(), loaded.getCheckpointId()); + assertEquals(stored.getTaskStates(), loaded.getTaskStates()); + 
assertEquals(stored.getMasterStates(), loaded.getMasterStates()); loaded.dispose(); @@ -126,8 +130,8 @@ public class SavepointStoreTest { File rootFile = new File(root); // New savepoint type for test - int version = ThreadLocalRandom.current().nextInt(); - long checkpointId = ThreadLocalRandom.current().nextLong(); + int version = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE); // make this a positive number + long checkpointId = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE); // make this a positive number // Add serializer serializers.put(version, NewSavepointSerializer.INSTANCE); @@ -143,7 +147,7 @@ public class SavepointStoreTest { // Savepoint v0 String savepointDirectory2 = SavepointStore.createSavepointDirectory(root, new JobID()); - Savepoint savepoint = new SavepointV1(checkpointId, SavepointV1Test.createTaskStates(4, 32)); + SavepointV2 savepoint = new SavepointV2(checkpointId, CheckpointTestUtils.createTaskStates(4, 32)); String pathSavepoint = SavepointStore.storeSavepoint(savepointDirectory2, savepoint); list = rootFile.listFiles(); @@ -156,7 +160,9 @@ public class SavepointStoreTest { assertEquals(newSavepoint, loaded); loaded = SavepointStore.loadSavepoint(pathSavepoint, Thread.currentThread().getContextClassLoader()); - assertEquals(savepoint, loaded); + assertEquals(savepoint.getCheckpointId(), loaded.getCheckpointId()); + assertEquals(savepoint.getTaskStates(), loaded.getTaskStates()); + assertEquals(savepoint.getMasterStates(), loaded.getMasterStates()); } /** @@ -199,7 +205,7 @@ public class SavepointStoreTest { FileSystem fs = FileSystem.get(new Path(root).toUri()); // Store - SavepointV1 savepoint = new SavepointV1(1929292, SavepointV1Test.createTaskStates(4, 24)); + SavepointV2 savepoint = new SavepointV2(1929292, CheckpointTestUtils.createTaskStates(4, 24)); FileStateHandle store1 = SavepointStore.storeExternalizedCheckpointToHandle(root, savepoint); fs.exists(store1.getFilePath()); @@ -251,7 +257,12 @@ public class 
SavepointStoreTest { @Override public Collection<TaskState> getTaskStates() { - return Collections.EMPTY_LIST; + return Collections.emptyList(); + } + + @Override + public Collection<MasterState> getMasterStates() { + return Collections.emptyList(); } @Override http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java index 58cf1aa..0eff7bc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1SerializerTest.java @@ -27,6 +27,7 @@ import java.io.ByteArrayInputStream; import java.util.Random; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; public class SavepointV1SerializerTest { @@ -35,25 +36,29 @@ public class SavepointV1SerializerTest { */ @Test public void testSerializeDeserializeV1() throws Exception { - Random r = new Random(42); - for (int i = 0; i < 100; ++i) { + final Random r = new Random(42); + + for (int i = 0; i < 50; ++i) { SavepointV1 expected = - new SavepointV1(i+ 123123, SavepointV1Test.createTaskStates(1 + r.nextInt(64), 1 + r.nextInt(64))); + new SavepointV1(i+ 123123, CheckpointTestUtils.createTaskStates(r, 1 + r.nextInt(64), 1 + r.nextInt(64))); SavepointV1Serializer serializer = SavepointV1Serializer.INSTANCE; // Serialize ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos(); - serializer.serialize(expected, new DataOutputViewStreamWrapper(baos)); + serializer.serializeOld(expected, new DataOutputViewStreamWrapper(baos)); byte[] bytes = 
baos.toByteArray(); // Deserialize ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - Savepoint actual = serializer.deserialize( + SavepointV2 actual = serializer.deserialize( new DataInputViewStreamWrapper(bais), Thread.currentThread().getContextClassLoader()); - assertEquals(expected, actual); + + assertEquals(expected.getCheckpointId(), actual.getCheckpointId()); + assertEquals(expected.getTaskStates(), actual.getTaskStates()); + assertTrue(actual.getMasterStates().isEmpty()); } } } http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java deleted file mode 100644 index 08ec35e..0000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.runtime.checkpoint.savepoint; - -import org.apache.flink.configuration.ConfigConstants; -import org.apache.flink.runtime.checkpoint.SubtaskState; -import org.apache.flink.runtime.checkpoint.TaskState; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyGroupRangeOffsets; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare; -import org.junit.Test; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.concurrent.ThreadLocalRandom; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -public class SavepointV1Test { - - /** - * Simple test of savepoint methods. 
- */ - @Test - public void testSavepointV1() throws Exception { - long checkpointId = ThreadLocalRandom.current().nextLong(Integer.MAX_VALUE); - int numTaskStates = 4; - int numSubtaskStates = 16; - - Collection<TaskState> expected = createTaskStates(numTaskStates, numSubtaskStates); - - SavepointV1 savepoint = new SavepointV1(checkpointId, expected); - - assertEquals(SavepointV1.VERSION, savepoint.getVersion()); - assertEquals(checkpointId, savepoint.getCheckpointId()); - assertEquals(expected, savepoint.getTaskStates()); - - assertFalse(savepoint.getTaskStates().isEmpty()); - savepoint.dispose(); - assertTrue(savepoint.getTaskStates().isEmpty()); - } - - static Collection<TaskState> createTaskStates(int numTaskStates, int numSubtasksPerTask) throws IOException { - - Random random = new Random(numTaskStates * 31 + numSubtasksPerTask); - - List<TaskState> taskStates = new ArrayList<>(numTaskStates); - - for (int stateIdx = 0; stateIdx < numTaskStates; ++stateIdx) { - - int chainLength = 1 + random.nextInt(8); - - TaskState taskState = new TaskState(new JobVertexID(), numSubtasksPerTask, 128, chainLength); - - int noNonPartitionableStateAtIndex = random.nextInt(chainLength); - int noOperatorStateBackendAtIndex = random.nextInt(chainLength); - int noOperatorStateStreamAtIndex = random.nextInt(chainLength); - - boolean hasKeyedBackend = random.nextInt(4) != 0; - boolean hasKeyedStream = random.nextInt(4) != 0; - - for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) { - - List<StreamStateHandle> nonPartitionableStates = new ArrayList<>(chainLength); - List<OperatorStateHandle> operatorStatesBackend = new ArrayList<>(chainLength); - List<OperatorStateHandle> operatorStatesStream = new ArrayList<>(chainLength); - - for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) { - - StreamStateHandle nonPartitionableState = - new TestByteStreamStateHandleDeepCompare("a-" + chainIdx, ("Hi-" + chainIdx).getBytes( - ConfigConstants.DEFAULT_CHARSET)); - 
StreamStateHandle operatorStateBackend = - new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET)); - StreamStateHandle operatorStateStream = - new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET)); - Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(); - offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); - offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); - offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); - - if (chainIdx != noNonPartitionableStateAtIndex) { - nonPartitionableStates.add(nonPartitionableState); - } - - if (chainIdx != noOperatorStateBackendAtIndex) { - OperatorStateHandle operatorStateHandleBackend = - new OperatorStateHandle(offsetsMap, operatorStateBackend); - operatorStatesBackend.add(operatorStateHandleBackend); - } - - if (chainIdx != noOperatorStateStreamAtIndex) { - OperatorStateHandle operatorStateHandleStream = - new OperatorStateHandle(offsetsMap, operatorStateStream); - operatorStatesStream.add(operatorStateHandleStream); - } - } - - KeyGroupsStateHandle keyedStateBackend = null; - KeyGroupsStateHandle keyedStateStream = null; - - if (hasKeyedBackend) { - keyedStateBackend = new KeyGroupsStateHandle( - new KeyGroupRangeOffsets(1, 1, new long[]{42}), - new TestByteStreamStateHandleDeepCompare("c", "Hello" - .getBytes(ConfigConstants.DEFAULT_CHARSET))); - } - - if (hasKeyedStream) { - keyedStateStream = new KeyGroupsStateHandle( - new KeyGroupRangeOffsets(1, 1, new long[]{23}), - new TestByteStreamStateHandleDeepCompare("d", "World" - .getBytes(ConfigConstants.DEFAULT_CHARSET))); - } - - taskState.putState(subtaskIdx, new SubtaskState( - new 
ChainedStateHandle<>(nonPartitionableStates), - new ChainedStateHandle<>(operatorStatesBackend), - new ChainedStateHandle<>(operatorStatesStream), - keyedStateStream, - keyedStateBackend)); - } - - taskStates.add(taskState); - } - - return taskStates; - } - -} http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java new file mode 100644 index 0000000..deb14dd --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.savepoint; + +import org.apache.flink.core.memory.ByteArrayInputStreamWithPos; +import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.TaskState; + +import org.junit.Test; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Random; + +import static org.junit.Assert.assertEquals; + +/** + * Various tests for the version 2 format serializer of a checkpoint. + */ +public class SavepointV2SerializerTest { + + @Test + public void testCheckpointWithNoState() throws Exception { + final Random rnd = new Random(); + + for (int i = 0; i < 100; ++i) { + final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; + final Collection<TaskState> taskStates = Collections.emptyList(); + final Collection<MasterState> masterStates = Collections.emptyList(); + + testCheckpointSerialization(checkpointId, taskStates, masterStates); + } + } + + @Test + public void testCheckpointWithOnlyMasterState() throws Exception { + final Random rnd = new Random(); + final int maxNumMasterStates = 5; + + for (int i = 0; i < 100; ++i) { + final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; + + final Collection<TaskState> taskStates = Collections.emptyList(); + + final int numMasterStates = rnd.nextInt(maxNumMasterStates) + 1; + final Collection<MasterState> masterStates = + CheckpointTestUtils.createRandomMasterStates(rnd, numMasterStates); + + testCheckpointSerialization(checkpointId, taskStates, masterStates); + } + } + + @Test + public void testCheckpointWithOnlyTaskState() throws Exception { + final Random rnd = new Random(); + final int 
maxTaskStates = 20; + final int maxNumSubtasks = 20; + + for (int i = 0; i < 100; ++i) { + final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; + + final int numTasks = rnd.nextInt(maxTaskStates) + 1; + final int numSubtasks = rnd.nextInt(maxNumSubtasks) + 1; + final Collection<TaskState> taskStates = + CheckpointTestUtils.createTaskStates(rnd, numTasks, numSubtasks); + + final Collection<MasterState> masterStates = Collections.emptyList(); + + testCheckpointSerialization(checkpointId, taskStates, masterStates); + } + } + + @Test + public void testCheckpointWithMasterAndTaskState() throws Exception { + final Random rnd = new Random(); + + final int maxNumMasterStates = 5; + final int maxTaskStates = 20; + final int maxNumSubtasks = 20; + + for (int i = 0; i < 100; ++i) { + final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; + + final int numTasks = rnd.nextInt(maxTaskStates) + 1; + final int numSubtasks = rnd.nextInt(maxNumSubtasks) + 1; + final Collection<TaskState> taskStates = + CheckpointTestUtils.createTaskStates(rnd, numTasks, numSubtasks); + + final int numMasterStates = rnd.nextInt(maxNumMasterStates) + 1; + final Collection<MasterState> masterStates = + CheckpointTestUtils.createRandomMasterStates(rnd, numMasterStates); + + testCheckpointSerialization(checkpointId, taskStates, masterStates); + } + } + + private void testCheckpointSerialization( + long checkpointId, + Collection<TaskState> taskStates, + Collection<MasterState> masterStates) throws IOException { + + SavepointV2Serializer serializer = SavepointV2Serializer.INSTANCE; + + ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos(); + DataOutputStream out = new DataOutputViewStreamWrapper(baos); + + serializer.serialize(new SavepointV2(checkpointId, taskStates, masterStates), out); + out.close(); + + byte[] bytes = baos.toByteArray(); + + DataInputStream in = new DataInputViewStreamWrapper(new ByteArrayInputStreamWithPos(bytes)); + SavepointV2 deserialized = 
serializer.deserialize(in, getClass().getClassLoader()); + + assertEquals(checkpointId, deserialized.getCheckpointId()); + assertEquals(taskStates, deserialized.getTaskStates()); + + assertEquals(masterStates.size(), deserialized.getMasterStates().size()); + for (Iterator<MasterState> a = masterStates.iterator(), b = deserialized.getMasterStates().iterator(); + a.hasNext();) + { + CheckpointTestUtils.assertMasterStateEquality(a.next(), b.next()); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java new file mode 100644 index 0000000..428a62a --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.checkpoint.savepoint; + +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.TaskState; + +import org.junit.Test; + +import java.util.Collection; +import java.util.Random; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class SavepointV2Test { + + /** + * Simple test of savepoint methods. + */ + @Test + public void testSavepointV1() throws Exception { + final Random rnd = new Random(); + + final long checkpointId = rnd.nextInt(Integer.MAX_VALUE) + 1; + final int numTaskStates = 4; + final int numSubtaskStates = 16; + final int numMasterStates = 7; + + Collection<TaskState> taskStates = + CheckpointTestUtils.createTaskStates(rnd, numTaskStates, numSubtaskStates); + + Collection<MasterState> masterStates = + CheckpointTestUtils.createRandomMasterStates(rnd, numMasterStates); + + SavepointV2 checkpoint = new SavepointV2(checkpointId, taskStates, masterStates); + + assertEquals(2, checkpoint.getVersion()); + assertEquals(checkpointId, checkpoint.getCheckpointId()); + assertEquals(taskStates, checkpoint.getTaskStates()); + assertEquals(masterStates, checkpoint.getMasterStates()); + + assertFalse(checkpoint.getTaskStates().isEmpty()); + assertFalse(checkpoint.getMasterStates().isEmpty()); + + checkpoint.dispose(); + + assertTrue(checkpoint.getTaskStates().isEmpty()); + assertTrue(checkpoint.getMasterStates().isEmpty()); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java index f96b624..4e1d0f7 
100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ArchivedExecutionGraphTest.java @@ -31,6 +31,7 @@ import org.apache.flink.runtime.accumulators.StringifiedAccumulatorResult; import org.apache.flink.runtime.akka.AkkaUtils; import org.apache.flink.runtime.checkpoint.CheckpointStatsSnapshot; import org.apache.flink.runtime.checkpoint.CheckpointStatsTracker; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter; import org.apache.flink.runtime.checkpoint.StandaloneCompletedCheckpointStore; import org.apache.flink.runtime.execution.ExecutionState; @@ -128,6 +129,7 @@ public class ArchivedExecutionGraphTest { Collections.<ExecutionJobVertex>emptyList(), Collections.<ExecutionJobVertex>emptyList(), Collections.<ExecutionJobVertex>emptyList(), + Collections.<MasterTriggerRestoreHook<?>>emptyList(), new StandaloneCheckpointIDCounter(), new StandaloneCompletedCheckpointStore(1), null, http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ExternallyInducedSource.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ExternallyInducedSource.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ExternallyInducedSource.java new file mode 100644 index 0000000..b26cf4f --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/ExternallyInducedSource.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.checkpoint; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.util.FlinkException; + +/** + * Sources that implement this interface do not trigger checkpoints when receiving a + * trigger message from the checkpoint coordinator, but when their input data/events + * indicate that a checkpoint should be triggered. + * + * <p>Since sources cannot simply create a new checkpoint on their own, this mechanism + * always goes together with a {@link WithMasterCheckpointHook hook on the master side}. + * In a typical setup, the hook on the master tells the source system (for example + * the message queue) to prepare a checkpoint. The exact point when the checkpoint is + * taken is then controlled by the event stream received from the source, and triggered + * by the source function (implementing this interface) in Flink when seeing the relevant + * events. + * + * @param <T> Type of the elements produced by the source function + * @param <CD> The type of the data stored in the checkpoint by the master that triggers + */ +@PublicEvolving +public interface ExternallyInducedSource<T, CD> extends SourceFunction<T>, WithMasterCheckpointHook<CD> { + + /** + * Sets the checkpoint trigger through which the source can trigger the checkpoint. 
+ * + * @param checkpointTrigger The checkpoint trigger to set + */ + void setCheckpointTrigger(CheckpointTrigger checkpointTrigger); + + // ------------------------------------------------------------------------ + + /** + * Through the {@code CheckpointTrigger}, the source function notifies the Flink + * source operator when to trigger the checkpoint. + */ + interface CheckpointTrigger { + + /** + * Triggers a checkpoint. This method should be called by the source + * when it sees the event that indicates that a checkpoint should be triggered. + * + * <p>When this method is called, the parallel operator instance in which the + * calling source function runs will perform its checkpoint and insert the + * checkpoint barrier into the data stream. + * + * @param checkpointId The ID that identifies the checkpoint. + * + * @throws FlinkException Thrown when the checkpoint could not be triggered, for example + * because of an invalid state or errors when storing the + * checkpoint state. + */ + void triggerCheckpoint(long checkpointId) throws FlinkException; + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/WithMasterCheckpointHook.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/WithMasterCheckpointHook.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/WithMasterCheckpointHook.java new file mode 100644 index 0000000..ef872de --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/checkpoint/WithMasterCheckpointHook.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.checkpoint; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; + +/** + * This interface can be implemented by streaming functions that need to trigger a + * "global action" on the master (in the checkpoint coordinator) as part of every + * checkpoint and restore operation. + * + * @param <E> The type of the data stored by the hook in the checkpoint, or {@code Void}, if none. + */ +@PublicEvolving +public interface WithMasterCheckpointHook<E> extends java.io.Serializable { + + /** + * Creates the hook that should be called by the checkpoint coordinator. 
+ */ + MasterTriggerRestoreHook<E> createMasterTriggerRestoreHook(); +} http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/FunctionMasterCheckpointHookFactory.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/FunctionMasterCheckpointHookFactory.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/FunctionMasterCheckpointHookFactory.java new file mode 100644 index 0000000..c256698 --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/FunctionMasterCheckpointHookFactory.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.graph; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; +import org.apache.flink.streaming.api.checkpoint.WithMasterCheckpointHook; + +/** + * Utility class that turns a {@link WithMasterCheckpointHook} into a + * {@link org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook.Factory}. 
+ */ +class FunctionMasterCheckpointHookFactory implements MasterTriggerRestoreHook.Factory { + + private static final long serialVersionUID = 2L; + + private final WithMasterCheckpointHook<?> creator; + + FunctionMasterCheckpointHookFactory(WithMasterCheckpointHook<?> creator) { + this.creator = checkNotNull(creator); + } + + @SuppressWarnings("unchecked") + @Override + public <V> MasterTriggerRestoreHook<V> create() { + return (MasterTriggerRestoreHook<V>) creator.createMasterTriggerRestoreHook(); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index a1d33d8..7f24cd3 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -29,6 +29,7 @@ import java.util.Map.Entry; import org.apache.commons.lang3.StringUtils; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.Function; import org.apache.flink.api.common.operators.ResourceSpec; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.restartstrategy.RestartStrategies; @@ -36,6 +37,7 @@ import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.IllegalConfigurationException; import org.apache.flink.migration.streaming.api.graph.StreamGraphHasherV1; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; import 
org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.InputFormatVertex; @@ -51,7 +53,9 @@ import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup; import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; import org.apache.flink.runtime.operators.util.TaskConfig; import org.apache.flink.streaming.api.CheckpointingMode; +import org.apache.flink.streaming.api.checkpoint.WithMasterCheckpointHook; import org.apache.flink.streaming.api.environment.CheckpointConfig; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; @@ -542,6 +546,8 @@ public class StreamingJobGraphGenerator { interval = Long.MAX_VALUE; } + // --- configure the participating vertices --- + // collect the vertices that receive "trigger checkpoint" messages. 
// currently, these are all the sources List<JobVertexID> triggerVertices = new ArrayList<>(); @@ -552,7 +558,7 @@ public class StreamingJobGraphGenerator { // collect the vertices that receive "commit checkpoint" messages // currently, these are all vertices - List<JobVertexID> commitVertices = new ArrayList<>(); + List<JobVertexID> commitVertices = new ArrayList<>(jobVertices.size()); for (JobVertex vertex : jobVertices.values()) { if (vertex.isInputVertex()) { @@ -562,6 +568,8 @@ public class StreamingJobGraphGenerator { ackVertices.add(vertex.getID()); } + // --- configure options --- + ExternalizedCheckpointSettings externalizedCheckpointSettings; if (cfg.isExternalizedCheckpointsEnabled()) { CheckpointConfig.ExternalizedCheckpointCleanup cleanup = cfg.getExternalizedCheckpointCleanup(); @@ -587,12 +595,30 @@ public class StreamingJobGraphGenerator { "exactly-once or at-least-once."); } + // --- configure the master-side checkpoint hooks --- + + final ArrayList<MasterTriggerRestoreHook.Factory> hooks = new ArrayList<>(); + + for (StreamNode node : streamGraph.getStreamNodes()) { + StreamOperator<?> op = node.getOperator(); + if (op instanceof AbstractUdfStreamOperator) { + Function f = ((AbstractUdfStreamOperator<?, ?>) op).getUserFunction(); + + if (f instanceof WithMasterCheckpointHook) { + hooks.add(new FunctionMasterCheckpointHookFactory((WithMasterCheckpointHook<?>) f)); + } + } + } + + // --- done, put it all together --- + JobCheckpointingSettings settings = new JobCheckpointingSettings( triggerVertices, ackVertices, commitVertices, interval, cfg.getCheckpointTimeout(), cfg.getMinPauseBetweenCheckpoints(), cfg.getMaxConcurrentCheckpoints(), externalizedCheckpointSettings, streamGraph.getStateBackend(), + hooks.toArray(new MasterTriggerRestoreHook.Factory[hooks.size()]), isExactlyOnce); jobGraph.setSnapshotSettings(settings); 
http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java index 66e92df..31cd7c1 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java @@ -19,8 +19,12 @@ package org.apache.flink.streaming.runtime.tasks; import org.apache.flink.annotation.Internal; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.streaming.api.checkpoint.ExternallyInducedSource; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.operators.StreamSource; +import org.apache.flink.util.FlinkException; /** * {@link StreamTask} for executing a {@link StreamSource}. 
@@ -40,9 +44,44 @@ import org.apache.flink.streaming.api.operators.StreamSource; public class SourceStreamTask<OUT, SRC extends SourceFunction<OUT>, OP extends StreamSource<OUT, SRC>> extends StreamTask<OUT, OP> { + private volatile boolean externallyInducedCheckpoints; + @Override protected void init() { // does not hold any resources, so no initialization needed + + // we check if the source is actually inducing the checkpoints, rather + // than the trigger message from the master + SourceFunction<?> source = headOperator.getUserFunction(); + if (source instanceof ExternallyInducedSource) { + externallyInducedCheckpoints = true; + + ExternallyInducedSource.CheckpointTrigger triggerHook = new ExternallyInducedSource.CheckpointTrigger() { + + @Override + public void triggerCheckpoint(long checkpointId) throws FlinkException { + // TODO - we need to see how to derive those. We should probably not encode this in the + // TODO - source's trigger message, but do a handshake in this task between the trigger + // TODO - message from the master, and the source's trigger notification + final CheckpointOptions checkpointOptions = CheckpointOptions.forFullCheckpoint(); + final long timestamp = System.currentTimeMillis(); + + final CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, timestamp); + + try { + SourceStreamTask.super.triggerCheckpoint(checkpointMetaData, checkpointOptions); + } + catch (RuntimeException | FlinkException e) { + throw e; + } + catch (Exception e) { + throw new FlinkException(e.getMessage(), e); + } + } + }; + + ((ExternallyInducedSource<?, ?>) source).setCheckpointTrigger(triggerHook); + } } @Override @@ -62,4 +101,21 @@ public class SourceStreamTask<OUT, SRC extends SourceFunction<OUT>, OP extends S headOperator.cancel(); } } + + // ------------------------------------------------------------------------ + // Checkpointing + // ------------------------------------------------------------------------ + + @Override + public boolean
triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) throws Exception { + if (!externallyInducedCheckpoints) { + return super.triggerCheckpoint(checkpointMetaData, checkpointOptions); + } + else { + // we do not trigger checkpoints here, we simply state whether we can trigger them + synchronized (getCheckpointLock()) { + return isRunning(); + } + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java new file mode 100644 index 0000000..b5a95eb --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/WithMasterCheckpointHookConfigTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.graph; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.core.io.SimpleVersionedSerializer; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook; +import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook.Factory; +import org.apache.flink.runtime.concurrent.Future; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.streaming.api.checkpoint.WithMasterCheckpointHook; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.DiscardingSink; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +import org.junit.Test; + +import javax.annotation.Nullable; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.Executor; + +import static java.util.Arrays.asList; +import static org.junit.Assert.*; + +/** + * Tests that when sources and other user functions implement {@link WithMasterCheckpointHook} the hooks are properly + * configured in the job's checkpoint settings. + */ +@SuppressWarnings("serial") +public class WithMasterCheckpointHookConfigTest { + + /** + * This test creates a program with 4 sources (2 with master hooks, 2 without) and 2 map functions with master hooks. + * The resulting job graph must have 4 configured master hooks.
+ */ + @Test + public void testHookConfiguration() throws Exception { + // create some sources some of which configure master hooks + final TestSource source1 = new TestSource(); + final TestSourceWithHook source2 = new TestSourceWithHook("foo"); + final TestSource source3 = new TestSource(); + final TestSourceWithHook source4 = new TestSourceWithHook("bar"); + + final MapFunction<String, String> identity = new Identity<>(); + final IdentityWithHook<String> identityWithHook1 = new IdentityWithHook<>("apple"); + final IdentityWithHook<String> identityWithHook2 = new IdentityWithHook<>("orange"); + + final Set<MasterTriggerRestoreHook<?>> hooks = new HashSet<MasterTriggerRestoreHook<?>>(asList( + source2.createMasterTriggerRestoreHook(), + source4.createMasterTriggerRestoreHook(), + identityWithHook1.createMasterTriggerRestoreHook(), + identityWithHook2.createMasterTriggerRestoreHook())); + + // we can instantiate a local environment here, because we never actually execute something + final StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(); + env.enableCheckpointing(500); + + env + .addSource(source1).map(identity) + .union(env.addSource(source2).map(identity)) + .union(env.addSource(source3).map(identityWithHook1)) + .union(env.addSource(source4).map(identityWithHook2)) + .addSink(new DiscardingSink<String>()); + + final JobGraph jg = env.getStreamGraph().getJobGraph(); + assertEquals(hooks.size(), jg.getCheckpointingSettings().getMasterHooks().length); + + // check that all hooks are contained and exist exactly once + for (Factory f : jg.getCheckpointingSettings().getMasterHooks()) { + MasterTriggerRestoreHook<?> hook = f.create(); + assertTrue(hooks.remove(hook)); + } + assertTrue(hooks.isEmpty()); + } + + // ----------------------------------------------------------------------- + + private static class TestHook implements MasterTriggerRestoreHook<String> { + + private final String id; + + TestHook(String id) { + this.id = id; + 
} + + @Override + public String getIdentifier() { + return id; + } + + @Override + public Future<String> triggerCheckpoint(long checkpointId, long timestamp, Executor executor) { + throw new UnsupportedOperationException(); + } + + @Override + public void restoreCheckpoint(long checkpointId, @Nullable String checkpointData) throws Exception { + throw new UnsupportedOperationException(); + } + + @Nullable + @Override + public SimpleVersionedSerializer<String> createCheckpointDataSerializer() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(Object obj) { + return obj == this || (obj != null && obj.getClass() == getClass() && ((TestHook) obj).id.equals(id)); + } + + @Override + public int hashCode() { + return id.hashCode(); + } + } + + // ----------------------------------------------------------------------- + + private static class TestSource implements SourceFunction<String> { + + @Override + public void run(SourceContext<String> ctx) { + throw new UnsupportedOperationException(); + } + + @Override + public void cancel() {} + } + + // ----------------------------------------------------------------------- + + private static class TestSourceWithHook extends TestSource implements WithMasterCheckpointHook<String> { + + private final String id; + + TestSourceWithHook(String id) { + this.id = id; + } + + @Override + public TestHook createMasterTriggerRestoreHook() { + return new TestHook(id); + } + } + + // ----------------------------------------------------------------------- + + private static class Identity<T> implements MapFunction<T, T> { + + @Override + public T map(T value) { + return value; + } + } + + // ----------------------------------------------------------------------- + + private static class IdentityWithHook<T> extends Identity<T> implements WithMasterCheckpointHook<String> { + + private final String id; + + IdentityWithHook(String id) { + this.id = id; + } + + @Override + public TestHook 
createMasterTriggerRestoreHook() { + return new TestHook(id); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/90ca4381/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java index 38741ba..54cd186 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java @@ -29,12 +29,9 @@ import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.types.LongValue; import org.junit.Test; -import org.junit.runner.RunWith; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; import java.io.IOException; @@ -45,8 +42,6 @@ import static org.mockito.Mockito.*; * This test uses the PowerMockRunner runner to work around the fact that the * {@link ResultPartitionWriter} class is final. */ -@RunWith(PowerMockRunner.class) -@PrepareForTest(ResultPartitionWriter.class) public class StreamRecordWriterTest { /**
