tombentley commented on code in PR #11780:
URL: https://github.com/apache/kafka/pull/11780#discussion_r887801819


##########
connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ExactlyOnceWorkerSourceTask.java:
##########
@@ -0,0 +1,525 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.connect.runtime;
+
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.clients.producer.RecordMetadata;
+import org.apache.kafka.common.errors.InvalidProducerEpochException;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.metrics.stats.Avg;
+import org.apache.kafka.common.metrics.stats.Max;
+import org.apache.kafka.common.metrics.stats.Min;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.connect.errors.ConnectException;
+import org.apache.kafka.connect.runtime.distributed.ClusterConfigState;
+import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator;
+import org.apache.kafka.connect.source.SourceRecord;
+import org.apache.kafka.connect.source.SourceTask;
+import org.apache.kafka.connect.source.SourceTask.TransactionBoundary;
+import org.apache.kafka.connect.storage.CloseableOffsetStorageReader;
+import org.apache.kafka.connect.storage.ConnectorOffsetBackingStore;
+import org.apache.kafka.connect.storage.Converter;
+import org.apache.kafka.connect.storage.HeaderConverter;
+import org.apache.kafka.connect.storage.OffsetStorageWriter;
+import org.apache.kafka.connect.storage.StatusBackingStore;
+import org.apache.kafka.connect.util.ConnectorTaskId;
+import org.apache.kafka.connect.util.LoggingContext;
+import org.apache.kafka.connect.util.TopicAdmin;
+import org.apache.kafka.connect.util.TopicCreationGroup;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicReference;
+
+
+/**
+ * WorkerTask that uses a SourceTask to ingest data into Kafka, with support 
for exactly-once delivery guarantees.
+ */
+class ExactlyOnceWorkerSourceTask extends AbstractWorkerSourceTask {
+    private static final Logger log = 
LoggerFactory.getLogger(ExactlyOnceWorkerSourceTask.class);
+
+    private boolean transactionOpen;
+    private final LinkedHashMap<SourceRecord, RecordMetadata> 
commitableRecords;
+
+    private final TransactionManager transactionManager;
+    private final TransactionMetricsGroup transactionMetrics;
+
+    private final ConnectorOffsetBackingStore offsetBackingStore;
+    private final Runnable preProducerCheck;
+    private final Runnable postProducerCheck;
+
+    public ExactlyOnceWorkerSourceTask(ConnectorTaskId id,
+                                       SourceTask task,
+                                       TaskStatus.Listener statusListener,
+                                       TargetState initialState,
+                                       Converter keyConverter,
+                                       Converter valueConverter,
+                                       HeaderConverter headerConverter,
+                                       TransformationChain<SourceRecord> 
transformationChain,
+                                       Producer<byte[], byte[]> producer,
+                                       TopicAdmin admin,
+                                       Map<String, TopicCreationGroup> 
topicGroups,
+                                       CloseableOffsetStorageReader 
offsetReader,
+                                       OffsetStorageWriter offsetWriter,
+                                       ConnectorOffsetBackingStore 
offsetBackingStore,
+                                       WorkerConfig workerConfig,
+                                       ClusterConfigState configState,
+                                       ConnectMetrics connectMetrics,
+                                       ClassLoader loader,
+                                       Time time,
+                                       RetryWithToleranceOperator 
retryWithToleranceOperator,
+                                       StatusBackingStore statusBackingStore,
+                                       SourceConnectorConfig sourceConfig,
+                                       Executor closeExecutor,
+                                       Runnable preProducerCheck,
+                                       Runnable postProducerCheck) {
+        super(id, task, statusListener, initialState, keyConverter, 
valueConverter, headerConverter, transformationChain,
+                new WorkerSourceTaskContext(offsetReader, id, configState, 
buildTransactionContext(sourceConfig)),
+                producer, admin, topicGroups, offsetReader, offsetWriter, 
offsetBackingStore, workerConfig, connectMetrics,
+                loader, time, retryWithToleranceOperator, statusBackingStore, 
closeExecutor);
+
+        this.transactionOpen = false;
+        this.commitableRecords = new LinkedHashMap<>();
+        this.offsetBackingStore = offsetBackingStore;
+
+        this.preProducerCheck = preProducerCheck;
+        this.postProducerCheck = postProducerCheck;
+
+        this.transactionManager = buildTransactionManager(workerConfig, 
sourceConfig, sourceTaskContext.transactionContext());
+        this.transactionMetrics = new TransactionMetricsGroup(id, 
connectMetrics);
+    }
+
+    private static WorkerTransactionContext 
buildTransactionContext(SourceConnectorConfig sourceConfig) {
+        return 
TransactionBoundary.CONNECTOR.equals(sourceConfig.transactionBoundary())
+                ? new WorkerTransactionContext()
+                : null;
+    }
+
+    @Override
+    protected void prepareToInitializeTask() {
+        preProducerCheck.run();
+
+        // Try not to start up the offset store (which has its own producer 
and consumer) if we've already been shut down at this point
+        if (isStopping())
+            return;
+        offsetBackingStore.start();
+
+        // Try not to initialize the transactional producer (which may 
accidentally fence out other, later task generations) if we've already
+        // been shut down at this point
+        if (isStopping())
+            return;
+        producer.initTransactions();
+
+        postProducerCheck.run();
+    }
+
+    @Override
+    protected void prepareToEnterSendLoop() {
+        transactionManager.initialize();
+    }
+
+    @Override
+    protected void beginSendIteration() {
+        // No-op
+    }
+
+    @Override
+    protected void prepareToPollTask() {
+        // No-op
+    }
+
+    @Override
+    protected void recordDropped(SourceRecord record) {
+        synchronized (this) {
+            commitableRecords.put(record, null);
+        }
+        transactionManager.maybeCommitTransactionForRecord(record);
+    }
+
+    @Override
+    protected Optional<SubmittedRecords.SubmittedRecord> prepareToSendRecord(
+            SourceRecord sourceRecord,
+            ProducerRecord<byte[], byte[]> producerRecord
+    ) {
+        if 
(offsetBackingStore.primaryOffsetsTopic().equals(producerRecord.topic())) {
+            // This is to prevent deadlock that occurs when:
+            //     1. A task provides a record whose topic is the task's 
offsets topic
+            //     2. That record is dispatched to the task's producer in a 
transaction that remains open
+            //        at least until the worker polls the task again
+            //     3. In the subsequent call to SourceTask::poll, the task 
requests offsets from the worker
+            //        (which requires a read to the end of the offsets topic, 
and will block until any open
+            //        transactions on the topic are either committed or 
aborted)
+            throw new ConnectException("Source tasks may not produce to their 
own offsets topics when exactly-once support is enabled");
+        }
+        maybeBeginTransaction();
+        return Optional.empty();
+    }
+
+    @Override
+    protected void recordDispatched(SourceRecord record) {
+        // Offsets are converted & serialized in the OffsetWriter
+        // Important: we only save offsets for the record after it has been 
accepted by the producer; this way,
+        // we commit those offsets if and only if the record is sent 
successfully.
+        offsetWriter.offset(record.sourcePartition(), record.sourceOffset());
+        transactionMetrics.addRecord();
+        transactionManager.maybeCommitTransactionForRecord(record);
+    }
+
+    @Override
+    protected void batchDispatched() {
+        transactionManager.maybeCommitTransactionForBatch();
+    }
+
+    @Override
+    protected void recordSent(
+            SourceRecord sourceRecord,
+            ProducerRecord<byte[], byte[]> producerRecord,
+            RecordMetadata recordMetadata
+    ) {
+        synchronized (this) {
+            commitableRecords.put(sourceRecord, recordMetadata);
+        }
+    }
+
+    @Override
+    protected void producerSendFailed(
+            boolean synchronous,
+            ProducerRecord<byte[], byte[]> producerRecord,
+            SourceRecord preTransformRecord,
+            Exception e
+    ) {
+        if (synchronous) {
+            throw maybeWrapProducerSendException(
+                    "Unrecoverable exception trying to send",
+                    e
+            );
+        } else {
+            // No-op; all asynchronously-reported producer exceptions should 
be bubbled up again by Producer::commitTransaction
+        }
+    }
+
+    @Override
+    protected void finalOffsetCommit(boolean failed) {
+        if (failed) {
+            log.debug("Skipping final offset commit as task has failed");
+            return;
+        }
+
+        // It should be safe to commit here even if we were in the middle of 
retrying on RetriableExceptions in the
+        // send loop since we only track source offsets for records that have 
been successfully dispatched to the
+        // producer.
+        // Any records that we were retrying on (and any records after them in 
the batch) won't be included in the
+        // transaction and their offsets won't be committed, but (unless the 
user has requested connector-defined
+        // transaction boundaries), it's better to commit some data than none.
+        transactionManager.maybeCommitFinalTransaction();
+    }
+
+    @Override
+    protected void onPause() {
+        super.onPause();
+        // Commit the transaction now so that we don't end up with a hanging 
transaction, or worse, get fenced out
+        // and fail the task once unpaused
+        transactionManager.maybeCommitFinalTransaction();
+    }
+
+    private void maybeBeginTransaction() {
+        if (!transactionOpen) {
+            producer.beginTransaction();
+            transactionOpen = true;
+        }
+    }
+
+    private void commitTransaction() {
+        log.debug("{} Committing offsets", this);
+
+        long started = time.milliseconds();
+
+        // We might have just aborted a transaction, in which case we'll have 
to begin a new one
+        // in order to commit offsets
+        maybeBeginTransaction();
+
+        AtomicReference<Throwable> flushError = new AtomicReference<>();
+        Future<Void> offsetFlush = null;
+        if (offsetWriter.beginFlush()) {
+            // Now we can actually write the offsets to the internal topic.
+            offsetFlush = offsetWriter.doFlush((error, result) -> {
+                if (error != null) {
+                    log.error("{} Failed to flush offsets to storage: ", 
ExactlyOnceWorkerSourceTask.this, error);
+                    flushError.compareAndSet(null, error);
+                } else {
+                    log.trace("{} Finished flushing offsets to storage", 
ExactlyOnceWorkerSourceTask.this);
+                }
+            });
+        }
+
+        // Commit the transaction
+        // Blocks until all outstanding records have been sent and ack'd
+        try {
+            producer.commitTransaction();
+            if (offsetFlush != null) {
+                // Although it's guaranteed by the above call to 
Producer::commitTransaction that all outstanding
+                // records for the task's producer (including those sent to 
the offsets topic) have been delivered and
+                // ack'd, there is no guarantee that the producer callbacks 
for those records have been completed. So,
+                // we add this call to Future::get to ensure that these 
callbacks are invoked successfully before
+                // proceeding.
+                offsetFlush.get();
+            }
+        } catch (Throwable t) {
+            flushError.compareAndSet(null, t);
+        }
+
+        transactionOpen = false;
+
+        Throwable error = flushError.get();
+        if (error != null) {
+            recordCommitFailure(time.milliseconds() - started, null);
+            offsetWriter.cancelFlush();
+            throw maybeWrapProducerSendException(
+                    "Failed to flush offsets and/or records for task " + id,
+                    error
+            );
+        }
+
+        transactionMetrics.commitTransaction();
+
+        long durationMillis = time.milliseconds() - started;
+        recordCommitSuccess(durationMillis);
+        log.debug("{} Finished commitOffsets successfully in {} ms", this, 
durationMillis);
+
+        // No need for any synchronization here; all other accesses to this 
field take place in producer callbacks,
+        // which should all be completed by this point

Review Comment:
   Why should they be completed by this point? I.e. what in the Java memory 
model guarantees it?



##########
connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java:
##########
@@ -784,6 +871,10 @@ private static Map<String, Object> 
connectorClientConfigOverrides(ConnectorTaskI
         return clientOverrides;
     }
 
+    public static String transactionalId(String groupId, String connector, int 
taskId) {

Review Comment:
   Perhaps we should call this `taskTransactionalId` to distinguish it from the 
transactional id used for the config topic?



##########
connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java:
##########
@@ -576,88 +672,42 @@ public boolean startTask(
 
             executor.submit(workerTask);
             if (workerTask instanceof WorkerSourceTask) {
-                sourceTaskOffsetCommitter.schedule(id, (WorkerSourceTask) 
workerTask);
+                sourceTaskOffsetCommitter.ifPresent(committer -> 
committer.schedule(id, (WorkerSourceTask) workerTask));
             }
             return true;
         }
     }
 
-    private WorkerTask buildWorkerTask(ClusterConfigState configState,
-                                       ConnectorConfig connConfig,
-                                       ConnectorTaskId id,
-                                       Task task,
-                                       TaskStatus.Listener statusListener,
-                                       TargetState initialState,
-                                       Converter keyConverter,
-                                       Converter valueConverter,
-                                       HeaderConverter headerConverter,
-                                       ClassLoader loader) {
-        ErrorHandlingMetrics errorHandlingMetrics = errorHandlingMetrics(id);
-        final Class<? extends Connector> connectorClass = 
plugins.connectorClass(
-            connConfig.getString(ConnectorConfig.CONNECTOR_CLASS_CONFIG));
-        RetryWithToleranceOperator retryWithToleranceOperator = new 
RetryWithToleranceOperator(connConfig.errorRetryTimeout(),
-                connConfig.errorMaxDelayInMillis(), 
connConfig.errorToleranceType(), Time.SYSTEM);
-        retryWithToleranceOperator.metrics(errorHandlingMetrics);
-
-        // Decide which type of worker task we need based on the type of task.
-        if (task instanceof SourceTask) {
-            SourceConnectorConfig sourceConfig = new 
SourceConnectorConfig(plugins,
-                    connConfig.originalsStrings(), 
config.topicCreationEnable());
-            retryWithToleranceOperator.reporters(sourceTaskReporters(id, 
sourceConfig, errorHandlingMetrics));
-            TransformationChain<SourceRecord> transformationChain = new 
TransformationChain<>(sourceConfig.<SourceRecord>transformations(), 
retryWithToleranceOperator);
-            log.info("Initializing: {}", transformationChain);
-            CloseableOffsetStorageReader offsetReader = new 
OffsetStorageReaderImpl(offsetBackingStore, id.connector(),
-                    internalKeyConverter, internalValueConverter);
-            OffsetStorageWriter offsetWriter = new 
OffsetStorageWriter(offsetBackingStore, id.connector(),
-                    internalKeyConverter, internalValueConverter);
-            Map<String, Object> producerProps = producerConfigs(id, 
"connector-producer-" + id, config, sourceConfig, connectorClass,
-                                                                
connectorClientConfigOverridePolicy, kafkaClusterId);
-            KafkaProducer<byte[], byte[]> producer = new 
KafkaProducer<>(producerProps);
-            TopicAdmin admin;
-            Map<String, TopicCreationGroup> topicCreationGroups;
-            if (config.topicCreationEnable() && 
sourceConfig.usesTopicCreation()) {
-                Map<String, Object> adminProps = adminConfigs(id, 
"connector-adminclient-" + id, config,
-                        sourceConfig, connectorClass, 
connectorClientConfigOverridePolicy, kafkaClusterId);
-                admin = new TopicAdmin(adminProps);
-                topicCreationGroups = 
TopicCreationGroup.configuredGroups(sourceConfig);
-            } else {
-                admin = null;
-                topicCreationGroups = null;
-            }
-
-            // Note we pass the configState as it performs dynamic 
transformations under the covers
-            return new WorkerSourceTask(id, (SourceTask) task, statusListener, 
initialState, keyConverter, valueConverter,
-                    headerConverter, transformationChain, producer, admin, 
topicCreationGroups,
-                    offsetReader, offsetWriter, config, configState, metrics, 
loader, time, retryWithToleranceOperator, herder.statusBackingStore(), 
executor);
-        } else if (task instanceof SinkTask) {
-            TransformationChain<SinkRecord> transformationChain = new 
TransformationChain<>(connConfig.<SinkRecord>transformations(), 
retryWithToleranceOperator);
-            log.info("Initializing: {}", transformationChain);
-            SinkConnectorConfig sinkConfig = new SinkConnectorConfig(plugins, 
connConfig.originalsStrings());
-            retryWithToleranceOperator.reporters(sinkTaskReporters(id, 
sinkConfig, errorHandlingMetrics, connectorClass));
-            WorkerErrantRecordReporter workerErrantRecordReporter = 
createWorkerErrantRecordReporter(sinkConfig, retryWithToleranceOperator,
-                    keyConverter, valueConverter, headerConverter);
-
-            Map<String, Object> consumerProps = consumerConfigs(id, config, 
connConfig, connectorClass, connectorClientConfigOverridePolicy, 
kafkaClusterId);
-            KafkaConsumer<byte[], byte[]> consumer = new 
KafkaConsumer<>(consumerProps);
-
-            return new WorkerSinkTask(id, (SinkTask) task, statusListener, 
initialState, config, configState, metrics, keyConverter,
-                                      valueConverter, headerConverter, 
transformationChain, consumer, loader, time,
-                                      retryWithToleranceOperator, 
workerErrantRecordReporter, herder.statusBackingStore());
-        } else {
-            log.error("Tasks must be a subclass of either SourceTask or 
SinkTask and current is {}", task);
-            throw new ConnectException("Tasks must be a subclass of either 
SourceTask or SinkTask");
-        }
+    static Map<String, Object> 
exactlyOnceSourceTaskProducerConfigs(ConnectorTaskId id,
+                                                              WorkerConfig 
config,
+                                                              ConnectorConfig 
connConfig,
+                                                              Class<? extends 
Connector>  connectorClass,
+                                                              
ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy,
+                                                              String 
clusterId) {
+        Map<String, Object> result = baseProducerConfigs(id.connector(), 
"connector-producer-" + id, config, connConfig, connectorClass, 
connectorClientConfigOverridePolicy, clusterId);
+        ConnectUtils.ensureProperty(
+                result, ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true",
+                "for connectors when exactly-once source support is enabled",
+                false
+        );
+        String transactionalId = transactionalId(config.groupId(), 
id.connector(), id.task());

Review Comment:
   Is this worth a comment?



##########
connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTransactionContext.java:
##########
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.connect.runtime;
+
+import org.apache.kafka.connect.source.SourceRecord;
+import org.apache.kafka.connect.source.TransactionContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashSet;
+import java.util.Objects;
+import java.util.Set;
+
+public class WorkerTransactionContext implements TransactionContext {
+
+    private static final Logger log = 
LoggerFactory.getLogger(WorkerTransactionContext.class);
+
+    private final Set<SourceRecord> commitableRecords = new HashSet<>();

Review Comment:
   A comment about thread safety here seems warranted. It's not clear which 
threads are accessing an instance, or quite what the mechanisms for ensuring 
thread safety are. E.g. why is `commitTransaction` `synchronized`, while other 
accesses rely only on volatile writes (surely one or the other thread-safety 
mechanism is needed, but not both)? And why is `abortTransaction` not 
`synchronized`?



##########
connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ExactlyOnceWorkerSourceTaskTest.java:
##########
@@ -0,0 +1,1330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.connect.runtime;
+
+import org.apache.kafka.clients.admin.NewTopic;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.clients.producer.RecordMetadata;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.InvalidTopicException;
+import org.apache.kafka.common.errors.RecordTooLargeException;
+import org.apache.kafka.common.errors.TopicAuthorizationException;
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.connect.data.Schema;
+import org.apache.kafka.connect.errors.ConnectException;
+import org.apache.kafka.connect.integration.MonitorableSourceConnector;
+import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup;
+import org.apache.kafka.connect.runtime.distributed.ClusterConfigState;
+import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperatorTest;
+import org.apache.kafka.connect.runtime.isolation.Plugins;
+import org.apache.kafka.connect.runtime.standalone.StandaloneConfig;
+import org.apache.kafka.connect.source.SourceRecord;
+import org.apache.kafka.connect.source.SourceTask;
+import org.apache.kafka.connect.source.SourceTaskContext;
+import org.apache.kafka.connect.source.TransactionContext;
+import org.apache.kafka.connect.storage.CloseableOffsetStorageReader;
+import org.apache.kafka.connect.storage.ConnectorOffsetBackingStore;
+import org.apache.kafka.connect.storage.Converter;
+import org.apache.kafka.connect.storage.HeaderConverter;
+import org.apache.kafka.connect.storage.OffsetStorageWriter;
+import org.apache.kafka.connect.storage.StatusBackingStore;
+import org.apache.kafka.connect.storage.StringConverter;
+import org.apache.kafka.connect.util.Callback;
+import org.apache.kafka.connect.util.ConnectorTaskId;
+import org.apache.kafka.connect.util.ParameterizedTest;
+import org.apache.kafka.connect.util.ThreadedTest;
+import org.apache.kafka.connect.util.TopicAdmin;
+import org.apache.kafka.connect.util.TopicCreationGroup;
+import org.easymock.Capture;
+import org.easymock.EasyMock;
+import org.easymock.IAnswer;
+import org.easymock.IExpectationSetters;
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.api.easymock.PowerMock;
+import org.powermock.api.easymock.annotation.Mock;
+import org.powermock.api.easymock.annotation.MockStrict;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.powermock.modules.junit4.PowerMockRunnerDelegate;
+
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static 
org.apache.kafka.connect.integration.MonitorableSourceConnector.TOPIC_CONFIG;
+import static 
org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG;
+import static 
org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.SourceConnectorConfig.TOPIC_CREATION_GROUPS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.SourceConnectorConfig.TRANSACTION_BOUNDARY_CONFIG;
+import static 
org.apache.kafka.connect.runtime.SourceConnectorConfig.TRANSACTION_BOUNDARY_INTERVAL_CONFIG;
+import static 
org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX;
+import static 
org.apache.kafka.connect.runtime.TopicCreationConfig.EXCLUDE_REGEX_CONFIG;
+import static 
org.apache.kafka.connect.runtime.TopicCreationConfig.INCLUDE_REGEX_CONFIG;
+import static 
org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG;
+import static 
org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG;
+import static 
org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_CREATION_ENABLE_CONFIG;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+@PowerMockIgnore({"javax.management.*",
+        "org.apache.log4j.*"})
+@RunWith(PowerMockRunner.class)
+@PowerMockRunnerDelegate(ParameterizedTest.class)
+public class ExactlyOnceWorkerSourceTaskTest extends ThreadedTest {
+    // Fixed topic plus source partition/offset used by every generated SourceRecord.
+    private static final String TOPIC = "topic";
+    private static final Map<String, byte[]> PARTITION = Collections.singletonMap("key", "partition".getBytes());
+    private static final Map<String, Integer> OFFSET = Collections.singletonMap("key", 12);
+
+    // Connect-format data
+    private static final Schema KEY_SCHEMA = Schema.INT32_SCHEMA;
+    private static final Integer KEY = -1;
+    private static final Schema RECORD_SCHEMA = Schema.INT64_SCHEMA;
+    private static final Long RECORD = 12L;
+    // Serialized data. The actual format of this data doesn't matter -- we just want to see that the right version
+    // is used in the right place.
+    private static final byte[] SERIALIZED_KEY = "converted-key".getBytes();
+    private static final byte[] SERIALIZED_RECORD = "converted-record".getBytes();
+
+    // Runs the task's poll-send loop on a background thread so tests can drive/stop it.
+    private final ExecutorService executor = Executors.newSingleThreadExecutor();
+    private final ConnectorTaskId taskId = new ConnectorTaskId("job", 0);
+    private WorkerConfig config;
+    private SourceConnectorConfig sourceConfig;
+    private Plugins plugins;
+    private MockConnectMetrics metrics;
+    private Time time;
+    // Counted down by mocked poll() expectations; tests reassign it to await N polls.
+    private CountDownLatch pollLatch;
+    @Mock private SourceTask sourceTask;
+    @Mock private Converter keyConverter;
+    @Mock private Converter valueConverter;
+    @Mock private HeaderConverter headerConverter;
+    @Mock private TransformationChain<SourceRecord> transformationChain;
+    @Mock private KafkaProducer<byte[], byte[]> producer;
+    @Mock private TopicAdmin admin;
+    @Mock private CloseableOffsetStorageReader offsetReader;
+    @Mock private OffsetStorageWriter offsetWriter;
+    @Mock private ClusterConfigState clusterConfigState;
+    private ExactlyOnceWorkerSourceTask workerTask;
+    @Mock private Future<RecordMetadata> sendFuture;
+    @MockStrict private TaskStatus.Listener statusListener;
+    @Mock private StatusBackingStore statusBackingStore;
+    @Mock private ConnectorOffsetBackingStore offsetStore;
+    // Pre/post-producer hooks invoked by the task during preflight (e.g. zombie fencing).
+    @Mock private Runnable preProducerCheck;
+    @Mock private Runnable postProducerCheck;
+
+    // Captures the producer callbacks the task registers on send, so tests can complete them.
+    private Capture<org.apache.kafka.clients.producer.Callback> producerCallbacks;
+
+    private static final Map<String, String> TASK_PROPS = new HashMap<>();
+    static {
+        TASK_PROPS.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName());
+    }
+    private static final TaskConfig TASK_CONFIG = new TaskConfig(TASK_PROPS);
+
+    private static final SourceRecord SOURCE_RECORD =
+            new SourceRecord(PARTITION, OFFSET, "topic", null, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+
+    private static final List<SourceRecord> RECORDS = Collections.singletonList(SOURCE_RECORD);
+
+    // Parameterized-test flag: whether automatic topic creation is enabled for this run.
+    private final boolean enableTopicCreation;
+
+    // Run the whole suite twice: with topic creation disabled and enabled.
+    @ParameterizedTest.Parameters
+    public static Collection<Boolean> parameters() {
+        return Arrays.asList(false, true);
+    }
+
+    // Invoked by the parameterized runner with each value from parameters().
+    public ExactlyOnceWorkerSourceTaskTest(boolean enableTopicCreation) {
+        this.enableTopicCreation = enableTopicCreation;
+    }
+
+    // Builds fresh worker/connector configs, metrics, and mock wiring before each test.
+    @Override
+    public void setup() {
+        super.setup();
+        Map<String, String> workerProps = workerProps();
+        plugins = new Plugins(workerProps);
+        config = new StandaloneConfig(workerProps);
+        sourceConfig = new SourceConnectorConfig(plugins, sourceConnectorProps(), true);
+        producerCallbacks = EasyMock.newCapture();
+        metrics = new MockConnectMetrics();
+        // Real clock by default; interval-based tests swap in a MockTime
+        time = Time.SYSTEM;
+        // Stubbed (not strict) since not every test exercises the offsets topic
+        EasyMock.expect(offsetStore.primaryOffsetsTopic()).andStubReturn("offsets-topic");
+        pollLatch = new CountDownLatch(1);
+    }
+
+    /**
+     * Builds the minimal worker-level configuration shared by all test cases,
+     * toggling automatic topic creation per the parameterized flag.
+     */
+    private Map<String, String> workerProps() {
+        String jsonConverter = "org.apache.kafka.connect.json.JsonConverter";
+        Map<String, String> workerProps = new HashMap<>();
+        workerProps.put("key.converter", jsonConverter);
+        workerProps.put("value.converter", jsonConverter);
+        workerProps.put("internal.key.converter", jsonConverter);
+        workerProps.put("internal.value.converter", jsonConverter);
+        workerProps.put("internal.key.converter.schemas.enable", "false");
+        workerProps.put("internal.value.converter.schemas.enable", "false");
+        workerProps.put("offset.storage.file.filename", "/tmp/connect.offsets");
+        workerProps.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(enableTopicCreation));
+        return workerProps;
+    }
+
+    // Convenience overload using the default transaction boundary.
+    private Map<String, String> sourceConnectorProps() {
+        return sourceConnectorProps(SourceTask.TransactionBoundary.DEFAULT);
+    }
+
+    // Builds connector-level props for the monitorable source connector, including
+    // topic-creation groups and the requested transaction boundary mode.
+    private Map<String, String> sourceConnectorProps(SourceTask.TransactionBoundary transactionBoundary) {
+        // set up props for the source connector
+        Map<String, String> props = new HashMap<>();
+        props.put("name", "foo-connector");
+        props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName());
+        props.put(TASKS_MAX_CONFIG, String.valueOf(1));
+        props.put(TOPIC_CONFIG, TOPIC);
+        props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName());
+        props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName());
+        props.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", "foo", "bar"));
+        props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1));
+        props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1));
+        props.put(TRANSACTION_BOUNDARY_CONFIG, transactionBoundary.toString());
+        // "foo" group matches only TOPIC; "bar" matches everything except TOPIC
+        props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + "foo" + "." + INCLUDE_REGEX_CONFIG, TOPIC);
+        props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + "bar" + "." + INCLUDE_REGEX_CONFIG, ".*");
+        props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + "bar" + "." + EXCLUDE_REGEX_CONFIG, TOPIC);
+        return props;
+    }
+
+    /** Releases metrics resources; guarded because setup may have failed early. */
+    @After
+    public void tearDown() {
+        if (metrics != null) {
+            metrics.stop();
+        }
+    }
+
+    // Creates the task in the STARTED state, the common case.
+    private void createWorkerTask() {
+        createWorkerTask(TargetState.STARTED);
+    }
+
+    // Creates the task in the given initial state with the standard mock converters.
+    private void createWorkerTask(TargetState initialState) {
+        createWorkerTask(initialState, keyConverter, valueConverter, headerConverter);
+    }
+
+    // Wires all mocks into a real ExactlyOnceWorkerSourceTask under test.
+    // Runnable::run executes the "executor" work inline on the calling thread.
+    private void createWorkerTask(TargetState initialState, Converter keyConverter, Converter valueConverter, HeaderConverter headerConverter) {
+        workerTask = new ExactlyOnceWorkerSourceTask(taskId, sourceTask, statusListener, initialState, keyConverter, valueConverter, headerConverter,
+                transformationChain, producer, admin, TopicCreationGroup.configuredGroups(sourceConfig), offsetReader, offsetWriter, offsetStore,
+                config, clusterConfigState, metrics, plugins.delegatingLoader(), time, RetryWithToleranceOperatorTest.NOOP_OPERATOR, statusBackingStore,
+                sourceConfig, Runnable::run, preProducerCheck, postProducerCheck);
+    }
+
+    // A task started in the PAUSED state should report onPause without ever polling,
+    // then shut down cleanly when stopped.
+    @Test
+    public void testStartPaused() throws Exception {
+        final CountDownLatch pauseLatch = new CountDownLatch(1);
+
+        createWorkerTask(TargetState.PAUSED);
+
+        expectCall(() -> statusListener.onPause(taskId)).andAnswer(() -> {
+            pauseLatch.countDown();
+            return null;
+        });
+
+        // The task checks to see if there are offsets to commit before pausing
+        EasyMock.expect(offsetWriter.willFlush()).andReturn(false);
+
+        expectClose();
+
+        expectCall(() -> statusListener.onShutdown(taskId));
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        assertTrue(pauseLatch.await(5, TimeUnit.SECONDS));
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        // Surfaces any exception thrown on the task thread
+        taskFuture.get();
+
+        PowerMock.verifyAll();
+    }
+
+    // Pausing a running task should halt polling (within one in-flight iteration)
+    // and trigger an offset flush on pause plus one final end-of-life flush.
+    @Test
+    public void testPause() throws Exception {
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        AtomicInteger polls = new AtomicInteger(0);
+        AtomicInteger flushes = new AtomicInteger(0);
+        pollLatch = new CountDownLatch(10);
+        expectPolls(polls);
+        expectAnyFlushes(flushes);
+
+        expectTopicCreation(TOPIC);
+
+        expectCall(() -> statusListener.onPause(taskId));
+
+        expectCall(sourceTask::stop);
+        expectCall(() -> statusListener.onShutdown(taskId));
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+        assertTrue(awaitLatch(pollLatch));
+
+        workerTask.transitionTo(TargetState.PAUSED);
+
+        int priorCount = polls.get();
+        Thread.sleep(100);
+
+        // since the transition is observed asynchronously, the count could be off by one loop iteration
+        assertTrue(polls.get() - priorCount <= 1);
+
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+
+        assertEquals("Task should have flushed offsets for every record poll, once on pause, and once for end-of-life offset commit",
+                flushes.get(), polls.get() + 2);
+
+        PowerMock.verifyAll();
+    }
+
+    // A failure in the pre-producer check (e.g. zombie fencing) during preflight
+    // should fail the task before the source task is ever started.
+    @Test
+    public void testFailureInPreProducerCheck() {
+        createWorkerTask();
+
+        Exception exception = new ConnectException("Failed to perform zombie fencing");
+        expectCall(preProducerCheck::run).andThrow(exception);
+
+        expectCall(() -> statusListener.onFailure(taskId, exception));
+
+        // Don't expect task to be stopped since it was never started to begin with
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        // No need to execute on a separate thread; preflight checks should all take place before the poll-send loop starts
+        workerTask.run();
+
+        PowerMock.verifyAll();
+    }
+
+    // A failure starting the per-connector offset store during preflight should
+    // fail the task before the source task is ever started.
+    @Test
+    public void testFailureInOffsetStoreStart() {
+        createWorkerTask();
+
+        expectCall(preProducerCheck::run);
+        Exception exception = new ConnectException("No soup for you!");
+        expectCall(offsetStore::start).andThrow(exception);
+
+        expectCall(() -> statusListener.onFailure(taskId, exception));
+
+        // Don't expect task to be stopped since it was never started to begin with
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        // No need to execute on a separate thread; preflight checks should all take place before the poll-send loop starts
+        workerTask.run();
+
+        PowerMock.verifyAll();
+    }
+
+    // A failure initializing the transactional producer during preflight should
+    // fail the task before the source task is ever started.
+    // NOTE: the original body of this test was swapped with testFailureInPostProducerCheck;
+    // it made postProducerCheck throw after a successful initTransactions. Fixed so the
+    // failure actually occurs in producer initialization, matching the test name.
+    @Test
+    public void testFailureInProducerInitialization() {
+        createWorkerTask();
+
+        expectCall(preProducerCheck::run);
+        expectCall(offsetStore::start);
+        // Producer initialization itself fails here
+        Exception exception = new ConnectException("You can't do that!");
+        expectCall(producer::initTransactions).andThrow(exception);
+
+        expectCall(() -> statusListener.onFailure(taskId, exception));
+
+        // Don't expect task to be stopped since it was never started to begin with
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        // No need to execute on a separate thread; preflight checks should all take place before the poll-send loop starts
+        workerTask.run();
+
+        PowerMock.verifyAll();
+    }
+
+    // A failure in the post-producer check (run after the transactional producer is
+    // initialized) should fail the task before the source task is ever started.
+    // NOTE: the original body of this test was swapped with testFailureInProducerInitialization;
+    // it made producer::initTransactions throw. Fixed so the failure actually occurs in the
+    // post-producer check, matching the test name.
+    @Test
+    public void testFailureInPostProducerCheck() {
+        createWorkerTask();
+
+        expectCall(preProducerCheck::run);
+        expectCall(offsetStore::start);
+        // Producer initialization succeeds; the subsequent post-producer check fails
+        expectCall(producer::initTransactions);
+        Exception exception = new ConnectException("New task configs for the connector have already been generated");
+        expectCall(postProducerCheck::run).andThrow(exception);
+
+        expectCall(() -> statusListener.onFailure(taskId, exception));
+
+        // Don't expect task to be stopped since it was never started to begin with
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        // No need to execute on a separate thread; preflight checks should all take place before the poll-send loop starts
+        workerTask.run();
+
+        PowerMock.verifyAll();
+    }
+
+    // Happy path: the task polls repeatedly on its background thread, flushing
+    // offsets once per poll plus one final end-of-life flush on shutdown.
+    @Test
+    public void testPollsInBackground() throws Exception {
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        AtomicInteger polls = new AtomicInteger(0);
+        AtomicInteger flushes = new AtomicInteger(0);
+        pollLatch = new CountDownLatch(10);
+        expectPolls(polls);
+        expectAnyFlushes(flushes);
+
+        expectTopicCreation(TOPIC);
+
+        expectCall(sourceTask::stop);
+        expectCall(() -> statusListener.onShutdown(taskId));
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        assertTrue(awaitLatch(pollLatch));
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+        assertPollMetrics(10);
+        assertTransactionMetrics(1);
+
+        assertEquals("Task should have flushed offsets for every record poll and for end-of-life offset commit",
+                flushes.get(), polls.get() + 1);
+
+        PowerMock.verifyAll();
+    }
+
+    // An exception thrown from SourceTask::poll should fail and stop the task
+    // automatically, reporting the failure to the status listener.
+    @Test
+    public void testFailureInPoll() throws Exception {
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        final CountDownLatch pollLatch = new CountDownLatch(1);
+        final RuntimeException exception = new RuntimeException();
+        EasyMock.expect(sourceTask.poll()).andAnswer(() -> {
+            pollLatch.countDown();
+            throw exception;
+        });
+
+        expectCall(() -> statusListener.onFailure(taskId, exception));
+        expectCall(sourceTask::stop);
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        assertTrue(awaitLatch(pollLatch));
+        //Failure in poll should trigger automatic stop of the worker
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+        assertPollMetrics(0);
+
+        PowerMock.verifyAll();
+    }
+
+    // If the task is cancelled while blocked in poll() and poll() then throws,
+    // the task should close its offset reader and producer without reporting
+    // failure to the status listener (cancellation suppresses status updates).
+    @Test
+    public void testFailureInPollAfterCancel() throws Exception {
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        final CountDownLatch pollLatch = new CountDownLatch(1);
+        final CountDownLatch workerCancelLatch = new CountDownLatch(1);
+        final RuntimeException exception = new RuntimeException();
+        EasyMock.expect(sourceTask.poll()).andAnswer(() -> {
+            pollLatch.countDown();
+            // Block inside poll() until the test has cancelled the task
+            assertTrue(awaitLatch(workerCancelLatch));
+            throw exception;
+        });
+
+        expectCall(offsetReader::close);
+        expectCall(() -> producer.close(Duration.ZERO));
+        expectCall(sourceTask::stop);
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        assertTrue(awaitLatch(pollLatch));
+        workerTask.cancel();
+        workerCancelLatch.countDown();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+        assertPollMetrics(0);
+
+        PowerMock.verifyAll();
+    }
+
+    // If the task is stopped while blocked in poll() and poll() then throws,
+    // the exception is ignored and the task reports a clean shutdown.
+    @Test
+    public void testFailureInPollAfterStop() throws Exception {
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        final CountDownLatch pollLatch = new CountDownLatch(1);
+        final CountDownLatch workerStopLatch = new CountDownLatch(1);
+        final RuntimeException exception = new RuntimeException();
+        EasyMock.expect(sourceTask.poll()).andAnswer(() -> {
+            pollLatch.countDown();
+            // Block inside poll() until the test has stopped the task
+            assertTrue(awaitLatch(workerStopLatch));
+            throw exception;
+        });
+
+        expectCall(() -> statusListener.onShutdown(taskId));
+        expectCall(sourceTask::stop);
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        assertTrue(awaitLatch(pollLatch));
+        workerTask.stop();
+        workerStopLatch.countDown();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+        assertPollMetrics(0);
+
+        PowerMock.verifyAll();
+    }
+
+    // Polls that return no records should not trigger offset flushes or sends.
+    @Test
+    public void testPollReturnsNoRecords() throws Exception {
+        // Test that the task handles an empty list of records
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        final CountDownLatch pollLatch = expectEmptyPolls(1, new AtomicInteger());
+        // No pending offsets, so no flush should ever be attempted
+        EasyMock.expect(offsetWriter.willFlush()).andReturn(false).anyTimes();
+
+        expectCall(sourceTask::stop);
+        expectCall(() -> statusListener.onShutdown(taskId));
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        assertTrue(awaitLatch(pollLatch));
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+        assertPollMetrics(0);
+
+        PowerMock.verifyAll();
+    }
+
+    // With TransactionBoundary.POLL, each poll that produces records should end
+    // with a transaction commit/offset flush, plus one final end-of-life flush.
+    @Test
+    public void testPollBasedCommit() throws Exception {
+        Map<String, String> connectorProps = sourceConnectorProps(SourceTask.TransactionBoundary.POLL);
+        sourceConfig = new SourceConnectorConfig(plugins, connectorProps, enableTopicCreation);
+
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        AtomicInteger polls = new AtomicInteger();
+        AtomicInteger flushes = new AtomicInteger();
+        expectPolls(polls);
+        expectAnyFlushes(flushes);
+
+        expectTopicCreation(TOPIC);
+
+        expectCall(sourceTask::stop);
+        expectCall(() -> statusListener.onShutdown(taskId));
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        assertTrue(awaitLatch(pollLatch));
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+
+        assertEquals("Task should have flushed offsets for every record poll, and for end-of-life offset commit",
+                flushes.get(), polls.get() + 1);
+
+        assertPollMetrics(1);
+        assertTransactionMetrics(1);
+
+        PowerMock.verifyAll();
+    }
+
+    // With TransactionBoundary.INTERVAL, flushes should occur only after the
+    // configured interval elapses (driven here by MockTime), plus one final
+    // end-of-life flush on shutdown.
+    @Test
+    public void testIntervalBasedCommit() throws Exception {
+        long commitInterval = 618;
+        Map<String, String> connectorProps = sourceConnectorProps(SourceTask.TransactionBoundary.INTERVAL);
+        connectorProps.put(TRANSACTION_BOUNDARY_INTERVAL_CONFIG, Long.toString(commitInterval));
+        sourceConfig = new SourceConnectorConfig(plugins, connectorProps, enableTopicCreation);
+
+        // Deterministic clock so the test controls when the interval elapses
+        time = new MockTime();
+
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        expectPolls();
+        final CountDownLatch firstPollLatch = new CountDownLatch(2);
+        final CountDownLatch secondPollLatch = new CountDownLatch(2);
+        final CountDownLatch thirdPollLatch = new CountDownLatch(2);
+
+        AtomicInteger flushes = new AtomicInteger();
+        expectFlush(FlushOutcome.SUCCEED, flushes);
+        expectFlush(FlushOutcome.SUCCEED, flushes);
+        expectFlush(FlushOutcome.SUCCEED, flushes);
+
+        expectTopicCreation(TOPIC);
+
+        expectCall(sourceTask::stop);
+        expectCall(() -> statusListener.onShutdown(taskId));
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        pollLatch = firstPollLatch;
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("No flushes should have taken place before offset commit interval has elapsed", 0, flushes.get());
+        time.sleep(commitInterval);
+
+        pollLatch = secondPollLatch;
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("One flush should have taken place after offset commit interval has elapsed", 1, flushes.get());
+        time.sleep(commitInterval * 2);
+
+        pollLatch = thirdPollLatch;
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("Two flushes should have taken place after offset commit interval has elapsed again", 2, flushes.get());
+
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+
+        assertEquals("Task should have flushed offsets twice based on offset commit interval, and performed final end-of-life offset commit",
+                3, flushes.get());
+
+        assertPollMetrics(2);
+
+        PowerMock.verifyAll();
+    }
+
+    // With TransactionBoundary.CONNECTOR, commits/aborts happen only when the
+    // connector requests them via its TransactionContext (per-batch and
+    // per-record variants of both commit and abort are exercised).
+    @Test
+    public void testConnectorBasedCommit() throws Exception {
+        Map<String, String> connectorProps = sourceConnectorProps(SourceTask.TransactionBoundary.CONNECTOR);
+        sourceConfig = new SourceConnectorConfig(plugins, connectorProps, enableTopicCreation);
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        expectPolls();
+        List<CountDownLatch> pollLatches = IntStream.range(0, 7).mapToObj(i -> new CountDownLatch(3)).collect(Collectors.toList());
+
+        AtomicInteger flushes = new AtomicInteger();
+        // First flush: triggered by TransactionContext::commitTransaction (batch)
+        expectFlush(FlushOutcome.SUCCEED, flushes);
+
+        // Second flush: triggered by TransactionContext::commitTransaction (record)
+        expectFlush(FlushOutcome.SUCCEED, flushes);
+
+        // Third flush: triggered by TransactionContext::abortTransaction (batch)
+        expectCall(producer::abortTransaction);
+        EasyMock.expect(offsetWriter.willFlush()).andReturn(true);
+        expectFlush(FlushOutcome.SUCCEED, flushes);
+
+        // Fourth flush: triggered by TransactionContext::abortTransaction (record)
+        EasyMock.expect(offsetWriter.willFlush()).andReturn(true);
+        expectCall(producer::abortTransaction);
+        expectFlush(FlushOutcome.SUCCEED, flushes);
+
+        expectTopicCreation(TOPIC);
+
+        expectCall(sourceTask::stop);
+        expectCall(() -> statusListener.onShutdown(taskId));
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        TransactionContext transactionContext = workerTask.sourceTaskContext.transactionContext();
+
+        int poll = -1;
+        pollLatch = pollLatches.get(++poll);
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("No flushes should have taken place without connector requesting transaction commit", 0, flushes.get());
+
+        transactionContext.commitTransaction();
+        pollLatch = pollLatches.get(++poll);
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("One flush should have taken place after connector requested batch commit", 1, flushes.get());
+
+        transactionContext.commitTransaction(SOURCE_RECORD);
+        pollLatch = pollLatches.get(++poll);
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("Two flushes should have taken place after connector requested individual record commit", 2, flushes.get());
+
+        pollLatch = pollLatches.get(++poll);
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("Only two flushes should still have taken place without connector re-requesting commit, even on identical records", 2, flushes.get());
+
+        transactionContext.abortTransaction();
+        pollLatch = pollLatches.get(++poll);
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("Three flushes should have taken place after connector requested batch abort", 3, flushes.get());
+
+        transactionContext.abortTransaction(SOURCE_RECORD);
+        pollLatch = pollLatches.get(++poll);
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("Four flushes should have taken place after connector requested individual record abort", 4, flushes.get());
+
+        pollLatch = pollLatches.get(++poll);
+        assertTrue(awaitLatch(pollLatch));
+        assertEquals("Only four flushes should still have taken place without connector re-requesting abort, even on identical records", 4, flushes.get());
+
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+
+        assertEquals("Task should have flushed offsets four times based on connector-defined boundaries, and skipped final end-of-life offset commit",
+                4, flushes.get());
+
+        assertPollMetrics(1);
+        assertTransactionMetrics(2);
+
+        PowerMock.verifyAll();
+    }
+
+    // Offset flush callback failure should permanently fail the task.
+    @Test
+    public void testCommitFlushCallbackFailure() throws Exception {
+        testCommitFailure(FlushOutcome.FAIL_FLUSH_CALLBACK);
+    }
+
+    // Producer transaction-commit failure should permanently fail the task.
+    @Test
+    public void testCommitTransactionFailure() throws Exception {
+        testCommitFailure(FlushOutcome.FAIL_TRANSACTION_COMMIT);
+    }
+
+    // Shared driver: injects the given flush failure and verifies the task
+    // reports onFailure and shuts down.
+    private void testCommitFailure(FlushOutcome causeOfFailure) throws Exception {
+        createWorkerTask();
+
+        expectPreflight();
+        expectStartup();
+
+        expectPolls();
+        expectFlush(causeOfFailure);
+
+        expectTopicCreation(TOPIC);
+
+        expectCall(sourceTask::stop);
+        // Unlike the standard WorkerSourceTask class, this one fails permanently when offset commits don't succeed
+        final CountDownLatch taskFailure = new CountDownLatch(1);
+        expectCall(() -> statusListener.onFailure(EasyMock.eq(taskId), EasyMock.anyObject()))
+                .andAnswer(() -> {
+                    taskFailure.countDown();
+                    return null;
+                });
+
+        expectClose();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        Future<?> taskFuture = executor.submit(workerTask);
+
+        assertTrue(awaitLatch(taskFailure));
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+        assertPollMetrics(1);
+
+        PowerMock.verifyAll();
+    }
+
+    // A retriable synchronous send failure should leave the unsent records queued
+    // in toSend so that a subsequent sendRecords() call retries and drains them.
+    @Test
+    public void testSendRecordsRetries() throws Exception {
+        createWorkerTask();
+
+        // Differentiate only by Kafka partition so we can reuse conversion expectations
+        SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, "topic", 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+        SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, "topic", 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+        SourceRecord record3 = new SourceRecord(PARTITION, OFFSET, "topic", 3, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+
+        expectTopicCreation(TOPIC);
+
+        // First round: record1 succeeds, record2 fails with a retriable error
+        expectSendRecordOnce(false);
+        expectCall(producer::beginTransaction);
+        // Any Producer retriable exception should work here
+        expectSendRecordSyncFailure(new org.apache.kafka.common.errors.TimeoutException("retriable sync failure"));
+
+        // Second round: the two remaining records go through
+        expectSendRecordOnce(true);
+        expectSendRecordOnce(false);
+
+        PowerMock.replayAll();
+
+        // Try to send 3, make first pass, second fail. Should save last two
+        workerTask.toSend = Arrays.asList(record1, record2, record3);
+        workerTask.sendRecords();
+        assertEquals(Arrays.asList(record2, record3), workerTask.toSend);
+
+        // Next they all succeed
+        workerTask.sendRecords();
+        assertNull(workerTask.toSend);
+
+        PowerMock.verifyAll();
+    }
+
+    // A synchronous, non-retriable failure thrown directly from producer.send()
+    // should surface to the caller of sendRecords() as a ConnectException.
+    @Test
+    public void testSendRecordsProducerSendFailsImmediately() {
+        if (!enableTopicCreation)
+            // should only test with topic creation enabled
+            return;
+
+        createWorkerTask();
+
+        SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+        SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+
+        expectPreliminaryCalls();
+        expectCall(producer::beginTransaction);
+        expectTopicCreation(TOPIC);
+
+        // Fail the very first send with a non-retriable exception
+        EasyMock.expect(producer.send(EasyMock.anyObject(), EasyMock.anyObject()))
+                .andThrow(new KafkaException("Producer closed while send in progress", new InvalidTopicException(TOPIC)));
+
+        PowerMock.replayAll();
+
+        workerTask.toSend = Arrays.asList(record1, record2);
+        assertThrows(ConnectException.class, workerTask::sendRecords);
+
+        // Verify all mock expectations were satisfied (previously missing)
+        PowerMock.verifyAll();
+    }

Review Comment:
   Nit: this test never calls `PowerMock.verifyAll()` after `replayAll()`, so the mock expectations are never checked — was that intentional, or should a verify be added like in the other tests?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to