sijie closed pull request #2555: [ecosystem] Flink pulsar source connector
URL: https://github.com/apache/incubator-pulsar/pull/2555
 
 
   

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:

diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java
new file mode 100644
index 0000000000..f1b2595596
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSource.java
@@ -0,0 +1,204 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.MessageAcknowledgingSourceBase;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.util.IOUtils;
+
+import org.apache.pulsar.client.api.Consumer;
+import org.apache.pulsar.client.api.Message;
+import org.apache.pulsar.client.api.MessageId;
+import org.apache.pulsar.client.api.PulsarClient;
+import org.apache.pulsar.client.api.PulsarClientException;
+import org.apache.pulsar.client.api.SubscriptionType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Pulsar source (consumer) which receives messages from a topic and acknowledges them.
+ * When checkpointing is enabled, it guarantees at-least-once processing semantics.
+ *
+ * <p>When checkpointing is disabled, it automatically acknowledges messages based on the number of messages it has
+ * received. In this mode messages may be dropped.
+ */
+class PulsarConsumerSource<T> extends MessageAcknowledgingSourceBase<T, MessageId> implements PulsarSourceBase<T> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(PulsarConsumerSource.class);
+
+    private final int messageReceiveTimeoutMs = 100;
+    private final String serviceUrl;
+    private final String topic;
+    private final String subscriptionName;
+    private final DeserializationSchema<T> deserializer;
+
+    private PulsarClient client;
+    private Consumer<byte[]> consumer;
+
+    private boolean isCheckpointingEnabled;
+
+    private final long acknowledgementBatchSize;
+    private long batchCount;
+    private long totalMessageCount;
+
+    private transient volatile boolean isRunning;
+
+    PulsarConsumerSource(PulsarSourceBuilder<T> builder) {
+        super(MessageId.class);
+        this.serviceUrl = builder.serviceUrl;
+        this.topic = builder.topic;
+        this.deserializer = builder.deserializationSchema;
+        this.subscriptionName = builder.subscriptionName;
+        this.acknowledgementBatchSize = builder.acknowledgementBatchSize;
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+
+        final RuntimeContext context = getRuntimeContext();
+        if (context instanceof StreamingRuntimeContext) {
+            isCheckpointingEnabled = ((StreamingRuntimeContext) context).isCheckpointingEnabled();
+        }
+
+        client = createClient();
+        consumer = createConsumer(client);
+
+        isRunning = true;
+    }
+
+    @Override
+    protected void acknowledgeIDs(long checkpointId, Set<MessageId> messageIds) {
+        if (consumer == null) {
+            LOG.error("null consumer unable to acknowledge messages");
+            throw new RuntimeException("null pulsar consumer unable to acknowledge messages");
+        }
+
+        if (messageIds.isEmpty()) {
+            LOG.info("no message ids to acknowledge");
+            return;
+        }
+
+        Map<String, CompletableFuture<Void>> futures = new HashMap<>(messageIds.size());
+        for (MessageId id : messageIds) {
+            futures.put(id.toString(), consumer.acknowledgeAsync(id));
+        }
+
+        futures.forEach((k, f) -> {
+            try {
+                f.get();
+            } catch (Exception e) {
+                LOG.error("failed to acknowledge messageId " + k, e);
+                throw new RuntimeException("Messages could not be acknowledged 
during checkpoint creation.", e);
+            }
+        });
+    }
+
+    @Override
+    public void run(SourceContext<T> context) throws Exception {
+        Message message;
+        while (isRunning) {
+            message = consumer.receive(messageReceiveTimeoutMs, TimeUnit.MILLISECONDS);
+            if (message == null) {
+                LOG.info("unexpected null message");
+                continue;
+            }
+
+            if (isCheckpointingEnabled) {
+                emitCheckpointing(context, message);
+            } else {
+                emitAutoAcking(context, message);
+            }
+        }
+    }
+
+    private void emitCheckpointing(SourceContext<T> context, Message message) throws IOException {
+        synchronized (context.getCheckpointLock()) {
+            if (!addId(message.getMessageId())) {
+                if (LOG.isDebugEnabled()) {
+                    LOG.debug("messageId=" + message.getMessageId().toString() + " already processed.");
+                }
+                return;
+            }
+            context.collect(deserialize(message));
+            totalMessageCount++;
+        }
+    }
+
+    private void emitAutoAcking(SourceContext<T> context, Message message) throws IOException {
+        context.collect(deserialize(message));
+        batchCount++;
+        totalMessageCount++;
+        if (batchCount >= acknowledgementBatchSize) {
+            LOG.info("processed {} messages acknowledging messageId {}", batchCount, message.getMessageId());
+            consumer.acknowledgeCumulative(message.getMessageId());
+            batchCount = 0;
+        }
+    }
+
+    private T deserialize(Message message) throws IOException {
+        return deserializer.deserialize(message.getData());
+    }
+
+    @Override
+    public void cancel() {
+        isRunning = false;
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        IOUtils.cleanup(LOG, consumer);
+        IOUtils.cleanup(LOG, client);
+    }
+
+    @Override
+    public TypeInformation<T> getProducedType() {
+        return deserializer.getProducedType();
+    }
+
+    boolean isCheckpointingEnabled() {
+        return isCheckpointingEnabled;
+    }
+
+    PulsarClient createClient() throws PulsarClientException {
+        return PulsarClient.builder()
+            .serviceUrl(serviceUrl)
+            .build();
+    }
+
+    Consumer<byte[]> createConsumer(PulsarClient client) throws PulsarClientException {
+        return client.newConsumer()
+            .topic(topic)
+            .subscriptionName(subscriptionName)
+            .subscriptionType(SubscriptionType.Failover)
+            .subscribe();
+    }
+}
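
For illustration only (not part of the PR diff), a minimal sketch of how this source might be wired into a Flink job with checkpointing enabled, so that message ids are acknowledged in acknowledgeIDs() once a checkpoint completes; the service URL, topic, subscription name, and job name below are placeholder values:

import org.apache.flink.api.common.serialization.SimpleStringSchema;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.connectors.pulsar.PulsarSourceBuilder;

public class PulsarSourceExample {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        // Checkpointing enabled: message ids are tracked per checkpoint and only
        // acknowledged to Pulsar after the checkpoint completes (at-least-once).
        env.enableCheckpointing(5000);

        SourceFunction<String> source = PulsarSourceBuilder.builder(new SimpleStringSchema())
            .serviceUrl("pulsar://localhost:6650")            // placeholder broker URL
            .topic("persistent://public/default/my-topic")    // placeholder topic
            .subscriptionName("flink-example-sub")            // placeholder subscription
            .build();

        DataStream<String> stream = env.addSource(source);
        stream.print();

        env.execute("pulsar-consumer-source-example");
    }
}

Without the enableCheckpointing call, the source instead falls back to cumulative acknowledgements every acknowledgementBatchSize messages, as described in the class javadoc above.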
diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java
new file mode 100644
index 0000000000..9d442152da
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBase.java
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
+
+/**
+ * Base interface for Pulsar sources.
+ * @param <T> the type of records produced by the source
+ */
+@PublicEvolving
+interface PulsarSourceBase<T> extends ParallelSourceFunction<T>, ResultTypeQueryable<T> {
+}
diff --git a/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java
new file mode 100644
index 0000000000..7f1ee9c377
--- /dev/null
+++ b/pulsar-flink/src/main/java/org/apache/flink/streaming/connectors/pulsar/PulsarSourceBuilder.java
@@ -0,0 +1,118 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.serialization.DeserializationSchema;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * A class for building a pulsar source.
+ */
+@PublicEvolving
+public class PulsarSourceBuilder<T> {
+
+    static final String SERVICE_URL = "pulsar://localhost:6650";
+    static final long ACKNOWLEDGEMENT_BATCH_SIZE = 100;
+    static final long MAX_ACKNOWLEDGEMENT_BATCH_SIZE = 1000;
+
+    final DeserializationSchema<T> deserializationSchema;
+    String serviceUrl = SERVICE_URL;
+    String topic;
+    String subscriptionName = "flink-sub";
+    long acknowledgementBatchSize = ACKNOWLEDGEMENT_BATCH_SIZE;
+
+    private PulsarSourceBuilder(DeserializationSchema<T> deserializationSchema) {
+        this.deserializationSchema = deserializationSchema;
+    }
+
+    /**
+     * Sets the pulsar service url to connect to. Defaults to pulsar://localhost:6650.
+     *
+     * @param serviceUrl service url to connect to
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> serviceUrl(String serviceUrl) {
+        Preconditions.checkNotNull(serviceUrl);
+        this.serviceUrl = serviceUrl;
+        return this;
+    }
+
+    /**
+     * Sets the topic to consume from. This is required.
+     *
+     * <p>Topic names (https://pulsar.incubator.apache.org/docs/latest/getting-started/ConceptsAndArchitecture/#Topics)
+     * are in the following format:
+     * {persistent|non-persistent}://tenant/namespace/topic
+     *
+     * @param topic the topic to consume from
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> topic(String topic) {
+        Preconditions.checkNotNull(topic);
+        this.topic = topic;
+        return this;
+    }
+
+    /**
+     * Sets the subscription name for the topic consumer. Defaults to flink-sub.
+     *
+     * @param subscriptionName the subscription name for the topic consumer
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> subscriptionName(String subscriptionName) {
+        Preconditions.checkNotNull(subscriptionName);
+        this.subscriptionName = subscriptionName;
+        return this;
+    }
+
+    /**
+     * Sets the number of messages to receive before acknowledging. This defaults to 100. This
+     * value is only used when checkpointing is disabled.
+     *
+     * @param size number of messages to receive before acknowledging
+     * @return this builder
+     */
+    public PulsarSourceBuilder<T> acknowledgementBatchSize(long size) {
+        if (size > 0 && size <= MAX_ACKNOWLEDGEMENT_BATCH_SIZE) {
+            acknowledgementBatchSize = size;
+        }
+        return this;
+    }
+
+    public SourceFunction<T> build() {
+        Preconditions.checkNotNull(serviceUrl, "a service url is required");
+        Preconditions.checkNotNull(topic, "a topic is required");
+        Preconditions.checkNotNull(subscriptionName, "a subscription name is required");
+
+        return new PulsarConsumerSource<>(this);
+    }
+
+    /**
+     * Creates a PulsarSourceBuilder.
+     *
+     * @param deserializationSchema the deserializer used to convert between Pulsar's byte messages and Flink's objects.
+     * @return a builder
+     */
+    public static <T> PulsarSourceBuilder<T> builder(DeserializationSchema<T> deserializationSchema) {
+        Preconditions.checkNotNull(deserializationSchema);
+        return new PulsarSourceBuilder<>(deserializationSchema);
+    }
+}
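
The builder accepts any Flink DeserializationSchema, not just the SimpleStringSchema used in the tests below. As an illustration (not part of the PR diff), a hypothetical schema sketching how a custom record type could be produced by the source; the class name and topic are made up for this example:

import java.io.IOException;
import java.nio.charset.StandardCharsets;

import org.apache.flink.api.common.serialization.DeserializationSchema;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;

// Hypothetical schema that turns each Pulsar payload (UTF-8 encoded digits) into an Integer.
public class IntegerDeserializationSchema implements DeserializationSchema<Integer> {

    @Override
    public Integer deserialize(byte[] message) throws IOException {
        return Integer.valueOf(new String(message, StandardCharsets.UTF_8));
    }

    @Override
    public boolean isEndOfStream(Integer nextElement) {
        // The Pulsar source is unbounded, so never signal end of stream.
        return false;
    }

    @Override
    public TypeInformation<Integer> getProducedType() {
        return Types.INT;
    }
}

// Usage sketch (placeholder values):
//
//   SourceFunction<Integer> source = PulsarSourceBuilder.builder(new IntegerDeserializationSchema())
//       .topic("persistent://public/default/int-topic")
//       .acknowledgementBatchSize(200)   // only consulted when checkpointing is disabled
//       .build();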
diff --git a/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java b/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java
new file mode 100644
index 0000000000..97811da0de
--- /dev/null
+++ b/pulsar-flink/src/test/java/org/apache/flink/streaming/connectors/pulsar/PulsarConsumerSourceTests.java
@@ -0,0 +1,524 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.flink.streaming.connectors.pulsar;
+
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.serialization.SimpleStringSchema;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.operators.StreamSource;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+
+import org.apache.pulsar.client.api.Consumer;
+import org.apache.pulsar.client.api.ConsumerStats;
+import org.apache.pulsar.client.api.Message;
+import org.apache.pulsar.client.api.MessageId;
+import org.apache.pulsar.client.api.PulsarClient;
+import org.apache.pulsar.client.api.PulsarClientException;
+import org.apache.pulsar.client.api.Schema;
+import org.apache.pulsar.client.impl.MessageImpl;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.mockito.Matchers.any;
+
+/**
+ * Tests for the PulsarConsumerSource. The source supports two operation modes:
+ * 1) At-least-once (when checkpointing is enabled), using Pulsar message acknowledgements and the
+ *    deduplication mechanism in {@link org.apache.flink.streaming.api.functions.source.MessageAcknowledgingSourceBase}.
+ * 2) No strong delivery guarantees (when checkpointing is disabled), with Pulsar acknowledging messages
+ *    cumulatively after every x received messages.
+ *
+ * <p>These tests assume that MessageIds are monotonically increasing, which does not have to be the
+ * case. The MessageId is used to uniquely identify messages.
+ */
+public class PulsarConsumerSourceTests {
+
+    private PulsarConsumerSource<String> source;
+
+    private TestConsumer consumer;
+
+    private TestSourceContext context;
+
+    private Thread sourceThread;
+
+    private Exception exception;
+
+    @Before
+    public void before() {
+        context = new TestSourceContext();
+
+        sourceThread = new Thread(() -> {
+            try {
+                source.run(context);
+            } catch (Exception e) {
+                exception = e;
+            }
+        });
+    }
+
+    @After
+    public void after() throws Exception {
+        if (source != null) {
+            source.cancel();
+        }
+        if (sourceThread != null) {
+            sourceThread.join();
+        }
+    }
+
+    @Test
+    public void testCheckpointing() throws Exception {
+        final int numMessages = 5;
+        consumer = new TestConsumer(numMessages);
+
+        source = createSource(consumer, 1, true);
+        source.open(new Configuration());
+
+        final StreamSource<String, PulsarConsumerSource<String>> src = new StreamSource<>(source);
+        final AbstractStreamOperatorTestHarness<String> testHarness =
+            new AbstractStreamOperatorTestHarness<>(src, 1, 1, 0);
+
+        testHarness.open();
+
+        sourceThread.start();
+
+        final Random random = new Random(System.currentTimeMillis());
+        for (int i = 0; i < 3; ++i) {
+
+            // wait and receive messages from the test consumer
+            receiveMessages();
+
+            final long snapshotId = random.nextLong();
+            OperatorSubtaskState data;
+            synchronized (context.getCheckpointLock()) {
+                data = testHarness.snapshot(snapshotId, System.currentTimeMillis());
+            }
+
+            final TestPulsarConsumerSource sourceCopy =
+                createSource(Mockito.mock(Consumer.class), 1, true);
+            final StreamSource<String, TestPulsarConsumerSource> srcCopy = new StreamSource<>(sourceCopy);
+            final AbstractStreamOperatorTestHarness<String> testHarnessCopy =
+                new AbstractStreamOperatorTestHarness<>(srcCopy, 1, 1, 0);
+
+            testHarnessCopy.setup();
+            testHarnessCopy.initializeState(data);
+            testHarnessCopy.open();
+
+            final ArrayDeque<Tuple2<Long, Set<MessageId>>> deque = sourceCopy.getRestoredState();
+            final Set<MessageId> messageIds = deque.getLast().f1;
+
+            final int start = consumer.currentMessage.get() - numMessages;
+            for (int mi = start; mi < (start + numMessages); ++mi) {
+                Assert.assertTrue(messageIds.contains(consumer.messages.get(mi).getMessageId()));
+            }
+
+            // check if the messages are being acknowledged
+            synchronized (context.getCheckpointLock()) {
+                source.notifyCheckpointComplete(snapshotId);
+
+                Assert.assertEquals(consumer.acknowledgedIds.keySet(), messageIds);
+                // clear acknowledgements for the next snapshot comparison
+                consumer.acknowledgedIds.clear();
+            }
+
+            final int lastMessageIndex = consumer.currentMessage.get();
+            consumer.addMessages(createMessages(lastMessageIndex, 5));
+        }
+    }
+
+    @Test
+    public void testCheckpointingDuplicatedIds() throws Exception {
+        consumer = new TestConsumer(5);
+
+        source = createSource(consumer, 1, true);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(5, context.elements.size());
+
+        // try to reprocess the messages; we should not collect any more elements
+        consumer.reset();
+
+        receiveMessages();
+
+        Assert.assertEquals(5, context.elements.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMessagesEqualBatchSize() throws Exception {
+
+        consumer = new TestConsumer(5);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(1, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMoreMessagesThanBatchSize() throws Exception {
+
+        consumer = new TestConsumer(6);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(1, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledLessMessagesThanBatchSize() throws Exception {
+
+        consumer = new TestConsumer(4);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(0, consumer.acknowledgedIds.size());
+    }
+
+    @Test
+    public void testCheckpointingDisabledMessages2XBatchSize() throws Exception {
+
+        consumer = new TestConsumer(10);
+
+        source = createSource(consumer, 5, false);
+        source.open(new Configuration());
+
+        sourceThread.start();
+
+        receiveMessages();
+
+        Assert.assertEquals(2, consumer.acknowledgedIds.size());
+    }
+
+    private void receiveMessages() throws InterruptedException {
+        while (consumer.currentMessage.get() < consumer.messages.size()) {
+            Thread.sleep(5);
+        }
+    }
+
+    private TestPulsarConsumerSource createSource(Consumer<byte[]> testConsumer,
+                                                  long batchSize, boolean isCheckpointingEnabled) throws Exception {
+        PulsarSourceBuilder<String> builder =
+            PulsarSourceBuilder.builder(new SimpleStringSchema())
+                .acknowledgementBatchSize(batchSize);
+        TestPulsarConsumerSource source = new TestPulsarConsumerSource(builder, testConsumer, isCheckpointingEnabled);
+
+        OperatorStateStore mockStore = Mockito.mock(OperatorStateStore.class);
+        FunctionInitializationContext mockContext = Mockito.mock(FunctionInitializationContext.class);
+        Mockito.when(mockContext.getOperatorStateStore()).thenReturn(mockStore);
+        Mockito.when(mockStore.getSerializableListState(any(String.class))).thenReturn(null);
+
+        source.initializeState(mockContext);
+
+        return source;
+    }
+
+    private static class TestPulsarConsumerSource extends PulsarConsumerSource<String> {
+
+        private ArrayDeque<Tuple2<Long, Set<MessageId>>> restoredState;
+
+        private Consumer<byte[]> testConsumer;
+        private boolean isCheckpointingEnabled;
+
+        TestPulsarConsumerSource(PulsarSourceBuilder<String> builder,
+                                 Consumer<byte[]> testConsumer, boolean isCheckpointingEnabled) {
+            super(builder);
+            this.testConsumer = testConsumer;
+            this.isCheckpointingEnabled = isCheckpointingEnabled;
+        }
+
+        @Override
+        protected boolean addId(MessageId messageId) {
+            Assert.assertEquals(true, isCheckpointingEnabled());
+            return super.addId(messageId);
+        }
+
+        @Override
+        public RuntimeContext getRuntimeContext() {
+            StreamingRuntimeContext context = Mockito.mock(StreamingRuntimeContext.class);
+            Mockito.when(context.isCheckpointingEnabled()).thenReturn(isCheckpointingEnabled);
+            return context;
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            super.initializeState(context);
+            this.restoredState = this.pendingCheckpoints;
+        }
+
+        public ArrayDeque<Tuple2<Long, Set<MessageId>>> getRestoredState() {
+            return this.restoredState;
+        }
+
+        @Override
+        PulsarClient createClient() {
+            return Mockito.mock(PulsarClient.class);
+        }
+
+        @Override
+        Consumer<byte[]> createConsumer(PulsarClient client) {
+            return testConsumer;
+        }
+    }
+
+    private static class TestSourceContext implements SourceFunction.SourceContext<String> {
+
+        private static final Object lock = new Object();
+
+        private final List<String> elements = new ArrayList<>();
+
+        @Override
+        public void collect(String element) {
+            elements.add(element);
+        }
+
+        @Override
+        public void collectWithTimestamp(String element, long timestamp) {
+
+        }
+
+        @Override
+        public void emitWatermark(Watermark mark) {
+
+        }
+
+        @Override
+        public void markAsTemporarilyIdle() {
+
+        }
+
+        @Override
+        public Object getCheckpointLock() {
+            return lock;
+        }
+
+        @Override
+        public void close() {
+
+        }
+    }
+
+    private static class TestConsumer implements Consumer<byte[]> {
+
+        private final List<Message> messages = new ArrayList<>();
+
+        private AtomicInteger currentMessage = new AtomicInteger();
+
+        private final Map<MessageId, MessageId> acknowledgedIds = new ConcurrentHashMap<>();
+
+        private TestConsumer(int numMessages) {
+            messages.addAll(createMessages(0, numMessages));
+        }
+
+        private void reset() {
+            currentMessage.set(0);
+        }
+
+        @Override
+        public String getTopic() {
+            return null;
+        }
+
+        @Override
+        public String getSubscription() {
+            return null;
+        }
+
+        @Override
+        public void unsubscribe() throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> unsubscribeAsync() {
+            return null;
+        }
+
+        @Override
+        public Message<byte[]> receive() throws PulsarClientException {
+            return null;
+        }
+
+        public synchronized void addMessages(List<Message> messages) {
+            this.messages.addAll(messages);
+        }
+
+        @Override
+        public CompletableFuture<Message<byte[]>> receiveAsync() {
+            return null;
+        }
+
+        @Override
+        public Message<byte[]> receive(int i, TimeUnit timeUnit) throws PulsarClientException {
+            synchronized (this) {
+                if (currentMessage.get() == messages.size()) {
+                    try {
+                        Thread.sleep(10);
+                    } catch (InterruptedException e) {
+                        System.out.println("no more messages sleeping index: " + currentMessage.get());
+                    }
+                    return null;
+                }
+                return messages.get(currentMessage.getAndIncrement());
+            }
+        }
+
+        @Override
+        public void acknowledge(Message<?> message) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledge(MessageId messageId) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledgeCumulative(Message<?> message) throws PulsarClientException {
+
+        }
+
+        @Override
+        public void acknowledgeCumulative(MessageId messageId) throws PulsarClientException {
+            acknowledgedIds.put(messageId, messageId);
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeAsync(Message<?> message) {
+            return null;
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeAsync(MessageId messageId) {
+            acknowledgedIds.put(messageId, messageId);
+            return CompletableFuture.completedFuture(null);
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeCumulativeAsync(Message<?> message) {
+            return null;
+        }
+
+        @Override
+        public CompletableFuture<Void> acknowledgeCumulativeAsync(MessageId messageId) {
+            return null;
+        }
+
+        @Override
+        public ConsumerStats getStats() {
+            return null;
+        }
+
+        @Override
+        public void close() throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> closeAsync() {
+            return null;
+        }
+
+        @Override
+        public boolean hasReachedEndOfTopic() {
+            return false;
+        }
+
+        @Override
+        public void redeliverUnacknowledgedMessages() {
+
+        }
+
+        @Override
+        public void seek(MessageId messageId) throws PulsarClientException {
+
+        }
+
+        @Override
+        public CompletableFuture<Void> seekAsync(MessageId messageId) {
+            return null;
+        }
+
+        @Override
+        public boolean isConnected() {
+            return true;
+        }
+
+        @Override
+        public String getConsumerName() {
+            return "test-consumer-0";
+        }
+    }
+
+    private static List<Message> createMessages(int startIndex, int numMessages) {
+        final List<Message> messages = new ArrayList<>();
+        for (int i = startIndex; i < (startIndex + numMessages); ++i) {
+            String content = "message-" + i;
+            messages.add(createMessage(content, createMessageId(1, i + 1, 1)));
+        }
+        return messages;
+    }
+
+    private static Message<byte[]> createMessage(String content, String messageId) {
+        return new MessageImpl<>(messageId, Collections.emptyMap(), content.getBytes(), Schema.BYTES);
+    }
+
+    private static String createMessageId(long ledgerId, long entryId, long partitionIndex) {
+        return String.format("%d:%d:%d", ledgerId, entryId, partitionIndex);
+    }
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
