This is an automated email from the ASF dual-hosted git repository.

ajothomas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new 63c86b5e6 Add new configuration allowing to keep processing when there 
are fatal exceptions or timeout (#1708)
63c86b5e6 is described below

commit 63c86b5e661b0a0a9a0b33a0851379b8786ee36e
Author: Haolan Ye <[email protected]>
AuthorDate: Mon Nov 25 10:56:07 2024 -0800

    Add new configuration allowing to keep processing when there are fatal 
exceptions or timeout (#1708)
---
 .../java/org/apache/samza/config/TaskConfig.java   |  27 +++
 .../org/apache/samza/container/TaskInstance.scala  |  57 ++++-
 .../samza/container/TaskInstanceMetrics.scala      |   6 +-
 .../apache/samza/container/TestTaskInstance.scala  | 249 +++++++++++++++++++--
 4 files changed, 313 insertions(+), 26 deletions(-)

diff --git a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java 
b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
index 0f168be18..276f3812f 100644
--- a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
@@ -65,6 +65,21 @@ public class TaskConfig extends MapConfig {
   public static final String COMMIT_TIMEOUT_MS = "task.commit.timeout.ms";
   static final long DEFAULT_COMMIT_TIMEOUT_MS = 
Duration.ofMinutes(30).toMillis();
 
+  // Flag to indicate whether to skip commit during failures (exceptions or 
timeouts)
+  // The number of allowed successive commit exceptions and timeouts are 
controlled by the following two configs.
+  public static final String SKIP_COMMIT_DURING_FAILURES_ENABLED = 
"task.commit.skip.commit.during.failures.enabled";
+  private static final boolean DEFAULT_SKIP_COMMIT_DURING_FAILURES_ENABLED = 
false;
+
+  // Maximum number of allowed successive commit exceptions.
+  // If the number of successive commit exceptions exceeds this limit, the 
task will be shut down.
+  public static final String SKIP_COMMIT_EXCEPTION_MAX_LIMIT = 
"task.commit.skip.commit.exception.max.limit";
+  private static final int DEFAULT_SKIP_COMMIT_EXCEPTION_MAX_LIMIT = 5;
+
+  // Maximum number of allowed successive commit timeouts.
+  // If the number of successive commit timeouts exceeds this limit, the task 
will be shut down.
+  public static final String SKIP_COMMIT_TIMEOUT_MAX_LIMIT = 
"task.commit.skip.commit.timeout.max.limit";
+  private static final int DEFAULT_SKIP_COMMIT_TIMEOUT_MAX_LIMIT = 2;
+
   // how long to wait for a clean shutdown
   public static final String TASK_SHUTDOWN_MS = "task.shutdown.ms";
   static final long DEFAULT_TASK_SHUTDOWN_MS = 30000L;
@@ -418,4 +433,16 @@ public class TaskConfig extends MapConfig {
   public double getWatermarkQuorumSizePercentage() {
     return getDouble(WATERMARK_QUORUM_SIZE_PERCENTAGE, 
DEFAULT_WATERMARK_QUORUM_SIZE_PERCENTAGE);
   }
+
+  public boolean getSkipCommitDuringFailuresEnabled() {
+    return getBoolean(SKIP_COMMIT_DURING_FAILURES_ENABLED, 
DEFAULT_SKIP_COMMIT_DURING_FAILURES_ENABLED);
+  }
+
+  public int getSkipCommitExceptionMaxLimit() {
+    return getInt(SKIP_COMMIT_EXCEPTION_MAX_LIMIT, 
DEFAULT_SKIP_COMMIT_EXCEPTION_MAX_LIMIT);
+  }
+
+  public int getSkipCommitTimeoutMaxLimit() {
+    return getInt(SKIP_COMMIT_TIMEOUT_MAX_LIMIT, 
DEFAULT_SKIP_COMMIT_TIMEOUT_MAX_LIMIT);
+  }
 }
diff --git 
a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala 
b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 70d9ca380..f5d13106f 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -38,7 +38,7 @@ import 
org.apache.samza.util.ScalaJavaUtil.JavaOptionals.toRichOptional
 import org.apache.samza.util.{Logging, ReflectionUtil, ScalaJavaUtil}
 
 import java.util
-import java.util.concurrent.atomic.AtomicReference
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
 import java.util.function.BiConsumer
 import java.util.function.Function
 import scala.collection.JavaConversions._
@@ -133,8 +133,13 @@ class TaskInstance(
   val checkpointWriteVersions = new 
TaskConfig(config).getCheckpointWriteVersions
 
   @volatile var lastCommitStartTimeMs = System.currentTimeMillis()
+  val commitExceptionCounter = new AtomicInteger(0)
+  val commitTimeoutCounter = new AtomicInteger(0)
   val commitMaxDelayMs = taskConfig.getCommitMaxDelayMs
   val commitTimeoutMs = taskConfig.getCommitTimeoutMs
+  val skipCommitDuringFailureEnabled = 
taskConfig.getSkipCommitDuringFailuresEnabled
+  val skipCommitExceptionMaxLimit = taskConfig.getSkipCommitExceptionMaxLimit
+  val skipCommitTimeoutMaxLimit = taskConfig.getSkipCommitTimeoutMaxLimit
   val commitInProgress = new Semaphore(1)
   val commitException = new AtomicReference[Exception]()
 
@@ -312,10 +317,22 @@ class TaskInstance(
 
     val commitStartNs = System.nanoTime()
     // first check if there were any unrecoverable errors during the async 
stage of the pending commit
-    // and if so, shut down the container.
+    // If there is an unrecoverable error, increment the metric and the counter.
+    // Shutdown the container in the following scenarios:
+    // 1. skipCommitDuringFailureEnabled is not enabled
+    // 2. skipCommitDuringFailureEnabled is enabled but the number of 
exceptions exceeded the max count
+    // Otherwise, ignore the exception.
     if (commitException.get() != null) {
-      throw new SamzaException("Unrecoverable error during pending commit for 
taskName: %s." format taskName,
-        commitException.get())
+      metrics.commitExceptions.inc()
+      commitExceptionCounter.incrementAndGet()
+      if (!skipCommitDuringFailureEnabled || commitExceptionCounter.get() > 
skipCommitExceptionMaxLimit) {
+        throw new SamzaException("Unrecoverable error during pending commit 
for taskName: %s. Exception Counter: %s"
+          format (taskName, commitExceptionCounter.get()), 
commitException.get())
+      } else {
+        warn("Ignored the commit failure for taskName %s. Exception Counter: 
%s."
+          format (taskName, commitExceptionCounter.get()), 
commitException.get())
+        commitException.set(null)
+      }
     }
 
     // if no commit is in progress for this task, continue with this commit.
@@ -328,7 +345,7 @@ class TaskInstance(
       if (timeSinceLastCommit < commitMaxDelayMs) {
         info("Skipping commit for taskName: %s since another commit is in 
progress. " +
           "%s ms have elapsed since the pending commit started." format 
(taskName, timeSinceLastCommit))
-        metrics.commitsSkipped.set(metrics.commitsSkipped.getValue + 1)
+        metrics.commitsSkipped.inc()
         return
       } else {
         warn("Blocking processing for taskName: %s until in-flight commit is 
complete. " +
@@ -336,13 +353,28 @@ class TaskInstance(
           "which is greater than the max allowed commit delay: %s."
           format (taskName, timeSinceLastCommit, commitMaxDelayMs))
 
+        // Wait for the previous commit to complete within the timeout.
+        // If it doesn't complete within the timeout, increment metric and the 
counter.
+        // Shutdown the container in the following scenarios:
+        // 1. skipCommitDuringFailureEnabled is not enabled
+        // 2. skipCommitDuringFailureEnabled is enabled but the number of 
timeouts exceeded the max count
+        // Otherwise, ignore the timeout.
         if (!commitInProgress.tryAcquire(commitTimeoutMs, 
TimeUnit.MILLISECONDS)) {
           val timeSinceLastCommit = System.currentTimeMillis() - 
lastCommitStartTimeMs
-          metrics.commitsTimedOut.set(metrics.commitsTimedOut.getValue + 1)
-          throw new SamzaException("Timeout waiting for pending commit for 
taskName: %s to finish. " +
-            "%s ms have elapsed since the pending commit started. Max allowed 
commit delay is %s ms " +
-            "and commit timeout beyond that is %s ms" format (taskName, 
timeSinceLastCommit,
-            commitMaxDelayMs, commitTimeoutMs))
+          metrics.commitsTimedOut.inc()
+          commitTimeoutCounter.incrementAndGet()
+          if (!skipCommitDuringFailureEnabled || commitTimeoutCounter.get() > 
skipCommitTimeoutMaxLimit) {
+            throw new SamzaException("Timeout waiting for pending commit for 
taskName: %s to finish. " +
+              "%s ms have elapsed since the pending commit started. Max 
allowed commit delay is %s ms " +
+              "and commit timeout beyond that is %s ms. Timeout Counter: %s" 
format (taskName, timeSinceLastCommit,
+              commitMaxDelayMs, commitTimeoutMs, commitTimeoutCounter.get()))
+          } else {
+            warn("Ignoring commit timeout for taskName: %s. %s ms have elapsed 
since another commit started. " +
+              "Max allowed commit delay is %s ms and commit timeout beyond 
that is %s ms. Timeout Counter: %s."
+              format (taskName, timeSinceLastCommit, commitMaxDelayMs, 
commitTimeoutMs, commitTimeoutCounter.get()))
+            commitInProgress.release()
+            return
+          }
         }
       }
     }
@@ -426,7 +458,7 @@ class TaskInstance(
       }
     })
 
-    metrics.lastCommitNs.set(System.nanoTime() - commitStartNs)
+    metrics.lastCommitNs.set(System.nanoTime())
     metrics.commitSyncNs.update(System.nanoTime() - commitStartNs)
     debug("Finishing sync stage of commit for taskName: %s checkpointId: %s" 
format (taskName, checkpointId))
   }
@@ -531,8 +563,11 @@ class TaskInstance(
                 "Saved exception under Caused By.", commitException.get())
             }
           } else {
+            commitExceptionCounter.set(0)
+            commitTimeoutCounter.set(0)
             metrics.commitAsyncNs.update(System.nanoTime() - asyncStageStartNs)
             metrics.commitNs.update(System.nanoTime() - commitStartNs)
+            metrics.lastCommitAsyncTimestamp.set(System.nanoTime())
           }
         } finally {
           // release the permit indicating that previous commit is complete.
diff --git 
a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
 
b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
index 54d366525..02674fb7e 100644
--- 
a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
+++ 
b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
@@ -38,10 +38,12 @@ class TaskInstanceMetrics(
   val pendingMessages = newGauge("pending-messages", 0)
   val messagesInFlight = newGauge("messages-in-flight", 0)
   val asyncCallbackCompleted = newCounter("async-callback-complete-calls")
-  val commitsTimedOut = newGauge("commits-timed-out", 0)
-  val commitsSkipped = newGauge("commits-skipped", 0)
+  val commitsTimedOut = newCounter("commits-timed-out")
+  val commitsSkipped = newCounter("commits-skipped")
+  val commitExceptions = newCounter("commit-exceptions")
   val commitNs = newTimer("commit-ns")
   val lastCommitNs = newGauge("last-commit-ns", 0L)
+  val lastCommitAsyncTimestamp = newGauge("last-async-commit-timestamp", 0L)
   val commitSyncNs = newTimer("commit-sync-ns")
   val commitAsyncNs = newTimer("commit-async-ns")
   val snapshotNs = newTimer("snapshot-ns")
diff --git 
a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala 
b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 6afec52e7..ff52b4006 100644
--- 
a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++ 
b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -277,7 +277,7 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
     inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -370,7 +370,7 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
     val uploadTimer = mock[Timer]
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
 
     val inputOffsets = Map(SYSTEM_STREAM_PARTITION -> "4").asJava
@@ -431,7 +431,7 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
     val uploadTimer = mock[Timer]
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -504,10 +504,12 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+    val commitExceptionsGauge = mock[Counter]
+    when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
 
     val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
     inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -556,10 +558,12 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+    val commitExceptionsGauge = mock[Counter]
+    when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
 
     val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
     inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -608,10 +612,12 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+    val commitExceptionsGauge = mock[Counter]
+    when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
 
     val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
     inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -661,10 +667,12 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+    val commitExceptionsGauge = mock[Counter]
+    when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
 
     val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
     inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -714,10 +722,12 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+    val commitExceptionsGauge = mock[Counter]
+    when(this.metrics.commitExceptions).thenReturn(commitExceptionsGauge)
 
     val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
     inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
@@ -768,7 +778,7 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -828,7 +838,7 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -859,7 +869,7 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
 
     taskInstance.commit
 
-    verify(skippedCounter).set(1)
+    verify(skippedCounter, times(1)).inc()
 
     verify(commitsCounter, times(1)).inc() // should only have been 
incremented once on the initial commit
     verify(snapshotTimer).update(anyLong())
@@ -884,7 +894,7 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -947,7 +957,7 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
     val cleanUpTimer = mock[Timer]
     when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
-    val skippedCounter = mock[Gauge[Int]]
+    val skippedCounter = mock[Counter]
     when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
     val lastCommitGauge = mock[Gauge[Long]]
     when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
@@ -1004,6 +1014,208 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
     verify(snapshotTimer, times(2)).update(anyLong())
   }
 
+  @Test
+  def testSkipExceptionFromFirstCommitAndContinueSecondCommit(): Unit = {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Counter]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+    val lastCommitGauge = mock[Gauge[Long]]
+    when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+    val commitExceptionCounter = mock[Counter]
+    when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter)
+
+    val taskConfigsMap = new util.HashMap[String, String]()
+    taskConfigsMap.put("task.commit.ms", "-1")
+    taskConfigsMap.put("task.commit.max.delay.ms", "-1")
+    taskConfigsMap.put("task.commit.timeout.ms", "2000000")
+    // skip commit if exception occurs during the commit
+    taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled", 
"true")
+    // should throw exception if second commit exception occurs
+    taskConfigsMap.put("task.commit.skip.commit.exception.max.limit", "1")
+    when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap))
+    setupTaskInstance(None, ForkJoinPool.commonPool())
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION, "4")
+    val stateCheckpointMarkers: util.Map[String, String] = new 
util.HashMap[String, String]()
+    
when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    // Ensure the second commit proceeds without exceptions
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        
Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME,
 stateCheckpointMarkers)))
+    //  exception during the first commit
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(FutureUtil.failedFuture[util.Map[String, util.Map[String, 
String]]](new RuntimeException))
+
+    // First commit fails but should not throw exception
+    taskInstance.commit
+    verify(commitsCounter).inc()
+    verify(snapshotTimer).update(anyLong())
+    verifyZeroInteractions(uploadTimer)
+    verifyZeroInteractions(commitTimer)
+    verifyZeroInteractions(skippedCounter)
+    waitForCommitExceptionIsSet(100, 5)
+    // Second commit should succeed
+    taskInstance.commit
+    verify(commitsCounter, times(2)).inc() // should only have been 
incremented twice - once for each commit
+    verify(commitExceptionCounter).inc()
+  }
+
+  @Test
+  def testCommitThrowsIfAllowSkipCommitButExceptionCountReachMaxLimit(): Unit 
= {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Counter]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+    val lastCommitGauge = mock[Gauge[Long]]
+    when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+    val commitExceptionCounter = mock[Counter]
+    when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter)
+
+    val taskConfigsMap = new util.HashMap[String, String]()
+    taskConfigsMap.put("task.commit.ms", "-1")
+    taskConfigsMap.put("task.commit.max.delay.ms", "-1")
+    taskConfigsMap.put("task.commit.timeout.ms", "2000000")
+    // skip commit if exception occurs during the commit
+    taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled", 
"true")
+    // should throw exception if second commit exception occurs
+    taskConfigsMap.put("task.commit.skip.commit.exception.max.limit", "1")
+    when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap))
+    setupTaskInstance(None, ForkJoinPool.commonPool())
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION, "4")
+    
when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    // exception for commits
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(FutureUtil.failedFuture[util.Map[String, util.Map[String, 
String]]](new RuntimeException))
+
+    // First commit fails but should not throw exception
+    taskInstance.commit
+    waitForCommitExceptionIsSet(100, 5)
+    // Second commit fails but should not throw exception
+    taskInstance.commit
+    verify(commitExceptionCounter).inc()
+    verify(commitsCounter, times(2)).inc()
+    verify(snapshotTimer, times(2)).update(anyLong())
+    verifyZeroInteractions(uploadTimer)
+    verifyZeroInteractions(commitTimer)
+    verifyZeroInteractions(skippedCounter)
+    waitForCommitExceptionIsSet(100, 5)
+    // third commit should fail as the commit exception counter is greater 
than the max limit
+    try {
+      taskInstance.commit
+      fail("Should have thrown an exception if exception count reached the max 
limit.")
+    } catch {
+      case e: Exception =>
+        // expected
+    }
+    verify(commitExceptionCounter, times(2)).inc()
+    verify(commitsCounter, times(2)).inc()
+  }
+
+  @Test
+  def testCommitThrowsIfAllowSkipTimeoutButTimeoutCountReachMaxLimit(): Unit = 
{
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Counter]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+    val commitsTimedOutCounter = mock[Counter]
+    when(this.metrics.commitsTimedOut).thenReturn(commitsTimedOutCounter)
+    val lastCommitGauge = mock[Gauge[Long]]
+    when(this.metrics.lastCommitNs).thenReturn(lastCommitGauge)
+    val commitExceptionCounter = mock[Counter]
+    when(this.metrics.commitExceptions).thenReturn(commitExceptionCounter)
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+    val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, 
"test-changelog-stream"), new Partition(0))
+
+    val stateCheckpointMarkers: util.Map[String, String] = new 
util.HashMap[String, String]()
+    val stateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new 
KafkaStateCheckpointMarker(changelogSSP, "5"))
+    stateCheckpointMarkers.put("storeName", stateCheckpointMarker)
+    
when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+
+    val snapshotSCMs = 
ImmutableMap.of(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, 
stateCheckpointMarkers)
+    when(this.taskCommitManager.snapshot(any())).thenReturn(snapshotSCMs)
+    val snapshotSCMFuture: CompletableFuture[util.Map[String, util.Map[String, 
String]]] =
+      CompletableFuture.completedFuture(snapshotSCMs)
+
+    when(this.taskCommitManager.upload(any(), 
Matchers.eq(snapshotSCMs))).thenReturn(snapshotSCMFuture) // kafka is no-op
+
+    val cleanUpFuture = new CompletableFuture[Void]()
+    when(this.taskCommitManager.cleanUp(any(), 
any())).thenReturn(cleanUpFuture)
+
+    // use a separate executor to perform async operations on to test caller 
thread blocking behavior
+    val taskConfigsMap = new util.HashMap[String, String]()
+    taskConfigsMap.put("task.commit.ms", "-1")
+    // "block" immediately if previous commit async stage not complete
+    taskConfigsMap.put("task.commit.max.delay.ms", "-1")
+    taskConfigsMap.put("task.commit.timeout.ms", "0") // throw exception 
immediately if blocked
+    taskConfigsMap.put("task.commit.skip.commit.during.failures.enabled", 
"true")
+    // should throw exception if second commit timeout occurs
+    taskConfigsMap.put("task.commit.skip.commit.timeout.max.limit", "1")
+    when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap)) 
// override default behavior
+
+    setupTaskInstance(None, ForkJoinPool.commonPool())
+
+    taskInstance.commit // async stage will not complete until cleanUpFuture 
is completed
+    taskInstance.commit // second commit hits the commit timeout and releases the 
semaphore
+
+    verifyZeroInteractions(commitExceptionCounter)
+    verifyZeroInteractions(skippedCounter)
+    verify(commitsTimedOutCounter).inc()
+    verify(commitsCounter, times(1)).inc() // should only have been 
incremented once now - second commit was skipped
+    taskInstance.commit // third commit should proceed without any issues and 
acquire the semaphore
+    try {
+      taskInstance.commit // fourth commit should throw exception as the 
timeout count reached the max limit
+      fail("Should have thrown an exception due to exceeding timeout limit.")
+    } catch {
+      case e: Exception =>
+        // expected
+    }
+    verify(commitsTimedOutCounter, times(2)).inc() // incremented twice 
(second and fourth commit)
+    verify(commitsCounter, times(2)).inc() // incremented twice (first and 
third commit)
+    cleanUpFuture.complete(null) // just to unblock shared executor
+  }
+
 
   /**
     * Given that no application task context factory is provided, then no 
lifecycle calls should be made.
@@ -1091,6 +1303,17 @@ class TestTaskInstance extends AssertionsForJUnit with 
MockitoSugar {
       externalContextOption = Some(this.externalContext), elasticityFactor = 
elasticityFactor)
   }
 
+  private def waitForCommitExceptionIsSet(sleepTimeInMs: Int, maxRetry: Int): 
Unit = {
+    var retries = 0
+    while (taskInstance.commitException.get() == null && retries < maxRetry) {
+      retries += 1
+      Thread.sleep(sleepTimeInMs)
+    }
+    if (taskInstance.commitException.get() == null) {
+      fail("Should have set the commit exception.")
+    }
+  }
+
   /**
     * Task type which has all task traits, which can be mocked.
     */

Reply via email to