This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fbb21f  [BEAM-11325] Support KafkaIO dynamic read
     new 47d3326  Merge pull request #13750 from [BEAM-11325] Kafka Dynamic Read
0fbb21f is described below

commit 0fbb21fd13b0f1ac1a28e4c839b8ebbe9420e9d3
Author: Boyuan Zhang <[email protected]>
AuthorDate: Fri Jan 8 10:31:34 2021 -0800

    [BEAM-11325] Support KafkaIO dynamic read
---
 .../java/org/apache/beam/sdk/io/kafka/KafkaIO.java | 156 +++++++-
 .../beam/sdk/io/kafka/TopicPartitionCoder.java     |  56 +++
 .../sdk/io/kafka/WatchKafkaTopicPartitionDoFn.java | 156 ++++++++
 .../beam/sdk/io/kafka/TopicPartitionCoderTest.java |  39 ++
 .../io/kafka/WatchKafkaTopicPartitionDoFnTest.java | 422 +++++++++++++++++++++
 5 files changed, 821 insertions(+), 8 deletions(-)

diff --git 
a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java 
b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 04b48fa..10aac4a 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -73,6 +73,7 @@ import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -163,6 +164,88 @@ import org.slf4j.LoggerFactory;
 * Read#withValueDeserializerAndCoder(Class, Coder)}. Note that Kafka messages are interpreted using
 * key and value <i>deserializers</i>.
  *
+ * <h3>Read From Kafka Dynamically</h3>
+ *
+ * For a given Kafka bootstrap_server, KafkaIO is also able to detect and read from available {@link
+ * TopicPartition}s dynamically, and to stop reading from the ones that become unavailable. KafkaIO
+ * uses {@link WatchKafkaTopicPartitionDoFn} to emit any newly added {@link TopicPartition} and uses
+ * {@link ReadFromKafkaDoFn} to read from each {@link KafkaSourceDescriptor}. Dynamic read is able
+ * to handle two scenarios:
+ *
+ * <ul>
+ *   <li>Certain topic or partition is added/deleted.
+ *   <li>Certain topic or partition is added, then removed, but added back again.
+ * </ul>
+ *
+ * When a {@code checkStopReadingFn} is provided, there are two more cases that dynamic read can
+ * handle:
+ *
+ * <ul>
+ *   <li>Certain topic or partition is stopped.
+ *   <li>Certain topic or partition is added, then stopped, but added back again.
+ * </ul>
+ *
+ * Race conditions may happen in two of the supported cases:
+ *
+ * <ul>
+ *   <li>A TopicPartition is removed, but added back again.
+ *   <li>A TopicPartition is stopped, but the pipeline author wants to read from it again.
+ * </ul>
+ *
+ * When a race condition happens, the stopped/removed TopicPartition may fail to be emitted to
+ * {@link ReadFromKafkaDoFn} again, or {@link ReadFromKafkaDoFn} may output duplicate records. The
+ * major cause of such race conditions is that both {@link WatchKafkaTopicPartitionDoFn} and {@link
+ * ReadFromKafkaDoFn} react to the signal from a removed/stopped {@link TopicPartition}, but there
+ * is no guarantee that both DoFns perform the related actions at the same time.
+ *
+ * <p>Here is one example of failing to emit a newly added {@link TopicPartition}:
+ *
+ * <ul>
+ *   <li>A {@link WatchKafkaTopicPartitionDoFn} is configured to update the current tracking set
+ *       every hour.
+ *   <li>TopicPartition A is tracked by the {@link WatchKafkaTopicPartitionDoFn} at 10:00AM, and
+ *       {@link ReadFromKafkaDoFn} starts to read from TopicPartition A immediately.
+ *   <li>At 10:30AM, the {@link ReadFromKafkaDoFn} notices that the {@link TopicPartition} has been
+ *       stopped/removed, so it stops reading from it and returns {@code ProcessContinuation.stop()}.
+ *   <li>At 10:45AM, the pipeline author wants to read from TopicPartition A again.
+ *   <li>At 11:00AM, when the {@link WatchKafkaTopicPartitionDoFn} is invoked by the firing timer,
+ *       it doesn't know that TopicPartition A was ever stopped/removed. All it knows is that
+ *       TopicPartition A is still an active TopicPartition, so it will not emit TopicPartition A
+ *       again.
+ * </ul>
+ *
+ * Here is another race condition example, which produces duplicate records:
+ *
+ * <ul>
+ *   <li>At 10:00AM, {@link ReadFromKafkaDoFn} is processing TopicPartition A.
+ *   <li>At 10:05AM, {@link ReadFromKafkaDoFn} starts to process other TopicPartitions (an
+ *       SDF-initiated checkpoint or a runner-issued checkpoint happens).
+ *   <li>At 10:10AM, {@link WatchKafkaTopicPartitionDoFn} learns that TopicPartition A has been
+ *       stopped/removed.
+ *   <li>At 10:15AM, {@link WatchKafkaTopicPartitionDoFn} sees that TopicPartition A has been added
+ *       again and emits TopicPartition A again.
+ *   <li>At 10:20AM, {@link ReadFromKafkaDoFn} starts to process the resumed TopicPartition A, but
+ *       at the same time {@link ReadFromKafkaDoFn} is also processing the newly emitted
+ *       TopicPartition A.
+ * </ul>
+ *
+ * For more design details, please refer to
+ * https://docs.google.com/document/d/1FU3GxVRetHPLVizP3Mdv6mP5tpjZ3fd99qNjUI5DT5k/. To enable
+ * dynamic read, you can write a pipeline like:
+ *
+ * <pre>{@code
+ * pipeline
+ *   .apply(KafkaIO.<Long, String>read()
+ *      // Configure dynamic read with a 1 hour interval: the pipeline will look for available
+ *      // TopicPartitions and emit newly added ones every hour.
+ *      .withDynamicRead(Duration.standardHours(1))
+ *      .withCheckStopReadingFn(new SerializableFunction<TopicPartition, Boolean>() {...})
+ *      .withBootstrapServers("broker_1:9092,broker_2:9092")
+ *      .withKeyDeserializer(LongDeserializer.class)
+ *      .withValueDeserializer(StringDeserializer.class)
+ *   )
+ *   .apply(Values.<String>create()) // PCollection<String>
+ *    ...
+ * }</pre>
+ *
  * <h3>Partition Assignment and Checkpointing</h3>
  *
  * The Kafka partitions are evenly distributed among splits (workers).
@@ -431,6 +514,7 @@ public class KafkaIO {
         .setConsumerConfig(KafkaIOUtils.DEFAULT_CONSUMER_PROPERTIES)
         .setMaxNumRecords(Long.MAX_VALUE)
         .setCommitOffsetsInFinalizeEnabled(false)
+        .setDynamicRead(false)
         .setTimestampPolicyFactory(TimestampPolicyFactory.withProcessingTime())
         .build();
   }
@@ -511,6 +595,10 @@ public class KafkaIO {
 
     abstract boolean isCommitOffsetsInFinalizeEnabled();
 
+    abstract boolean isDynamicRead();
+
+    abstract @Nullable Duration getWatchTopicPartitionDuration();
+
     abstract TimestampPolicyFactory<K, V> getTimestampPolicyFactory();
 
     abstract @Nullable Map<String, Object> getOffsetConsumerConfig();
@@ -550,6 +638,10 @@ public class KafkaIO {
 
+      abstract Builder<K, V> setCommitOffsetsInFinalizeEnabled(boolean commitOffsetInFinalize);
 
+      abstract Builder<K, V> setDynamicRead(boolean dynamicRead);
+
+      abstract Builder<K, V> setWatchTopicPartitionDuration(Duration duration);
+
       abstract Builder<K, V> setTimestampPolicyFactory(
           TimestampPolicyFactory<K, V> timestampPolicyFactory);
 
@@ -616,6 +708,11 @@ public class KafkaIO {
         if (config.startReadTime != null) {
           setStartReadTime(Instant.ofEpochMilli(config.startReadTime));
         }
+
+        // We can expose dynamic read to the external build once ReadFromKafkaDoFn is the default
+        // implementation.
+        setDynamicRead(false);
+
        // We do not include Metadata until we can encode KafkaRecords cross-language
         return build().withoutMetadata();
       }
@@ -998,6 +1095,16 @@ public class KafkaIO {
     }
 
     /**
+     * Configure the KafkaIO to use {@link WatchKafkaTopicPartitionDoFn} to detect and emit any new
+     * available {@link TopicPartition} for {@link ReadFromKafkaDoFn} to consume during pipeline
+     * execution time. The KafkaIO will regularly check the availability based on the given
+     * duration. If the duration is given as {@code null}, the default duration of 1 hour is used.
+     */
+    public Read<K, V> withDynamicRead(Duration duration) {
+      return toBuilder().setDynamicRead(true).setWatchTopicPartitionDuration(duration).build();
+    }
+
+    /**
      * Set additional configuration for the backend offset consumer. It may be required for a
      * secured Kafka cluster, especially when you see a WARN log message similar to 'exception
      * while fetching latest offset for partition {}. will be retried'.
@@ -1052,9 +1159,20 @@ public class KafkaIO {
      checkArgument(
          getConsumerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) != null,
          "withBootstrapServers() is required");
-      checkArgument(
-          getTopics().size() > 0 || getTopicPartitions().size() > 0,
-          "Either withTopic(), withTopics() or withTopicPartitions() is required");
+      // With dynamic read, we no longer require providing a topic/partition at pipeline
+      // construction time, but dynamic read requires enabling the use_sdf_kafka_read experiment.
+      if (!isDynamicRead()) {
+        checkArgument(
+            getTopics().size() > 0 || getTopicPartitions().size() > 0,
+            "Either withTopic(), withTopics() or withTopicPartitions() is required");
+      } else {
+        checkArgument(
+            ExperimentalOptions.hasExperiment(
+                    input.getPipeline().getOptions(), "use_sdf_kafka_read")
+                && ExperimentalOptions.hasExperiment(
+                    input.getPipeline().getOptions(), "beam_fn_api"),
+            "Kafka Dynamic Read requires enabling experiments use_sdf_kafka_read and beam_fn_api.");
+      }
       checkArgument(getKeyDeserializerProvider() != null, "withKeyDeserializer() is required");
       checkArgument(getValueDeserializerProvider() != null, "withValueDeserializer() is required");
 
@@ -1129,11 +1247,33 @@ public class KafkaIO {
       if (isCommitOffsetsInFinalizeEnabled()) {
         readTransform = readTransform.commitOffsets();
       }
-      PCollection<KafkaSourceDescriptor> output =
-          input
-              .getPipeline()
-              .apply(Impulse.create())
-              .apply(ParDo.of(new GenerateKafkaSourceDescriptor(this)));
+      PCollection<KafkaSourceDescriptor> output;
+      if (isDynamicRead()) {
+        output =
+            input
+                .getPipeline()
+                .apply(Impulse.create())
+                .apply(
+                    MapElements.into(
+                            TypeDescriptors.kvs(
+                                new TypeDescriptor<byte[]>() {}, new TypeDescriptor<byte[]>() {}))
+                        .via(element -> KV.of(element, element)))
+                .apply(
+                    ParDo.of(
+                        new WatchKafkaTopicPartitionDoFn(
+                            getWatchTopicPartitionDuration(),
+                            getConsumerFactoryFn(),
+                            getCheckStopReadingFn(),
+                            getConsumerConfig(),
+                            getStartReadTime())));
+
+      } else {
+        output =
+            input
+                .getPipeline()
+                .apply(Impulse.create())
+                .apply(ParDo.of(new GenerateKafkaSourceDescriptor(this)));
+      }
      return output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
     }
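
To make the checkStopReadingFn hook described in the javadoc concrete, here is a
minimal sketch of a function that could be passed to withCheckStopReadingFn(); the
topic name and partition below are hypothetical, and a real implementation might
consult an external service or configuration instead:

    // Sketch: stop reading any partition found in a fixed blocklist (hypothetical).
    SerializableFunction<TopicPartition, Boolean> checkStopReadingFn =
        new SerializableFunction<TopicPartition, Boolean>() {
          @Override
          public Boolean apply(TopicPartition topicPartition) {
            // Returning true tells KafkaIO to stop reading from this partition.
            return topicPartition.equals(new TopicPartition("orders", 3));
          }
        };

    pipeline.apply(
        KafkaIO.<Long, String>read()
            .withBootstrapServers("broker_1:9092,broker_2:9092")
            .withKeyDeserializer(LongDeserializer.class)
            .withValueDeserializer(StringDeserializer.class)
            .withDynamicRead(Duration.standardHours(1))
            .withCheckStopReadingFn(checkStopReadingFn));

Per the validation added above, such a pipeline must also be launched with
--experiments=use_sdf_kafka_read,beam_fn_api.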
 
diff --git 
a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoder.java
 
b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoder.java
new file mode 100644
index 0000000..f11e8ca
--- /dev/null
+++ 
b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoder.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.StructuredCoder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.kafka.common.TopicPartition;
+
+/** The {@link Coder} for encoding and decoding {@link TopicPartition} in Beam. */
+@SuppressWarnings({"nullness"})
+public class TopicPartitionCoder extends StructuredCoder<TopicPartition> {
+
+  @Override
+  public void encode(TopicPartition value, OutputStream outStream)
+      throws CoderException, IOException {
+    StringUtf8Coder.of().encode(value.topic(), outStream);
+    VarIntCoder.of().encode(value.partition(), outStream);
+  }
+
+  @Override
+  public TopicPartition decode(InputStream inStream) throws CoderException, IOException {
+    String topic = StringUtf8Coder.of().decode(inStream);
+    int partition = VarIntCoder.of().decode(inStream);
+    return new TopicPartition(topic, partition);
+  }
+
+  @Override
+  public List<? extends Coder<?>> getCoderArguments() {
+    return null;
+  }
+
+  @Override
+  public void verifyDeterministic() throws NonDeterministicException {}
+}
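
As a usage sketch, the coder round-trips a TopicPartition through bytes. This assumes
Beam's CoderUtils helper (org.apache.beam.sdk.util.CoderUtils), which is not part of
this change:

    TopicPartition original = new TopicPartition("my_topic", 0); // hypothetical topic
    // encodeToByteArray/decodeFromByteArray throw CoderException, so call these from
    // a method that declares or handles it.
    byte[] bytes = CoderUtils.encodeToByteArray(new TopicPartitionCoder(), original);
    TopicPartition decoded = CoderUtils.decodeFromByteArray(new TopicPartitionCoder(), bytes);
    // decoded.equals(original) holds: topic name and partition id survive the round trip.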
diff --git 
a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFn.java
 
b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFn.java
new file mode 100644
index 0000000..d82bfcf
--- /dev/null
+++ 
b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFn.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * A stateful {@link DoFn} that regularly emits any newly available {@link TopicPartition}.
+ *
+ * <p>Please refer to
+ * https://docs.google.com/document/d/1FU3GxVRetHPLVizP3Mdv6mP5tpjZ3fd99qNjUI5DT5k/edit# for more
+ * details.
+ */
+@SuppressWarnings({"nullness"})
+@Experimental
+class WatchKafkaTopicPartitionDoFn extends DoFn<KV<byte[], byte[]>, KafkaSourceDescriptor> {
+
+  private static final Duration DEFAULT_CHECK_DURATION = Duration.standardHours(1);
+  private static final String TIMER_ID = "watch_timer";
+  private static final String STATE_ID = "topic_partition_set";
+  private final Duration checkDuration;
+
+  private final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
+      kafkaConsumerFactoryFn;
+  private final SerializableFunction<TopicPartition, Boolean> checkStopReadingFn;
+  private final Map<String, Object> kafkaConsumerConfig;
+  private final Instant startReadTime;
+
+  WatchKafkaTopicPartitionDoFn(
+      Duration checkDuration,
+      SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> kafkaConsumerFactoryFn,
+      SerializableFunction<TopicPartition, Boolean> checkStopReadingFn,
+      Map<String, Object> kafkaConsumerConfig,
+      Instant startReadTime) {
+    this.checkDuration = checkDuration == null ? DEFAULT_CHECK_DURATION : checkDuration;
+    this.kafkaConsumerFactoryFn = kafkaConsumerFactoryFn;
+    this.checkStopReadingFn = checkStopReadingFn;
+    this.kafkaConsumerConfig = kafkaConsumerConfig;
+    this.startReadTime = startReadTime;
+  }
+
+  @TimerId(TIMER_ID)
+  private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
+
+  @StateId(STATE_ID)
+  private final StateSpec<BagState<TopicPartition>> bagStateSpec =
+      StateSpecs.bag(new TopicPartitionCoder());
+
+  @VisibleForTesting
+  Set<TopicPartition> getAllTopicPartitions() {
+    Set<TopicPartition> current = new HashSet<>();
+    try (Consumer<byte[], byte[]> kafkaConsumer =
+        kafkaConsumerFactoryFn.apply(kafkaConsumerConfig)) {
+      for (Map.Entry<String, List<PartitionInfo>> topicInfo :
+          kafkaConsumer.listTopics().entrySet()) {
+        for (PartitionInfo partition : topicInfo.getValue()) {
+          current.add(new TopicPartition(topicInfo.getKey(), partition.partition()));
+        }
+      }
+    }
+    return current;
+  }
+
+  @ProcessElement
+  public void processElement(
+      @TimerId(TIMER_ID) Timer timer,
+      @StateId(STATE_ID) BagState<TopicPartition> existingTopicPartitions,
+      OutputReceiver<KafkaSourceDescriptor> outputReceiver) {
+    // On the first invocation, emit all available TopicPartitions and write them into state.
+    Set<TopicPartition> current = getAllTopicPartitions();
+    current.forEach(
+        topicPartition -> {
+          if (checkStopReadingFn == null || !checkStopReadingFn.apply(topicPartition)) {
+            existingTopicPartitions.add(topicPartition);
+            outputReceiver.output(
+                KafkaSourceDescriptor.of(topicPartition, null, startReadTime, null));
+          }
+        });
+
+    timer.set(Instant.now().plus(checkDuration.getMillis()));
+  }
+
+  @OnTimer(TIMER_ID)
+  public void onTimer(
+      @TimerId(TIMER_ID) Timer timer,
+      @StateId(STATE_ID) BagState<TopicPartition> existingTopicPartitions,
+      OutputReceiver<KafkaSourceDescriptor> outputReceiver) {
+    Set<TopicPartition> readingTopicPartitions = new HashSet<>();
+    existingTopicPartitions
+        .read()
+        .forEach(
+            topicPartition -> {
+              readingTopicPartitions.add(topicPartition);
+            });
+    existingTopicPartitions.clear();
+
+    Set<TopicPartition> currentAll = getAllTopicPartitions();
+
+    // Emit newly added TopicPartitions.
+    Set<TopicPartition> newAdded = Sets.difference(currentAll, readingTopicPartitions);
+    newAdded.forEach(
+        topicPartition -> {
+          if (checkStopReadingFn == null || !checkStopReadingFn.apply(topicPartition)) {
+            outputReceiver.output(
+                KafkaSourceDescriptor.of(topicPartition, null, startReadTime, null));
+          }
+        });
+
+    // Update the State.
+    currentAll.forEach(
+        topicPartition -> {
+          if (checkStopReadingFn == null || !checkStopReadingFn.apply(topicPartition)) {
+            existingTopicPartitions.add(topicPartition);
+          }
+        });
+
+    // Reset the timer.
+    timer.set(Instant.now().plus(checkDuration.getMillis()));
+  }
+}
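
The onTimer bookkeeping above reduces to a set difference between what the bag state
held and what listTopics() currently reports. A standalone sketch of the same logic,
using plain java.util sets instead of the vendored Guava Sets.difference:

    Set<TopicPartition> tracked = new HashSet<>();   // what the bag state held
    tracked.add(new TopicPartition("topic1", 0));

    Set<TopicPartition> current = new HashSet<>();   // what listTopics() reports now
    current.add(new TopicPartition("topic1", 0));
    current.add(new TopicPartition("topic2", 0));    // newly added partition

    Set<TopicPartition> newlyAdded = new HashSet<>(current);
    newlyAdded.removeAll(tracked);                   // -> {topic2-0}, emitted downstream

    // Partitions that disappeared simply drop out of tracking: the state is cleared
    // and rewritten from `current` on every timer firing, so only live (and, with a
    // checkStopReadingFn, non-stopped) partitions are carried forward.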
diff --git 
a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoderTest.java
 
b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoderTest.java
new file mode 100644
index 0000000..01c5acd
--- /dev/null
+++ 
b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoderTest.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.junit.Assert.assertEquals;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import org.apache.kafka.common.TopicPartition;
+import org.junit.Test;
+
+@SuppressWarnings({"nullness"})
+public class TopicPartitionCoderTest {
+
+  @Test
+  public void testEncodeDecodeRoundTrip() throws Exception {
+    TopicPartitionCoder coder = new TopicPartitionCoder();
+    TopicPartition topicPartition = new TopicPartition("topic", 1);
+    ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    coder.encode(topicPartition, outputStream);
+    assertEquals(
+        topicPartition, coder.decode(new ByteArrayInputStream(outputStream.toByteArray())));
+  }
+}
diff --git 
a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFnTest.java
 
b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFnTest.java
new file mode 100644
index 0000000..14460d6
--- /dev/null
+++ 
b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFnTest.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.powermock.api.mockito.PowerMockito.mockStatic;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({Instant.class})
+@SuppressWarnings({"nullness"})
+public class WatchKafkaTopicPartitionDoFnTest {
+
+  @Mock Consumer<byte[], byte[]> mockConsumer;
+  @Mock Timer timer;
+
+  private final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFn =
+      new SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>() {
+        @Override
+        public Consumer<byte[], byte[]> apply(Map<String, Object> input) {
+          return mockConsumer;
+        }
+      };
+
+  @Test
+  public void testGetAllTopicPartitions() throws Exception {
+    when(mockConsumer.listTopics())
+        .thenReturn(
+            ImmutableMap.of(
+                "topic1",
+                ImmutableList.of(
+                    new PartitionInfo("topic1", 0, null, null, null),
+                    new PartitionInfo("topic1", 1, null, null, null)),
+                "topic2",
+                ImmutableList.of(
+                    new PartitionInfo("topic2", 0, null, null, null),
+                    new PartitionInfo("topic2", 1, null, null, null))));
+    WatchKafkaTopicPartitionDoFn dofnInstance =
+        new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(1L), consumerFn, null, ImmutableMap.of(), null);
+    assertEquals(
+        ImmutableSet.of(
+            new TopicPartition("topic1", 0),
+            new TopicPartition("topic1", 1),
+            new TopicPartition("topic2", 0),
+            new TopicPartition("topic2", 1)),
+        dofnInstance.getAllTopicPartitions());
+  }
+
+  @Test
+  public void testProcessElementWhenNoAvailableTopicPartition() throws Exception {
+    WatchKafkaTopicPartitionDoFn dofnInstance =
+        new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L), consumerFn, null, ImmutableMap.of(), null);
+    MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+    when(mockConsumer.listTopics()).thenReturn(ImmutableMap.of());
+    MockBagState bagState = new MockBagState(ImmutableList.of());
+    Instant now = Instant.EPOCH;
+    mockStatic(Instant.class);
+    when(Instant.now()).thenReturn(now);
+
+    dofnInstance.processElement(timer, bagState, outputReceiver);
+    verify(timer, times(1)).set(now.plus(600L));
+    assertTrue(outputReceiver.getOutputs().isEmpty());
+    assertTrue(bagState.getCurrentStates().isEmpty());
+  }
+
+  @Test
+  public void testProcessElementWithAvailableTopicPartitions() throws Exception {
+    Instant startReadTime = Instant.ofEpochMilli(1L);
+    WatchKafkaTopicPartitionDoFn dofnInstance =
+        new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L), consumerFn, null, ImmutableMap.of(), startReadTime);
+    MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+    when(mockConsumer.listTopics())
+        .thenReturn(
+            ImmutableMap.of(
+                "topic1",
+                ImmutableList.of(
+                    new PartitionInfo("topic1", 0, null, null, null),
+                    new PartitionInfo("topic1", 1, null, null, null)),
+                "topic2",
+                ImmutableList.of(
+                    new PartitionInfo("topic2", 0, null, null, null),
+                    new PartitionInfo("topic2", 1, null, null, null))));
+    MockBagState bagState = new MockBagState(ImmutableList.of());
+    Instant now = Instant.EPOCH;
+    mockStatic(Instant.class);
+    when(Instant.now()).thenReturn(now);
+
+    dofnInstance.processElement(timer, bagState, outputReceiver);
+
+    verify(timer, times(1)).set(now.plus(600L));
+    Set<TopicPartition> expectedOutputTopicPartitions =
+        ImmutableSet.of(
+            new TopicPartition("topic1", 0),
+            new TopicPartition("topic1", 1),
+            new TopicPartition("topic2", 0),
+            new TopicPartition("topic2", 1));
+    Set<KafkaSourceDescriptor> expectedOutputDescriptor =
+        generateDescriptorsFromTopicPartitions(expectedOutputTopicPartitions, startReadTime);
+    assertEquals(expectedOutputDescriptor, new HashSet<>(outputReceiver.getOutputs()));
+    assertEquals(expectedOutputTopicPartitions, bagState.getCurrentStates());
+  }
+
+  @Test
+  public void testProcessElementWithStoppingReadingTopicPartition() throws Exception {
+    Instant startReadTime = Instant.ofEpochMilli(1L);
+    SerializableFunction<TopicPartition, Boolean> checkStopReadingFn =
+        new SerializableFunction<TopicPartition, Boolean>() {
+          @Override
+          public Boolean apply(TopicPartition input) {
+            return input.equals(new TopicPartition("topic1", 1));
+          }
+        };
+    WatchKafkaTopicPartitionDoFn dofnInstance =
+        new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L),
+            consumerFn,
+            checkStopReadingFn,
+            ImmutableMap.of(),
+            startReadTime);
+    MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+    when(mockConsumer.listTopics())
+        .thenReturn(
+            ImmutableMap.of(
+                "topic1",
+                ImmutableList.of(
+                    new PartitionInfo("topic1", 0, null, null, null),
+                    new PartitionInfo("topic1", 1, null, null, null)),
+                "topic2",
+                ImmutableList.of(
+                    new PartitionInfo("topic2", 0, null, null, null),
+                    new PartitionInfo("topic2", 1, null, null, null))));
+    MockBagState bagState = new MockBagState(ImmutableList.of());
+    Instant now = Instant.EPOCH;
+    mockStatic(Instant.class);
+    when(Instant.now()).thenReturn(now);
+
+    dofnInstance.processElement(timer, bagState, outputReceiver);
+
+    verify(timer, times(1)).set(now.plus(600L));
+    Set<TopicPartition> expectedOutputTopicPartitions =
+        ImmutableSet.of(
+            new TopicPartition("topic1", 0),
+            new TopicPartition("topic2", 0),
+            new TopicPartition("topic2", 1));
+    Set<KafkaSourceDescriptor> expectedOutputDescriptor =
+        generateDescriptorsFromTopicPartitions(expectedOutputTopicPartitions, startReadTime);
+    assertEquals(expectedOutputDescriptor, new HashSet<>(outputReceiver.getOutputs()));
+    assertEquals(expectedOutputTopicPartitions, bagState.getCurrentStates());
+  }
+
+  @Test
+  public void testOnTimerWithNoAvailableTopicPartition() throws Exception {
+    WatchKafkaTopicPartitionDoFn dofnInstance =
+        new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L), consumerFn, null, ImmutableMap.of(), null);
+    MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+    when(mockConsumer.listTopics()).thenReturn(ImmutableMap.of());
+    MockBagState bagState = new MockBagState(ImmutableList.of(new TopicPartition("topic1", 0)));
+    Instant now = Instant.EPOCH;
+    mockStatic(Instant.class);
+    when(Instant.now()).thenReturn(now);
+
+    dofnInstance.onTimer(timer, bagState, outputReceiver);
+
+    verify(timer, times(1)).set(now.plus(600L));
+    assertTrue(outputReceiver.getOutputs().isEmpty());
+    assertTrue(bagState.getCurrentStates().isEmpty());
+  }
+
+  @Test
+  public void testOnTimerWithAdditionOnly() throws Exception {
+    Instant startReadTime = Instant.ofEpochMilli(1L);
+    WatchKafkaTopicPartitionDoFn dofnInstance =
+        new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L), consumerFn, null, ImmutableMap.of(), startReadTime);
+    MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+    when(mockConsumer.listTopics())
+        .thenReturn(
+            ImmutableMap.of(
+                "topic1",
+                ImmutableList.of(
+                    new PartitionInfo("topic1", 0, null, null, null),
+                    new PartitionInfo("topic1", 1, null, null, null)),
+                "topic2",
+                ImmutableList.of(
+                    new PartitionInfo("topic2", 0, null, null, null),
+                    new PartitionInfo("topic2", 1, null, null, null))));
+    MockBagState bagState =
+        new MockBagState(
+            ImmutableList.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1)));
+    Instant now = Instant.EPOCH;
+    mockStatic(Instant.class);
+    when(Instant.now()).thenReturn(now);
+
+    dofnInstance.onTimer(timer, bagState, outputReceiver);
+
+    verify(timer, times(1)).set(now.plus(600L));
+    Set<TopicPartition> expectedOutputTopicPartitions =
+        ImmutableSet.of(new TopicPartition("topic2", 0), new TopicPartition("topic2", 1));
+    Set<TopicPartition> expectedCurrentTopicPartitions =
+        ImmutableSet.of(
+            new TopicPartition("topic1", 0),
+            new TopicPartition("topic1", 1),
+            new TopicPartition("topic2", 0),
+            new TopicPartition("topic2", 1));
+    Set<KafkaSourceDescriptor> expectedOutputDescriptor =
+        generateDescriptorsFromTopicPartitions(expectedOutputTopicPartitions, startReadTime);
+    assertEquals(expectedOutputDescriptor, new HashSet<>(outputReceiver.getOutputs()));
+    assertEquals(expectedCurrentTopicPartitions, bagState.getCurrentStates());
+  }
+
+  @Test
+  public void testOnTimerWithRemovalOnly() throws Exception {
+    Instant startReadTime = Instant.ofEpochMilli(1L);
+    WatchKafkaTopicPartitionDoFn dofnInstance =
+        new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L), consumerFn, null, ImmutableMap.of(), startReadTime);
+    MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+    when(mockConsumer.listTopics())
+        .thenReturn(
+            ImmutableMap.of(
+                "topic1",
+                ImmutableList.of(new PartitionInfo("topic1", 0, null, null, null)),
+                "topic2",
+                ImmutableList.of(
+                    new PartitionInfo("topic2", 0, null, null, null),
+                    new PartitionInfo("topic2", 1, null, null, null))));
+    MockBagState bagState =
+        new MockBagState(
+            ImmutableList.of(
+                new TopicPartition("topic1", 0),
+                new TopicPartition("topic1", 1),
+                new TopicPartition("topic2", 0),
+                new TopicPartition("topic2", 1)));
+    Instant now = Instant.EPOCH;
+    mockStatic(Instant.class);
+    when(Instant.now()).thenReturn(now);
+
+    dofnInstance.onTimer(timer, bagState, outputReceiver);
+
+    verify(timer, times(1)).set(now.plus(600L));
+    Set<TopicPartition> expectedCurrentTopicPartitions =
+        ImmutableSet.of(
+            new TopicPartition("topic1", 0),
+            new TopicPartition("topic2", 0),
+            new TopicPartition("topic2", 1));
+    assertTrue(outputReceiver.getOutputs().isEmpty());
+    assertEquals(expectedCurrentTopicPartitions, bagState.getCurrentStates());
+  }
+
+  @Test
+  public void testOnTimerWithStoppedTopicPartitions() throws Exception {
+    Instant startReadTime = Instant.ofEpochMilli(1L);
+    SerializableFunction<TopicPartition, Boolean> checkStopReadingFn =
+        new SerializableFunction<TopicPartition, Boolean>() {
+          @Override
+          public Boolean apply(TopicPartition input) {
+            return input.equals(new TopicPartition("topic1", 1));
+          }
+        };
+    WatchKafkaTopicPartitionDoFn dofnInstance =
+        new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L),
+            consumerFn,
+            checkStopReadingFn,
+            ImmutableMap.of(),
+            startReadTime);
+    MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+    when(mockConsumer.listTopics())
+        .thenReturn(
+            ImmutableMap.of(
+                "topic1",
+                ImmutableList.of(
+                    new PartitionInfo("topic1", 0, null, null, null),
+                    new PartitionInfo("topic1", 1, null, null, null)),
+                "topic2",
+                ImmutableList.of(
+                    new PartitionInfo("topic2", 0, null, null, null),
+                    new PartitionInfo("topic2", 1, null, null, null))));
+    MockBagState bagState =
+        new MockBagState(
+            ImmutableList.of(
+                new TopicPartition("topic1", 0),
+                new TopicPartition("topic2", 0),
+                new TopicPartition("topic2", 1)));
+    Instant now = Instant.EPOCH;
+    mockStatic(Instant.class);
+    when(Instant.now()).thenReturn(now);
+
+    dofnInstance.onTimer(timer, bagState, outputReceiver);
+
+    Set<TopicPartition> expectedCurrentTopicPartitions =
+        ImmutableSet.of(
+            new TopicPartition("topic1", 0),
+            new TopicPartition("topic2", 0),
+            new TopicPartition("topic2", 1));
+
+    verify(timer, times(1)).set(now.plus(600L));
+    assertTrue(outputReceiver.getOutputs().isEmpty());
+    assertEquals(expectedCurrentTopicPartitions, bagState.getCurrentStates());
+  }
+
+  private static class MockOutputReceiver implements OutputReceiver<KafkaSourceDescriptor> {
+
+    private List<KafkaSourceDescriptor> outputs = new ArrayList<>();
+
+    @Override
+    public void output(KafkaSourceDescriptor output) {
+      outputs.add(output);
+    }
+
+    @Override
+    public void outputWithTimestamp(KafkaSourceDescriptor output, Instant timestamp) {}
+
+    public List<KafkaSourceDescriptor> getOutputs() {
+      return outputs;
+    }
+  }
+
+  private static class MockBagState implements BagState<TopicPartition> {
+    private Set<TopicPartition> topicPartitions = new HashSet<>();
+
+    MockBagState(List<TopicPartition> readReturn) {
+      topicPartitions.addAll(readReturn);
+    }
+
+    @Override
+    public Iterable<TopicPartition> read() {
+      return topicPartitions;
+    }
+
+    @Override
+    public void add(TopicPartition value) {
+      topicPartitions.add(value);
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return null;
+    }
+
+    @Override
+    public BagState<TopicPartition> readLater() {
+      return null;
+    }
+
+    @Override
+    public void clear() {
+      topicPartitions.clear();
+    }
+
+    public Set<TopicPartition> getCurrentStates() {
+      return topicPartitions;
+    }
+  }
+
+  private Set<KafkaSourceDescriptor> generateDescriptorsFromTopicPartitions(
+      Set<TopicPartition> topicPartitions, Instant startReadTime) {
+    return topicPartitions.stream()
+        .map(topicPartition -> KafkaSourceDescriptor.of(topicPartition, null, startReadTime, null))
+        .collect(Collectors.toSet());
+  }
+}
