This is an automated email from the ASF dual-hosted git repository.

arvid pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 900af3e  [FLINK-23183][connectors/rabbitmq] Fix ACKs for redelivered messages in RMQSource and add integration tests
900af3e is described below

commit 900af3e43eb34f8899a4509a8f6c8e43d72f1cc8
Author: Michal Ciesielczyk <[email protected]>
AuthorDate: Mon Jul 5 20:01:39 2021 +0200

    [FLINK-23183][connectors/rabbitmq] Fix ACKs for redelivered messages in RMQSource and add integration tests
    
    Changes:
    - RMQSource now calls channel.basicReject (with requeue=false) for redelivered messages that were already processed (and checkpointed), e.g. after a job failover, instead of returning without acknowledging or rejecting the delivery
    - add an integration test verifying that the source actually consumes the messages
    - add an integration test reproducing the message redelivery issue on ack failure
---
 .../streaming/connectors/rabbitmq/RMQSource.java   |  10 ++-
 .../connectors/rabbitmq/RMQSourceITCase.java       | 100 ++++++++++++++++++++-
 .../connectors/rabbitmq/RMQSourceTest.java         |  52 ++++++++++-
 3 files changed, 158 insertions(+), 4 deletions(-)

diff --git 
a/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
 
b/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
index b3444f1..017bc7c 100644
--- 
a/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
+++ 
b/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
@@ -410,10 +410,16 @@ public class RMQSource<OUT> extends 
MultipleIdsMessageAcknowledgingSourceBase<OU
                 if (usesCorrelationId) {
                     Preconditions.checkNotNull(
                             correlationId,
-                            "RabbitMQ source was instantiated "
-                                    + "with usesCorrelationId set to true yet 
we couldn't extract the correlation id from it !");
+                            "RabbitMQ source was instantiated with 
usesCorrelationId set to "
+                                    + "true yet we couldn't extract the 
correlation id from it!");
                     if (!addId(correlationId)) {
                         // we have already processed this message
+                        try {
+                            channel.basicReject(deliveryTag, false);
+                        } catch (IOException e) {
+                            throw new RuntimeException(
+                                    "Message could not be acknowledged with 
basicReject.", e);
+                        }
                         return false;
                     }
                 }
diff --git 
a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceITCase.java
 
b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceITCase.java
index 0e687f9..7f462f5 100644
--- 
a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceITCase.java
+++ 
b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceITCase.java
@@ -21,6 +21,7 @@ package org.apache.flink.streaming.connectors.rabbitmq;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.JobStatus;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
 import org.apache.flink.api.common.serialization.SimpleStringSchema;
 import org.apache.flink.api.common.time.Deadline;
 import org.apache.flink.client.program.rest.RestClusterClient;
@@ -31,10 +32,12 @@ import 
org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
 import org.apache.flink.streaming.api.datastream.DataStreamSource;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import 
org.apache.flink.streaming.connectors.rabbitmq.common.RMQConnectionConfig;
 import org.apache.flink.test.util.MiniClusterWithClientResource;
 import org.apache.flink.util.DockerImageVersions;
 
+import com.rabbitmq.client.AMQP;
 import com.rabbitmq.client.Channel;
 import com.rabbitmq.client.Connection;
 import com.rabbitmq.client.ConnectionFactory;
@@ -52,7 +55,11 @@ import org.testcontainers.utility.DockerImageName;
 
 import java.io.IOException;
 import java.time.Duration;
+import java.util.List;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 /** A class containing RabbitMQ source tests against a real RabbiMQ cluster. */
 public class RMQSourceITCase {
@@ -64,6 +71,7 @@ public class RMQSourceITCase {
     private static final int RABBITMQ_PORT = 5672;
     private static final String QUEUE_NAME = "test-queue";
     private static final JobID JOB_ID = new JobID();
+    private static final SimpleStringSchema SCHEMA = new SimpleStringSchema();
 
     private RestClusterClient<?> clusterClient;
     private RMQConnectionConfig config;
@@ -90,24 +98,27 @@ public class RMQSourceITCase {
         final Connection connection = getRMQConnection();
         final Channel channel = connection.createChannel();
         channel.queueDeclare(QUEUE_NAME, true, false, false, null);
+        channel.queuePurge(QUEUE_NAME);
         channel.txSelect();
         clusterClient = flinkCluster.getRestClusterClient();
         config =
                 new RMQConnectionConfig.Builder()
                         .setHost(RMQ_CONTAINER.getHost())
                         .setDeliveryTimeout(500)
+                        .setPrefetchCount(5)
                         .setVirtualHost("/")
                         .setUserName(RMQ_CONTAINER.getAdminUsername())
                         .setPassword(RMQ_CONTAINER.getAdminPassword())
                         .setPort(RMQ_CONTAINER.getMappedPort(RABBITMQ_PORT))
                         .build();
+        CountingSink.reset();
     }
 
     @Test
     public void testStopWithSavepoint() throws Exception {
         final StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
         final DataStreamSource<String> source =
-                env.addSource(new RMQSource<>(config, QUEUE_NAME, new 
SimpleStringSchema()));
+                env.addSource(new RMQSource<>(config, QUEUE_NAME, SCHEMA));
         source.addSink(new DiscardingSink<>());
         env.enableCheckpointing(500);
         final JobGraph jobGraph = env.getStreamGraph().getJobGraph();
@@ -128,6 +139,59 @@ public class RMQSourceITCase {
         clusterClient.stopWithSavepoint(JOB_ID, false, 
tmp.newFolder().getAbsolutePath()).get();
     }
 
+    @Test
+    public void testMessageDelivery() throws Exception {
+        final StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        List<String> msgs =
+                IntStream.range(0, 
10).mapToObj(String::valueOf).collect(Collectors.toList());
+        publishToRMQ(msgs);
+
+        final DataStreamSource<String> source =
+                env.addSource(new RMQSource<>(config, QUEUE_NAME, SCHEMA));
+        source.addSink(CountingSink.getInstance());
+        final JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+        JobID jobId = clusterClient.submitJob(jobGraph).get();
+        CommonTestUtils.waitUntilCondition(
+                () -> CountingSink.getCount() == msgs.size(),
+                Deadline.fromNow(Duration.ofSeconds(30)),
+                5L);
+        clusterClient.cancel(jobId);
+    }
+
+    @Test
+    public void testAckFailure() throws Exception {
+        final StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        
env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 
500));
+        env.enableCheckpointing(500);
+        List<String> msgs =
+                IntStream.range(0, 
10).mapToObj(String::valueOf).collect(Collectors.toList());
+        publishToRMQ(msgs);
+
+        RMQSource<String> rmqSource =
+                new RMQSource<String>(config, QUEUE_NAME, true, SCHEMA) {
+                    @Override
+                    protected void acknowledgeSessionIDs(List<Long> 
sessionIds) {
+                        try {
+                            if (!sessionIds.isEmpty()) {
+                                throw new RuntimeException("Test acknowledge 
failure");
+                            }
+                            channel.txCommit();
+                        } catch (IOException e) {
+                            throw new RuntimeException("Error while committing 
transaction", e);
+                        }
+                    }
+                };
+        final DataStreamSource<String> source = env.addSource(rmqSource);
+        source.addSink(CountingSink.getInstance());
+        final JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+        JobID jobId = clusterClient.submitJob(jobGraph).get();
+        CommonTestUtils.waitUntilCondition(
+                () -> CountingSink.getCount() == msgs.size(),
+                Deadline.fromNow(Duration.ofSeconds(60)),
+                5L);
+        clusterClient.cancel(jobId);
+    }
+
     private static Connection getRMQConnection() throws IOException, 
TimeoutException {
         ConnectionFactory factory = new ConnectionFactory();
         factory.setUsername(RMQ_CONTAINER.getAdminUsername());
@@ -138,4 +202,38 @@ public class RMQSourceITCase {
         factory.setPort(RMQ_CONTAINER.getAmqpPort());
         return factory.newConnection();
     }
+
+    private static void publishToRMQ(Iterable<String> messages)
+            throws IOException, TimeoutException {
+        AMQP.BasicProperties.Builder propertiesBuilder = new 
AMQP.BasicProperties.Builder();
+        try (Connection rmqConnection = getRMQConnection();
+                Channel channel = rmqConnection.createChannel()) {
+            for (String msg : messages) {
+                AMQP.BasicProperties properties = 
propertiesBuilder.correlationId(msg).build();
+                channel.basicPublish("", QUEUE_NAME, properties, 
SCHEMA.serialize(msg));
+            }
+        }
+    }
+
+    private static class CountingSink implements SinkFunction<String> {
+
+        private static final AtomicInteger count = new AtomicInteger();
+
+        public static CountingSink getInstance() {
+            return new CountingSink();
+        }
+
+        public static void reset() {
+            count.set(0);
+        }
+
+        public static int getCount() {
+            return count.get();
+        }
+
+        @Override
+        public void invoke(String value, SinkFunction.Context context) {
+            count.incrementAndGet();
+        }
+    }
 }
diff --git 
a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
 
b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
index c30bc93..dab807c 100644
--- 
a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
+++ 
b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
@@ -336,7 +336,8 @@ public class RMQSourceTest {
             Thread.sleep(5);
         }
 
-        // see addId in RMQTestSource.addId for the assert
+        // verify if RMQTestSource#addId was never called
+        assertEquals(0, ((RMQTestSource) source).addIdCalls);
     }
 
     /** Tests error reporting in case of invalid correlation ids. */
@@ -352,6 +353,53 @@ public class RMQSourceTest {
         assertTrue(exception instanceof NullPointerException);
     }
 
+    /** Tests whether redelivered messages are acknowledged properly. */
+    @Test
+    public void testRedeliveredSessionIDsAck() throws Exception {
+        source.autoAck = false;
+
+        StreamSource<String, RMQSource<String>> src = new 
StreamSource<>(source);
+        AbstractStreamOperatorTestHarness<String> testHarness =
+                new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
+        testHarness.open();
+        sourceThread.start();
+
+        while (DummySourceContext.numElementsCollected < 10) {
+            // wait until messages have been processed
+            Thread.sleep(5);
+        }
+
+        // mock message redelivery by resetting the message ID
+        long numMsgRedelivered;
+        synchronized (DummySourceContext.lock) {
+            numMsgRedelivered = DummySourceContext.numElementsCollected;
+            messageId = 0;
+        }
+        while (DummySourceContext.numElementsCollected < numMsgRedelivered + 
10) {
+            // wait until some messages will be redelivered
+            Thread.sleep(5);
+        }
+
+        // ack the messages by snapshotting the state
+        final Random random = new Random(System.currentTimeMillis());
+        long lastMessageId;
+        long snapshotId = random.nextLong();
+        synchronized (DummySourceContext.lock) {
+            testHarness.snapshot(snapshotId, System.currentTimeMillis());
+            source.notifyCheckpointComplete(snapshotId);
+            lastMessageId = messageId;
+        }
+
+        // check if all the messages are being acknowledged
+        long totalNumberOfAcks = numMsgRedelivered + lastMessageId;
+        assertEquals(lastMessageId, DummySourceContext.numElementsCollected);
+        assertEquals(totalNumberOfAcks, ((RMQTestSource) source).addIdCalls);
+        Mockito.verify(source.channel, Mockito.times((int) lastMessageId))
+                .basicAck(Mockito.anyLong(), Mockito.eq(false));
+        Mockito.verify(source.channel, Mockito.times((int) numMsgRedelivered))
+                .basicReject(Mockito.anyLong(), Mockito.eq(false));
+    }
+
     /** Tests whether constructor params are passed correctly. */
     @Test
     public void testConstructorParams() throws Exception {
@@ -659,6 +707,7 @@ public class RMQSourceTest {
         private Delivery mockedDelivery;
         public Envelope mockedAMQPEnvelope;
         public AMQP.BasicProperties mockedAMQPProperties;
+        public int addIdCalls = 0;
 
         public RMQTestSource() {
             super();
@@ -758,6 +807,7 @@ public class RMQSourceTest {
 
         @Override
         protected boolean addId(String uid) {
+            addIdCalls++;
             assertEquals(false, autoAck);
             return super.addId(uid);
         }

Reply via email to