Repository: flink
Updated Branches:
  refs/heads/master d6aed38b3 -> b0f0f3722


[FLINK-5701] [kafka] FlinkKafkaProducer should check asyncException on 
checkpoints

This closes #3278.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/646490c4
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/646490c4
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/646490c4

Branch: refs/heads/master
Commit: 646490c4e93eca315e4bf41704f149390f8639cc
Parents: d6aed38
Author: Tzu-Li (Gordon) Tai <[email protected]>
Authored: Tue Feb 7 00:37:13 2017 +0800
Committer: Tzu-Li (Gordon) Tai <[email protected]>
Committed: Thu Feb 23 01:16:57 2017 +0800

----------------------------------------------------------------------
 .../kafka/FlinkKafkaProducerBase.java           |  15 +-
 .../kafka/FlinkKafkaProducerBaseTest.java       | 391 ++++++++++++-------
 2 files changed, 272 insertions(+), 134 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/646490c4/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
 
b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
index 679b731..6a7b17f 100644
--- 
a/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
+++ 
b/flink-connectors/flink-connector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBase.java
@@ -17,6 +17,7 @@
 
 package org.apache.flink.streaming.connectors.kafka;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.java.ClosureCleaner;
@@ -348,6 +349,9 @@ public abstract class FlinkKafkaProducerBase<IN> extends 
RichSinkFunction<IN> im
 
        @Override
        public void snapshotState(FunctionSnapshotContext ctx) throws Exception 
{
+               // check for asynchronous errors and fail the checkpoint if 
necessary
+               checkErroneous();
+
                if (flushOnCheckpoint) {
                        // flushing is activated: We need to wait until 
pendingRecords is 0
                        flush();
@@ -355,7 +359,9 @@ public abstract class FlinkKafkaProducerBase<IN> extends 
RichSinkFunction<IN> im
                                if (pendingRecords != 0) {
                                        throw new 
IllegalStateException("Pending record count must be zero at this point: " + 
pendingRecords);
                                }
-                               // pending records count is 0. We can now 
confirm the checkpoint
+
+               // if any of the flushed records failed, propagate the error and fail the checkpoint
+                               checkErroneous();
                        }
                }
        }
@@ -383,4 +389,11 @@ public abstract class FlinkKafkaProducerBase<IN> extends 
RichSinkFunction<IN> im
                props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, 
brokerList);
                return props;
        }
+
+       @VisibleForTesting
+       protected long numPendingRecords() {
+               synchronized (pendingRecordsLock) {
+                       return pendingRecords;
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/646490c4/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
 
b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
index 2e06160..1f16d8e 100644
--- 
a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
+++ 
b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducerBaseTest.java
@@ -18,38 +18,36 @@
 package org.apache.flink.streaming.connectors.kafka;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
-import org.apache.flink.api.java.tuple.Tuple1;
+import org.apache.flink.core.testutils.CheckedThread;
+import org.apache.flink.core.testutils.MultiShotLatch;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.operators.StreamSink;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import 
org.apache.flink.streaming.connectors.kafka.partitioner.KafkaPartitioner;
 import 
org.apache.flink.streaming.connectors.kafka.testutils.FakeStandardProducerConfig;
 import org.apache.flink.streaming.util.serialization.KeyedSerializationSchema;
 import org.apache.kafka.clients.producer.Callback;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.ProducerConfig;
-import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.clients.producer.ProducerRecord;
-import org.apache.kafka.common.Metric;
-import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
 import org.junit.Assert;
 import org.junit.Test;
-import scala.concurrent.duration.Deadline;
-import scala.concurrent.duration.FiniteDuration;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Map;
 import java.util.Properties;
-import java.util.concurrent.Future;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyString;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -88,201 +86,328 @@ public class FlinkKafkaProducerBaseTest {
        @Test
        public void testPartitionerOpenedWithDeterminatePartitionList() throws 
Exception {
                KafkaPartitioner mockPartitioner = mock(KafkaPartitioner.class);
+
                RuntimeContext mockRuntimeContext = mock(RuntimeContext.class);
                when(mockRuntimeContext.getIndexOfThisSubtask()).thenReturn(0);
                
when(mockRuntimeContext.getNumberOfParallelSubtasks()).thenReturn(1);
-
-               DummyFlinkKafkaProducer producer = new DummyFlinkKafkaProducer(
+               
+               // out-of-order list of 4 partitions
+               List<PartitionInfo> mockPartitionsList = new ArrayList<>(4);
+               mockPartitionsList.add(new 
PartitionInfo(DummyFlinkKafkaProducer.DUMMY_TOPIC, 3, null, null, null));
+               mockPartitionsList.add(new 
PartitionInfo(DummyFlinkKafkaProducer.DUMMY_TOPIC, 1, null, null, null));
+               mockPartitionsList.add(new 
PartitionInfo(DummyFlinkKafkaProducer.DUMMY_TOPIC, 0, null, null, null));
+               mockPartitionsList.add(new 
PartitionInfo(DummyFlinkKafkaProducer.DUMMY_TOPIC, 2, null, null, null));
+
+               final DummyFlinkKafkaProducer producer = new 
DummyFlinkKafkaProducer(
                        FakeStandardProducerConfig.get(), mockPartitioner);
                producer.setRuntimeContext(mockRuntimeContext);
 
+               final KafkaProducer mockProducer = 
producer.getMockKafkaProducer();
+               
when(mockProducer.partitionsFor(anyString())).thenReturn(mockPartitionsList);
+               when(mockProducer.metrics()).thenReturn(null);
+
                producer.open(new Configuration());
 
-               // the internal mock KafkaProducer will return an out-of-order 
list of 4 partitions,
-               // which should be sorted before provided to the custom 
partitioner's open() method
+               // the out-of-order partitions list should be sorted before 
provided to the custom partitioner's open() method
                int[] correctPartitionList = {0, 1, 2, 3};
                verify(mockPartitioner).open(0, 1, correctPartitionList);
        }
 
        /**
-        * Test ensuring that the producer is not dropping buffered records.;
-        * we set a timeout because the test will not finish if the logic is 
broken
+        * Test ensuring that if an invoke call happens right after an async 
exception is caught, it should be rethrown
         */
-       @Test(timeout=5000)
-       public void testAtLeastOnceProducer() throws Throwable {
-               runAtLeastOnceTest(true);
+       @Test
+       public void testAsyncErrorRethrownOnInvoke() throws Throwable {
+               final DummyFlinkKafkaProducer<String> producer = new 
DummyFlinkKafkaProducer<>(
+                       FakeStandardProducerConfig.get(), null);
+
+               OneInputStreamOperatorTestHarness<String, Object> testHarness =
+                       new OneInputStreamOperatorTestHarness<>(new 
StreamSink<>(producer));
+
+               testHarness.open();
+
+               testHarness.processElement(new StreamRecord<>("msg-1"));
+
+               // let the message request return an async exception
+               producer.getPendingCallbacks().get(0).onCompletion(null, new 
Exception("artificial async exception"));
+
+               try {
+                       testHarness.processElement(new StreamRecord<>("msg-2"));
+               } catch (Exception e) {
+                       // the next invoke should rethrow the async exception
+                       
Assert.assertTrue(e.getCause().getMessage().contains("artificial async 
exception"));
+
+                       // test succeeded
+                       return;
+               }
+
+               Assert.fail();
        }
 
        /**
-        * Ensures that the at least once producing test fails if the flushing 
is disabled
+        * Test ensuring that if a snapshot call happens right after an async 
exception is caught, it should be rethrown
         */
-       @Test(expected = AssertionError.class, timeout=5000)
-       public void testAtLeastOnceProducerFailsIfFlushingDisabled() throws 
Throwable {
-               runAtLeastOnceTest(false);
-       }
-
-       private void runAtLeastOnceTest(boolean flushOnCheckpoint) throws 
Throwable {
-               final AtomicBoolean snapshottingFinished = new 
AtomicBoolean(false);
+       @Test
+       public void testAsyncErrorRethrownOnCheckpoint() throws Throwable {
                final DummyFlinkKafkaProducer<String> producer = new 
DummyFlinkKafkaProducer<>(
-                       FakeStandardProducerConfig.get(), null, 
snapshottingFinished);
-               producer.setFlushOnCheckpoint(flushOnCheckpoint);
+                       FakeStandardProducerConfig.get(), null);
 
                OneInputStreamOperatorTestHarness<String, Object> testHarness =
-                               new OneInputStreamOperatorTestHarness<>(new 
StreamSink(producer));
+                       new OneInputStreamOperatorTestHarness<>(new 
StreamSink<>(producer));
 
                testHarness.open();
 
-               for (int i = 0; i < 100; i++) {
-                       testHarness.processElement(new StreamRecord<>("msg-" + 
i));
+               testHarness.processElement(new StreamRecord<>("msg-1"));
+
+               // let the message request return an async exception
+               producer.getPendingCallbacks().get(0).onCompletion(null, new 
Exception("artificial async exception"));
+
+               try {
+                       testHarness.snapshot(123L, 123L);
+               } catch (Exception e) {
+                       // the snapshot should rethrow the async exception
+                       
Assert.assertTrue(e.getCause().getMessage().contains("artificial async 
exception"));
+
+                       // test succeeded
+                       return;
                }
 
-               // start a thread confirming all pending records
-               final Tuple1<Throwable> runnableError = new Tuple1<>(null);
-               final Thread threadA = Thread.currentThread();
+               Assert.fail();
+       }
+
+       /**
+        * Test ensuring that if an async exception is caught for one of the 
flushed requests on checkpoint,
+        * it should be rethrown; we set a timeout because the test will not 
finish if the logic is broken.
+        *
+        * Note that this test does not verify that the snapshot method blocks correctly while there are pending records.
+        * The test for that is covered in testAtLeastOnceProducer.
+        */
+       @SuppressWarnings("unchecked")
+       @Test(timeout=5000)
+       public void testAsyncErrorRethrownOnCheckpointAfterFlush() throws 
Throwable {
+               final DummyFlinkKafkaProducer<String> producer = new 
DummyFlinkKafkaProducer<>(
+                       FakeStandardProducerConfig.get(), null);
+               producer.setFlushOnCheckpoint(true);
+
+               final KafkaProducer<?, ?> mockProducer = 
producer.getMockKafkaProducer();
+
+               final OneInputStreamOperatorTestHarness<String, Object> 
testHarness =
+                       new OneInputStreamOperatorTestHarness<>(new 
StreamSink<>(producer));
+
+               testHarness.open();
+
+               testHarness.processElement(new StreamRecord<>("msg-1"));
+               testHarness.processElement(new StreamRecord<>("msg-2"));
+               testHarness.processElement(new StreamRecord<>("msg-3"));
+
+               verify(mockProducer, times(3)).send(any(ProducerRecord.class), 
any(Callback.class));
+
+               // only let the first callback succeed for now
+               producer.getPendingCallbacks().get(0).onCompletion(null, null);
 
-               Runnable confirmer = new Runnable() {
+               CheckedThread snapshotThread = new CheckedThread() {
                        @Override
-                       public void run() {
-                               try {
-                                       MockProducer mp = 
producer.getProducerInstance();
-                                       List<Callback> pending = 
mp.getPending();
-
-                                       // we need to find out if the 
snapshot() method blocks forever
-                                       // this is not possible. If snapshot() 
is running, it will
-                                       // start removing elements from the 
pending list.
-                                       synchronized (threadA) {
-                                               threadA.wait(500L);
-                                       }
-                                       // we now check that no records have 
been confirmed yet
-                                       Assert.assertEquals(100, 
pending.size());
-                                       Assert.assertFalse("Snapshot method 
returned before all records were confirmed",
-                                               snapshottingFinished.get());
-
-                                       // now confirm all checkpoints
-                                       for (Callback c: pending) {
-                                               c.onCompletion(null, null);
-                                       }
-                                       pending.clear();
-                               } catch(Throwable t) {
-                                       runnableError.f0 = t;
-                               }
+                       public void go() throws Exception {
+                               // this should block at first, since there are still two pending records that need to be flushed
+                               testHarness.snapshot(123L, 123L);
                        }
                };
-               Thread threadB = new Thread(confirmer);
-               threadB.start();
+               snapshotThread.start();
 
-               // this should block:
-               testHarness.snapshot(0, 0);
+               // let the 2nd message fail with an async exception
+               producer.getPendingCallbacks().get(1).onCompletion(null, new 
Exception("artificial async failure for 2nd message"));
+               producer.getPendingCallbacks().get(2).onCompletion(null, null);
 
-               synchronized (threadA) {
-                       threadA.notifyAll(); // just in case, to let the test 
fail faster
-               }
-               Assert.assertEquals(0, 
producer.getProducerInstance().getPending().size());
-               Deadline deadline = FiniteDuration.apply(5, "s").fromNow();
-               while (deadline.hasTimeLeft() && threadB.isAlive()) {
-                       threadB.join(500);
-               }
-               Assert.assertFalse("Thread A is expected to be finished at this 
point. If not, the test is prone to fail", threadB.isAlive());
-               if (runnableError.f0 != null) {
-                       throw runnableError.f0;
+               try {
+                       snapshotThread.sync();
+               } catch (Exception e) {
+                       // the snapshot should have failed with the async 
exception
+                       
Assert.assertTrue(e.getCause().getMessage().contains("artificial async failure 
for 2nd message"));
+
+                       // test succeeded
+                       return;
                }
 
+               Assert.fail();
+       }
+
+       /**
+        * Test ensuring that the producer is not dropping buffered records;
+        * we set a timeout because the test will not finish if the logic is 
broken
+        */
+       @SuppressWarnings("unchecked")
+       @Test(timeout=10000)
+       public void testAtLeastOnceProducer() throws Throwable {
+               final DummyFlinkKafkaProducer<String> producer = new 
DummyFlinkKafkaProducer<>(
+                       FakeStandardProducerConfig.get(), null);
+               producer.setFlushOnCheckpoint(true);
+
+               final KafkaProducer<?, ?> mockProducer = 
producer.getMockKafkaProducer();
+
+               final OneInputStreamOperatorTestHarness<String, Object> 
testHarness =
+                       new OneInputStreamOperatorTestHarness<>(new 
StreamSink<>(producer));
+
+               testHarness.open();
+
+               testHarness.processElement(new StreamRecord<>("msg-1"));
+               testHarness.processElement(new StreamRecord<>("msg-2"));
+               testHarness.processElement(new StreamRecord<>("msg-3"));
+
+               verify(mockProducer, times(3)).send(any(ProducerRecord.class), 
any(Callback.class));
+               Assert.assertEquals(3, producer.getPendingSize());
+
+               // start a thread to perform checkpointing
+               CheckedThread snapshotThread = new CheckedThread() {
+                       @Override
+                       public void go() throws Exception {
+                               // this should block until all records are flushed;
+                               // if it returns before pending records are flushed, DummyFlinkKafkaProducer#snapshotState throws
+                               testHarness.snapshot(123L, 123L);
+                       }
+               };
+               snapshotThread.start();
+
+               // before proceeding, make sure that flushing has started and 
that the snapshot is still blocked;
+               // this would block forever if the snapshot didn't perform a 
flush
+               producer.waitUntilFlushStarted();
+               Assert.assertTrue("Snapshot returned before all records were 
flushed", snapshotThread.isAlive());
+
+               // now, complete the callbacks
+               producer.getPendingCallbacks().get(0).onCompletion(null, null);
+               Assert.assertTrue("Snapshot returned before all records were 
flushed", snapshotThread.isAlive());
+               Assert.assertEquals(2, producer.getPendingSize());
+
+               producer.getPendingCallbacks().get(1).onCompletion(null, null);
+               Assert.assertTrue("Snapshot returned before all records were 
flushed", snapshotThread.isAlive());
+               Assert.assertEquals(1, producer.getPendingSize());
+
+               producer.getPendingCallbacks().get(2).onCompletion(null, null);
+               Assert.assertEquals(0, producer.getPendingSize());
+
+               // this would fail with an exception if flushing wasn't 
completed before the snapshot method returned
+               snapshotThread.sync();
+
                testHarness.close();
        }
 
+       /**
+        * This test is meant to assure that testAtLeastOnceProducer is valid 
by testing that if flushing is disabled,
+        * the snapshot method does indeed finish without waiting for pending records;
+        * we set a timeout because the test will not finish if the logic is 
broken
+        */
+       @SuppressWarnings("unchecked")
+       @Test(timeout=5000)
+       public void testDoesNotWaitForPendingRecordsIfFlushingDisabled() throws 
Throwable {
+               final DummyFlinkKafkaProducer<String> producer = new 
DummyFlinkKafkaProducer<>(
+                       FakeStandardProducerConfig.get(), null);
+               producer.setFlushOnCheckpoint(false);
+
+               final KafkaProducer<?, ?> mockProducer = 
producer.getMockKafkaProducer();
+
+               final OneInputStreamOperatorTestHarness<String, Object> 
testHarness =
+                       new OneInputStreamOperatorTestHarness<>(new 
StreamSink<>(producer));
+
+               testHarness.open();
+
+               testHarness.processElement(new StreamRecord<>("msg"));
+
+               // make sure the record was sent; its callback is deliberately left uncompleted
+               verify(mockProducer, times(1)).send(any(ProducerRecord.class), 
any(Callback.class));
+
+               // should return even if there are pending records
+               testHarness.snapshot(123L, 123L);
+
+               testHarness.close();
+       }
 
        // 
------------------------------------------------------------------------
 
        private static class DummyFlinkKafkaProducer<T> extends 
FlinkKafkaProducerBase<T> {
                private static final long serialVersionUID = 1L;
+               
+               private final static String DUMMY_TOPIC = "dummy-topic";
 
-               private transient MockProducer prod;
-               private AtomicBoolean snapshottingFinished;
+               private transient KafkaProducer<?, ?> mockProducer;
+               private transient List<Callback> pendingCallbacks;
+               private transient MultiShotLatch flushLatch;
+               private boolean isFlushed;
 
                @SuppressWarnings("unchecked")
-               public DummyFlinkKafkaProducer(Properties producerConfig, 
KafkaPartitioner partitioner, AtomicBoolean snapshottingFinished) {
-                       super("dummy-topic", (KeyedSerializationSchema< T >) 
mock(KeyedSerializationSchema.class), producerConfig, partitioner);
-                       this.snapshottingFinished = snapshottingFinished;
-               }
+               DummyFlinkKafkaProducer(Properties producerConfig, 
KafkaPartitioner partitioner) {
 
-               // constructor variant for test irrelated to snapshotting
-               @SuppressWarnings("unchecked")
-               public DummyFlinkKafkaProducer(Properties producerConfig, 
KafkaPartitioner partitioner) {
-                       super("dummy-topic", (KeyedSerializationSchema< T >) 
mock(KeyedSerializationSchema.class), producerConfig, partitioner);
-                       this.snapshottingFinished = new AtomicBoolean(true);
-               }
+                       super(DUMMY_TOPIC, (KeyedSerializationSchema<T>) 
mock(KeyedSerializationSchema.class), producerConfig, partitioner);
 
-               @Override
-               protected <K, V> KafkaProducer<K, V> 
getKafkaProducer(Properties props) {
-                       this.prod = new MockProducer();
-                       return this.prod;
-               }
+                       this.mockProducer = mock(KafkaProducer.class);
+                       when(mockProducer.send(any(ProducerRecord.class), 
any(Callback.class))).thenAnswer(new Answer<Object>() {
+                               @Override
+                               public Object answer(InvocationOnMock 
invocationOnMock) throws Throwable {
+                                       
pendingCallbacks.add(invocationOnMock.getArgumentAt(1, Callback.class));
+                                       return null;
+                               }
+                       });
 
-               @Override
-               public void snapshotState(FunctionSnapshotContext ctx) throws 
Exception {
-                       // call the actual snapshot state
-                       super.snapshotState(ctx);
-                       // notify test that snapshotting has been done
-                       snapshottingFinished.set(true);
+                       this.pendingCallbacks = new ArrayList<>();
+                       this.flushLatch = new MultiShotLatch();
                }
 
-               @Override
-               protected void flush() {
-                       this.prod.flush();
+               long getPendingSize() {
+                       if (flushOnCheckpoint) {
+                               return numPendingRecords();
+                       } else {
+                               // when flushing is disabled, the 
implementation does not
+                               // maintain the current number of pending 
records to reduce
+                               // the extra locking overhead required to do so
+                               throw new 
UnsupportedOperationException("getPendingSize not supported when flushing is 
disabled");
+                       }
                }
 
-               public MockProducer getProducerInstance() {
-                       return this.prod;
+               List<Callback> getPendingCallbacks() {
+                       return pendingCallbacks;
                }
-       }
-
-       private static class MockProducer<K, V> extends KafkaProducer<K, V> {
-               List<Callback> pendingCallbacks = new ArrayList<>();
 
-               public MockProducer() {
-                       super(FakeStandardProducerConfig.get());
+               KafkaProducer<?, ?> getMockKafkaProducer() {
+                       return mockProducer;
                }
 
                @Override
-               public Future<RecordMetadata> send(ProducerRecord<K, V> record) 
{
-                       throw new UnsupportedOperationException("Unexpected");
-               }
+               public void snapshotState(FunctionSnapshotContext ctx) throws 
Exception {
+                       isFlushed = false;
 
-               @Override
-               public Future<RecordMetadata> send(ProducerRecord<K, V> record, 
Callback callback) {
-                       pendingCallbacks.add(callback);
-                       return null;
+                       super.snapshotState(ctx);
+
+                       // if the snapshot implementation doesn't wait until 
all pending records are flushed, we should fail the test
+                       if (flushOnCheckpoint && !isFlushed) {
+                               throw new RuntimeException("Flushing is 
enabled; snapshots should be blocked until all pending records are flushed");
+                       }
                }
 
-               @Override
-               public List<PartitionInfo> partitionsFor(String topic) {
-                       List<PartitionInfo> list = new ArrayList<>();
-                       // deliberately return an out-of-order partition list
-                       list.add(new PartitionInfo(topic, 3, null, null, null));
-                       list.add(new PartitionInfo(topic, 1, null, null, null));
-                       list.add(new PartitionInfo(topic, 0, null, null, null));
-                       list.add(new PartitionInfo(topic, 2, null, null, null));
-                       return list;
+               public void waitUntilFlushStarted() throws Exception {
+                       flushLatch.await();
                }
 
+               @SuppressWarnings("unchecked")
                @Override
-               public Map<MetricName, ? extends Metric> metrics() {
-                       return null;
+               protected <K, V> KafkaProducer<K, V> 
getKafkaProducer(Properties props) {
+                       return (KafkaProducer<K, V>) mockProducer;
                }
 
+               @Override
+               protected void flush() {
+                       flushLatch.trigger();
 
-               public List<Callback> getPending() {
-                       return this.pendingCallbacks;
-               }
-
-               public void flush() {
-                       while (pendingCallbacks.size() > 0) {
+                       // simply wait until the producer's pending records 
become zero.
+                       // This relies on the fact that the producer's Callback 
implementation
+                       // and pending records tracking logic is implemented 
correctly, otherwise
+                       // we will loop forever.
+                       while (numPendingRecords() > 0) {
                                try {
                                        Thread.sleep(10);
                                } catch (InterruptedException e) {
                                        throw new RuntimeException("Unable to 
flush producer, task was interrupted");
                                }
                        }
+
+                       isFlushed = true;
                }
        }
 }

Reply via email to