This is an automated email from the ASF dual-hosted git repository.

kkarantasis pushed a commit to branch 3.1
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.1 by this push:
     new 87abbb2  KAFKA-12487: Add support for cooperative consumer protocol with sink connectors (#10563)
87abbb2 is described below

commit 87abbb2447f5e4e23919102bf78d43ff68272ed2
Author: Chris Egerton <chr...@confluent.io>
AuthorDate: Wed Nov 10 14:14:50 2021 -0500

    KAFKA-12487: Add support for cooperative consumer protocol with sink connectors (#10563)
    
    Currently, the `WorkerSinkTask`'s consumer rebalance listener (and related
    logic) is hardcoded to assume eager rebalancing: every rebalance revokes all
    partitions, and the set of partitions passed to `onPartitionsAssigned` is
    assumed to be the complete assignment for the task. Not only does this cause
    failures when the cooperative consumer protocol is used, but it also fails
    to take advantage of the benefits that protocol provides.
    
    These changes alter the framework logic so that it not only keeps working
    when the cooperative consumer protocol is used for a sink connector, but
    also reaps that protocol's benefits by not revoking partitions from tasks
    only to reopen them immediately after the rebalance completes.
    
    This change will be necessary in order to support
    [KIP-726](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=177048248),
    which currently proposes that the default consumer partition assignor be
    changed to the `CooperativeStickyAssignor`.
    
    Two integration tests are added to verify sink task behavior with both the
    eager and cooperative consumer protocols, and new unit tests are added and
    existing ones adapted as well.
    
    Reviewers: Nigel Liang <ni...@nigelliang.com>, Konstantine Karantasis <k.karanta...@gmail.com>
---
 .../kafka/connect/runtime/WorkerSinkTask.java      | 169 +++++++----
 .../runtime/errors/WorkerErrantRecordReporter.java |  53 +++-
 .../integration/ErrantRecordSinkConnector.java     |   2 +-
 .../integration/ErrorHandlingIntegrationTest.java  |   2 +-
 .../integration/ExampleConnectIntegrationTest.java |   2 +-
 .../integration/MonitorableSinkConnector.java      |  40 ++-
 .../integration/SinkConnectorsIntegrationTest.java | 321 +++++++++++++++++++++
 .../kafka/connect/integration/TaskHandle.java      | 123 +++++++-
 .../integration/TransformationIntegrationTest.java |   2 +-
 .../kafka/connect/runtime/WorkerSinkTaskTest.java  | 311 ++++++++++++++++----
 .../runtime/WorkerSinkTaskThreadedTest.java        |  19 +-
 .../errors/WorkerErrantRecordReporterTest.java     |  13 +-
 .../util/clusters/EmbeddedKafkaCluster.java        |  13 +
 13 files changed, 907 insertions(+), 163 deletions(-)
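
For context (not part of this patch), here is a minimal sketch of why the two rebalance protocols require different listener logic. The class name and the tracked `owned` set are hypothetical, but the `ConsumerRebalanceListener` callbacks are the real Kafka consumer API:

    import java.util.Collection;
    import java.util.HashSet;
    import java.util.Set;

    import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
    import org.apache.kafka.common.TopicPartition;

    // Hypothetical listener illustrating the difference the commit message
    // describes: under the eager protocol every rebalance first revokes the
    // entire assignment and onPartitionsAssigned then carries the complete new
    // assignment, while under the cooperative protocol both callbacks carry
    // only incremental changes, so tracked state must be merged, not replaced.
    public class IncrementalAwareRebalanceListener implements ConsumerRebalanceListener {
        private final Set<TopicPartition> owned = new HashSet<>();

        @Override
        public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
            // Eager: the whole previous assignment; cooperative: only the
            // partitions actually migrating away. Commit/close just these.
            owned.removeAll(partitions);
        }

        @Override
        public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
            // Eager: the complete new assignment (state was already cleared).
            // Cooperative: only newly added partitions; do not reset the
            // tracked assignment to this collection, merge into it instead.
            owned.addAll(partitions);
        }

        @Override
        public void onPartitionsLost(Collection<TopicPartition> partitions) {
            // Partitions were lost without a clean revocation (e.g. the member
            // fell out of the group); drop them without committing offsets.
            owned.removeAll(partitions);
        }
    }

The patch below applies the same idea to `WorkerSinkTask`: offsets are committed and partitions closed only for the partitions actually revoked, and newly assigned partitions are merged into the existing offset tracking instead of replacing it.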

diff --git 
a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java
 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java
index ed7ad73..7992c65 100644
--- 
a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java
+++ 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java
@@ -61,6 +61,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
 
 import static java.util.Collections.singleton;
 import static 
org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ENABLE_CONFIG;
@@ -125,6 +126,7 @@ class WorkerSinkTask extends WorkerTask {
         this.headerConverter = headerConverter;
         this.transformationChain = transformationChain;
         this.messageBatch = new ArrayList<>();
+        this.lastCommittedOffsets = new HashMap<>();
         this.currentOffsets = new HashMap<>();
         this.origOffsets = new HashMap<>();
         this.pausedForRedelivery = false;
@@ -196,7 +198,7 @@ class WorkerSinkTask extends WorkerTask {
         log.info("{} Executing sink task", this);
         // Make sure any uncommitted data has been committed and the task has
         // a chance to clean up its state
-        try (UncheckedCloseable suppressible = this::closePartitions) {
+        try (UncheckedCloseable suppressible = this::closeAllPartitions) {
             while (!isStopping())
                 iteration();
         } catch (WakeupException e) {
@@ -368,13 +370,22 @@ class WorkerSinkTask extends WorkerTask {
     }
 
     private void commitOffsets(long now, boolean closing) {
+        commitOffsets(now, closing, consumer.assignment());
+    }
+
+    private void commitOffsets(long now, boolean closing, 
Collection<TopicPartition> topicPartitions) {
+        log.trace("Committing offsets for partitions {}", topicPartitions);
         if (workerErrantRecordReporter != null) {
-            log.trace("Awaiting all reported errors to be completed");
-            workerErrantRecordReporter.awaitAllFutures();
-            log.trace("Completed all reported errors");
+            log.trace("Awaiting reported errors to be completed");
+            workerErrantRecordReporter.awaitFutures(topicPartitions);
+            log.trace("Completed reported errors");
         }
 
-        if (currentOffsets.isEmpty())
+        Map<TopicPartition, OffsetAndMetadata> offsetsToCommit = 
currentOffsets.entrySet().stream()
+            .filter(e -> topicPartitions.contains(e.getKey()))
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+        if (offsetsToCommit.isEmpty())
             return;
 
         committing = true;
@@ -382,28 +393,31 @@ class WorkerSinkTask extends WorkerTask {
         commitStarted = now;
         sinkTaskMetricsGroup.recordOffsetSequenceNumber(commitSeqno);
 
+        Map<TopicPartition, OffsetAndMetadata> 
lastCommittedOffsetsForPartitions = 
this.lastCommittedOffsets.entrySet().stream()
+            .filter(e -> offsetsToCommit.containsKey(e.getKey()))
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
         final Map<TopicPartition, OffsetAndMetadata> taskProvidedOffsets;
         try {
-            log.trace("{} Calling task.preCommit with current offsets: {}", 
this, currentOffsets);
-            taskProvidedOffsets = task.preCommit(new 
HashMap<>(currentOffsets));
+            log.trace("{} Calling task.preCommit with current offsets: {}", 
this, offsetsToCommit);
+            taskProvidedOffsets = task.preCommit(new 
HashMap<>(offsetsToCommit));
         } catch (Throwable t) {
             if (closing) {
                 log.warn("{} Offset commit failed during close", this);
-                onCommitCompleted(t, commitSeqno, null);
             } else {
                 log.error("{} Offset commit failed, rewinding to last 
committed offsets", this, t);
-                for (Map.Entry<TopicPartition, OffsetAndMetadata> entry : 
lastCommittedOffsets.entrySet()) {
+                for (Map.Entry<TopicPartition, OffsetAndMetadata> entry : 
lastCommittedOffsetsForPartitions.entrySet()) {
                     log.debug("{} Rewinding topic partition {} to offset {}", 
this, entry.getKey(), entry.getValue().offset());
                     consumer.seek(entry.getKey(), entry.getValue().offset());
                 }
-                currentOffsets = new HashMap<>(lastCommittedOffsets);
-                onCommitCompleted(t, commitSeqno, null);
+                currentOffsets.putAll(lastCommittedOffsetsForPartitions);
             }
+            onCommitCompleted(t, commitSeqno, null);
             return;
         } finally {
             if (closing) {
-                log.trace("{} Closing the task before committing the offsets: 
{}", this, currentOffsets);
-                task.close(currentOffsets.keySet());
+                log.trace("{} Closing the task before committing the offsets: 
{}", this, offsetsToCommit);
+                task.close(topicPartitions);
             }
         }
 
@@ -413,32 +427,36 @@ class WorkerSinkTask extends WorkerTask {
             return;
         }
 
-        final Map<TopicPartition, OffsetAndMetadata> commitableOffsets = new 
HashMap<>(lastCommittedOffsets);
+        Collection<TopicPartition> allAssignedTopicPartitions = 
consumer.assignment();
+        final Map<TopicPartition, OffsetAndMetadata> committableOffsets = new 
HashMap<>(lastCommittedOffsetsForPartitions);
         for (Map.Entry<TopicPartition, OffsetAndMetadata> 
taskProvidedOffsetEntry : taskProvidedOffsets.entrySet()) {
             final TopicPartition partition = taskProvidedOffsetEntry.getKey();
             final OffsetAndMetadata taskProvidedOffset = 
taskProvidedOffsetEntry.getValue();
-            if (commitableOffsets.containsKey(partition)) {
+            if (committableOffsets.containsKey(partition)) {
                 long taskOffset = taskProvidedOffset.offset();
-                long currentOffset = currentOffsets.get(partition).offset();
+                long currentOffset = offsetsToCommit.get(partition).offset();
                 if (taskOffset <= currentOffset) {
-                    commitableOffsets.put(partition, taskProvidedOffset);
+                    committableOffsets.put(partition, taskProvidedOffset);
                 } else {
                     log.warn("{} Ignoring invalid task provided offset {}/{} 
-- not yet consumed, taskOffset={} currentOffset={}",
-                            this, partition, taskProvidedOffset, taskOffset, 
currentOffset);
+                        this, partition, taskProvidedOffset, taskOffset, 
currentOffset);
                 }
-            } else {
+            } else if (!allAssignedTopicPartitions.contains(partition)) {
                 log.warn("{} Ignoring invalid task provided offset {}/{} -- 
partition not assigned, assignment={}",
-                        this, partition, taskProvidedOffset, 
consumer.assignment());
+                        this, partition, taskProvidedOffset, 
allAssignedTopicPartitions);
+            } else {
+                log.debug("{} Ignoring task provided offset {}/{} -- partition 
not requested, requested={}",
+                        this, partition, taskProvidedOffset, 
committableOffsets.keySet());
             }
         }
 
-        if (commitableOffsets.equals(lastCommittedOffsets)) {
+        if (committableOffsets.equals(lastCommittedOffsetsForPartitions)) {
             log.debug("{} Skipping offset commit, no change since last 
commit", this);
             onCommitCompleted(null, commitSeqno, null);
             return;
         }
 
-        doCommit(commitableOffsets, closing, commitSeqno);
+        doCommit(committableOffsets, closing, commitSeqno);
     }
 
 
@@ -579,10 +597,12 @@ class WorkerSinkTask extends WorkerTask {
             }
         } catch (RetriableException e) {
             log.error("{} RetriableException from SinkTask:", this, e);
-            // If we're retrying a previous batch, make sure we've paused all 
topic partitions so we don't get new data,
-            // but will still be able to poll in order to handle 
user-requested timeouts, keep group membership, etc.
-            pausedForRedelivery = true;
-            pauseAll();
+            if (!pausedForRedelivery) {
+                // If we're retrying a previous batch, make sure we've paused 
all topic partitions so we don't get new data,
+                // but will still be able to poll in order to handle 
user-requested timeouts, keep group membership, etc.
+                pausedForRedelivery = true;
+                pauseAll();
+            }
             // Let this exit normally, the batch will be reprocessed on the 
next loop.
         } catch (Throwable t) {
             log.error("{} Task threw an uncaught and unrecoverable exception. 
Task is being killed and will not "
@@ -612,13 +632,32 @@ class WorkerSinkTask extends WorkerTask {
     }
 
     private void openPartitions(Collection<TopicPartition> partitions) {
-        sinkTaskMetricsGroup.recordPartitionCount(partitions.size());
+        updatePartitionCount();
         task.open(partitions);
     }
 
-    private void closePartitions() {
-        commitOffsets(time.milliseconds(), true);
-        sinkTaskMetricsGroup.recordPartitionCount(0);
+    private void closeAllPartitions() {
+        closePartitions(currentOffsets.keySet(), false);
+    }
+
+    private void closePartitions(Collection<TopicPartition> topicPartitions, 
boolean lost) {
+        if (!lost) {
+            commitOffsets(time.milliseconds(), true, topicPartitions);
+        } else {
+            log.trace("{} Closing the task as partitions have been lost: {}", 
this, topicPartitions);
+            task.close(topicPartitions);
+            if (workerErrantRecordReporter != null) {
+                log.trace("Cancelling reported errors for {}", 
topicPartitions);
+                workerErrantRecordReporter.cancelFutures(topicPartitions);
+                log.trace("Cancelled all reported errors for {}", 
topicPartitions);
+            }
+            topicPartitions.forEach(currentOffsets::remove);
+        }
+        updatePartitionCount();
+    }
+
+    private void updatePartitionCount() {
+        
sinkTaskMetricsGroup.recordPartitionCount(consumer.assignment().size());
     }
 
     @Override
@@ -651,8 +690,7 @@ class WorkerSinkTask extends WorkerTask {
         @Override
         public void onPartitionsAssigned(Collection<TopicPartition> 
partitions) {
             log.debug("{} Partitions assigned {}", WorkerSinkTask.this, 
partitions);
-            lastCommittedOffsets = new HashMap<>();
-            currentOffsets = new HashMap<>();
+
             for (TopicPartition tp : partitions) {
                 long pos = consumer.position(tp);
                 lastCommittedOffsets.put(tp, new OffsetAndMetadata(pos));
@@ -661,17 +699,29 @@ class WorkerSinkTask extends WorkerTask {
             }
             sinkTaskMetricsGroup.assignedOffsets(currentOffsets);
 
-            // If we paused everything for redelivery (which is no longer 
relevant since we discarded the data), make
-            // sure anything we paused that the task didn't request to be 
paused *and* which we still own is resumed.
-            // Also make sure our tracking of paused partitions is updated to 
remove any partitions we no longer own.
-            pausedForRedelivery = false;
-
-            // Ensure that the paused partitions contains only assigned 
partitions and repause as necessary
-            context.pausedPartitions().retainAll(partitions);
-            if (shouldPause())
+            boolean wasPausedForRedelivery = pausedForRedelivery;
+            pausedForRedelivery = wasPausedForRedelivery && 
!messageBatch.isEmpty();
+            if (pausedForRedelivery) {
+                // Re-pause here in case we picked up new partitions in the 
rebalance
                 pauseAll();
-            else if (!context.pausedPartitions().isEmpty())
-                consumer.pause(context.pausedPartitions());
+            } else {
+                // If we paused everything for redelivery and all partitions 
for the failed deliveries have been revoked, make
+                // sure anything we paused that the task didn't request to be 
paused *and* which we still own is resumed.
+                // Also make sure our tracking of paused partitions is updated 
to remove any partitions we no longer own.
+                if (wasPausedForRedelivery) {
+                    resumeAll();
+                }
+                // Ensure that the paused partitions contains only assigned 
partitions and repause as necessary
+                context.pausedPartitions().retainAll(consumer.assignment());
+                if (shouldPause())
+                    pauseAll();
+                else if (!context.pausedPartitions().isEmpty())
+                    consumer.pause(context.pausedPartitions());
+            }
+
+            if (partitions.isEmpty()) {
+                return;
+            }
 
             // Instead of invoking the assignment callback on initialization, 
we guarantee the consumer is ready upon
             // task start. Since this callback gets invoked during that 
initial setup before we've started the task, we
@@ -691,22 +741,35 @@ class WorkerSinkTask extends WorkerTask {
 
         @Override
         public void onPartitionsRevoked(Collection<TopicPartition> partitions) 
{
+            onPartitionsRemoved(partitions, false);
+        }
+
+        @Override
+        public void onPartitionsLost(Collection<TopicPartition> partitions) {
+            onPartitionsRemoved(partitions, true);
+        }
+
+        private void onPartitionsRemoved(Collection<TopicPartition> 
partitions, boolean lost) {
             if (taskStopped) {
                 log.trace("Skipping partition revocation callback as task has 
already been stopped");
                 return;
             }
-            log.debug("{} Partitions revoked", WorkerSinkTask.this);
+            log.debug("{} Partitions {}: {}", WorkerSinkTask.this, lost ? 
"lost" : "revoked", partitions);
+
+            if (partitions.isEmpty())
+                return;
+
             try {
-                closePartitions();
-                sinkTaskMetricsGroup.clearOffsets();
+                closePartitions(partitions, lost);
+                sinkTaskMetricsGroup.clearOffsets(partitions);
             } catch (RuntimeException e) {
                 // The consumer swallows exceptions raised in the rebalance 
listener, so we need to store
                 // exceptions and rethrow when poll() returns.
                 rebalanceException = e;
             }
 
-            // Make sure we don't have any leftover data since offsets will be 
reset to committed positions
-            messageBatch.clear();
+            // Make sure we don't have any leftover data since offsets for 
these partitions will be reset to committed positions
+            messageBatch.removeIf(record -> partitions.contains(new 
TopicPartition(record.topic(), record.kafkaPartition())));
         }
     }
 
@@ -825,13 +888,15 @@ class WorkerSinkTask extends WorkerTask {
         void assignedOffsets(Map<TopicPartition, OffsetAndMetadata> offsets) {
             consumedOffsets = new HashMap<>(offsets);
             committedOffsets = offsets;
-            sinkRecordActiveCount.record(0.0);
+            computeSinkRecordLag();
         }
 
-        void clearOffsets() {
-            consumedOffsets.clear();
-            committedOffsets.clear();
-            sinkRecordActiveCount.record(0.0);
+        void clearOffsets(Collection<TopicPartition> topicPartitions) {
+            topicPartitions.forEach(tp -> {
+                consumedOffsets.remove(tp);
+                committedOffsets.remove(tp);
+            });
+            computeSinkRecordLag();
         }
 
         void recordOffsetCommitSuccess() {
diff --git 
a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporter.java
 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporter.java
index aa1d0be..ed48f79 100644
--- 
a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporter.java
+++ 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporter.java
@@ -18,6 +18,7 @@ package org.apache.kafka.connect.runtime.errors;
 
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.producer.RecordMetadata;
+import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.connect.errors.ConnectException;
 import org.apache.kafka.connect.header.Header;
@@ -31,13 +32,18 @@ import org.apache.kafka.connect.storage.HeaderConverter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.LinkedList;
+import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Optional;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.stream.Collectors;
 
 public class WorkerErrantRecordReporter implements ErrantRecordReporter {
 
@@ -49,7 +55,7 @@ public class WorkerErrantRecordReporter implements 
ErrantRecordReporter {
     private final HeaderConverter headerConverter;
 
     // Visible for testing
-    protected final LinkedList<Future<Void>> futures;
+    protected final ConcurrentMap<TopicPartition, List<Future<Void>>> futures;
 
     public WorkerErrantRecordReporter(
         RetryWithToleranceOperator retryWithToleranceOperator,
@@ -61,7 +67,7 @@ public class WorkerErrantRecordReporter implements 
ErrantRecordReporter {
         this.keyConverter = keyConverter;
         this.valueConverter = valueConverter;
         this.headerConverter = headerConverter;
-        this.futures = new LinkedList<>();
+        this.futures = new ConcurrentHashMap<>();
     }
 
     @Override
@@ -103,26 +109,49 @@ public class WorkerErrantRecordReporter implements 
ErrantRecordReporter {
         Future<Void> future = 
retryWithToleranceOperator.executeFailed(Stage.TASK_PUT, SinkTask.class, 
consumerRecord, error);
 
         if (!future.isDone()) {
-            futures.add(future);
+            TopicPartition partition = new 
TopicPartition(consumerRecord.topic(), consumerRecord.partition());
+            futures.computeIfAbsent(partition, p -> new 
ArrayList<>()).add(future);
         }
         return future;
     }
 
     /**
-     * Gets all futures returned by the sink records sent to Kafka by the 
errant
-     * record reporter. This function is intended to be used to block on all 
the errant record
-     * futures.
+     * Awaits the completion of all error reports for a given set of topic 
partitions
+     * @param topicPartitions the topic partitions to await reporter 
completion for
      */
-    public void awaitAllFutures() {
-        Future<?> future;
-        while ((future = futures.poll()) != null) {
+    public void awaitFutures(Collection<TopicPartition> topicPartitions) {
+        futuresFor(topicPartitions).forEach(future -> {
             try {
                 future.get();
             } catch (InterruptedException | ExecutionException e) {
-                log.error("Encountered an error while awaiting an errant 
record future's completion.");
+                log.error("Encountered an error while awaiting an errant 
record future's completion.", e);
                 throw new ConnectException(e);
             }
-        }
+        });
+    }
+
+    /**
+     * Cancels all active error reports for a given set of topic partitions
+     * @param topicPartitions the topic partitions to cancel reporting for
+     */
+    public void cancelFutures(Collection<TopicPartition> topicPartitions) {
+        futuresFor(topicPartitions).forEach(future -> {
+            try {
+                future.cancel(true);
+            } catch (Exception e) {
+                log.error("Encountered an error while cancelling an errant 
record future", e);
+                // No need to throw the exception here; it's enough to log an 
error message
+            }
+        });
+    }
+
+    // Removes and returns all futures for the given topic partitions from the 
set of currently-active futures
+    private Collection<Future<Void>> futuresFor(Collection<TopicPartition> 
topicPartitions) {
+        return topicPartitions.stream()
+                .map(futures::remove)
+                .filter(Objects::nonNull)
+                .flatMap(List::stream)
+                .collect(Collectors.toList());
     }
 
     /**
diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrantRecordSinkConnector.java
 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrantRecordSinkConnector.java
index 0fe2f88..251c67c 100644
--- 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrantRecordSinkConnector.java
+++ 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrantRecordSinkConnector.java
@@ -53,7 +53,7 @@ public class ErrantRecordSinkConnector extends 
MonitorableSinkConnector {
                 TopicPartition tp = cachedTopicPartitions
                     .computeIfAbsent(rec.topic(), v -> new HashMap<>())
                     .computeIfAbsent(rec.kafkaPartition(), v -> new 
TopicPartition(rec.topic(), rec.kafkaPartition()));
-                committedOffsets.put(tp, committedOffsets.getOrDefault(tp, 0L) 
+ 1);
+                committedOffsets.put(tp, committedOffsets.getOrDefault(tp, 0) 
+ 1);
                 reporter.report(rec, new Throwable());
             }
         }
diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrorHandlingIntegrationTest.java
 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrorHandlingIntegrationTest.java
index b6211ed..b3dd9a0 100644
--- 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrorHandlingIntegrationTest.java
+++ 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrorHandlingIntegrationTest.java
@@ -261,7 +261,7 @@ public class ErrorHandlingIntegrationTest {
         try {
             ConnectorStateInfo info = connect.connectorStatus(CONNECTOR_NAME);
             return info != null && info.tasks().size() == NUM_TASKS
-                    && 
connectorHandle.taskHandle(TASK_ID).partitionsAssigned() == 1;
+                    && 
connectorHandle.taskHandle(TASK_ID).numPartitionsAssigned() == 1;
         }  catch (Exception e) {
             // Log the exception and return that the partitions were not 
assigned
             log.error("Could not check connector state info.", e);
diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ExampleConnectIntegrationTest.java
 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ExampleConnectIntegrationTest.java
index 6f8d8a1..23a87c2 100644
--- 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ExampleConnectIntegrationTest.java
+++ 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ExampleConnectIntegrationTest.java
@@ -236,7 +236,7 @@ public class ExampleConnectIntegrationTest {
         try {
             ConnectorStateInfo info = connect.connectorStatus(CONNECTOR_NAME);
             return info != null && info.tasks().size() == NUM_TASKS
-                    && connectorHandle.tasks().stream().allMatch(th -> 
th.partitionsAssigned() == 1);
+                    && connectorHandle.tasks().stream().allMatch(th -> 
th.numPartitionsAssigned() == 1);
         } catch (Exception e) {
             // Log the exception and return that the partitions were not 
assigned
             log.error("Could not check connector state info.", e);
diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSinkConnector.java
 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSinkConnector.java
index 7b9afa4..5733199 100644
--- 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSinkConnector.java
+++ 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSinkConnector.java
@@ -29,10 +29,8 @@ import org.slf4j.LoggerFactory;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 
 /**
  * A sink connector that is used in Apache Kafka integration tests to verify 
the behavior of the
@@ -91,12 +89,10 @@ public class MonitorableSinkConnector extends 
TestSinkConnector {
         private String connectorName;
         private String taskId;
         TaskHandle taskHandle;
-        Set<TopicPartition> assignments;
-        Map<TopicPartition, Long> committedOffsets;
+        Map<TopicPartition, Integer> committedOffsets;
         Map<String, Map<Integer, TopicPartition>> cachedTopicPartitions;
 
         public MonitorableSinkTask() {
-            this.assignments = new HashSet<>();
             this.committedOffsets = new HashMap<>();
             this.cachedTopicPartitions = new HashMap<>();
         }
@@ -117,9 +113,15 @@ public class MonitorableSinkConnector extends 
TestSinkConnector {
 
         @Override
         public void open(Collection<TopicPartition> partitions) {
-            log.debug("Opening {} partitions", partitions.size());
-            assignments.addAll(partitions);
-            taskHandle.partitionsAssigned(partitions.size());
+            log.debug("Opening partitions {}", partitions);
+            taskHandle.partitionsAssigned(partitions);
+        }
+
+        @Override
+        public void close(Collection<TopicPartition> partitions) {
+            log.debug("Closing partitions {}", partitions);
+            taskHandle.partitionsRevoked(partitions);
+            partitions.forEach(committedOffsets::remove);
         }
 
         @Override
@@ -129,26 +131,22 @@ public class MonitorableSinkConnector extends 
TestSinkConnector {
                 TopicPartition tp = cachedTopicPartitions
                         .computeIfAbsent(rec.topic(), v -> new HashMap<>())
                         .computeIfAbsent(rec.kafkaPartition(), v -> new 
TopicPartition(rec.topic(), rec.kafkaPartition()));
-                committedOffsets.put(tp, committedOffsets.getOrDefault(tp, 0L) 
+ 1);
+                committedOffsets.put(tp, committedOffsets.getOrDefault(tp, 0) 
+ 1);
                 log.trace("Task {} obtained record (key='{}' value='{}')", 
taskId, rec.key(), rec.value());
             }
         }
 
         @Override
         public Map<TopicPartition, OffsetAndMetadata> 
preCommit(Map<TopicPartition, OffsetAndMetadata> offsets) {
-            for (TopicPartition tp : assignments) {
-                Long recordsSinceLastCommit = committedOffsets.get(tp);
-                if (recordsSinceLastCommit == null) {
-                    log.warn("preCommit was called with topic-partition {} 
that is not included "
-                            + "in the assignments of this task {}", tp, 
assignments);
-                } else {
-                    taskHandle.commit(recordsSinceLastCommit.intValue());
-                    log.error("Forwarding to framework request to commit 
additional {} for {}",
-                            recordsSinceLastCommit, tp);
-                    taskHandle.commit((int) (long) recordsSinceLastCommit);
-                    committedOffsets.put(tp, 0L);
+            taskHandle.partitionsCommitted(offsets.keySet());
+            offsets.forEach((tp, offset) -> {
+                int recordsSinceLastCommit = committedOffsets.getOrDefault(tp, 
0);
+                if (recordsSinceLastCommit != 0) {
+                    taskHandle.commit(recordsSinceLastCommit);
+                    log.debug("Forwarding to framework request to commit {} 
records for {}", recordsSinceLastCommit, tp);
+                    committedOffsets.put(tp, 0);
                 }
-            }
+            });
             return offsets;
         }
 
diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SinkConnectorsIntegrationTest.java
 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SinkConnectorsIntegrationTest.java
new file mode 100644
index 0000000..a8bfbb2
--- /dev/null
+++ 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SinkConnectorsIntegrationTest.java
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.connect.integration;
+
+import org.apache.kafka.clients.consumer.CooperativeStickyAssignor;
+import org.apache.kafka.clients.consumer.RoundRobinAssignor;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.connect.sink.SinkRecord;
+import org.apache.kafka.connect.storage.StringConverter;
+import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster;
+import org.apache.kafka.test.IntegrationTest;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Properties;
+import java.util.Set;
+import java.util.function.Consumer;
+
+import static 
org.apache.kafka.clients.consumer.ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG;
+import static 
org.apache.kafka.clients.consumer.ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG;
+import static 
org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX;
+import static 
org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG;
+import static 
org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.WorkerConfig.KEY_CONVERTER_CLASS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG;
+import static org.apache.kafka.test.TestUtils.waitForCondition;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Integration test for sink connectors
+ */
+@Category(IntegrationTest.class)
+public class SinkConnectorsIntegrationTest {
+
+    private static final int NUM_TASKS = 1;
+    private static final int NUM_WORKERS = 1;
+    private static final String CONNECTOR_NAME = 
"connect-integration-test-sink";
+    private static final long TASK_CONSUME_TIMEOUT_MS = 10_000L;
+
+    private EmbeddedConnectCluster connect;
+
+    @Before
+    public void setup() throws Exception {
+        Map<String, String> workerProps = new HashMap<>();
+        // permit all Kafka client overrides; required for testing different 
consumer partition assignment strategies
+        workerProps.put(CONNECTOR_CLIENT_POLICY_CLASS_CONFIG, "All");
+
+        // setup Kafka broker properties
+        Properties brokerProps = new Properties();
+        brokerProps.put("auto.create.topics.enable", "false");
+        brokerProps.put("delete.topic.enable", "true");
+
+        // build a Connect cluster backed by Kafka and Zk
+        connect = new EmbeddedConnectCluster.Builder()
+                .name("connect-cluster")
+                .numWorkers(NUM_WORKERS)
+                .workerProps(workerProps)
+                .brokerProps(brokerProps)
+                .build();
+        connect.start();
+        connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, 
"Initial group of workers did not start in time.");
+    }
+
+    @After
+    public void close() {
+        // delete connector handle
+        RuntimeHandles.get().deleteConnector(CONNECTOR_NAME);
+
+        // stop all Connect, Kafka and Zk threads.
+        connect.stop();
+    }
+
+    @Test
+    public void testEagerConsumerPartitionAssignment() throws Exception {
+        final String topic1 = "topic1", topic2 = "topic2", topic3 = "topic3";
+        final TopicPartition tp1 = new TopicPartition(topic1, 0), tp2 = new 
TopicPartition(topic2, 0), tp3 = new TopicPartition(topic3, 0);
+        final Collection<String> topics = Arrays.asList(topic1, topic2, 
topic3);
+
+        Map<String, String> connectorProps = 
baseSinkConnectorProps(String.join(",", topics));
+        // Need an eager assignor here; round robin is as good as any
+        connectorProps.put(
+            CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + 
PARTITION_ASSIGNMENT_STRATEGY_CONFIG,
+            RoundRobinAssignor.class.getName());
+        // After deleting a topic, offset commits will fail for it; reduce the 
timeout here so that the test doesn't take forever to proceed past that point
+        connectorProps.put(
+            CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + 
DEFAULT_API_TIMEOUT_MS_CONFIG,
+            "5000");
+
+        final Set<String> consumedRecordValues = new HashSet<>();
+        Consumer<SinkRecord> onPut = record -> assertTrue("Task received 
duplicate record from Connect", 
consumedRecordValues.add(Objects.toString(record.value())));
+        ConnectorHandle connector = 
RuntimeHandles.get().connectorHandle(CONNECTOR_NAME);
+        TaskHandle task = connector.taskHandle(CONNECTOR_NAME + "-0", onPut);
+
+        connect.configureConnector(CONNECTOR_NAME, connectorProps);
+        
connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME,
 NUM_TASKS, "Connector tasks did not start in time.");
+
+        // None of the topics has been created yet; the task shouldn't be 
assigned any partitions
+        assertEquals(0, task.numPartitionsAssigned());
+
+        Set<String> expectedRecordValues = new HashSet<>();
+        Set<TopicPartition> expectedAssignment = new HashSet<>();
+
+        connect.kafka().createTopic(topic1, 1);
+        expectedAssignment.add(tp1);
+        connect.kafka().produce(topic1, "t1v1");
+        expectedRecordValues.add("t1v1");
+
+        waitForCondition(
+            () -> expectedRecordValues.equals(consumedRecordValues),
+            TASK_CONSUME_TIMEOUT_MS,
+            "Task did not receive records in time");
+        assertEquals(1, task.timesAssigned(tp1));
+        assertEquals(0, task.timesRevoked(tp1));
+        assertEquals(expectedAssignment, task.assignment());
+
+        connect.kafka().createTopic(topic2, 1);
+        expectedAssignment.add(tp2);
+        connect.kafka().produce(topic2, "t2v1");
+        expectedRecordValues.add("t2v1");
+        connect.kafka().produce(topic2, "t1v2");
+        expectedRecordValues.add("t1v2");
+
+        waitForCondition(
+            () -> expectedRecordValues.equals(consumedRecordValues),
+            TASK_CONSUME_TIMEOUT_MS,
+            "Task did not receive records in time");
+        assertEquals(2, task.timesAssigned(tp1));
+        assertEquals(1, task.timesRevoked(tp1));
+        assertEquals(1, task.timesCommitted(tp1));
+        assertEquals(1, task.timesAssigned(tp2));
+        assertEquals(0, task.timesRevoked(tp2));
+        assertEquals(expectedAssignment, task.assignment());
+
+        connect.kafka().createTopic(topic3, 1);
+        expectedAssignment.add(tp3);
+        connect.kafka().produce(topic3, "t3v1");
+        expectedRecordValues.add("t3v1");
+        connect.kafka().produce(topic2, "t2v2");
+        expectedRecordValues.add("t2v2");
+        connect.kafka().produce(topic2, "t1v3");
+        expectedRecordValues.add("t1v3");
+
+        expectedAssignment.add(tp3);
+        waitForCondition(
+            () -> expectedRecordValues.equals(consumedRecordValues),
+            TASK_CONSUME_TIMEOUT_MS,
+            "Task did not receive records in time");
+        assertEquals(3, task.timesAssigned(tp1));
+        assertEquals(2, task.timesRevoked(tp1));
+        assertEquals(2, task.timesCommitted(tp1));
+        assertEquals(2, task.timesAssigned(tp2));
+        assertEquals(1, task.timesRevoked(tp2));
+        assertEquals(1, task.timesCommitted(tp2));
+        assertEquals(1, task.timesAssigned(tp3));
+        assertEquals(0, task.timesRevoked(tp3));
+        assertEquals(expectedAssignment, task.assignment());
+
+        connect.kafka().deleteTopic(topic1);
+        expectedAssignment.remove(tp1);
+        connect.kafka().produce(topic3, "t3v2");
+        expectedRecordValues.add("t3v2");
+        connect.kafka().produce(topic2, "t2v3");
+        expectedRecordValues.add("t2v3");
+
+        waitForCondition(
+            () -> expectedRecordValues.equals(consumedRecordValues) && 
expectedAssignment.equals(task.assignment()),
+            TASK_CONSUME_TIMEOUT_MS,
+            "Timed out while waiting for task to receive records and updated 
topic partition assignment");
+        assertEquals(3, task.timesAssigned(tp1));
+        assertEquals(3, task.timesRevoked(tp1));
+        assertEquals(3, task.timesCommitted(tp1));
+        assertEquals(3, task.timesAssigned(tp2));
+        assertEquals(2, task.timesRevoked(tp2));
+        assertEquals(2, task.timesCommitted(tp2));
+        assertEquals(2, task.timesAssigned(tp3));
+        assertEquals(1, task.timesRevoked(tp3));
+        assertEquals(1, task.timesCommitted(tp3));
+    }
+
+    @Test
+    public void testCooperativeConsumerPartitionAssignment() throws Exception {
+        final String topic1 = "topic1", topic2 = "topic2", topic3 = "topic3";
+        final TopicPartition tp1 = new TopicPartition(topic1, 0), tp2 = new 
TopicPartition(topic2, 0), tp3 = new TopicPartition(topic3, 0);
+        final Collection<String> topics = Arrays.asList(topic1, topic2, 
topic3);
+
+        Map<String, String> connectorProps = 
baseSinkConnectorProps(String.join(",", topics));
+        // Need an eager assignor here; round robin is as good as any
+        connectorProps.put(
+                CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + 
PARTITION_ASSIGNMENT_STRATEGY_CONFIG,
+                CooperativeStickyAssignor.class.getName());
+        // After deleting a topic, offset commits will fail for it; reduce the 
timeout here so that the test doesn't take forever to proceed past that point
+        connectorProps.put(
+                CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + 
DEFAULT_API_TIMEOUT_MS_CONFIG,
+                "5000");
+
+        final Set<String> consumedRecordValues = new HashSet<>();
+        Consumer<SinkRecord> onPut = record -> assertTrue("Task received 
duplicate record from Connect", 
consumedRecordValues.add(Objects.toString(record.value())));
+        ConnectorHandle connector = 
RuntimeHandles.get().connectorHandle(CONNECTOR_NAME);
+        TaskHandle task = connector.taskHandle(CONNECTOR_NAME + "-0", onPut);
+
+        connect.configureConnector(CONNECTOR_NAME, connectorProps);
+        
connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME,
 NUM_TASKS, "Connector tasks did not start in time.");
+
+        // None of the topics has been created yet; the task shouldn't be 
assigned any partitions
+        assertEquals(0, task.numPartitionsAssigned());
+
+        Set<String> expectedRecordValues = new HashSet<>();
+        Set<TopicPartition> expectedAssignment = new HashSet<>();
+
+        connect.kafka().createTopic(topic1, 1);
+        expectedAssignment.add(tp1);
+        connect.kafka().produce(topic1, "t1v1");
+        expectedRecordValues.add("t1v1");
+
+        waitForCondition(
+            () -> expectedRecordValues.equals(consumedRecordValues),
+            TASK_CONSUME_TIMEOUT_MS,
+            "Task did not receive records in time");
+        assertEquals(1, task.timesAssigned(tp1));
+        assertEquals(0, task.timesRevoked(tp1));
+        assertEquals(expectedAssignment, task.assignment());
+
+        connect.kafka().createTopic(topic2, 1);
+        expectedAssignment.add(tp2);
+        connect.kafka().produce(topic2, "t2v1");
+        expectedRecordValues.add("t2v1");
+        connect.kafka().produce(topic2, "t1v2");
+        expectedRecordValues.add("t1v2");
+
+        waitForCondition(
+            () -> expectedRecordValues.equals(consumedRecordValues),
+            TASK_CONSUME_TIMEOUT_MS,
+            "Task did not receive records in time");
+        assertEquals(1, task.timesAssigned(tp1));
+        assertEquals(0, task.timesRevoked(tp1));
+        assertEquals(0, task.timesCommitted(tp1));
+        assertEquals(1, task.timesAssigned(tp2));
+        assertEquals(0, task.timesRevoked(tp2));
+        assertEquals(expectedAssignment, task.assignment());
+
+        connect.kafka().createTopic(topic3, 1);
+        expectedAssignment.add(tp3);
+        connect.kafka().produce(topic3, "t3v1");
+        expectedRecordValues.add("t3v1");
+        connect.kafka().produce(topic2, "t2v2");
+        expectedRecordValues.add("t2v2");
+        connect.kafka().produce(topic2, "t1v3");
+        expectedRecordValues.add("t1v3");
+
+        expectedAssignment.add(tp3);
+        waitForCondition(
+            () -> expectedRecordValues.equals(consumedRecordValues),
+            TASK_CONSUME_TIMEOUT_MS,
+            "Task did not receive records in time");
+        assertEquals(1, task.timesAssigned(tp1));
+        assertEquals(0, task.timesRevoked(tp1));
+        assertEquals(0, task.timesCommitted(tp1));
+        assertEquals(1, task.timesAssigned(tp2));
+        assertEquals(0, task.timesRevoked(tp2));
+        assertEquals(0, task.timesCommitted(tp2));
+        assertEquals(1, task.timesAssigned(tp3));
+        assertEquals(0, task.timesRevoked(tp3));
+        assertEquals(expectedAssignment, task.assignment());
+
+        connect.kafka().deleteTopic(topic1);
+        expectedAssignment.remove(tp1);
+        connect.kafka().produce(topic3, "t3v2");
+        expectedRecordValues.add("t3v2");
+        connect.kafka().produce(topic2, "t2v3");
+        expectedRecordValues.add("t2v3");
+
+        waitForCondition(
+            () -> expectedRecordValues.equals(consumedRecordValues) && 
expectedAssignment.equals(task.assignment()),
+            TASK_CONSUME_TIMEOUT_MS,
+            "Timed out while waiting for task to receive records and updated 
topic partition assignment");
+        assertEquals(1, task.timesAssigned(tp1));
+        assertEquals(1, task.timesRevoked(tp1));
+        assertEquals(1, task.timesCommitted(tp1));
+        assertEquals(1, task.timesAssigned(tp2));
+        assertEquals(0, task.timesRevoked(tp2));
+        assertEquals(0, task.timesCommitted(tp2));
+        assertEquals(1, task.timesAssigned(tp3));
+        assertEquals(0, task.timesRevoked(tp3));
+        assertEquals(0, task.timesCommitted(tp3));
+    }
+
+    private Map<String, String> baseSinkConnectorProps(String topics) {
+        Map<String, String> props = new HashMap<>();
+        props.put(CONNECTOR_CLASS_CONFIG, 
MonitorableSinkConnector.class.getSimpleName());
+        props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS));
+        props.put(TOPICS_CONFIG, topics);
+        props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName());
+        props.put(VALUE_CONVERTER_CLASS_CONFIG, 
StringConverter.class.getName());
+        return props;
+    }
+}
diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TaskHandle.java
 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TaskHandle.java
index acb8eb3..ab5b711 100644
--- 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TaskHandle.java
+++ 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TaskHandle.java
@@ -16,15 +16,19 @@
  */
 package org.apache.kafka.connect.integration;
 
+import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.connect.errors.DataException;
 import org.apache.kafka.connect.sink.SinkRecord;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.Collection;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Consumer;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 /**
@@ -37,7 +41,7 @@ public class TaskHandle {
 
     private final String taskId;
     private final ConnectorHandle connectorHandle;
-    private final AtomicInteger partitionsAssigned = new AtomicInteger(0);
+    private final ConcurrentMap<TopicPartition, PartitionHistory> partitions = 
new ConcurrentHashMap<>();
     private final StartAndStopCounter startAndStopCounter = new 
StartAndStopCounter();
     private final Consumer<SinkRecord> consumer;
 
@@ -128,19 +132,74 @@ public class TaskHandle {
     }
 
     /**
-     * Set the number of partitions assigned to this task.
+     * Adds a set of partitions to the (sink) task's assignment
      *
-     * @param numPartitions number of partitions
+     * @param partitions the newly-assigned partitions
      */
-    public void partitionsAssigned(int numPartitions) {
-        partitionsAssigned.set(numPartitions);
+    public void partitionsAssigned(Collection<TopicPartition> partitions) {
+        partitions.forEach(partition -> 
this.partitions.computeIfAbsent(partition, PartitionHistory::new).assigned());
     }
 
     /**
-     * @return the number of topic partitions assigned to this task.
+     * Removes a set of partitions to the (sink) task's assignment
+     *
+     * @param partitions the newly-revoked partitions
+     */
+    public void partitionsRevoked(Collection<TopicPartition> partitions) {
+        partitions.forEach(partition -> 
this.partitions.computeIfAbsent(partition, PartitionHistory::new).revoked());
+    }
+
+    /**
+     * Records offset commits for a (sink) task's partitions
+     *
+     * @param partitions the committed partitions
+     */
+    public void partitionsCommitted(Collection<TopicPartition> partitions) {
+        partitions.forEach(partition -> 
this.partitions.computeIfAbsent(partition, PartitionHistory::new).committed());
+    }
+
+    /**
+     * @return the complete set of partitions currently assigned to this 
(sink) task
+     */
+    public Collection<TopicPartition> assignment() {
+        return partitions.values().stream()
+                .filter(PartitionHistory::isAssigned)
+                .map(PartitionHistory::topicPartition)
+                .collect(Collectors.toSet());
+    }
+
+    /**
+     * @return the number of topic partitions assigned to this (sink) task.
+     */
+    public int numPartitionsAssigned() {
+        return assignment().size();
+    }
+
+    /**
+     * Returns the number of times the partition has been assigned to this 
(sink) task.
+     * @param partition the partition
+     * @return the number of times it has been assigned; may be 0 if never 
assigned
+     */
+    public int timesAssigned(TopicPartition partition) {
+        return partitions.computeIfAbsent(partition, 
PartitionHistory::new).timesAssigned();
+    }
+
+    /**
+     * Returns the number of times the partition has been revoked from this 
(sink) task.
+     * @param partition the partition
+     * @return the number of times it has been revoked; may be 0 if never 
revoked
+     */
+    public int timesRevoked(TopicPartition partition) {
+        return partitions.computeIfAbsent(partition, 
PartitionHistory::new).timesRevoked();
+    }
+
+    /**
+     * Returns the number of times the framework has committed offsets for 
this partition
+     * @param partition the partition
+     * @return the number of times it has been committed; may be 0 if never 
committed
      */
-    public int partitionsAssigned() {
-        return partitionsAssigned.get();
+    public int timesCommitted(TopicPartition partition) {
+        return partitions.computeIfAbsent(partition, 
PartitionHistory::new).timesCommitted();
     }
 
     /**
@@ -266,4 +325,50 @@ public class TaskHandle {
                 "taskId='" + taskId + '\'' +
                 '}';
     }
+
+    private static class PartitionHistory {
+        private final TopicPartition topicPartition;
+        private boolean assigned = false;
+        private int timesAssigned = 0;
+        private int timesRevoked = 0;
+        private int timesCommitted = 0;
+
+        public PartitionHistory(TopicPartition topicPartition) {
+            this.topicPartition = topicPartition;
+        }
+
+        public void assigned() {
+            timesAssigned++;
+            assigned = true;
+        }
+
+        public void revoked() {
+            timesRevoked++;
+            assigned = false;
+        }
+
+        public void committed() {
+            timesCommitted++;
+        }
+
+        public TopicPartition topicPartition() {
+            return topicPartition;
+        }
+
+        public boolean isAssigned() {
+            return assigned;
+        }
+
+        public int timesAssigned() {
+            return timesAssigned;
+        }
+
+        public int timesRevoked() {
+            return timesRevoked;
+        }
+
+        public int timesCommitted() {
+            return timesCommitted;
+        }
+    }
 }
diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TransformationIntegrationTest.java
 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TransformationIntegrationTest.java
index 7b71f2f..02d8c7f 100644
--- 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TransformationIntegrationTest.java
+++ 
b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TransformationIntegrationTest.java
@@ -236,7 +236,7 @@ public class TransformationIntegrationTest {
         props.put(PREDICATES_CONFIG + ".barPredicate.type", 
RecordIsTombstone.class.getSimpleName());
 
         // expect only half the records to be consumed by the connector
-        connectorHandle.expectedCommits(numRecords);
+        connectorHandle.expectedCommits(numRecords / 2);
         connectorHandle.expectedRecords(numRecords / 2);
 
         // start a sink connector
diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java
 
b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java
index 7a2a6e4..08a0458 100644
--- 
a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java
+++ 
b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java
@@ -83,7 +83,9 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
 
 import static java.util.Arrays.asList;
 import static java.util.Collections.singleton;
@@ -117,6 +119,9 @@ public class WorkerSinkTaskTest {
     private static final TopicPartition TOPIC_PARTITION2 = new 
TopicPartition(TOPIC, PARTITION2);
     private static final TopicPartition TOPIC_PARTITION3 = new 
TopicPartition(TOPIC, PARTITION3);
 
+    private static final Set<TopicPartition> INITIAL_ASSIGNMENT =
+        new HashSet<>(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2));
+
     private static final Map<String, String> TASK_PROPS = new HashMap<>();
     static {
         TASK_PROPS.put(SinkConnector.TOPICS_CONFIG, TOPIC);
@@ -195,9 +200,8 @@ public class WorkerSinkTaskTest {
         expectTaskGetTopic(true);
         expectPollInitialAssignment();
 
-        Set<TopicPartition> partitions = new HashSet<>(asList(TOPIC_PARTITION, 
TOPIC_PARTITION2));
-        EasyMock.expect(consumer.assignment()).andReturn(partitions);
-        consumer.pause(partitions);
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT);
+        consumer.pause(INITIAL_ASSIGNMENT);
         PowerMock.expectLastCall();
 
         PowerMock.replayAll();
@@ -229,14 +233,12 @@ public class WorkerSinkTaskTest {
         sinkTask.put(EasyMock.anyObject());
         EasyMock.expectLastCall();
 
-        Set<TopicPartition> partitions = new HashSet<>(asList(TOPIC_PARTITION, 
TOPIC_PARTITION2));
-
         // Pause
         statusListener.onPause(taskId);
         EasyMock.expectLastCall();
         expectConsumerWakeup();
-        EasyMock.expect(consumer.assignment()).andReturn(partitions);
-        consumer.pause(partitions);
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT);
+        consumer.pause(INITIAL_ASSIGNMENT);
         PowerMock.expectLastCall();
 
         // Offset commit as requested when pausing; No records returned by consumer.poll()
@@ -250,11 +252,11 @@ public class WorkerSinkTaskTest {
         statusListener.onResume(taskId);
         EasyMock.expectLastCall();
         expectConsumerWakeup();
-        EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(asList(TOPIC_PARTITION, TOPIC_PARTITION2)));
-        consumer.resume(singleton(TOPIC_PARTITION));
-        PowerMock.expectLastCall();
-        consumer.resume(singleton(TOPIC_PARTITION2));
-        PowerMock.expectLastCall();
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2);
+        INITIAL_ASSIGNMENT.forEach(tp -> {
+            consumer.resume(Collections.singleton(tp));
+            PowerMock.expectLastCall();
+        });
 
         expectConsumerPoll(1);
         expectConversionAndTransformation(1);
@@ -334,11 +336,12 @@ public class WorkerSinkTaskTest {
         sinkTask.stop();
         PowerMock.expectLastCall();
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT);
         // WorkerSinkTask::close
         consumer.close();
         PowerMock.expectLastCall().andAnswer(() -> {
             rebalanceListener.getValue().onPartitionsRevoked(
-                asList(TOPIC_PARTITION, TOPIC_PARTITION2)
+                INITIAL_ASSIGNMENT
             );
             return null;
         });
@@ -373,9 +376,8 @@ public class WorkerSinkTaskTest {
         sinkTask.put(EasyMock.capture(records));
         EasyMock.expectLastCall().andThrow(new RetriableException("retry"));
         // Pause
-        HashSet<TopicPartition> partitions = new HashSet<>(asList(TOPIC_PARTITION, TOPIC_PARTITION2));
-        EasyMock.expect(consumer.assignment()).andReturn(partitions);
-        consumer.pause(partitions);
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT);
+        consumer.pause(INITIAL_ASSIGNMENT);
         PowerMock.expectLastCall();
 
         // Retry delivery should succeed
@@ -383,11 +385,11 @@ public class WorkerSinkTaskTest {
         sinkTask.put(EasyMock.capture(records));
         EasyMock.expectLastCall();
         // And unpause
-        EasyMock.expect(consumer.assignment()).andReturn(partitions);
-        consumer.resume(singleton(TOPIC_PARTITION));
-        PowerMock.expectLastCall();
-        consumer.resume(singleton(TOPIC_PARTITION2));
-        PowerMock.expectLastCall();
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT);
+        INITIAL_ASSIGNMENT.forEach(tp -> {
+            consumer.resume(singleton(tp));
+            PowerMock.expectLastCall();
+        });
 
         PowerMock.replayAll();
 
@@ -434,6 +436,111 @@ public class WorkerSinkTaskTest {
     }
 
     @Test
+    public void testPollRedeliveryWithConsumerRebalance() throws Exception {
+        createTask(initialState);
+
+        expectInitializeTask();
+        expectTaskGetTopic(true);
+        expectPollInitialAssignment();
+
+        // If a retriable exception is thrown, we should redeliver the same batch, pausing the consumer in the meantime
+        expectConsumerPoll(1);
+        expectConversionAndTransformation(1);
+        sinkTask.put(EasyMock.anyObject());
+        EasyMock.expectLastCall().andThrow(new RetriableException("retry"));
+        // Pause
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT);
+        consumer.pause(INITIAL_ASSIGNMENT);
+        PowerMock.expectLastCall();
+
+        // Empty consumer poll (all partitions are paused) with rebalance; one new partition is assigned
+        EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
+            () -> {
+                rebalanceListener.getValue().onPartitionsRevoked(Collections.emptySet());
+                rebalanceListener.getValue().onPartitionsAssigned(Collections.singleton(TOPIC_PARTITION3));
+                return ConsumerRecords.empty();
+            });
+        Set<TopicPartition> newAssignment = new HashSet<>(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3));
+        EasyMock.expect(consumer.assignment()).andReturn(newAssignment).times(3);
+        EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET);
+        sinkTask.open(Collections.singleton(TOPIC_PARTITION3));
+        EasyMock.expectLastCall();
+        // All partitions are re-paused in order to pause any newly-assigned partitions so that redelivery efforts can continue
+        consumer.pause(newAssignment);
+        EasyMock.expectLastCall();
+        sinkTask.put(EasyMock.anyObject());
+        EasyMock.expectLastCall().andThrow(new RetriableException("retry"));
+
+        // Next delivery attempt fails again
+        expectConsumerPoll(0);
+        sinkTask.put(EasyMock.anyObject());
+        EasyMock.expectLastCall().andThrow(new RetriableException("retry"));
+
+        // Non-empty consumer poll; all initially-assigned partitions are revoked in rebalance, and new partitions are allowed to resume
+        ConsumerRecord<byte[], byte[]> newRecord = new ConsumerRecord<>(TOPIC, PARTITION3, FIRST_OFFSET, RAW_KEY, RAW_VALUE);
+        EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
+            () -> {
+                rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT);
+                rebalanceListener.getValue().onPartitionsAssigned(Collections.emptyList());
+                return new ConsumerRecords<>(Collections.singletonMap(TOPIC_PARTITION3, Collections.singletonList(newRecord)));
+            });
+        newAssignment = Collections.singleton(TOPIC_PARTITION3);
+        EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(newAssignment)).times(3);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = INITIAL_ASSIGNMENT.stream()
+                .collect(Collectors.toMap(Function.identity(), tp -> new OffsetAndMetadata(FIRST_OFFSET)));
+        sinkTask.preCommit(offsets);
+        EasyMock.expectLastCall().andReturn(offsets);
+        sinkTask.close(INITIAL_ASSIGNMENT);
+        EasyMock.expectLastCall();
+        // All partitions are resumed, as all previously paused-for-redelivery partitions were revoked
+        newAssignment.forEach(tp -> {
+            consumer.resume(Collections.singleton(tp));
+            EasyMock.expectLastCall();
+        });
+        expectConversionAndTransformation(1);
+        sinkTask.put(EasyMock.anyObject());
+        EasyMock.expectLastCall();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        workerTask.iteration();
+        workerTask.iteration();
+        workerTask.iteration();
+        workerTask.iteration();
+        workerTask.iteration();
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testErrorInRebalancePartitionLoss() throws Exception {
+        RuntimeException exception = new RuntimeException("Revocation error");
+
+        createTask(initialState);
+
+        expectInitializeTask();
+        expectTaskGetTopic(true);
+        expectPollInitialAssignment();
+        expectRebalanceLossError(exception);
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        workerTask.iteration();
+        try {
+            workerTask.iteration();
+            fail("Poll should have raised the rebalance exception");
+        } catch (RuntimeException e) {
+            assertEquals(exception, e);
+        }
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
     public void testErrorInRebalancePartitionRevocation() throws Exception {
         RuntimeException exception = new RuntimeException("Revocation error");
 
@@ -486,6 +593,74 @@ public class WorkerSinkTaskTest {
     }
 
     @Test
+    public void testPartialRevocationAndAssignment() throws Exception {
+        createTask(initialState);
+
+        expectInitializeTask();
+        expectTaskGetTopic(true);
+        expectPollInitialAssignment();
+
+        EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
+            () -> {
+                rebalanceListener.getValue().onPartitionsRevoked(Collections.singleton(TOPIC_PARTITION));
+                rebalanceListener.getValue().onPartitionsAssigned(Collections.emptySet());
+                return ConsumerRecords.empty();
+            });
+        EasyMock.expect(consumer.assignment()).andReturn(Collections.singleton(TOPIC_PARTITION)).times(2);
+        final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET));
+        sinkTask.preCommit(offsets);
+        EasyMock.expectLastCall().andReturn(offsets);
+        sinkTask.close(Collections.singleton(TOPIC_PARTITION));
+        EasyMock.expectLastCall();
+        sinkTask.put(Collections.emptyList());
+        EasyMock.expectLastCall();
+
+        EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
+            () -> {
+                rebalanceListener.getValue().onPartitionsRevoked(Collections.emptySet());
+                rebalanceListener.getValue().onPartitionsAssigned(Collections.singleton(TOPIC_PARTITION3));
+                return ConsumerRecords.empty();
+            });
+        EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2, TOPIC_PARTITION3))).times(2);
+        EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET);
+        sinkTask.open(Collections.singleton(TOPIC_PARTITION3));
+        EasyMock.expectLastCall();
+        sinkTask.put(Collections.emptyList());
+        EasyMock.expectLastCall();
+
+        EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
+            () -> {
+                rebalanceListener.getValue().onPartitionsLost(Collections.singleton(TOPIC_PARTITION3));
+                rebalanceListener.getValue().onPartitionsAssigned(Collections.singleton(TOPIC_PARTITION));
+                return ConsumerRecords.empty();
+            });
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(4);
+        sinkTask.close(Collections.singleton(TOPIC_PARTITION3));
+        EasyMock.expectLastCall();
+        EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
+        sinkTask.open(Collections.singleton(TOPIC_PARTITION));
+        EasyMock.expectLastCall();
+        sinkTask.put(Collections.emptyList());
+        EasyMock.expectLastCall();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        // First iteration--first call to poll, first consumer assignment
+        workerTask.iteration();
+        // Second iteration--second call to poll, partial consumer revocation
+        workerTask.iteration();
+        // Third iteration--third call to poll, partial consumer assignment
+        workerTask.iteration();
+        // Fourth iteration--fourth call to poll, one partition lost; can't commit offsets for it, one new partition assigned
+        workerTask.iteration();
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
     public void testWakeupInCommitSyncCausesRetry() throws Exception {
         createTask(initialState);
 
@@ -498,8 +673,6 @@ public class WorkerSinkTaskTest {
         sinkTask.put(EasyMock.anyObject());
         EasyMock.expectLastCall();
 
-        final List<TopicPartition> partitions = asList(TOPIC_PARTITION, TOPIC_PARTITION2);
-
         final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
         offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
         offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
@@ -514,29 +687,26 @@ public class WorkerSinkTaskTest {
         consumer.commitSync(EasyMock.<Map<TopicPartition, OffsetAndMetadata>>anyObject());
         EasyMock.expectLastCall();
 
-        sinkTask.close(new HashSet<>(partitions));
+        sinkTask.close(INITIAL_ASSIGNMENT);
         EasyMock.expectLastCall();
 
-        EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
-        EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET);
+        INITIAL_ASSIGNMENT.forEach(tp -> EasyMock.expect(consumer.position(tp)).andReturn(FIRST_OFFSET));
 
-        sinkTask.open(partitions);
+        sinkTask.open(INITIAL_ASSIGNMENT);
         EasyMock.expectLastCall();
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(5);
         EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
             () -> {
-                rebalanceListener.getValue().onPartitionsRevoked(partitions);
-                rebalanceListener.getValue().onPartitionsAssigned(partitions);
+                rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT);
+                rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT);
                 return ConsumerRecords.empty();
             });
 
-        EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(partitions));
-
-        consumer.resume(Collections.singleton(TOPIC_PARTITION));
-        EasyMock.expectLastCall();
-
-        consumer.resume(Collections.singleton(TOPIC_PARTITION2));
-        EasyMock.expectLastCall();
+        INITIAL_ASSIGNMENT.forEach(tp -> {
+            consumer.resume(Collections.singleton(tp));
+            EasyMock.expectLastCall();
+        });
 
         statusListener.onResume(taskId);
         EasyMock.expectLastCall();
@@ -600,6 +770,8 @@ public class WorkerSinkTaskTest {
         sinkTask.put(EasyMock.eq(Collections.emptyList()));
         EasyMock.expectLastCall();
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(1);
+
         final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
         offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
         offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
@@ -647,6 +819,8 @@ public class WorkerSinkTaskTest {
         sinkTask.preCommit(offsets);
         EasyMock.expectLastCall().andReturn(offsets);
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2);
+
         final Capture<OffsetCommitCallback> callback = EasyMock.newCapture();
         consumer.commitAsync(EasyMock.eq(offsets), EasyMock.capture(callback));
         EasyMock.expectLastCall().andAnswer(() -> {
@@ -767,7 +941,7 @@ public class WorkerSinkTaskTest {
         sinkTask.preCommit(workerCurrentOffsets);
         EasyMock.expectLastCall().andReturn(taskOffsets);
         // Expect extra invalid topic partition to be filtered, which causes the consumer assignment to be logged
-        EasyMock.expect(consumer.assignment()).andReturn(workerCurrentOffsets.keySet());
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2);
         final Capture<OffsetCommitCallback> callback = EasyMock.newCapture();
         consumer.commitAsync(EasyMock.eq(committableOffsets), EasyMock.capture(callback));
         EasyMock.expectLastCall().andAnswer(() -> {
@@ -820,6 +994,8 @@ public class WorkerSinkTaskTest {
         workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
         workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2);
+
         // iter 3
         sinkTask.preCommit(workerCurrentOffsets);
         EasyMock.expectLastCall().andReturn(workerStartingOffsets);
@@ -871,6 +1047,8 @@ public class WorkerSinkTaskTest {
         workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
         workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2);
+
         // iter 3 - note that we return the current offset to indicate they should be committed
         sinkTask.preCommit(workerCurrentOffsets);
         EasyMock.expectLastCall().andReturn(workerCurrentOffsets);
@@ -944,6 +1122,8 @@ public class WorkerSinkTaskTest {
         expectInitializeTask();
         expectTaskGetTopic(true);
 
+        expectPollInitialAssignment();
+
         // Put one message through the task to get some offsets to commit
         expectConsumerPoll(1);
         expectConversionAndTransformation(1);
@@ -988,6 +1168,8 @@ public class WorkerSinkTaskTest {
         expectInitializeTask();
         expectTaskGetTopic(true);
 
+        expectPollInitialAssignment();
+
         // Put one message through the task to get some offsets to commit
         expectConsumerPoll(1);
         expectConversionAndTransformation(1);
@@ -1051,7 +1233,7 @@ public class WorkerSinkTaskTest {
         workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
         workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
 
-        final List<TopicPartition> originalPartitions = asList(TOPIC_PARTITION, TOPIC_PARTITION2);
+        final List<TopicPartition> originalPartitions = new ArrayList<>(INITIAL_ASSIGNMENT);
         final List<TopicPartition> rebalancedPartitions = asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3);
         final Map<TopicPartition, OffsetAndMetadata> rebalanceOffsets = new HashMap<>();
         rebalanceOffsets.put(TOPIC_PARTITION, workerCurrentOffsets.get(TOPIC_PARTITION));
@@ -1063,6 +1245,8 @@ public class WorkerSinkTaskTest {
         postRebalanceCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
         postRebalanceCurrentOffsets.put(TOPIC_PARTITION3, new OffsetAndMetadata(FIRST_OFFSET + 2));
 
+        EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(originalPartitions)).times(2);
+
         // iter 3 - note that we return the current offset to indicate they should be committed
         sinkTask.preCommit(workerCurrentOffsets);
         EasyMock.expectLastCall().andReturn(workerCurrentOffsets);
@@ -1125,7 +1309,7 @@ public class WorkerSinkTaskTest {
         EasyMock.expectLastCall().andReturn(workerCurrentOffsets);
         sinkTask.put(EasyMock.anyObject());
         EasyMock.expectLastCall();
-        sinkTask.close(workerCurrentOffsets.keySet());
+        sinkTask.close(new ArrayList<>(workerCurrentOffsets.keySet()));
         EasyMock.expectLastCall();
         consumer.commitSync(workerCurrentOffsets);
         EasyMock.expectLastCall();
@@ -1137,9 +1321,10 @@
         EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(offsetTp1);
         EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(offsetTp2);
         EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(offsetTp3);
+        EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(rebalancedPartitions)).times(6);
 
         // onPartitionsAssigned - step 2
-        sinkTask.open(rebalancedPartitions);
+        sinkTask.open(EasyMock.eq(rebalancedPartitions));
         EasyMock.expectLastCall();
 
         // onPartitionsAssigned - step 3 rewind
@@ -1258,6 +1443,8 @@ public class WorkerSinkTaskTest {
         sinkTask.preCommit(offsets);
         EasyMock.expectLastCall().andReturn(offsets);
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2);
+
         final Capture<OffsetCommitCallback> callback = EasyMock.newCapture();
         consumer.commitAsync(EasyMock.eq(offsets), EasyMock.capture(callback));
         EasyMock.expectLastCall().andAnswer(() -> {
@@ -1370,9 +1557,8 @@ public class WorkerSinkTaskTest {
 
         expectPollInitialAssignment();
 
-        Set<TopicPartition> partitions = new HashSet<>(asList(TOPIC_PARTITION, TOPIC_PARTITION2));
-        EasyMock.expect(consumer.assignment()).andReturn(partitions);
-        consumer.pause(partitions);
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT);
+        consumer.pause(INITIAL_ASSIGNMENT);
         PowerMock.expectLastCall();
 
         PowerMock.replayAll();
@@ -1537,10 +1723,19 @@ public class WorkerSinkTaskTest {
         PowerMock.expectLastCall();
     }
 
-    private void expectRebalanceRevocationError(RuntimeException e) {
-        final List<TopicPartition> partitions = asList(TOPIC_PARTITION, TOPIC_PARTITION2);
+    private void expectRebalanceLossError(RuntimeException e) {
+        sinkTask.close(new HashSet<>(INITIAL_ASSIGNMENT));
+        EasyMock.expectLastCall().andThrow(e);
+
+        EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
+            () -> {
+                rebalanceListener.getValue().onPartitionsLost(INITIAL_ASSIGNMENT);
+                return ConsumerRecords.empty();
+            });
+    }
 
-        sinkTask.close(new HashSet<>(partitions));
+    private void expectRebalanceRevocationError(RuntimeException e) {
+        sinkTask.close(new HashSet<>(INITIAL_ASSIGNMENT));
         EasyMock.expectLastCall().andThrow(e);
 
         sinkTask.preCommit(EasyMock.anyObject());
@@ -1548,15 +1743,13 @@
 
         EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
             () -> {
-                rebalanceListener.getValue().onPartitionsRevoked(partitions);
+                rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT);
                 return ConsumerRecords.empty();
             });
     }
 
     private void expectRebalanceAssignmentError(RuntimeException e) {
-        final List<TopicPartition> partitions = asList(TOPIC_PARTITION, TOPIC_PARTITION2);
-
-        sinkTask.close(new HashSet<>(partitions));
+        sinkTask.close(INITIAL_ASSIGNMENT);
         EasyMock.expectLastCall();
 
         sinkTask.preCommit(EasyMock.anyObject());
@@ -1565,29 +1758,29 @@
         EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
         EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET);
 
-        sinkTask.open(partitions);
+        sinkTask.open(INITIAL_ASSIGNMENT);
         EasyMock.expectLastCall().andThrow(e);
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(3);
         EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(
             () -> {
-                rebalanceListener.getValue().onPartitionsRevoked(partitions);
-                rebalanceListener.getValue().onPartitionsAssigned(partitions);
+                rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT);
+                rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT);
                 return ConsumerRecords.empty();
             });
     }
 
     private void expectPollInitialAssignment() {
-        final List<TopicPartition> partitions = asList(TOPIC_PARTITION, TOPIC_PARTITION2);
-
-        sinkTask.open(partitions);
+        sinkTask.open(INITIAL_ASSIGNMENT);
         EasyMock.expectLastCall();
 
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2);
+
         EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(() -> {
-            rebalanceListener.getValue().onPartitionsAssigned(partitions);
+            rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT);
             return ConsumerRecords.empty();
         });
-        EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
-        EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET);
+        INITIAL_ASSIGNMENT.forEach(tp -> EasyMock.expect(consumer.position(tp)).andReturn(FIRST_OFFSET));
 
         sinkTask.put(Collections.emptyList());
         EasyMock.expectLastCall();
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java
index 5918747..a7c6a8a 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java
@@ -65,6 +65,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
@@ -92,6 +93,8 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     private static final TopicPartition TOPIC_PARTITION2 = new TopicPartition(TOPIC, PARTITION2);
     private static final TopicPartition TOPIC_PARTITION3 = new TopicPartition(TOPIC, PARTITION3);
     private static final TopicPartition UNASSIGNED_TOPIC_PARTITION = new TopicPartition(TOPIC, 200);
+    private static final Set<TopicPartition> INITIAL_ASSIGNMENT = new HashSet<>(Arrays.asList(
+            TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3));
 
     private static final Map<String, String> TASK_PROPS = new HashMap<>();
     private static final long TIMESTAMP = 42L;
@@ -198,6 +201,7 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         expectInitializeTask();
         expectTaskGetTopic(true);
         expectPollInitialAssignment();
+        expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2);
 
         // Make each poll() take the offset commit interval
         Capture<Collection<SinkRecord>> capturedRecords
@@ -232,6 +236,7 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         expectInitializeTask();
         expectTaskGetTopic(true);
         expectPollInitialAssignment();
+        expectConsumerAssignment(INITIAL_ASSIGNMENT);
 
         Capture<Collection<SinkRecord>> capturedRecords = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetCommit(1L, new RuntimeException(), null, 0, true);
@@ -272,6 +277,7 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         expectInitializeTask();
         expectTaskGetTopic(true);
         expectPollInitialAssignment();
+        expectConsumerAssignment(INITIAL_ASSIGNMENT).times(3);
         Capture<Collection<SinkRecord>> capturedRecords = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetCommit(1L, null, null, 0, true);
         expectOffsetCommit(2L, new RuntimeException(), null, 0, true);
@@ -311,6 +317,7 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         expectInitializeTask();
         expectTaskGetTopic(true);
         expectPollInitialAssignment();
+        expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2);
 
         Capture<Collection<SinkRecord>> capturedRecords
                 = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
@@ -343,6 +350,7 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         expectInitializeTask();
         expectTaskGetTopic(true);
         expectPollInitialAssignment();
+        expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2);
 
         // Cut down amount of time to pass in each poll so we trigger exactly 1 offset commit
         Capture<Collection<SinkRecord>> capturedRecords
@@ -479,6 +487,7 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         expectInitializeTask();
         expectTaskGetTopic(true);
         expectPollInitialAssignment();
+        expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2);
 
         expectRebalanceDuringPoll().andAnswer(() -> {
             Map<TopicPartition, Long> offsets = sinkTaskContext.getValue().offsets();
@@ -511,13 +520,13 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     }
 
     private void expectPollInitialAssignment() throws Exception {
-        final List<TopicPartition> partitions = Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3);
+        expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2);
 
-        sinkTask.open(partitions);
+        sinkTask.open(INITIAL_ASSIGNMENT);
         EasyMock.expectLastCall();
 
         EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(() -> {
-            rebalanceListener.getValue().onPartitionsAssigned(partitions);
+            rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT);
             return ConsumerRecords.empty();
         });
         EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
@@ -528,6 +537,10 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         EasyMock.expectLastCall();
     }
 
+    private IExpectationSetters<Set<TopicPartition>> expectConsumerAssignment(Set<TopicPartition> assignment) {
+        return EasyMock.expect(consumer.assignment()).andReturn(assignment);
+    }
+
     private void expectStopTask() throws Exception {
         sinkTask.stop();
         PowerMock.expectLastCall();
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporterTest.java
index 07a4f9e..2d78297 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporterTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporterTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.kafka.connect.runtime.errors;
 
+import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.connect.sink.SinkRecord;
 import org.apache.kafka.connect.storage.Converter;
 import org.apache.kafka.connect.storage.HeaderConverter;
@@ -27,6 +28,9 @@ import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PowerMockIgnore;
 import org.powermock.modules.junit4.PowerMockRunner;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.concurrent.CompletableFuture;
 
 import static org.junit.Assert.assertFalse;
@@ -62,13 +66,16 @@ public class WorkerErrantRecordReporterTest {
     }
 
     @Test
-    public void testGetAllFutures() {
+    public void testGetFutures() {
+        Collection<TopicPartition> topicPartitions = new ArrayList<>();
         assertTrue(reporter.futures.isEmpty());
         for (int i = 0; i < 4; i++) {
-            reporter.futures.add(CompletableFuture.completedFuture(null));
+            TopicPartition topicPartition = new TopicPartition("topic", i);
+            topicPartitions.add(topicPartition);
+            reporter.futures.put(topicPartition, Collections.singletonList(CompletableFuture.completedFuture(null)));
         }
         assertFalse(reporter.futures.isEmpty());
-        reporter.awaitAllFutures();
+        reporter.awaitFutures(topicPartitions);
         assertTrue(reporter.futures.isEmpty());
     }
 }
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedKafkaCluster.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedKafkaCluster.java
index 17fd1ac..cf7fde5 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedKafkaCluster.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedKafkaCluster.java
@@ -383,6 +383,19 @@ public class EmbeddedKafkaCluster {
         }
     }
 
+    /**
+     * Delete a Kafka topic.
+     *
+     * @param topic the topic to delete; may not be null
+     */
+    public void deleteTopic(String topic) {
+        try (final Admin adminClient = createAdminClient()) {
+            adminClient.deleteTopics(Collections.singleton(topic)).all().get();
+        } catch (final InterruptedException | ExecutionException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     public void produce(String topic, String value) {
         produce(topic, null, null, value);
     }
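
For anyone trying the cooperative protocol with a sink connector locally, the sketch below is illustrative only and is not part of the change above. The connector class and topic names are placeholders, and the consumer.override.* prefix (KIP-458) only takes effect when the worker's connector.client.config.override.policy permits consumer overrides:

    import java.util.HashMap;
    import java.util.Map;

    public class CooperativeSinkConnectorConfigExample {
        public static Map<String, String> exampleConfig() {
            Map<String, String> config = new HashMap<>();
            // Placeholder connector class and topic -- substitute a real sink connector.
            config.put("connector.class", "com.example.MySinkConnector");
            config.put("tasks.max", "1");
            config.put("topics", "test-topic");
            // Per-connector consumer override; honored only if the worker's
            // connector.client.config.override.policy allows consumer overrides.
            config.put("consumer.override.partition.assignment.strategy",
                    "org.apache.kafka.clients.consumer.CooperativeStickyAssignor");
            return config;
        }
    }

The same assignor can also be set for all sink tasks on a worker through the worker-level consumer.partition.assignment.strategy property.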
