This is an automated email from the ASF dual-hosted git repository.

cameronlee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new dcdd0ec  SAMZA-2695: Unit tests for KafkaCheckpointManager take too long to run (#1541)
dcdd0ec is described below

commit dcdd0ec9de223753a9146e8f7bc535575afb738c
Author: Cameron Lee <[email protected]>
AuthorDate: Fri Oct 8 11:12:19 2021 -0700

    SAMZA-2695: Unit tests for KafkaCheckpointManager take too long to run (#1541)
---
 .../kafka/TestKafkaCheckpointManager.java          | 561 +++++++++++++++++++++
 .../kafka/TestKafkaCheckpointManagerJava.java      | 285 -----------
 .../kafka/TestKafkaCheckpointManager.scala         | 533 --------------------
 .../samza/test/harness/IntegrationTestHarness.java |   1 +
 .../KafkaCheckpointManagerIntegrationTest.java     | 206 ++++++++
 5 files changed, 768 insertions(+), 818 deletions(-)

diff --git a/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.java b/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.java
new file mode 100644
index 0000000..fe9bfb1
--- /dev/null
+++ b/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.java
@@ -0,0 +1,561 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.checkpoint.kafka;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import kafka.common.TopicAlreadyMarkedForDeletionException;
+import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.Checkpoint;
+import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.checkpoint.CheckpointV1;
+import org.apache.samza.checkpoint.CheckpointV2;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.container.grouper.stream.GroupByPartitionFactory;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.serializers.CheckpointV1Serde;
+import org.apache.samza.serializers.CheckpointV2Serde;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.OutgoingMessageEnvelope;
+import org.apache.samza.system.StreamValidationException;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemConsumer;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemProducer;
+import org.apache.samza.system.SystemStreamMetadata;
+import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.system.kafka.KafkaStreamSpec;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.stubbing.OngoingStubbing;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+
+public class TestKafkaCheckpointManager {
+  private static final TaskName TASK0 = new TaskName("Partition 0");
+  private static final TaskName TASK1 = new TaskName("Partition 1");
+  private static final String CHECKPOINT_TOPIC = "checkpointTopic";
+  private static final String CHECKPOINT_SYSTEM = "checkpointSystem";
+  private static final SystemStreamPartition CHECKPOINT_SSP =
+      new SystemStreamPartition(CHECKPOINT_SYSTEM, CHECKPOINT_TOPIC, new Partition(0));
+  private static final SystemStreamPartition INPUT_SSP0 =
+      new SystemStreamPartition("inputSystem", "inputTopic", new Partition(0));
+  private static final SystemStreamPartition INPUT_SSP1 =
+      new SystemStreamPartition("inputSystem", "inputTopic", new Partition(1));
+  private static final String GROUPER_FACTORY_CLASS = GroupByPartitionFactory.class.getCanonicalName();
+  private static final KafkaStreamSpec CHECKPOINT_SPEC =
+      new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC, CHECKPOINT_SYSTEM, 1);
+  private static final CheckpointV1Serde CHECKPOINT_V1_SERDE = new CheckpointV1Serde();
+  private static final CheckpointV2Serde CHECKPOINT_V2_SERDE = new CheckpointV2Serde();
+  private static final KafkaCheckpointLogKeySerde KAFKA_CHECKPOINT_LOG_KEY_SERDE = new KafkaCheckpointLogKeySerde();
+
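+  // All Kafka interactions below go through Mockito mocks, so these tests run without a real broker.
+  // Broker-backed coverage lives in KafkaCheckpointManagerIntegrationTest.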
+  @Mock
+  private SystemProducer systemProducer;
+  @Mock
+  private SystemConsumer systemConsumer;
+  @Mock
+  private SystemAdmin systemAdmin;
+  @Mock
+  private SystemAdmin createResourcesSystemAdmin;
+  @Mock
+  private SystemFactory systemFactory;
+  @Mock
+  private MetricsRegistry metricsRegistry;
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+  }
+
+  @Test(expected = TopicAlreadyMarkedForDeletionException.class)
+  public void testCreateResourcesTopicCreationError() {
+    setupSystemFactory(config());
+    // throw an exception during createStream
+    doThrow(new TopicAlreadyMarkedForDeletionException("invalid stream")).when(this.createResourcesSystemAdmin)
+        .createStream(CHECKPOINT_SPEC);
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+    // expect an exception during startup
+    checkpointManager.createResources();
+  }
+
+  @Test(expected = StreamValidationException.class)
+  public void testCreateResourcesTopicValidationError() {
+    setupSystemFactory(config());
+    // throw an exception during validateStream
+    doThrow(new StreamValidationException("invalid stream")).when(this.createResourcesSystemAdmin)
+        .validateStream(CHECKPOINT_SPEC);
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+    // expect an exception during startup
+    checkpointManager.createResources();
+  }
+
+  @Test(expected = SamzaException.class)
+  public void testReadFailsOnSerdeExceptions() throws InterruptedException {
+    setupSystemFactory(config());
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"));
+    setupConsumer(checkpointEnvelopes);
+    // wire up an exception throwing serde with the checkpointManager
+    CheckpointV1Serde checkpointV1Serde = mock(CheckpointV1Serde.class);
+    doThrow(new RuntimeException("serde failed")).when(checkpointV1Serde).fromBytes(any());
+    KafkaCheckpointManager checkpointManager =
+        new KafkaCheckpointManager(CHECKPOINT_SPEC, this.systemFactory, true, config(), this.metricsRegistry,
+            checkpointV1Serde, CHECKPOINT_V2_SERDE, KAFKA_CHECKPOINT_LOG_KEY_SERDE);
+    checkpointManager.register(TASK0);
+
+    // expect an exception
+    checkpointManager.readLastCheckpoint(TASK0);
+  }
+
+  @Test
+  public void testReadSucceedsOnKeySerdeExceptionsWhenValidationIsDisabled() throws InterruptedException {
+    setupSystemFactory(config());
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"));
+    setupConsumer(checkpointEnvelopes);
+    // wire up an exception throwing serde with the checkpointManager
+    CheckpointV1Serde checkpointV1Serde = mock(CheckpointV1Serde.class);
+    doThrow(new RuntimeException("serde failed")).when(checkpointV1Serde).fromBytes(any());
+    KafkaCheckpointManager checkpointManager =
+        new KafkaCheckpointManager(CHECKPOINT_SPEC, this.systemFactory, false, config(), this.metricsRegistry,
+            checkpointV1Serde, CHECKPOINT_V2_SERDE, KAFKA_CHECKPOINT_LOG_KEY_SERDE);
+    checkpointManager.register(TASK0);
+
+    // expect the read to succeed in spite of the exception from ExceptionThrowingSerde
+    assertNull(checkpointManager.readLastCheckpoint(TASK0));
+  }
+
+  @Test
+  public void testStart() {
+    setupSystemFactory(config());
+    String oldestOffset = "1";
+    String newestOffset = "2";
+    SystemStreamMetadata checkpointTopicMetadata = new SystemStreamMetadata(CHECKPOINT_TOPIC,
+        ImmutableMap.of(new Partition(0), new SystemStreamPartitionMetadata(oldestOffset, newestOffset,
+            Integer.toString(Integer.parseInt(newestOffset) + 1))));
+    when(this.systemAdmin.getSystemStreamMetadata(Collections.singleton(CHECKPOINT_TOPIC))).thenReturn(
+        ImmutableMap.of(CHECKPOINT_TOPIC, checkpointTopicMetadata));
+
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+
+    checkpointManager.start();
+
+    verify(this.systemProducer).start();
+    verify(this.systemAdmin).start();
+    verify(this.systemConsumer).register(CHECKPOINT_SSP, oldestOffset);
+    verify(this.systemConsumer).start();
+  }
+
+  @Test
+  public void testRegister() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    verify(this.systemProducer).register(TASK0.getTaskName());
+  }
+
+  @Test
+  public void testStop() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+    checkpointManager.stop();
+    verify(this.systemProducer).stop();
+    // default configuration for stopConsumerAfterFirstRead means that consumer is not stopped here
+    verify(this.systemConsumer, never()).stop();
+    verify(this.systemAdmin).stop();
+  }
+
+  @Test
+  public void testWriteCheckpointShouldRecreateSystemProducerOnFailure() {
+    setupSystemFactory(config());
+    SystemProducer secondKafkaProducer = mock(SystemProducer.class);
+    // override default mock behavior to return a second producer on the second call to create a producer
+    when(this.systemFactory.getProducer(CHECKPOINT_SYSTEM, config(), this.metricsRegistry,
+        KafkaCheckpointManager.class.getSimpleName())).thenReturn(this.systemProducer, secondKafkaProducer);
+    // first producer throws an exception on flush
+    doThrow(new RuntimeException("flush failed")).when(this.systemProducer).flush(TASK0.getTaskName());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    kafkaCheckpointManager.writeCheckpoint(TASK0, checkpointV1);
+
+    // first producer should be stopped
+    verify(this.systemProducer).stop();
+    // register and start the second producer
+    verify(secondKafkaProducer).register(TASK0.getTaskName());
+    verify(secondKafkaProducer).start();
+    // check that the second producer was given the message to send out
+    ArgumentCaptor<OutgoingMessageEnvelope> outgoingMessageEnvelopeArgumentCaptor =
+        ArgumentCaptor.forClass(OutgoingMessageEnvelope.class);
+    verify(secondKafkaProducer).send(eq(TASK0.getTaskName()), outgoingMessageEnvelopeArgumentCaptor.capture());
+    assertEquals(CHECKPOINT_SSP, outgoingMessageEnvelopeArgumentCaptor.getValue().getSystemStream());
+    assertEquals(new KafkaCheckpointLogKey(KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE, TASK0, GROUPER_FACTORY_CLASS),
+        KAFKA_CHECKPOINT_LOG_KEY_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getKey()));
+    assertEquals(checkpointV1,
+        CHECKPOINT_V1_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getMessage()));
+    verify(secondKafkaProducer).flush(TASK0.getTaskName());
+  }
+
+  @Test
+  public void testCreateResources() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.createResources();
+
+    verify(this.createResourcesSystemAdmin).start();
+    verify(this.createResourcesSystemAdmin).createStream(CHECKPOINT_SPEC);
+    verify(this.createResourcesSystemAdmin).validateStream(CHECKPOINT_SPEC);
+    verify(this.createResourcesSystemAdmin).stop();
+  }
+
+  @Test
+  public void testCreateResourcesSkipValidation() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(false, config());
+    kafkaCheckpointManager.createResources();
+
+    verify(this.createResourcesSystemAdmin).start();
+    verify(this.createResourcesSystemAdmin).createStream(CHECKPOINT_SPEC);
+    verify(this.createResourcesSystemAdmin, never()).validateStream(CHECKPOINT_SPEC);
+    verify(this.createResourcesSystemAdmin).stop();
+  }
+
+  @Test
+  public void testReadEmpty() throws InterruptedException {
+    setupSystemFactory(config());
+    setupConsumer(ImmutableList.of());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    assertNull(kafkaCheckpointManager.readLastCheckpoint(TASK0));
+  }
+
+  @Test
+  public void testReadCheckpointV1() throws InterruptedException {
+    setupSystemFactory(config());
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, checkpointV1, "0"));
+    setupConsumer(checkpointEnvelopes);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    Checkpoint actualCheckpoint = kafkaCheckpointManager.readLastCheckpoint(TASK0);
+    assertEquals(checkpointV1, actualCheckpoint);
+  }
+
+  @Test
+  public void testReadIgnoreCheckpointV2WhenV1Enabled() throws InterruptedException {
+    setupSystemFactory(config());
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, checkpointV1, "0"),
+            newCheckpointV2Envelope(TASK0, buildCheckpointV2(INPUT_SSP0, "1"), "1"));
+    setupConsumer(checkpointEnvelopes);
+    // default is to only read CheckpointV1
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    Checkpoint actualCheckpoint = kafkaCheckpointManager.readLastCheckpoint(TASK0);
+    assertEquals(checkpointV1, actualCheckpoint);
+  }
+
+  @Test
+  public void testReadCheckpointV2() throws InterruptedException {
+    Config config = config(ImmutableMap.of(TaskConfig.CHECKPOINT_READ_VERSIONS, "1,2"));
+    setupSystemFactory(config);
+    CheckpointV2 checkpointV2 = buildCheckpointV2(INPUT_SSP0, "0");
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV2Envelope(TASK0, checkpointV2, "0"));
+    setupConsumer(checkpointEnvelopes);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config);
+    kafkaCheckpointManager.register(TASK0);
+    Checkpoint actualCheckpoint = kafkaCheckpointManager.readLastCheckpoint(TASK0);
+    assertEquals(checkpointV2, actualCheckpoint);
+  }
+
+  @Test
+  public void testReadCheckpointPriority() throws InterruptedException {
+    Config config = config(ImmutableMap.of(TaskConfig.CHECKPOINT_READ_VERSIONS, "2,1"));
+    setupSystemFactory(config);
+    CheckpointV2 checkpointV2 = buildCheckpointV2(INPUT_SSP0, "1");
+    List<IncomingMessageEnvelope> checkpointEnvelopes =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"),
+            newCheckpointV2Envelope(TASK0, checkpointV2, "1"));
+    setupConsumer(checkpointEnvelopes);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config);
+    kafkaCheckpointManager.register(TASK0);
+    Checkpoint actualCheckpoint = kafkaCheckpointManager.readLastCheckpoint(TASK0);
+    assertEquals(checkpointV2, actualCheckpoint);
+  }
+
+  @Test
+  public void testReadMultipleCheckpointsMultipleSSP() throws InterruptedException {
+    setupSystemFactory(config());
+    KafkaCheckpointManager checkpointManager = buildKafkaCheckpointManager(true, config());
+    checkpointManager.register(TASK0);
+    checkpointManager.register(TASK1);
+
+    // mock out a consumer that returns 5 checkpoint IMEs for each SSP
+    int newestOffset = 5;
+    int checkpointOffsetCounter = 0;
+    List<List<IncomingMessageEnvelope>> pollOutputs = new ArrayList<>();
+    for (int offset = 1; offset <= newestOffset; offset++) {
+      pollOutputs.add(ImmutableList.of(
+          // use regular offset value for INPUT_SSP0
+          newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, Integer.toString(offset)),
+              Integer.toString(checkpointOffsetCounter++)),
+          // use (offset * 2) value for INPUT_SSP1 so offsets are different from INPUT_SSP0
+          newCheckpointV1Envelope(TASK1, buildCheckpointV1(INPUT_SSP1, Integer.toString(offset * 2)),
+              Integer.toString(checkpointOffsetCounter++))));
+    }
+    setupConsumerMultiplePoll(pollOutputs);
+
+    assertEquals(buildCheckpointV1(INPUT_SSP0, Integer.toString(newestOffset)),
+        checkpointManager.readLastCheckpoint(TASK0));
+    assertEquals(buildCheckpointV1(INPUT_SSP1, Integer.toString(newestOffset * 2)),
+        checkpointManager.readLastCheckpoint(TASK1));
+    // check expected number of polls (+1 is for the final empty poll), and the checkpoint is the newest message
+    verify(this.systemConsumer, times(newestOffset + 1)).poll(ImmutableSet.of(CHECKPOINT_SSP),
+        SystemConsumer.BLOCK_ON_OUTSTANDING_MESSAGES);
+  }
+
+  @Test
+  public void testReadMultipleCheckpointsUpgradeCheckpointVersion() throws InterruptedException {
+    Config config = config(ImmutableMap.of(TaskConfig.CHECKPOINT_READ_VERSIONS, "2,1"));
+    setupSystemFactory(config);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config);
+    kafkaCheckpointManager.register(TASK0);
+    kafkaCheckpointManager.register(TASK1);
+
+    List<IncomingMessageEnvelope> checkpointEnvelopesV1 =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"),
+            newCheckpointV1Envelope(TASK1, buildCheckpointV1(INPUT_SSP1, "0"), "1"));
+    CheckpointV2 ssp0CheckpointV2 = buildCheckpointV2(INPUT_SSP0, "10");
+    CheckpointV2 ssp1CheckpointV2 = buildCheckpointV2(INPUT_SSP1, "11");
+    List<IncomingMessageEnvelope> checkpointEnvelopesV2 =
+        ImmutableList.of(newCheckpointV2Envelope(TASK0, ssp0CheckpointV2, "2"),
+            newCheckpointV2Envelope(TASK1, ssp1CheckpointV2, "3"));
+    setupConsumerMultiplePoll(ImmutableList.of(checkpointEnvelopesV1, checkpointEnvelopesV2));
+    assertEquals(ssp0CheckpointV2, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+    assertEquals(ssp1CheckpointV2, kafkaCheckpointManager.readLastCheckpoint(TASK1));
+    // 2 polls for actual checkpoints, 1 final empty poll
+    verify(this.systemConsumer, times(3)).poll(ImmutableSet.of(CHECKPOINT_SSP),
+        SystemConsumer.BLOCK_ON_OUTSTANDING_MESSAGES);
+  }
+
+  @Test
+  public void testWriteCheckpointV1() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    kafkaCheckpointManager.writeCheckpoint(TASK0, checkpointV1);
+    ArgumentCaptor<OutgoingMessageEnvelope> outgoingMessageEnvelopeArgumentCaptor =
+        ArgumentCaptor.forClass(OutgoingMessageEnvelope.class);
+    verify(this.systemProducer).send(eq(TASK0.getTaskName()), outgoingMessageEnvelopeArgumentCaptor.capture());
+    assertEquals(CHECKPOINT_SSP, outgoingMessageEnvelopeArgumentCaptor.getValue().getSystemStream());
+    assertEquals(new KafkaCheckpointLogKey(KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE, TASK0, GROUPER_FACTORY_CLASS),
+        KAFKA_CHECKPOINT_LOG_KEY_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getKey()));
+    assertEquals(checkpointV1,
+        CHECKPOINT_V1_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getMessage()));
+    verify(this.systemProducer).flush(TASK0.getTaskName());
+  }
+
+  @Test
+  public void testWriteCheckpointV2() {
+    setupSystemFactory(config());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    CheckpointV2 checkpointV2 = buildCheckpointV2(INPUT_SSP0, "0");
+    kafkaCheckpointManager.writeCheckpoint(TASK0, checkpointV2);
+    ArgumentCaptor<OutgoingMessageEnvelope> outgoingMessageEnvelopeArgumentCaptor =
+        ArgumentCaptor.forClass(OutgoingMessageEnvelope.class);
+    verify(this.systemProducer).send(eq(TASK0.getTaskName()), outgoingMessageEnvelopeArgumentCaptor.capture());
+    assertEquals(CHECKPOINT_SSP, outgoingMessageEnvelopeArgumentCaptor.getValue().getSystemStream());
+    assertEquals(new KafkaCheckpointLogKey(KafkaCheckpointLogKey.CHECKPOINT_V2_KEY_TYPE, TASK0, GROUPER_FACTORY_CLASS),
+        KAFKA_CHECKPOINT_LOG_KEY_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getKey()));
+    assertEquals(checkpointV2,
+        CHECKPOINT_V2_SERDE.fromBytes((byte[]) outgoingMessageEnvelopeArgumentCaptor.getValue().getMessage()));
+    verify(this.systemProducer).flush(TASK0.getTaskName());
+  }
+
+  @Test
+  public void testWriteCheckpointShouldRetryFiniteTimesOnFailure() {
+    setupSystemFactory(config());
+    doThrow(new RuntimeException("send failed")).when(this.systemProducer).send(any(), any());
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    kafkaCheckpointManager.MaxRetryDurationInMillis_$eq(100); // setter for scala var MaxRetryDurationInMillis
+    CheckpointV2 checkpointV2 = buildCheckpointV2(INPUT_SSP0, "0");
+    try {
+      kafkaCheckpointManager.writeCheckpoint(TASK0, checkpointV2);
+      fail("Expected to throw SamzaException");
+    } catch (SamzaException e) {
+      // expected to get here
+    }
+    // one call to send which fails, then writeCheckpoint gives up
+    verify(this.systemProducer).send(any(), any());
+    verify(this.systemProducer, never()).flush(any());
+  }
+
+  @Test
+  public void testConsumerStopsAfterInitialRead() throws Exception {
+    setupSystemFactory(config());
+    CheckpointV1 checkpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    setupConsumer(ImmutableList.of(newCheckpointV1Envelope(TASK0, checkpointV1, "0")));
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config());
+    kafkaCheckpointManager.register(TASK0);
+    assertEquals(checkpointV1, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+    // 1 call to get actual checkpoints, 1 call for empty poll to signal done reading
+    verify(this.systemConsumer, times(2)).poll(ImmutableSet.of(CHECKPOINT_SSP), SystemConsumer.BLOCK_ON_OUTSTANDING_MESSAGES);
+    verify(this.systemConsumer).stop();
+    // reading checkpoint again should not read more messages from the consumer
+    assertEquals(checkpointV1, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+    verifyNoMoreInteractions(this.systemConsumer);
+  }
+
+  @Test
+  public void testConsumerStopsAfterInitialReadDisabled() throws Exception {
+    Config config =
+        config(ImmutableMap.of(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "false"));
+    setupSystemFactory(config);
+    // 1) return checkpointV1 for INPUT_SSP
+    CheckpointV1 ssp0FirstCheckpointV1 = buildCheckpointV1(INPUT_SSP0, "0");
+    List<IncomingMessageEnvelope> checkpointEnvelopes0 =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, buildCheckpointV1(INPUT_SSP0, "0"), "0"));
+    setupConsumer(checkpointEnvelopes0);
+    KafkaCheckpointManager kafkaCheckpointManager = buildKafkaCheckpointManager(true, config);
+    kafkaCheckpointManager.register(TASK0);
+    assertEquals(ssp0FirstCheckpointV1, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+
+    // 2) return new checkpointV1 for just INPUT_SSP
+    CheckpointV1 ssp0SecondCheckpointV1 = buildCheckpointV1(INPUT_SSP0, "10");
+    List<IncomingMessageEnvelope> checkpointEnvelopes1 =
+        ImmutableList.of(newCheckpointV1Envelope(TASK0, ssp0SecondCheckpointV1, "1"));
+    setupConsumer(checkpointEnvelopes1);
+    assertEquals(ssp0SecondCheckpointV1, kafkaCheckpointManager.readLastCheckpoint(TASK0));
+
+    verify(this.systemConsumer, never()).stop();
+  }
+
+  private KafkaCheckpointManager buildKafkaCheckpointManager(boolean validateCheckpoint, Config config) {
+    return new KafkaCheckpointManager(CHECKPOINT_SPEC, this.systemFactory, validateCheckpoint, config,
+        this.metricsRegistry, CHECKPOINT_V1_SERDE, CHECKPOINT_V2_SERDE, KAFKA_CHECKPOINT_LOG_KEY_SERDE);
+  }
+
+  private void setupConsumer(List<IncomingMessageEnvelope> pollOutput) throws InterruptedException {
+    setupConsumerMultiplePoll(ImmutableList.of(pollOutput));
+  }
+
+  /**
+   * Create a new {@link SystemConsumer} that returns a list of messages sequentially at each subsequent poll.
+   *
+   * @param pollOutputs a list of poll outputs to be returned at subsequent polls.
+   *                    The i'th call to consumer.poll() will return the list at pollOutputs[i]
+   */
+  private void setupConsumerMultiplePoll(List<List<IncomingMessageEnvelope>> pollOutputs) throws InterruptedException {
+    OngoingStubbing<Map<SystemStreamPartition, List<IncomingMessageEnvelope>>> when =
+        when(this.systemConsumer.poll(ImmutableSet.of(CHECKPOINT_SSP), SystemConsumer.BLOCK_ON_OUTSTANDING_MESSAGES));
+    for (List<IncomingMessageEnvelope> pollOutput : pollOutputs) {
+      when = when.thenReturn(ImmutableMap.of(CHECKPOINT_SSP, pollOutput));
+    }
+    when.thenReturn(ImmutableMap.of());
+  }
+
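+  /**
+   * Stubs the mock {@link SystemFactory} to hand out the shared mock producer, consumer, and admins.
+   * Note that createResources is wired to a separate admin via the "createResource"-suffixed source name.
+   */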
+  private void setupSystemFactory(Config config) {
+    when(this.systemFactory.getProducer(CHECKPOINT_SYSTEM, config, this.metricsRegistry,
+        KafkaCheckpointManager.class.getSimpleName())).thenReturn(this.systemProducer);
+    when(this.systemFactory.getConsumer(CHECKPOINT_SYSTEM, config, this.metricsRegistry,
+        KafkaCheckpointManager.class.getSimpleName())).thenReturn(this.systemConsumer);
+    when(this.systemFactory.getAdmin(CHECKPOINT_SYSTEM, config,
+        KafkaCheckpointManager.class.getSimpleName())).thenReturn(this.systemAdmin);
+    when(this.systemFactory.getAdmin(CHECKPOINT_SYSTEM, config,
+        KafkaCheckpointManager.class.getSimpleName() + "createResource")).thenReturn(this.createResourcesSystemAdmin);
+  }
+
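+  /**
+   * Builds a {@link CheckpointV1} containing a single ssp-to-offset entry.
+   */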
+  private static CheckpointV1 buildCheckpointV1(SystemStreamPartition ssp, String offset) {
+    return new CheckpointV1(ImmutableMap.of(ssp, offset));
+  }
+
+  /**
+   * Creates a new checkpoint envelope for the provided task, ssp and offset
+   */
+  private IncomingMessageEnvelope newCheckpointV1Envelope(TaskName taskName, CheckpointV1 checkpointV1,
+      String checkpointMessageOffset) {
+    KafkaCheckpointLogKey checkpointKey = new KafkaCheckpointLogKey("checkpoint", taskName, GROUPER_FACTORY_CLASS);
+    KafkaCheckpointLogKeySerde checkpointKeySerde = new KafkaCheckpointLogKeySerde();
+    CheckpointV1Serde checkpointMsgSerde = new CheckpointV1Serde();
+    return new IncomingMessageEnvelope(CHECKPOINT_SSP, checkpointMessageOffset,
+        checkpointKeySerde.toBytes(checkpointKey), checkpointMsgSerde.toBytes(checkpointV1));
+  }
+
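+  /**
+   * Builds a {@link CheckpointV2} with a single input offset and a fixed placeholder state checkpoint map.
+   */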
+  private static CheckpointV2 buildCheckpointV2(SystemStreamPartition ssp, String offset) {
+    return new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, offset),
+        ImmutableMap.of("backend", ImmutableMap.of("store", "10")));
+  }
+
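+  /**
+   * Creates a new checkpoint v2 envelope for the provided task, checkpoint and offset
+   */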
+  private IncomingMessageEnvelope newCheckpointV2Envelope(TaskName taskName, CheckpointV2 checkpointV2,
+      String checkpointMessageOffset) {
+    KafkaCheckpointLogKey checkpointKey =
+        new KafkaCheckpointLogKey(KafkaCheckpointLogKey.CHECKPOINT_V2_KEY_TYPE, taskName, GROUPER_FACTORY_CLASS);
+    KafkaCheckpointLogKeySerde checkpointKeySerde = new KafkaCheckpointLogKeySerde();
+    CheckpointV2Serde checkpointMsgSerde = new CheckpointV2Serde();
+    return new IncomingMessageEnvelope(CHECKPOINT_SSP, checkpointMessageOffset,
+        checkpointKeySerde.toBytes(checkpointKey), checkpointMsgSerde.toBytes(checkpointV2));
+  }
+
+  /**
+   * Build base {@link Config} for tests.
+   */
+  private static Config config() {
+    return new MapConfig(ImmutableMap.of(JobConfig.SSP_GROUPER_FACTORY, GROUPER_FACTORY_CLASS));
+  }
+
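+  /**
+   * Build a {@link Config} with the provided entries added on top of the base config.
+   */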
+  private static Config config(Map<String, String> additional) {
+    Map<String, String> configMap = new HashMap<>(config());
+    configMap.putAll(additional);
+    return new MapConfig(configMap);
+  }
+}
\ No newline at end of file
diff --git a/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManagerJava.java b/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManagerJava.java
deleted file mode 100644
index d0e927f..0000000
--- a/samza-kafka/src/test/java/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManagerJava.java
+++ /dev/null
@@ -1,285 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.checkpoint.kafka;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
-import kafka.common.KafkaException;
-import kafka.common.TopicAlreadyMarkedForDeletionException;
-import org.apache.samza.Partition;
-import org.apache.samza.SamzaException;
-import org.apache.samza.checkpoint.CheckpointV1;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.container.grouper.stream.GroupByPartitionFactory;
-import org.apache.samza.metrics.MetricsRegistry;
-import org.apache.samza.serializers.CheckpointV1Serde;
-import org.apache.samza.serializers.CheckpointV2Serde;
-import org.apache.samza.system.IncomingMessageEnvelope;
-import org.apache.samza.system.StreamValidationException;
-import org.apache.samza.system.SystemAdmin;
-import org.apache.samza.system.SystemConsumer;
-import org.apache.samza.system.SystemFactory;
-import org.apache.samza.system.SystemProducer;
-import org.apache.samza.system.SystemStreamMetadata;
-import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
-import org.apache.samza.system.SystemStreamPartition;
-import org.apache.samza.system.kafka.KafkaStreamSpec;
-import org.junit.Assert;
-import org.junit.Test;
-import org.mockito.stubbing.OngoingStubbing;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-
-import static org.mockito.Mockito.*;
-
-public class TestKafkaCheckpointManagerJava {
-  private static final TaskName TASK1 = new TaskName("task1");
-  private static final String CHECKPOINT_TOPIC = "topic-1";
-  private static final String CHECKPOINT_SYSTEM = "system-1";
-  private static final Partition CHECKPOINT_PARTITION = new Partition(0);
-  private static final SystemStreamPartition CHECKPOINT_SSP =
-      new SystemStreamPartition(CHECKPOINT_SYSTEM, CHECKPOINT_TOPIC, CHECKPOINT_PARTITION);
-  private static final String GROUPER_FACTORY_CLASS = GroupByPartitionFactory.class.getCanonicalName();
-
-  @Test(expected = TopicAlreadyMarkedForDeletionException.class)
-  public void testStartFailsOnTopicCreationErrors() {
-
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    // create an admin that throws an exception during createStream
-    SystemAdmin mockAdmin = newAdmin("0", "10");
-    doThrow(new TopicAlreadyMarkedForDeletionException("invalid stream")).when(mockAdmin).createStream(checkpointSpec);
-
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mock(SystemConsumer.class), mockAdmin);
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mock(Config.class), mock(MetricsRegistry.class), null, null, new KafkaCheckpointLogKeySerde());
-
-    // expect an exception during startup
-    checkpointManager.createResources();
-    checkpointManager.start();
-  }
-
-  @Test(expected = StreamValidationException.class)
-  public void testStartFailsOnTopicValidationErrors() {
-
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-
-    // create an admin that throws an exception during validateStream
-    SystemAdmin mockAdmin = newAdmin("0", "10");
-    doThrow(new StreamValidationException("invalid stream")).when(mockAdmin).validateStream(checkpointSpec);
-
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mock(SystemConsumer.class), mockAdmin);
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mock(Config.class), mock(MetricsRegistry.class), null, null, new KafkaCheckpointLogKeySerde());
-
-    // expect an exception during startup
-    checkpointManager.createResources();
-    checkpointManager.start();
-  }
-
-  @Test(expected = SamzaException.class)
-  public void testReadFailsOnSerdeExceptions() throws Exception {
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.SSP_GROUPER_FACTORY)).thenReturn(GROUPER_FACTORY_CLASS);
-
-    // mock out a consumer that returns a single checkpoint IME
-    SystemStreamPartition ssp = new SystemStreamPartition("system-1", "input-topic", new Partition(0));
-    List<List<IncomingMessageEnvelope>> checkpointEnvelopes = ImmutableList.of(
-        ImmutableList.of(newCheckpointEnvelope(TASK1, ssp, "0")));
-    SystemConsumer mockConsumer = newConsumer(checkpointEnvelopes);
-
-    SystemAdmin mockAdmin = newAdmin("0", "1");
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mockConsumer, mockAdmin);
-
-    // wire up an exception throwing serde with the checkpointmanager
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mockConfig, mock(MetricsRegistry.class), new ExceptionThrowingCheckpointV1Serde(), null, new KafkaCheckpointLogKeySerde());
-    checkpointManager.register(TASK1);
-    checkpointManager.start();
-
-    // expect an exception from ExceptionThrowingSerde
-    checkpointManager.readLastCheckpoint(TASK1);
-  }
-
-  @Test
-  public void testReadSucceedsOnKeySerdeExceptionsWhenValidationIsDisabled() throws Exception {
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.SSP_GROUPER_FACTORY)).thenReturn(GROUPER_FACTORY_CLASS);
-
-    // mock out a consumer that returns a single checkpoint IME
-    SystemStreamPartition ssp = new SystemStreamPartition("system-1", "input-topic", new Partition(0));
-    List<List<IncomingMessageEnvelope>> checkpointEnvelopes = ImmutableList.of(
-        ImmutableList.of(newCheckpointEnvelope(TASK1, ssp, "0")));
-    SystemConsumer mockConsumer = newConsumer(checkpointEnvelopes);
-
-    SystemAdmin mockAdmin = newAdmin("0", "1");
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mockConsumer, mockAdmin);
-
-    // wire up an exception throwing serde with the checkpointmanager
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        false, mockConfig, mock(MetricsRegistry.class), new ExceptionThrowingCheckpointV1Serde(), null,
-        new ExceptionThrowingCheckpointKeySerde());
-    checkpointManager.register(TASK1);
-    checkpointManager.start();
-
-    // expect the read to succeed inspite of the exception from ExceptionThrowingSerde
-    checkpointManager.readLastCheckpoint(TASK1);
-  }
-
-  @Test
-  public void testCheckpointsAreReadFromOldestOffset() throws Exception {
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.SSP_GROUPER_FACTORY)).thenReturn(GROUPER_FACTORY_CLASS);
-
-    // mock out a consumer that returns a single checkpoint IME
-    SystemStreamPartition ssp = new SystemStreamPartition("system-1", "input-topic", new Partition(0));
-    SystemConsumer mockConsumer = newConsumer(ImmutableList.of(
-        ImmutableList.of(newCheckpointEnvelope(TASK1, ssp, "0"))));
-
-    String oldestOffset = "0";
-    SystemAdmin mockAdmin = newAdmin(oldestOffset, "1");
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mockConsumer, mockAdmin);
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mockConfig, mock(MetricsRegistry.class), new CheckpointV1Serde(), new CheckpointV2Serde(),
-        new KafkaCheckpointLogKeySerde());
-    checkpointManager.register(TASK1);
-
-    // 1. verify that consumer.register is called only during checkpointManager.start.
-    // 2. verify that consumer.register is called with the oldest offset.
-    // 3. verify that no other operation on the CheckpointManager re-invokes register since start offsets are set during
-    // register
-    verify(mockConsumer, times(0)).register(CHECKPOINT_SSP, oldestOffset);
-    checkpointManager.start();
-    verify(mockConsumer, times(1)).register(CHECKPOINT_SSP, oldestOffset);
-
-    checkpointManager.readLastCheckpoint(TASK1);
-    verify(mockConsumer, times(1)).register(CHECKPOINT_SSP, oldestOffset);
-  }
-
-  @Test
-  public void testAllMessagesInTheLogAreRead() throws Exception {
-    KafkaStreamSpec checkpointSpec = new KafkaStreamSpec(CHECKPOINT_TOPIC, CHECKPOINT_TOPIC,
-        CHECKPOINT_SYSTEM, 1);
-    Config mockConfig = mock(Config.class);
-    when(mockConfig.get(JobConfig.SSP_GROUPER_FACTORY)).thenReturn(GROUPER_FACTORY_CLASS);
-
-    SystemStreamPartition ssp = new SystemStreamPartition("system-1", "input-topic", new Partition(0));
-
-    int oldestOffset = 0;
-    int newestOffset = 10;
-
-    // mock out a consumer that returns ten checkpoint IMEs for the same ssp
-    List<List<IncomingMessageEnvelope>> pollOutputs = new ArrayList<>();
-    for (int offset = oldestOffset; offset <= newestOffset; offset++) {
-      pollOutputs.add(ImmutableList.of(newCheckpointEnvelope(TASK1, ssp, Integer.toString(offset))));
-    }
-
-    // return one message at a time from each poll simulating a KafkaConsumer with max.poll.records = 1
-    SystemConsumer mockConsumer = newConsumer(pollOutputs);
-    SystemAdmin mockAdmin = newAdmin(Integer.toString(oldestOffset), Integer.toString(newestOffset));
-    SystemFactory factory = newFactory(mock(SystemProducer.class), mockConsumer, mockAdmin);
-
-    KafkaCheckpointManager checkpointManager = new KafkaCheckpointManager(checkpointSpec, factory,
-        true, mockConfig, mock(MetricsRegistry.class), new CheckpointV1Serde(), new CheckpointV2Serde(),
-        new KafkaCheckpointLogKeySerde());
-    checkpointManager.register(TASK1);
-    checkpointManager.start();
-
-    // check that all ten messages are read, and the checkpoint is the newest message
-    CheckpointV1 checkpoint = (CheckpointV1) checkpointManager.readLastCheckpoint(TASK1);
-    Assert.assertEquals(checkpoint.getOffsets(), ImmutableMap.of(ssp, Integer.toString(newestOffset)));
-  }
-
-  /**
-   * Create a new {@link SystemConsumer} that returns a list of messages sequentially at each subsequent poll.
-   *
-   * @param pollOutputs a list of poll outputs to be returned at subsequent polls.
-   *                    The i'th call to consumer.poll() will return the list at pollOutputs[i]
-   * @return the consumer
-   */
-  private SystemConsumer newConsumer(List<List<IncomingMessageEnvelope>> pollOutputs) throws Exception {
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    OngoingStubbing<Map> when = when(mockConsumer.poll(anySet(), anyLong()));
-    for (List<IncomingMessageEnvelope> pollOutput : pollOutputs) {
-      when = when.thenReturn(ImmutableMap.of(CHECKPOINT_SSP, pollOutput));
-    }
-    when.thenReturn(ImmutableMap.of());
-    return mockConsumer;
-  }
-
-  /**
-   * Create a new {@link SystemAdmin} that returns the provided oldest and newest offsets for its topics
-   */
-  private SystemAdmin newAdmin(String oldestOffset, String newestOffset) {
-    SystemStreamMetadata checkpointTopicMetadata = new SystemStreamMetadata(CHECKPOINT_TOPIC,
-        ImmutableMap.of(new Partition(0), new SystemStreamPartitionMetadata(oldestOffset,
-            newestOffset, Integer.toString(Integer.parseInt(newestOffset) + 1))));
-    SystemAdmin mockAdmin = mock(SystemAdmin.class);
-    when(mockAdmin.getSystemStreamMetadata(Collections.singleton(CHECKPOINT_TOPIC))).thenReturn(
-        ImmutableMap.of(CHECKPOINT_TOPIC, checkpointTopicMetadata));
-    return mockAdmin;
-  }
-
-  private SystemFactory newFactory(SystemProducer producer, SystemConsumer consumer, SystemAdmin admin) {
-    SystemFactory factory = mock(SystemFactory.class);
-    when(factory.getProducer(anyString(), any(Config.class), any(MetricsRegistry.class), anyString())).thenReturn(producer);
-    when(factory.getConsumer(anyString(), any(Config.class), any(MetricsRegistry.class), anyString())).thenReturn(consumer);
-    when(factory.getAdmin(anyString(), any(Config.class), anyString())).thenReturn(admin);
-    return factory;
-  }
-
-  /**
-   * Creates a new checkpoint envelope for the provided task, ssp and offset
-   */
-  private IncomingMessageEnvelope newCheckpointEnvelope(TaskName taskName, SystemStreamPartition ssp, String offset) {
-    KafkaCheckpointLogKey checkpointKey =
-        new KafkaCheckpointLogKey("checkpoint", taskName, GROUPER_FACTORY_CLASS);
-    KafkaCheckpointLogKeySerde checkpointKeySerde = new KafkaCheckpointLogKeySerde();
-
-    CheckpointV1 checkpointMsg = new CheckpointV1(ImmutableMap.of(ssp, offset));
-    CheckpointV1Serde checkpointMsgSerde = new CheckpointV1Serde();
-
-    return new IncomingMessageEnvelope(CHECKPOINT_SSP, offset, checkpointKeySerde.toBytes(checkpointKey),
-        checkpointMsgSerde.toBytes(checkpointMsg));
-  }
-
-  private static class ExceptionThrowingCheckpointV1Serde extends CheckpointV1Serde {
-    public CheckpointV1 fromBytes(byte[] bytes) {
-      throw new KafkaException("exception");
-    }
-  }
-
-  private static class ExceptionThrowingCheckpointKeySerde extends KafkaCheckpointLogKeySerde {
-    public KafkaCheckpointLogKey fromBytes(byte[] bytes) {
-      throw new KafkaException("exception");
-    }
-  }
-}
diff --git a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala b/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
deleted file mode 100644
index 835f53e..0000000
--- a/samza-kafka/src/test/scala/org/apache/samza/checkpoint/kafka/TestKafkaCheckpointManager.scala
+++ /dev/null
@@ -1,533 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.checkpoint.kafka
-
-import java.util.Properties
-import kafka.integration.KafkaServerTestHarness
-import kafka.utils.{CoreUtils, TestUtils}
-import com.google.common.collect.ImmutableMap
-import org.apache.samza.checkpoint.{Checkpoint, CheckpointId, CheckpointV1, CheckpointV2}
-import org.apache.samza.config._
-import org.apache.samza.container.TaskName
-import org.apache.samza.container.grouper.stream.GroupByPartitionFactory
-import org.apache.samza.metrics.MetricsRegistry
-import org.apache.samza.serializers.{CheckpointV1Serde, CheckpointV2Serde}
-import org.apache.samza.system._
-import org.apache.samza.system.kafka.{KafkaStreamSpec, KafkaSystemFactory}
-import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
-import org.apache.samza.util.{NoOpMetricsRegistry, ReflectionUtil}
-import org.apache.samza.{Partition, SamzaException}
-import org.junit.Assert._
-import org.junit._
-import org.mockito.Mockito
-import org.mockito.Matchers
-
-class TestKafkaCheckpointManager extends KafkaServerTestHarness {
-
-  protected def numBrokers: Int = 3
-
-  val checkpointSystemName = "kafka"
-  val sspGrouperFactoryName = classOf[GroupByPartitionFactory].getCanonicalName
-
-  val ssp = new SystemStreamPartition("kafka", "topic", new Partition(0))
-  val checkpoint1 = new CheckpointV1(ImmutableMap.of(ssp, "offset-1"))
-  val checkpoint2 = new CheckpointV1(ImmutableMap.of(ssp, "offset-2"))
-  val taskName = new TaskName("Partition 0")
-  var config: Config = null
-
-  @Before
-  override def setUp {
-    super.setUp
-    TestUtils.waitUntilTrue(() => servers.head.metadataCache.getAliveBrokers.size == numBrokers, "Wait for cache to update")
-    config = getConfig()
-  }
-
-  override def generateConfigs() = {
-    val props = TestUtils.createBrokerConfigs(numBrokers, zkConnect, enableControlledShutdown = true)
-    // do not use relative imports
-    props.map(_root_.kafka.server.KafkaConfig.fromProps)
-  }
-
-  @Test
-  def testWriteCheckpointShouldRecreateSystemProducerOnFailure(): Unit = {
-    val checkpointTopic = "checkpoint-topic-2"
-    val mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
-
-    class MockSystemFactory extends KafkaSystemFactory {
-      override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
-        mockKafkaProducer
-      }
-    }
-
-    Mockito.doThrow(new RuntimeException()).when(mockKafkaProducer).flush(taskName.getTaskName)
-
-    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
-    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-    val checkPointManager = Mockito.spy(new KafkaCheckpointManager(spec, new MockSystemFactory, false, config, new NoOpMetricsRegistry))
-    val newKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
-
-    Mockito.doReturn(newKafkaProducer).when(checkPointManager).getSystemProducer()
-
-    checkPointManager.register(taskName)
-    checkPointManager.start
-    checkPointManager.writeCheckpoint(taskName, new CheckpointV1(ImmutableMap.of()))
-    checkPointManager.stop()
-
-    // Verifications after the test
-
-    Mockito.verify(mockKafkaProducer).stop()
-    Mockito.verify(newKafkaProducer).register(taskName.getTaskName)
-    Mockito.verify(newKafkaProducer).start()
-  }
-
-  @Test
-  def testCheckpointShouldBeNullIfCheckpointTopicDoesNotExistShouldBeCreatedOnWriteAndShouldBeReadableAfterWrite(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-    kcm1.register(taskName)
-    kcm1.createResources
-    kcm1.start
-    kcm1.stop
-
-    // check that start actually creates the topic with log compaction enabled
-    val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-    assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-    assertEquals("compact", topicConfig.get("cleanup.policy"))
-    assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-    // read before topic exists should result in a null checkpoint
-    val readCp = readCheckpoint(checkpointTopic, taskName)
-    assertNull(readCp)
-
-    writeCheckpoint(checkpointTopic, taskName, checkpoint1)
-    assertEquals(checkpoint1, readCheckpoint(checkpointTopic, taskName))
-
-    // writing a second message and reading it returns a more recent checkpoint
-    writeCheckpoint(checkpointTopic, taskName, checkpoint2)
-    assertEquals(checkpoint2, readCheckpoint(checkpointTopic, taskName))
-  }
-
-  @Test
-  def testCheckpointV1AndV2WriteAndReadV1(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-    kcm1.register(taskName)
-    kcm1.createResources
-    kcm1.start
-    kcm1.stop
-
-    // check that start actually creates the topic with log compaction enabled
-    val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-    assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-    assertEquals("compact", topicConfig.get("cleanup.policy"))
-    assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-    // read before topic exists should result in a null checkpoint
-    val readCp = readCheckpoint(checkpointTopic, taskName)
-    assertNull(readCp)
-
-    val checkpointV1 = new CheckpointV1(ImmutableMap.of(ssp, "offset-1"))
-    val checkpointV2 = new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, "offset-2"),
-      ImmutableMap.of("factory1", ImmutableMap.of("store1", "changelogOffset")))
-
-    // skips v2 checkpoints from checkpoint topic
-    writeCheckpoint(checkpointTopic, taskName, checkpointV2)
-    assertNull(readCheckpoint(checkpointTopic, taskName))
-
-    // reads latest v1 checkpoints
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertEquals(checkpointV1, readCheckpoint(checkpointTopic, taskName))
-
-    // writing checkpoint v2 still returns the previous v1 checkpoint
-    writeCheckpoint(checkpointTopic, taskName, checkpointV2)
-    assertEquals(checkpointV1, readCheckpoint(checkpointTopic, taskName))
-  }
-
-  @Test
-  def testCheckpointV1AndV2WriteAndReadV2(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-    kcm1.register(taskName)
-    kcm1.createResources
-    kcm1.start
-    kcm1.stop
-
-    // check that start actually creates the topic with log compaction enabled
-    val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-    assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-    assertEquals("compact", topicConfig.get("cleanup.policy"))
-    assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-    // read before topic exists should result in a null checkpoint
-    val readCp = readCheckpoint(checkpointTopic, taskName)
-    assertNull(readCp)
-
-    val checkpointV1 = new CheckpointV1(ImmutableMap.of(ssp, "offset-1"))
-    val checkpointV2 = new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, "offset-2"),
-      ImmutableMap.of("factory1", ImmutableMap.of("store1", "changelogOffset")))
-
-    val overrideConfig = new MapConfig(new ImmutableMap.Builder[String, String]()
-      .put(JobConfig.JOB_NAME, "some-job-name")
-      .put(JobConfig.JOB_ID, "i001")
-      .put(s"systems.$checkpointSystemName.samza.factory", classOf[KafkaSystemFactory].getCanonicalName)
-      .put(s"systems.$checkpointSystemName.producer.bootstrap.servers", brokerList)
-      .put(s"systems.$checkpointSystemName.consumer.zookeeper.connect", zkConnect)
-      .put("task.checkpoint.system", checkpointSystemName)
-      .put(TaskConfig.CHECKPOINT_READ_VERSIONS, "2")
-      .build())
-
-    // Skips reading any v1 checkpoints
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertNull(readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    // writing a v2 checkpoint would allow reading it back
-    writeCheckpoint(checkpointTopic, taskName, checkpointV2)
-    assertEquals(checkpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    // writing v1 checkpoint is still skipped
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertEquals(checkpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-  }
-
-  @Test
-  def testCheckpointV1AndV2WriteAndReadV1V2PrecedenceList(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-    kcm1.register(taskName)
-    kcm1.createResources
-    kcm1.start
-    kcm1.stop
-
-    // check that start actually creates the topic with log compaction enabled
-    val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-    assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-    assertEquals("compact", topicConfig.get("cleanup.policy"))
-    assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-    // read before topic exists should result in a null checkpoint
-    val readCp = readCheckpoint(checkpointTopic, taskName)
-    assertNull(readCp)
-
-    val checkpointV1 = new CheckpointV1(ImmutableMap.of(ssp, "offset-1"))
-    val checkpointV2 = new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, "offset-2"),
-      ImmutableMap.of("factory1", ImmutableMap.of("store1", "changelogOffset")))
-
-    val overrideConfig = new MapConfig(new ImmutableMap.Builder[String, String]()
-      .put(JobConfig.JOB_NAME, "some-job-name")
-      .put(JobConfig.JOB_ID, "i001")
-      .put(s"systems.$checkpointSystemName.samza.factory", classOf[KafkaSystemFactory].getCanonicalName)
-      .put(s"systems.$checkpointSystemName.producer.bootstrap.servers", brokerList)
-      .put(s"systems.$checkpointSystemName.consumer.zookeeper.connect", zkConnect)
-      .put("task.checkpoint.system", checkpointSystemName)
-      .put(TaskConfig.CHECKPOINT_READ_VERSIONS, "2,1")
-      .build())
-
-    // Still reads any v1 checkpoints due to precedence list
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertEquals(checkpointV1, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    // writing a v2 checkpoint would allow reading it back
-    writeCheckpoint(checkpointTopic, taskName, checkpointV2)
-    assertEquals(checkpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    // a v1 checkpoint written afterwards is still skipped on read
-    writeCheckpoint(checkpointTopic, taskName, checkpointV1)
-    assertEquals(checkpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-
-    val newCheckpointV2 = new CheckpointV2(CheckpointId.create(), ImmutableMap.of(ssp, "offset-3"),
-      ImmutableMap.of("factory1", ImmutableMap.of("store1", "changelogOffset")))
-    // writing v2 returns a new checkpoint v2
-    writeCheckpoint(checkpointTopic, taskName, newCheckpointV2)
-    assertEquals(newCheckpointV2, readCheckpoint(checkpointTopic, taskName, overrideConfig))
-  }
-
-  @Test
-  def testCheckpointValidationSkipped(): Unit = {
-    val checkpointTopic = "checkpoint-topic-1"
-    val kcm1 = createKafkaCheckpointManager(checkpointTopic, serde = new MockCheckpointSerde(),
-      failOnTopicValidation = false)
-    kcm1.register(taskName)
-    kcm1.start
-    kcm1.writeCheckpoint(taskName, new CheckpointV1(ImmutableMap.of(ssp, "offset-1")))
-    kcm1.readLastCheckpoint(taskName)
-    kcm1.stop
-  }
-
-  @Test
-  def testReadCheckpointShouldIgnoreUnknownCheckpointKeys(): Unit = {
-      val checkpointTopic = "checkpoint-topic-1"
-      val kcm1 = createKafkaCheckpointManager(checkpointTopic)
-      kcm1.register(taskName)
-      kcm1.createResources
-      kcm1.start
-      kcm1.stop
-
-      // check that start actually creates the topic with log compaction enabled
-      val topicConfig = adminZkClient.getAllTopicConfigs().getOrElse(checkpointTopic, new Properties())
-
-      assertEquals(topicConfig, new KafkaConfig(config).getCheckpointTopicProperties())
-      assertEquals("compact", topicConfig.get("cleanup.policy"))
-      assertEquals("26214400", topicConfig.get("segment.bytes"))
-
-      // read before topic exists should result in a null checkpoint
-      val readCp = readCheckpoint(checkpointTopic, taskName)
-      assertNull(readCp)
-    // skips unknown checkpoints from checkpoint topic
-    writeCheckpoint(checkpointTopic, taskName, checkpoint1, "checkpoint-v2", 
useMock = true)
-    assertNull(readCheckpoint(checkpointTopic, taskName, useMock = true))
-
-    // reads latest v1 checkpoints
-    writeCheckpoint(checkpointTopic, taskName, checkpoint1, useMock = true)
-    assertEquals(checkpoint1, readCheckpoint(checkpointTopic, taskName, useMock = true))
-
-    // writing checkpoint v2 still returns the previous v1 checkpoint
-    writeCheckpoint(checkpointTopic, taskName, checkpoint2, "checkpoint-v2", useMock = true)
-    assertEquals(checkpoint1, readCheckpoint(checkpointTopic, taskName, useMock = true))
-
-    // writing checkpoint2 with the correct key makes it the checkpoint that is read back
-    writeCheckpoint(checkpointTopic, taskName, checkpoint2, useMock = true)
-    assertEquals(checkpoint2, readCheckpoint(checkpointTopic, taskName, useMock = true))
-  }
-
-  @Test
-  def testWriteCheckpointShouldRetryFiniteTimesOnFailure(): Unit = {
-    val checkpointTopic = "checkpoint-topic-2"
-    val mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])
-    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
-
-    Mockito.doThrow(new RuntimeException()).when(mockKafkaProducer).flush(taskName.getTaskName)
-
-    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
-    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-    val checkPointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer, mockKafkaProducer), false, config, new NoOpMetricsRegistry)
-    checkPointManager.MaxRetryDurationInMillis = 1
-
-    try {
-      checkPointManager.register(taskName)
-      checkPointManager.start
-      checkPointManager.writeCheckpoint(taskName, new CheckpointV1(ImmutableMap.of()))
-    } catch {
-      case _: SamzaException => info("Got SamzaException as expected.")
-      case unexpectedException: Throwable => fail("Expected SamzaException but got %s" format unexpectedException)
-    } finally {
-      checkPointManager.stop()
-    }
-  }
-
-  @Test
-  def testFailOnTopicValidation(): Unit = {
-    // By default, should fail if there is a topic validation error
-    val checkpointTopic = "eight-partition-topic";
-    val kcm = createKafkaCheckpointManager(checkpointTopic)
-    kcm.register(taskName)
-    // create topic with the wrong number of partitions
-    createTopic(checkpointTopic, 8, new KafkaConfig(config).getCheckpointTopicProperties())
-    try {
-      kcm.createResources()
-      kcm.start()
-      fail("Expected an exception for invalid number of partitions in the 
checkpoint topic.")
-    } catch {
-      case e: StreamValidationException => None
-    }
-    kcm.stop()
-  }
-
-  @Test
-  def testNoFailOnTopicValidationDisabled(): Unit = {
-    val checkpointTopic = "eight-partition-topic";
-    // create topic with the wrong number of partitions
-    createTopic(checkpointTopic, 8, new KafkaConfig(config).getCheckpointTopicProperties())
-    val failOnTopicValidation = false
-    val kcm = createKafkaCheckpointManager(checkpointTopic, new CheckpointV1Serde, failOnTopicValidation)
-    kcm.register(taskName)
-    kcm.createResources()
-    kcm.start()
-    kcm.stop()
-  }
-
-  @Test
-  def testConsumerStopsAfterInitialReadIfConfigSetTrue(): Unit = {
-    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
-
-    val checkpointTopic = "checkpoint-topic-test"
-    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
-    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-
-    val configMapWithOverride = new java.util.HashMap[String, String](config)
-    configMapWithOverride.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "true")
-    val kafkaCheckpointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer), false, new MapConfig(configMapWithOverride), new NoOpMetricsRegistry)
-
-    kafkaCheckpointManager.register(taskName)
-    kafkaCheckpointManager.start()
-    kafkaCheckpointManager.readLastCheckpoint(taskName)
-
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).register(Matchers.any(), Matchers.any())
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).start()
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).poll(Matchers.any(), Matchers.any())
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).stop()
-
-    kafkaCheckpointManager.stop()
-
-    Mockito.verifyNoMoreInteractions(mockKafkaSystemConsumer)
-  }
-
-  @Test
-  def testConsumerDoesNotStopAfterInitialReadIfConfigSetFalse(): Unit = {
-    val mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer])
-
-    val checkpointTopic = "checkpoint-topic-test"
-    val props = new org.apache.samza.config.KafkaConfig(config).getCheckpointTopicProperties()
-    val spec = new KafkaStreamSpec("id", checkpointTopic, checkpointSystemName, 1, 1, props)
-
-    val configMapWithOverride = new java.util.HashMap[String, String](config)
-    configMapWithOverride.put(TaskConfig.INTERNAL_CHECKPOINT_MANAGER_CONSUMER_STOP_AFTER_FIRST_READ, "false")
-    val kafkaCheckpointManager = new KafkaCheckpointManager(spec, new MockSystemFactory(mockKafkaSystemConsumer), false, new MapConfig(configMapWithOverride), new NoOpMetricsRegistry)
-
-    kafkaCheckpointManager.register(taskName)
-    kafkaCheckpointManager.start()
-    kafkaCheckpointManager.readLastCheckpoint(taskName)
-
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(0)).stop()
-
-    kafkaCheckpointManager.stop()
-
-    Mockito.verify(mockKafkaSystemConsumer, Mockito.times(1)).stop()
-  }
-
-  @After
-  override def tearDown(): Unit = {
-    if (servers != null) {
-      servers.foreach(_.shutdown())
-      servers.foreach(server => CoreUtils.delete(server.config.logDirs))
-    }
-    super.tearDown
-  }
-
-  private def getConfig(): Config = {
-    new MapConfig(new ImmutableMap.Builder[String, String]()
-      .put(JobConfig.JOB_NAME, "some-job-name")
-      .put(JobConfig.JOB_ID, "i001")
-      .put(s"systems.$checkpointSystemName.samza.factory", 
classOf[KafkaSystemFactory].getCanonicalName)
-      .put(s"systems.$checkpointSystemName.producer.bootstrap.servers", 
brokerList)
-      .put(s"systems.$checkpointSystemName.consumer.zookeeper.connect", 
zkConnect)
-      .put("task.checkpoint.system", checkpointSystemName)
-      .build())
-  }
-
-  private def createKafkaCheckpointManager(cpTopic: String, serde: CheckpointV1Serde = new CheckpointV1Serde,
-    failOnTopicValidation: Boolean = true, useMock: Boolean = false, checkpointKey: String = KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE,
-    overrideConfig: Config = config) = {
-    val kafkaConfig = new org.apache.samza.config.KafkaConfig(overrideConfig)
-    val props = kafkaConfig.getCheckpointTopicProperties()
-    val systemName = kafkaConfig.getCheckpointSystem.getOrElse(
-      throw new SamzaException("No system defined for Kafka's checkpoint manager."))
-
-    val systemConfig = new SystemConfig(overrideConfig)
-    val systemFactoryClassName = JavaOptionals.toRichOptional(systemConfig.getSystemFactory(systemName)).toOption
-      .getOrElse(throw new SamzaException("Missing configuration: " + SystemConfig.SYSTEM_FACTORY_FORMAT format systemName))
-
-    val systemFactory = ReflectionUtil.getObj(systemFactoryClassName, classOf[SystemFactory])
-
-    val spec = new KafkaStreamSpec("id", cpTopic, checkpointSystemName, 1, 1, 
props)
-
-    if (useMock) {
-      new MockKafkaCheckpointManager(spec, systemFactory, failOnTopicValidation, serde, checkpointKey)
-    } else {
-      new KafkaCheckpointManager(spec, systemFactory, failOnTopicValidation, overrideConfig, new NoOpMetricsRegistry, serde)
-    }
-  }
-
-  private def readCheckpoint(checkpointTopic: String, taskName: TaskName, config: Config = config,
-    useMock: Boolean = false): Checkpoint = {
-    val kcm = createKafkaCheckpointManager(checkpointTopic, overrideConfig = config, useMock = useMock)
-    kcm.register(taskName)
-    kcm.start
-    val checkpoint = kcm.readLastCheckpoint(taskName)
-    kcm.stop
-    checkpoint
-  }
-
-  private def writeCheckpoint(checkpointTopic: String, taskName: TaskName, checkpoint: Checkpoint,
-    checkpointKey: String = KafkaCheckpointLogKey.CHECKPOINT_V1_KEY_TYPE, useMock: Boolean = false): Unit = {
-    val kcm = createKafkaCheckpointManager(checkpointTopic, checkpointKey = checkpointKey, useMock = useMock)
-    kcm.register(taskName)
-    kcm.start
-    kcm.writeCheckpoint(taskName, checkpoint)
-    kcm.stop
-  }
-
-  private def createTopic(cpTopic: String, partNum: Int, props: Properties) {
-    adminZkClient.createTopic(cpTopic, partNum, 1, props)
-  }
-
-  class MockSystemFactory(
-    mockKafkaSystemConsumer: SystemConsumer = Mockito.mock(classOf[SystemConsumer]),
-    mockKafkaProducer: SystemProducer = Mockito.mock(classOf[SystemProducer])) extends KafkaSystemFactory {
-    override def getProducer(systemName: String, config: Config, registry: MetricsRegistry): SystemProducer = {
-      mockKafkaProducer
-    }
-
-    override def getConsumer(systemName: String, config: Config, registry: MetricsRegistry): SystemConsumer = {
-      mockKafkaSystemConsumer
-    }
-  }
-
-  class MockCheckpointSerde() extends CheckpointV1Serde {
-    override def fromBytes(bytes: Array[Byte]): CheckpointV1 = {
-      throw new SamzaException("Failed to deserialize")
-    }
-  }
-
-
-  class MockKafkaCheckpointManager(spec: KafkaStreamSpec, systemFactory: SystemFactory, failOnTopicValidation: Boolean,
-    serde: CheckpointV1Serde = new CheckpointV1Serde, checkpointKey: String)
-    extends KafkaCheckpointManager(spec, systemFactory, failOnTopicValidation, config,
-      new NoOpMetricsRegistry, serde) {
-
-    override def buildOutgoingMessageEnvelope[T <: Checkpoint](taskName: TaskName, checkpoint: T): OutgoingMessageEnvelope = {
-      val key = new KafkaCheckpointLogKey(checkpointKey, taskName, expectedGrouperFactory)
-      val keySerde = new KafkaCheckpointLogKeySerde
-      val checkpointMsgSerde = new CheckpointV1Serde
-      val checkpointV2MsgSerde = new CheckpointV2Serde
-      val keyBytes = try {
-        keySerde.toBytes(key)
-      } catch {
-        case e: Exception => throw new SamzaException(s"Exception when writing checkpoint-key for $taskName: $checkpoint", e)
-      }
-      val msgBytes = try {
-        checkpoint match {
-          case v1: CheckpointV1 =>
-            checkpointMsgSerde.toBytes(v1)
-          case v2: CheckpointV2 =>
-            checkpointV2MsgSerde.toBytes(v2)
-          case _ =>
-            throw new IllegalArgumentException("Unknown checkpoint key type for test, please use Checkpoint v1 or v2")
-        }
-      } catch {
-        case e: Exception => throw new SamzaException(s"Exception when writing checkpoint for $taskName: $checkpoint", e)
-      }
-      new OutgoingMessageEnvelope(checkpointSsp, keyBytes, msgBytes)
-    }
-  }
-}
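
A note on the precedence behavior that the two removed tests above pin down: TaskConfig.CHECKPOINT_READ_VERSIONS is an ordered list, and the first version in the list for which a checkpoint has been written wins. With "2" alone, v1 checkpoints are never read back; with "2,1", a v1 checkpoint serves as a fallback until a v2 checkpoint exists. The stand-alone Java sketch below models only that selection rule; it is an illustrative model, not KafkaCheckpointManager's actual code, and resolve() and latestByVersion are hypothetical names:

    import java.util.Arrays;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    public class CheckpointVersionPrecedence {
      // Return the checkpoint for the first version in the precedence list that
      // has one; null mirrors the "no checkpoint yet" case asserted in the tests.
      public static <C> C resolve(List<Short> readVersions, Map<Short, C> latestByVersion) {
        for (Short version : readVersions) {
          C checkpoint = latestByVersion.get(version);
          if (checkpoint != null) {
            return checkpoint;
          }
        }
        return null;
      }

      public static void main(String[] args) {
        Map<Short, String> latest = new LinkedHashMap<>();
        latest.put((short) 1, "v1-offset-1");
        // read versions "2": the v1 checkpoint is skipped entirely
        System.out.println(resolve(Arrays.asList((short) 2), latest));            // null
        // read versions "2,1": v1 acts as a fallback
        System.out.println(resolve(Arrays.asList((short) 2, (short) 1), latest)); // v1-offset-1
        latest.put((short) 2, "v2-offset-2");
        // once a v2 checkpoint exists, it wins under both settings
        System.out.println(resolve(Arrays.asList((short) 2, (short) 1), latest)); // v2-offset-2
      }
    }
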
diff --git a/samza-test/src/test/java/org/apache/samza/test/harness/IntegrationTestHarness.java b/samza-test/src/test/java/org/apache/samza/test/harness/IntegrationTestHarness.java
index 57987af..c1db4d1 100644
--- a/samza-test/src/test/java/org/apache/samza/test/harness/IntegrationTestHarness.java
+++ b/samza-test/src/test/java/org/apache/samza/test/harness/IntegrationTestHarness.java
@@ -109,6 +109,7 @@ public class IntegrationTestHarness extends AbstractKafkaServerTestHarness {
     * it shouldn't impact the tests nor have any side effects.
     */
     adminClient.close(ADMIN_OPERATION_WAIT_DURATION_MS, TimeUnit.MILLISECONDS);
+    consumer.unsubscribe();
     consumer.close();
     producer.close();
     super.tearDown();
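
The single added line above calls unsubscribe() before close(). KafkaConsumer.unsubscribe() drops the consumer's subscriptions and its group membership, which plausibly keeps close() from lingering on group coordination during test teardown. A minimal stand-alone sketch of the same pattern against the plain Kafka consumer API; the broker address, group id, and topic name are placeholders:

    import java.time.Duration;
    import java.util.Collections;
    import java.util.Properties;
    import org.apache.kafka.clients.consumer.KafkaConsumer;

    public class ConsumerTeardownSketch {
      public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092"); // placeholder broker
        props.put("group.id", "teardown-sketch");         // placeholder group
        props.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
        props.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
        KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props);
        consumer.subscribe(Collections.singletonList("some-topic")); // placeholder topic
        try {
          consumer.poll(Duration.ofMillis(100));
        } finally {
          consumer.unsubscribe(); // leave the group first, as in the harness change above
          consumer.close();
        }
      }
    }
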
diff --git a/samza-test/src/test/java/org/apache/samza/test/kafka/KafkaCheckpointManagerIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/kafka/KafkaCheckpointManagerIntegrationTest.java
new file mode 100644
index 0000000..612647c
--- /dev/null
+++ b/samza-test/src/test/java/org/apache/samza/test/kafka/KafkaCheckpointManagerIntegrationTest.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.test.kafka;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import com.google.common.collect.ImmutableMap;
+import org.apache.samza.application.TaskApplication;
+import org.apache.samza.application.descriptors.TaskApplicationDescriptor;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.JobCoordinatorConfig;
+import org.apache.samza.config.KafkaConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.kafka.descriptors.KafkaInputDescriptor;
+import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.StreamTask;
+import org.apache.samza.task.StreamTaskFactory;
+import org.apache.samza.task.TaskCoordinator;
+import org.apache.samza.test.framework.StreamApplicationIntegrationTestHarness;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+
+/**
+ * 1) Run app and consume messages
+ * 2) Commit only for first message
+ * 3) Shutdown application
+ * 4) Run app a second time to use the checkpoint
+ * 5) Verify that the uncommitted message after the first one had to be re-processed
+ */
+public class KafkaCheckpointManagerIntegrationTest extends StreamApplicationIntegrationTestHarness {
+  private static final String SYSTEM = "kafka";
+  private static final String INPUT_STREAM = "inputStream";
+  private static final Map<String, String> CONFIGS = ImmutableMap.of(
+      JobCoordinatorConfig.JOB_COORDINATOR_FACTORY, "org.apache.samza.standalone.PassthroughJobCoordinatorFactory",
+      JobConfig.PROCESSOR_ID, "0",
+      TaskConfig.CHECKPOINT_MANAGER_FACTORY, "org.apache.samza.checkpoint.kafka.KafkaCheckpointManagerFactory",
+      KafkaConfig.CHECKPOINT_REPLICATION_FACTOR(), "1",
+      TaskConfig.COMMIT_MS, "-1"); // manual commit only
+  /**
+   * Keep track of which messages have been received by the application.
+   */
+  private static final Map<String, AtomicInteger> PROCESSED = new HashMap<>();
+
+  /**
+   * If message has this prefix, then request a commit after processing it.
+   */
+  private static final String COMMIT_PREFIX = "commit";
+  /**
+   * If message equals this string, then shut down the task if the task is configured to handle intermediate shutdown.
+   */
+  private static final String INTERMEDIATE_SHUTDOWN = "intermediateShutdown";
+  /**
+   * If message equals this string, then shut down the task.
+   */
+  private static final String END_OF_STREAM = "endOfStream";
+
+  @Before
+  public void setup() {
+    PROCESSED.clear();
+  }
+
+  @Test
+  public void testCheckpoint() {
+    createTopic(INPUT_STREAM, 2);
+    produceMessages(0);
+    produceMessages(1);
+
+    // run application once and verify processed messages before shutdown
+    runApplication(new CheckpointApplication(true), "CheckpointApplication", CONFIGS).getRunner().waitForFinish();
+    verifyProcessedMessagesFirstRun();
+
+    // run application a second time and verify that certain messages had to be re-processed
+    runApplication(new CheckpointApplication(false), "CheckpointApplication", CONFIGS).getRunner().waitForFinish();
+    verifyProcessedMessagesSecondRun();
+  }
+
+  private void produceMessages(int partitionId) {
+    String key = "key" + partitionId;
+    // commit first message
+    produceMessage(INPUT_STREAM, partitionId, key, commitMessage(partitionId, 0));
+    // don't commit second message
+    produceMessage(INPUT_STREAM, partitionId, key, noCommitMessage(partitionId, 1));
+    // do an initial shutdown so that the test can check that the second message gets re-processed
+    produceMessage(INPUT_STREAM, partitionId, key, INTERMEDIATE_SHUTDOWN);
+    // do a commit on the third message
+    produceMessage(INPUT_STREAM, partitionId, key, commitMessage(partitionId, 2));
+    // this will make the task shut down for the second run
+    produceMessage(INPUT_STREAM, partitionId, key, END_OF_STREAM);
+  }
+
+  /**
+   * Each partition should have seen two messages before shutting down.
+   */
+  private static void verifyProcessedMessagesFirstRun() {
+    assertEquals(4, PROCESSED.size());
+    assertEquals(1, PROCESSED.get(commitMessage(0, 0)).get());
+    assertEquals(1, PROCESSED.get(noCommitMessage(0, 1)).get());
+    assertEquals(1, PROCESSED.get(commitMessage(1, 0)).get());
+    assertEquals(1, PROCESSED.get(noCommitMessage(1, 1)).get());
+  }
+
+  /**
+   * For each partition: the uncommitted second message is re-processed (bringing its total to 2), and the third
+   * message is processed for the first time.
+   */
+  private static void verifyProcessedMessagesSecondRun() {
+    assertEquals(6, PROCESSED.size());
+    assertEquals(1, PROCESSED.get(commitMessage(0, 0)).get());
+    assertEquals(2, PROCESSED.get(noCommitMessage(0, 1)).get());
+    assertEquals(1, PROCESSED.get(commitMessage(0, 2)).get());
+    assertEquals(1, PROCESSED.get(commitMessage(1, 0)).get());
+    assertEquals(2, PROCESSED.get(noCommitMessage(1, 1)).get());
+    assertEquals(1, PROCESSED.get(commitMessage(1, 2)).get());
+  }
+
+  private static String commitMessage(int partitionId, int messageId) {
+    return String.join("_", COMMIT_PREFIX, "partition", 
Integer.toString(partitionId), Integer.toString(messageId));
+  }
+
+  private static String noCommitMessage(int partitionId, int messageId) {
+    return String.join("_", "partition", Integer.toString(partitionId), 
Integer.toString(messageId));
+  }
+
+  private static class CheckpointApplication implements TaskApplication {
+    private final boolean handleIntermediateShutdown;
+
+    private CheckpointApplication(boolean handleIntermediateShutdown) {
+      this.handleIntermediateShutdown = handleIntermediateShutdown;
+    }
+
+    @Override
+    public void describe(TaskApplicationDescriptor appDescriptor) {
+      KafkaSystemDescriptor sd = new KafkaSystemDescriptor(SYSTEM);
+      KafkaInputDescriptor<String> isd = sd.getInputDescriptor(INPUT_STREAM, new StringSerde());
+      appDescriptor.withInputStream(isd)
+          .withTaskFactory((StreamTaskFactory) () -> new CheckpointTask(this.handleIntermediateShutdown));
+    }
+  }
+
+  private static class CheckpointTask implements StreamTask {
+    /**
+     * Determine if task should respond to {@link #INTERMEDIATE_SHUTDOWN}.
+     * Helps with testing that any uncommitted messages get reprocessed if the job starts again.
+     */
+    private final boolean handleIntermediateShutdown;
+    /**
+     * When requesting shutdown, there is no guarantee of an immediate shutdown, since there are multiple tasks in
+     * the container. Use this flag to make sure we don't process more messages past the shutdown request, keeping
+     * the message counts deterministic for the test.
+     */
+    private boolean stopProcessing = false;
+
+    private CheckpointTask(boolean handleIntermediateShutdown) {
+      this.handleIntermediateShutdown = handleIntermediateShutdown;
+    }
+
+    @Override
+    public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) {
+      if (!this.stopProcessing) {
+        String value = (String) envelope.getMessage();
+        if (INTERMEDIATE_SHUTDOWN.equals(value)) {
+          if (this.handleIntermediateShutdown) {
+            setShutdown(coordinator);
+          }
+        } else if (END_OF_STREAM.equals(value)) {
+          setShutdown(coordinator);
+        } else {
+          synchronized (this) {
+            PROCESSED.putIfAbsent(value, new AtomicInteger(0));
+            PROCESSED.get(value).incrementAndGet();
+          }
+          if (value.startsWith(COMMIT_PREFIX)) {
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+          }
+        }
+      }
+    }
+
+    private void setShutdown(TaskCoordinator coordinator) {
+      this.stopProcessing = true;
+      coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+    }
+  }
+}
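
Since the whole test hinges on manual commits (TaskConfig.COMMIT_MS of "-1" disables time-based commits, as noted in CONFIGS above), here is the commit pattern in isolation. This is a minimal sketch, not part of the commit; handle() is a hypothetical stand-in for real processing logic:

    import org.apache.samza.system.IncomingMessageEnvelope;
    import org.apache.samza.task.MessageCollector;
    import org.apache.samza.task.StreamTask;
    import org.apache.samza.task.TaskCoordinator;

    public class ManualCommitTask implements StreamTask {
      @Override
      public void process(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator) {
        handle(envelope); // hypothetical application logic
        // Checkpoint this task's input offsets now; anything processed after
        // the last commit is re-delivered if the job restarts.
        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
      }

      private void handle(IncomingMessageEnvelope envelope) {
        System.out.println("processed: " + envelope.getMessage());
      }
    }

Anything a task processes after its last commit (like the noCommitMessage above) shows up again on restart, which is exactly what verifyProcessedMessagesSecondRun() asserts.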
