Repository: samza
Updated Branches:
  refs/heads/master 403042394 -> 8f7f56744


SAMZA-1384: Race condition with async commit affects checkpoint correctness

Author: Jacob Maes <jm...@linkedin.com>

Reviewers: Prateek Maheshwari <pmahe...@linkedin.com>

Closes #263 from jmakes/samza-1384


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/8f7f5674
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/8f7f5674
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/8f7f5674

Branch: refs/heads/master
Commit: 8f7f56744c2824426962e1ca4f382b1877e4a9b0
Parents: 4030423
Author: Jacob Maes <jm...@linkedin.com>
Authored: Thu Aug 10 15:44:42 2017 -0700
Committer: Jacob Maes <jm...@linkedin.com>
Committed: Thu Aug 10 15:44:42 2017 -0700

----------------------------------------------------------------------
 .../apache/samza/checkpoint/OffsetManager.scala | 50 +++++++++-----
 .../apache/samza/container/TaskInstance.scala   |  4 +-
 .../org/apache/samza/task/TestAsyncRunLoop.java | 24 ++++---
 .../samza/checkpoint/TestOffsetManager.scala    | 58 ++++++++++++++--
 .../samza/container/TestTaskInstance.scala      | 69 +++++++++++++++-----
 5 files changed, 154 insertions(+), 51 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/8f7f5674/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
index 783340a..8c739d4 100644
--- a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
@@ -20,20 +20,20 @@
 package org.apache.samza.checkpoint
 
 
+
+import java.util.HashMap
 import java.util.concurrent.ConcurrentHashMap
 
-import org.apache.samza.system.SystemStream
-import org.apache.samza.system.SystemStreamPartition
-import org.apache.samza.system.SystemStreamMetadata
-import org.apache.samza.system.SystemStreamMetadata.OffsetType
 import org.apache.samza.SamzaException
-import scala.collection.JavaConverters._
-import org.apache.samza.util.Logging
 import org.apache.samza.config.Config
 import org.apache.samza.config.StreamConfig.Config2Stream
 import org.apache.samza.config.SystemConfig.Config2System
-import org.apache.samza.system.SystemAdmin
 import org.apache.samza.container.TaskName
+import org.apache.samza.system.SystemStreamMetadata.OffsetType
+import org.apache.samza.system.{SystemAdmin, SystemStream, 
SystemStreamMetadata, SystemStreamPartition}
+import org.apache.samza.util.Logging
+
+import scala.collection.JavaConverters._
 import scala.collection._
 
 /**
@@ -217,25 +217,42 @@ class OffsetManager(
   }
 
   /**
-   * Checkpoint all offsets for a given TaskName using the CheckpointManager.
-   */
-  def checkpoint(taskName: TaskName) {
+    * Gets a snapshot of all the current offsets for the specified task. This 
is useful to
+    * ensure there are no concurrent updates to the offsets between when this 
method is
+    * invoked and the corresponding call to [[OffsetManager.writeCheckpoint()]]
+    */
+  def buildCheckpoint(taskName: TaskName): Checkpoint = {
     if (checkpointManager != null || checkpointListeners.nonEmpty) {
-      debug("Checkpointing offsets for taskName %s." format taskName)
+      debug("Getting checkpoint offsets for taskName %s." format taskName)
 
-      val sspsForTaskName = systemStreamPartitions.getOrElse(taskName, throw 
new SamzaException("No such SystemStreamPartition set " + taskName + " 
registered for this checkpointmanager")).toSet
+      val sspsForTaskName = systemStreamPartitions.getOrElse(taskName, throw 
new SamzaException("No SSPs registered for task: " + taskName)).toSet
       val sspToOffsets = lastProcessedOffsets.get(taskName)
-      val partitionOffsets = if(sspToOffsets != null) {
+
+      val partitionOffsets = if (sspToOffsets != null) {
+        // Filter the offsets in case the task model changed since the last 
checkpoint was written.
         sspToOffsets.asScala.filterKeys(sspsForTaskName.contains)
       } else {
         warn(taskName + " is not found... ")
         Map[SystemStreamPartition, String]()
       }
 
-      val checkpoint = new Checkpoint(partitionOffsets.asJava)
+      new Checkpoint(new HashMap(partitionOffsets.asJava)) // Copy into new 
Map to prevent mutation
+    } else {
+      debug("Returning null checkpoint for taskName %s because no checkpoint 
manager/callback is defined." format taskName)
+      null
+    }
+  }
+
+  /**
+    * Write the specified checkpoint for the given task.
+    */
+  def writeCheckpoint(taskName: TaskName, checkpoint: Checkpoint) {
+    if (checkpoint != null && (checkpointManager != null || 
checkpointListeners.nonEmpty)) {
+      debug("Writing checkpoint for taskName %s with offsets %s." format 
(taskName, checkpoint))
 
       if(checkpointManager != null) {
         checkpointManager.writeCheckpoint(taskName, checkpoint)
+        val sspToOffsets = checkpoint.getOffsets
         if(sspToOffsets != null) {
           sspToOffsets.asScala.foreach {
             case (ssp, cp) => 
offsetManagerMetrics.checkpointedOffsets.get(ssp).set(cp)
@@ -244,15 +261,12 @@ class OffsetManager(
       }
 
       // invoke checkpoint listeners
-      //partitionOffsets.groupBy(_._1.getSystem).foreach {
-      partitionOffsets.groupBy { case (ssp, _) => ssp.getSystem }.foreach {
+      checkpoint.getOffsets.asScala.groupBy { case (ssp, _) => ssp.getSystem 
}.foreach {
         case (systemName:String, offsets: Map[SystemStreamPartition, String]) 
=> {
           // Option is empty if there is no checkpointListener for this 
systemName
           
checkpointListeners.get(systemName).foreach(_.onCheckpoint(offsets.asJava))
         }
       }
-    } else {
-      debug("Skipping checkpointing for taskName %s because no checkpoint 
manager/callback is defined." format taskName)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/8f7f5674/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 84e993b..65fefda 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -206,6 +206,8 @@ class TaskInstance(
   def commit {
     metrics.commits.inc
 
+    val checkpoint = offsetManager.buildCheckpoint(taskName)
+
     trace("Flushing producers for taskName: %s" format taskName)
 
     collector.flush
@@ -218,7 +220,7 @@ class TaskInstance(
 
     trace("Checkpointing offsets for taskName: %s" format taskName)
 
-    offsetManager.checkpoint(taskName)
+    offsetManager.writeCheckpoint(taskName, checkpoint)
   }
 
   def shutdownTask {

http://git-wip-us.apache.org/repos/asf/samza/blob/8f7f5674/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
index 1afc26a..6694f26 100644
--- a/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
+++ b/samza-core/src/test/java/org/apache/samza/task/TestAsyncRunLoop.java
@@ -50,8 +50,7 @@ import scala.Option;
 import scala.collection.JavaConverters;
 
 import static org.junit.Assert.assertEquals;
-import static org.mockito.Matchers.anyLong;
-import static org.mockito.Matchers.anyObject;
+import static org.mockito.Matchers.*;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -368,8 +367,10 @@ public class TestAsyncRunLoop {
     task0ProcessedMessagesLatch.await();
     task1ProcessedMessagesLatch.await();
 
-    verify(offsetManager).checkpoint(taskName0);
-    verify(offsetManager, never()).checkpoint(taskName1);
+    verify(offsetManager).buildCheckpoint(taskName0);
+    verify(offsetManager).writeCheckpoint(taskName0, any());
+    verify(offsetManager, never()).buildCheckpoint(taskName1);
+    verify(offsetManager, never()).writeCheckpoint(taskName1, any());
   }
 
   //@Test
@@ -398,8 +399,10 @@ public class TestAsyncRunLoop {
     task0ProcessedMessagesLatch.await();
     task1ProcessedMessagesLatch.await();
 
-    verify(offsetManager).checkpoint(taskName0);
-    verify(offsetManager).checkpoint(taskName1);
+    verify(offsetManager).buildCheckpoint(taskName0);
+    verify(offsetManager).writeCheckpoint(taskName0, any());
+    verify(offsetManager).buildCheckpoint(taskName1);
+    verify(offsetManager).writeCheckpoint(taskName1, any());
   }
 
   //@Test
@@ -552,8 +555,10 @@ public class TestAsyncRunLoop {
     task0ProcessedMessagesLatch.await();
     task1ProcessedMessagesLatch.await();
 
-    verify(offsetManager).checkpoint(taskName0);
-    verify(offsetManager).checkpoint(taskName1);
+    verify(offsetManager).buildCheckpoint(taskName0);
+    verify(offsetManager).writeCheckpoint(taskName0, any());
+    verify(offsetManager).buildCheckpoint(taskName1);
+    verify(offsetManager).writeCheckpoint(taskName1, any());
   }
 
   // TODO: Add assertions.
@@ -641,7 +646,8 @@ public class TestAsyncRunLoop {
           secondMsgCompletionLatch.countDown();
           // OffsetManager.update with firstMsg offset, task.commit has 
happened when second message callback has not completed.
           verify(offsetManager).update(taskName0, 
firstMsg.getSystemStreamPartition(), firstMsg.getOffset());
-          verify(offsetManager, atLeastOnce()).checkpoint(taskName0);
+          verify(offsetManager, atLeastOnce()).buildCheckpoint(taskName0);
+          verify(offsetManager, atLeastOnce()).writeCheckpoint(taskName0, 
any());
         }
       } catch (Exception e) {
         e.printStackTrace();

http://git-wip-us.apache.org/repos/asf/samza/blob/8f7f5674/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
index abfc63f..54a08f6 100644
--- a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
+++ b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
@@ -82,7 +82,7 @@ class TestOffsetManager {
     assertEquals("46", offsetManager.getStartingOffset(taskName, 
systemStreamPartition).get)
     // Should not update null offset
     offsetManager.update(taskName, systemStreamPartition, null)
-    offsetManager.checkpoint(taskName)
+    checkpoint(offsetManager, taskName)
     val expectedCheckpoint = new Checkpoint(Map(systemStreamPartition -> 
"47").asJava)
     assertEquals(expectedCheckpoint, 
checkpointManager.readLastCheckpoint(taskName))
   }
@@ -102,14 +102,14 @@ class TestOffsetManager {
     offsetManager.register(taskName, Set(systemStreamPartition))
     offsetManager.start
     // Should get offset 45 back from the checkpoint manager, which is last 
processed, and system admin should return 46 as starting offset.
-    offsetManager.checkpoint(taskName)
+    checkpoint(offsetManager, taskName)
     assertEquals("45", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     offsetManager.update(taskName, systemStreamPartition, "46")
     offsetManager.update(taskName, systemStreamPartition, "47")
-    offsetManager.checkpoint(taskName)
+    checkpoint(offsetManager, taskName)
     assertEquals("47", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     offsetManager.update(taskName, systemStreamPartition, "48")
-    offsetManager.checkpoint(taskName)
+    checkpoint(offsetManager, taskName)
     assertEquals("48", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
   }
 
@@ -269,7 +269,7 @@ class TestOffsetManager {
 
     offsetManager.start
     // Should get offset 45 back from the checkpoint manager, which is last 
processed, and system admin should return 46 as starting offset.
-    offsetManager.checkpoint(taskName)
+    checkpoint(offsetManager, taskName)
     assertEquals("45", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     assertEquals("100", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
     assertEquals("45", consumer.recentCheckpoint.get(systemStreamPartition))
@@ -278,7 +278,7 @@ class TestOffsetManager {
 
     offsetManager.update(taskName, systemStreamPartition, "46")
     offsetManager.update(taskName, systemStreamPartition, "47")
-    offsetManager.checkpoint(taskName)
+    checkpoint(offsetManager, taskName)
     assertEquals("47", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     assertEquals("100", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
     assertEquals("47", consumer.recentCheckpoint.get(systemStreamPartition))
@@ -286,7 +286,7 @@ class TestOffsetManager {
 
     offsetManager.update(taskName, systemStreamPartition, "48")
     offsetManager.update(taskName, systemStreamPartition2, "101")
-    offsetManager.checkpoint(taskName)
+    checkpoint(offsetManager, taskName)
     assertEquals("48", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     assertEquals("101", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
     assertEquals("48", consumer.recentCheckpoint.get(systemStreamPartition))
@@ -294,6 +294,50 @@ class TestOffsetManager {
     offsetManager.stop
   }
 
+  /**
+    * If task.max.concurrency > 1 and task.async.commit == true, a task could 
update its offsets at the same time as
+    * TaskInstance.commit(). This makes it possible to checkpoint offsets 
which did not successfully flush.
+    *
+    * This is prevented by using separate methods to get a checkpoint and 
write that checkpoint. See SAMZA-1384
+    */
+  @Test
+  def testConcurrentCheckpointAndUpdate{
+    val taskName = new TaskName("c")
+    val systemStream = new SystemStream("test-system", "test-stream")
+    val partition = new Partition(0)
+    val systemStreamPartition = new SystemStreamPartition(systemStream, 
partition)
+    val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, 
Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")).asJava)
+    val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
+    val checkpointManager = getCheckpointManager(systemStreamPartition, 
taskName)
+    val systemAdmins = Map("test-system" -> getSystemAdmin)
+    val offsetManager = OffsetManager(systemStreamMetadata, new MapConfig, 
checkpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
+    offsetManager.register(taskName, Set(systemStreamPartition))
+    offsetManager.start
+
+    // Should get offset 45 back from the checkpoint manager, which is last 
processed, and system admin should return 46 as starting offset.
+    checkpoint(offsetManager, taskName)
+    assertEquals("45", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
+
+    offsetManager.update(taskName, systemStreamPartition, "46")
+    // Get checkpoint snapshot like we do at the beginning of 
TaskInstance.commit()
+    val checkpoint46 = offsetManager.buildCheckpoint(taskName)
+    offsetManager.update(taskName, systemStreamPartition, "47") // Offset 
updated before checkpoint
+    offsetManager.writeCheckpoint(taskName, checkpoint46)
+    assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, 
systemStreamPartition))
+    assertEquals("46", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
+
+    // Now write the checkpoint for the latest offset
+    val checkpoint47 = offsetManager.buildCheckpoint(taskName)
+    offsetManager.writeCheckpoint(taskName, checkpoint47)
+    assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, 
systemStreamPartition))
+    assertEquals("47", 
offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
+  }
+
+  // Utility method to create and write checkpoint in one statement
+  def checkpoint(offsetManager: OffsetManager, taskName: TaskName): Unit = {
+    offsetManager.writeCheckpoint(taskName, 
offsetManager.buildCheckpoint(taskName))
+  }
+
   class SystemConsumerWithCheckpointCallback extends SystemConsumer with 
CheckpointListener{
     var recentCheckpoint: java.util.Map[SystemStreamPartition, String] = 
java.util.Collections.emptyMap[SystemStreamPartition, String]
     override def start() {}

http://git-wip-us.apache.org/repos/asf/samza/blob/8f7f5674/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
----------------------------------------------------------------------
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 40974a6..4958a57 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -20,31 +20,25 @@
 package org.apache.samza.container
 
 import java.util.concurrent.ConcurrentHashMap
-import org.apache.samza.SamzaException
+
 import org.apache.samza.Partition
-import org.apache.samza.checkpoint.OffsetManager
-import org.apache.samza.config.Config
-import org.apache.samza.config.MapConfig
-import org.apache.samza.metrics.Counter
-import org.apache.samza.metrics.Metric
-import org.apache.samza.metrics.MetricsRegistryMap
+import org.apache.samza.checkpoint.{Checkpoint, OffsetManager}
+import org.apache.samza.config.{Config, MapConfig}
+import org.apache.samza.metrics.{Counter, Metric, MetricsRegistryMap}
 import org.apache.samza.serializers.SerdeManager
-import org.apache.samza.system.IncomingMessageEnvelope
-import org.apache.samza.system.SystemConsumer
-import org.apache.samza.system.SystemConsumers
-import org.apache.samza.system.SystemProducer
-import org.apache.samza.system.SystemProducers
-import org.apache.samza.system.SystemStream
-import org.apache.samza.system.SystemStreamMetadata
+import org.apache.samza.storage.TaskStorageManager
 import 
org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
-import org.apache.samza.system.SystemStreamPartition
+import org.apache.samza.system._
 import org.apache.samza.system.chooser.RoundRobinChooser
 import org.apache.samza.task._
 import org.junit.Assert._
 import org.junit.Test
+import org.mockito.Matchers._
+import org.mockito.Mockito
+import org.mockito.Mockito._
 import org.scalatest.Assertions.intercept
+
 import scala.collection.JavaConverters._
-import org.apache.samza.system.SystemAdmin
 import scala.collection.mutable.ListBuffer
 
 class TestTaskInstance {
@@ -354,6 +348,49 @@ class TestTaskInstance {
     val expected = List(envelope1, envelope2, envelope4)
     assertEquals(expected, result.toList)
   }
+
+  @Test
+  def testCommitOrder {
+    // Simple objects
+    val partition = new Partition(0)
+    val taskName = new TaskName("taskName")
+    val systemStream = new SystemStream("test-system", "test-stream")
+    val systemStreamPartition = new SystemStreamPartition(systemStream, 
partition)
+    val checkpoint = new Checkpoint(Map(systemStreamPartition -> "4").asJava)
+
+    // Mocks
+    val collector = Mockito.mock(classOf[TaskInstanceCollector])
+    val storageManager = Mockito.mock(classOf[TaskStorageManager])
+    val offsetManager = Mockito.mock(classOf[OffsetManager])
+    when(offsetManager.buildCheckpoint(any())).thenReturn(checkpoint)
+    val mockOrder = inOrder(offsetManager, collector, storageManager)
+
+    val taskInstance: TaskInstance = new TaskInstance(
+      Mockito.mock(classOf[StreamTask]).asInstanceOf[StreamTask],
+      taskName,
+      new MapConfig,
+      new TaskInstanceMetrics,
+      null,
+      Mockito.mock(classOf[SystemConsumers]),
+      collector,
+      Mockito.mock(classOf[SamzaContainerContext]),
+      offsetManager,
+      storageManager,
+      systemStreamPartitions = Set(systemStreamPartition))
+
+    taskInstance.commit
+
+    // We must first get a snapshot of the checkpoint so it doesn't change 
while we flush. SAMZA-1384
+    mockOrder.verify(offsetManager).buildCheckpoint(taskName)
+    // Producers must be flushed next and ideally the output would be flushed 
before the changelog
+    // s.t. the changelog and checkpoints (state and inputs) are captured last
+    mockOrder.verify(collector).flush
+    // Local state is next, to ensure that the state (particularly the offset 
file) never points to a newer changelog
+    // offset than what is reflected in the on disk state.
+    mockOrder.verify(storageManager).flush()
+    // Finally, checkpoint the inputs with the snapshotted checkpoint captured 
at the beginning of commit
+    mockOrder.verify(offsetManager).writeCheckpoint(taskName, checkpoint)
+  }
 }
 
 class MockSystemAdmin extends SystemAdmin {

Reply via email to