This is an automated email from the ASF dual-hosted git repository.
boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 0fbb21f [BEAM-11325] Support KafkaIO dynamic read
new 47d3326 Merge pull request #13750 from [BEAM-11325] Kafka Dynamic Read
0fbb21f is described below
commit 0fbb21fd13b0f1ac1a28e4c839b8ebbe9420e9d3
Author: Boyuan Zhang <[email protected]>
AuthorDate: Fri Jan 8 10:31:34 2021 -0800
[BEAM-11325] Support KafkaIO dynamic read
---
.../java/org/apache/beam/sdk/io/kafka/KafkaIO.java | 156 +++++++-
.../beam/sdk/io/kafka/TopicPartitionCoder.java | 56 +++
.../sdk/io/kafka/WatchKafkaTopicPartitionDoFn.java | 156 ++++++++
.../beam/sdk/io/kafka/TopicPartitionCoderTest.java | 39 ++
.../io/kafka/WatchKafkaTopicPartitionDoFnTest.java | 422 +++++++++++++++++++++
5 files changed, 821 insertions(+), 8 deletions(-)
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 04b48fa..10aac4a 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -73,6 +73,7 @@ import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -163,6 +164,88 @@ import org.slf4j.LoggerFactory;
* Read#withValueDeserializerAndCoder(Class, Coder)}. Note that Kafka messages are interpreted using
* key and value <i>deserializers</i>.
*
+ * <h3>Read From Kafka Dynamically</h3>
+ *
+ * For a given Kafka bootstrap_server, KafkaIO is also able to detect and read from available {@link
+ * TopicPartition}s dynamically and stop reading from unavailable ones. KafkaIO uses {@link
+ * WatchKafkaTopicPartitionDoFn} to emit any newly added {@link TopicPartition} and uses {@link
+ * ReadFromKafkaDoFn} to read from each {@link KafkaSourceDescriptor}. Dynamic read is able to
+ * handle 2 scenarios:
+ *
+ * <ul>
+ *   <li>Certain topic or partition is added/deleted.
+ *   <li>Certain topic or partition is added, then removed but added back again.
+ * </ul>
+ *
+ * When providing a {@code checkStopReadingFn}, there are 2 more cases that dynamic read can handle:
+ *
+ * <ul>
+ *   <li>Certain topic or partition is stopped.
+ *   <li>Certain topic or partition is added, then stopped but added back again.
+ * </ul>
+ *
+ * Race conditions may happen under 2 supported cases:
+ *
+ * <ul>
+ *   <li>A TopicPartition is removed, but added back again.
+ *   <li>A TopicPartition is stopped, then the pipeline wants to read from it again.
+ * </ul>
+ *
+ * When a race condition happens, the stopped/removed TopicPartition may fail to be emitted to
+ * ReadFromKafkaDoFn again, or ReadFromKafkaDoFn may output duplicate records. The major cause of
+ * such race conditions is that both {@link WatchKafkaTopicPartitionDoFn} and {@link
+ * ReadFromKafkaDoFn} react to the signal from a removed/stopped {@link TopicPartition}, but we
+ * cannot guarantee that both DoFns perform the related actions at the same time.
+ *
+ * <p>Here is one example of failing to emit a newly added {@link TopicPartition}:
+ *
+ * <ul>
+ *   <li>A {@link WatchKafkaTopicPartitionDoFn} is configured to update the current tracking set
+ *       every 1 hour.
+ *   <li>One TopicPartition A is tracked by the {@link WatchKafkaTopicPartitionDoFn} at 10:00AM and
+ *       {@link ReadFromKafkaDoFn} starts to read from TopicPartition A immediately.
+ *   <li>At 10:30AM, the {@link ReadFromKafkaDoFn} notices that TopicPartition A has been
+ *       stopped/removed, so it stops reading from it and returns {@code
+ *       ProcessContinuation.stop()}.
+ *   <li>At 10:45AM, the pipeline author wants to read from TopicPartition A again.
+ *   <li>At 11:00AM, when {@link WatchKafkaTopicPartitionDoFn} is invoked by the firing timer, it
+ *       doesn't know that TopicPartition A has been stopped/removed. All it knows is that
+ *       TopicPartition A is still an active TopicPartition, so it will not emit TopicPartition A
+ *       again.
+ * </ul>
+ *
+ * Another race condition example, producing duplicate records:
+ *
+ * <ul>
+ *   <li>At 10:00AM, {@link ReadFromKafkaDoFn} is processing TopicPartition A.
+ *   <li>At 10:05AM, {@link ReadFromKafkaDoFn} starts to process other TopicPartitions (an
+ *       SDF-initiated checkpoint or runner-issued checkpoint happens).
+ *   <li>At 10:10AM, {@link WatchKafkaTopicPartitionDoFn} learns that TopicPartition A has been
+ *       stopped/removed.
+ *   <li>At 10:15AM, {@link WatchKafkaTopicPartitionDoFn} learns that TopicPartition A has been
+ *       added again and emits TopicPartition A again.
+ *   <li>At 10:20AM, {@link ReadFromKafkaDoFn} starts to process the resumed TopicPartition A, but
+ *       at the same time {@link ReadFromKafkaDoFn} is also processing the newly emitted
+ *       TopicPartition A.
+ * </ul>
+ *
+ * For more design details, please refer to
+ * https://docs.google.com/document/d/1FU3GxVRetHPLVizP3Mdv6mP5tpjZ3fd99qNjUI5DT5k/. To enable
+ * dynamic read, you can write a pipeline like:
+ *
+ * <pre>{@code
+ * pipeline
+ *   .apply(KafkaIO.<Long, String>read()
+ *      // Configure the dynamic read with 1 hour, where the pipeline will look into available
+ *      // TopicPartitions and emit newly added ones every 1 hour.
+ *      .withDynamicRead(Duration.standardHours(1))
+ *      .withCheckStopReadingFn(new SerializableFunction<TopicPartition, Boolean>() {})
+ *      .withBootstrapServers("broker_1:9092,broker_2:9092")
+ *      .withKeyDeserializer(LongDeserializer.class)
+ *      .withValueDeserializer(StringDeserializer.class)
+ *   )
+ *   .apply(Values.<String>create()) // PCollection<String>
+ *    ...
+ * }</pre>
+ *
* <h3>Partition Assignment and Checkpointing</h3>
*
* The Kafka partitions are evenly distributed among splits (workers).
@@ -431,6 +514,7 @@ public class KafkaIO {
.setConsumerConfig(KafkaIOUtils.DEFAULT_CONSUMER_PROPERTIES)
.setMaxNumRecords(Long.MAX_VALUE)
.setCommitOffsetsInFinalizeEnabled(false)
+ .setDynamicRead(false)
.setTimestampPolicyFactory(TimestampPolicyFactory.withProcessingTime())
.build();
}
@@ -511,6 +595,10 @@ public class KafkaIO {
abstract boolean isCommitOffsetsInFinalizeEnabled();
+ abstract boolean isDynamicRead();
+
+ abstract @Nullable Duration getWatchTopicPartitionDuration();
+
abstract TimestampPolicyFactory<K, V> getTimestampPolicyFactory();
abstract @Nullable Map<String, Object> getOffsetConsumerConfig();
@@ -550,6 +638,10 @@ public class KafkaIO {
abstract Builder<K, V> setCommitOffsetsInFinalizeEnabled(boolean commitOffsetInFinalize);
+ abstract Builder<K, V> setDynamicRead(boolean dynamicRead);
+
+ abstract Builder<K, V> setWatchTopicPartitionDuration(Duration duration);
+
abstract Builder<K, V> setTimestampPolicyFactory(
TimestampPolicyFactory<K, V> timestampPolicyFactory);
@@ -616,6 +708,11 @@ public class KafkaIO {
if (config.startReadTime != null) {
setStartReadTime(Instant.ofEpochMilli(config.startReadTime));
}
+
+      // We can expose dynamic read to external build when ReadFromKafkaDoFn is the default
+      // implementation.
+ setDynamicRead(false);
+
+      // We do not include Metadata until we can encode KafkaRecords cross-language
return build().withoutMetadata();
}
@@ -998,6 +1095,16 @@ public class KafkaIO {
}
    /**
+     * Configure the KafkaIO to use {@link WatchKafkaTopicPartitionDoFn} to detect and emit any
+     * newly available {@link TopicPartition} for {@link ReadFromKafkaDoFn} to consume during
+     * pipeline execution time. The KafkaIO will regularly check the availability based on the
+     * given duration. If the duration is given as {@code null}, the default duration of 1 hour is
+     * used.
+     */
+    public Read<K, V> withDynamicRead(Duration duration) {
+      return toBuilder().setDynamicRead(true).setWatchTopicPartitionDuration(duration).build();
+    }
+
+ /**
+     * Set additional configuration for the backend offset consumer. It may be required for a
+     * secured Kafka cluster, especially when you see similar WARN log message 'exception while
+     * fetching latest offset for partition {}. will be retried'.
@@ -1052,9 +1159,20 @@ public class KafkaIO {
checkArgument(
getConsumerConfig().get(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG) !=
null,
"withBootstrapServers() is required");
-      checkArgument(
-          getTopics().size() > 0 || getTopicPartitions().size() > 0,
-          "Either withTopic(), withTopics() or withTopicPartitions() is required");
+      // With dynamic read, we no longer require providing topic/partition during pipeline
+      // construction time. But dynamic read requires enabling the use_sdf_kafka_read and
+      // beam_fn_api experiments.
+      if (!isDynamicRead()) {
+        checkArgument(
+            getTopics().size() > 0 || getTopicPartitions().size() > 0,
+            "Either withTopic(), withTopics() or withTopicPartitions() is required");
+      } else {
+        checkArgument(
+            ExperimentalOptions.hasExperiment(
+                    input.getPipeline().getOptions(), "use_sdf_kafka_read")
+                && ExperimentalOptions.hasExperiment(
+                    input.getPipeline().getOptions(), "beam_fn_api"),
+            "Kafka Dynamic Read requires enabling experiments use_sdf_kafka_read and beam_fn_api.");
+      }
      checkArgument(getKeyDeserializerProvider() != null, "withKeyDeserializer() is required");
      checkArgument(getValueDeserializerProvider() != null, "withValueDeserializer() is required");
@@ -1129,11 +1247,33 @@ public class KafkaIO {
if (isCommitOffsetsInFinalizeEnabled()) {
readTransform = readTransform.commitOffsets();
}
- PCollection<KafkaSourceDescriptor> output =
- input
- .getPipeline()
- .apply(Impulse.create())
- .apply(ParDo.of(new GenerateKafkaSourceDescriptor(this)));
+ PCollection<KafkaSourceDescriptor> output;
+ if (isDynamicRead()) {
+ output =
+ input
+ .getPipeline()
+ .apply(Impulse.create())
+ .apply(
+ MapElements.into(
+ TypeDescriptors.kvs(
+                            new TypeDescriptor<byte[]>() {}, new TypeDescriptor<byte[]>() {}))
+ .via(element -> KV.of(element, element)))
+ .apply(
+ ParDo.of(
+ new WatchKafkaTopicPartitionDoFn(
+ getWatchTopicPartitionDuration(),
+ getConsumerFactoryFn(),
+ getCheckStopReadingFn(),
+ getConsumerConfig(),
+ getStartReadTime())));
+
+ } else {
+ output =
+ input
+ .getPipeline()
+ .apply(Impulse.create())
+ .apply(ParDo.of(new GenerateKafkaSourceDescriptor(this)));
+ }
      return output.apply(readTransform).setCoder(KafkaRecordCoder.of(keyCoder, valueCoder));
}
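As an aside on the {@code checkStopReadingFn} hook wired in above: its contract is simply a function from a topic partition to a Boolean, where returning {@code true} tells KafkaIO to stop reading that partition. The sketch below (not part of this commit) models that contract with plain Java; {@code SimpleTopicPartition} and {@code denylistTopics} are hypothetical stand-ins so the example runs without Kafka or Beam on the classpath.

```java
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

public class CheckStopReadingSketch {

  // Hypothetical stand-in for org.apache.kafka.common.TopicPartition.
  static final class SimpleTopicPartition {
    final String topic;
    final int partition;

    SimpleTopicPartition(String topic, int partition) {
      this.topic = topic;
      this.partition = partition;
    }

    @Override
    public boolean equals(Object o) {
      if (!(o instanceof SimpleTopicPartition)) {
        return false;
      }
      SimpleTopicPartition other = (SimpleTopicPartition) o;
      return topic.equals(other.topic) && partition == other.partition;
    }

    @Override
    public int hashCode() {
      return Objects.hash(topic, partition);
    }
  }

  // Builds a predicate that stops reading from any partition of a denylisted topic.
  // Returning true means "stop reading", mirroring the checkStopReadingFn contract.
  static Function<SimpleTopicPartition, Boolean> denylistTopics(Set<String> denylist) {
    return tp -> denylist.contains(tp.topic);
  }

  public static void main(String[] args) {
    Function<SimpleTopicPartition, Boolean> checkStopReadingFn =
        denylistTopics(Set.of("retired_topic"));
    System.out.println(checkStopReadingFn.apply(new SimpleTopicPartition("retired_topic", 0))); // true
    System.out.println(checkStopReadingFn.apply(new SimpleTopicPartition("live_topic", 0))); // false
  }
}
```

In the real pipeline the equivalent predicate is a `SerializableFunction<TopicPartition, Boolean>` passed to `withCheckStopReadingFn`.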
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoder.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoder.java
new file mode 100644
index 0000000..f11e8ca
--- /dev/null
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoder.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.List;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.StructuredCoder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.kafka.common.TopicPartition;
+
+/** The {@link Coder} for encoding and decoding {@link TopicPartition} in Beam. */
+@SuppressWarnings({"nullness"})
+public class TopicPartitionCoder extends StructuredCoder<TopicPartition> {
+
+ @Override
+ public void encode(TopicPartition value, OutputStream outStream)
+ throws CoderException, IOException {
+ StringUtf8Coder.of().encode(value.topic(), outStream);
+ VarIntCoder.of().encode(value.partition(), outStream);
+ }
+
+ @Override
+  public TopicPartition decode(InputStream inStream) throws CoderException, IOException {
+ String topic = StringUtf8Coder.of().decode(inStream);
+ int partition = VarIntCoder.of().decode(inStream);
+ return new TopicPartition(topic, partition);
+ }
+
+ @Override
+ public List<? extends Coder<?>> getCoderArguments() {
+ return null;
+ }
+
+ @Override
+ public void verifyDeterministic() throws NonDeterministicException {}
+}
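To make the new coder's wire format concrete: it writes the topic string first and the partition number second, and decodes in the same order. Below is a stand-alone sketch of that round trip (not part of the commit) using only java.io; `DataOutputStream#writeUTF`/`writeInt` stand in for Beam's `StringUtf8Coder`/`VarIntCoder`, so the exact byte layout differs from the real `TopicPartitionCoder`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

public class TopicPartitionRoundTrip {

  // Encode the topic, then the partition, mirroring the order in TopicPartitionCoder#encode.
  static byte[] encode(String topic, int partition) {
    try {
      ByteArrayOutputStream bytes = new ByteArrayOutputStream();
      DataOutputStream out = new DataOutputStream(bytes);
      out.writeUTF(topic);
      out.writeInt(partition);
      return bytes.toByteArray();
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  // Decode in the same order, mirroring TopicPartitionCoder#decode.
  static Object[] decode(byte[] encoded) {
    try {
      DataInputStream in = new DataInputStream(new ByteArrayInputStream(encoded));
      String topic = in.readUTF();
      int partition = in.readInt();
      return new Object[] {topic, partition};
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  public static void main(String[] args) {
    Object[] decoded = decode(encode("topic", 1));
    System.out.println(decoded[0] + ":" + decoded[1]); // prints "topic:1"
  }
}
```

The design point the coder (and this sketch) relies on is that encode and decode agree on field order, so the round trip is lossless.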
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFn.java
new file mode 100644
index 0000000..d82bfcf
--- /dev/null
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFn.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.state.TimerSpec;
+import org.apache.beam.sdk.state.TimerSpecs;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+
+/**
+ * A stateful {@link DoFn} that regularly emits newly available {@link TopicPartition}s.
+ *
+ * <p>Please refer to
+ * https://docs.google.com/document/d/1FU3GxVRetHPLVizP3Mdv6mP5tpjZ3fd99qNjUI5DT5k/edit# for more
+ * details.
+ */
+@SuppressWarnings({"nullness"})
+@Experimental
+class WatchKafkaTopicPartitionDoFn extends DoFn<KV<byte[], byte[]>, KafkaSourceDescriptor> {
+
+  private static final Duration DEFAULT_CHECK_DURATION = Duration.standardHours(1);
+  private static final String TIMER_ID = "watch_timer";
+  private static final String STATE_ID = "topic_partition_set";
+  private final Duration checkDuration;
+
+  private final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>
+      kafkaConsumerFactoryFn;
+  private final SerializableFunction<TopicPartition, Boolean> checkStopReadingFn;
+  private final Map<String, Object> kafkaConsumerConfig;
+  private final Instant startReadTime;
+
+  WatchKafkaTopicPartitionDoFn(
+      Duration checkDuration,
+      SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> kafkaConsumerFactoryFn,
+      SerializableFunction<TopicPartition, Boolean> checkStopReadingFn,
+      Map<String, Object> kafkaConsumerConfig,
+      Instant startReadTime) {
+    this.checkDuration = checkDuration == null ? DEFAULT_CHECK_DURATION : checkDuration;
+    this.kafkaConsumerFactoryFn = kafkaConsumerFactoryFn;
+    this.checkStopReadingFn = checkStopReadingFn;
+    this.kafkaConsumerConfig = kafkaConsumerConfig;
+    this.startReadTime = startReadTime;
+  }
+
+ @TimerId(TIMER_ID)
+  private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
+
+ @StateId(STATE_ID)
+ private final StateSpec<BagState<TopicPartition>> bagStateSpec =
+ StateSpecs.bag(new TopicPartitionCoder());
+
+ @VisibleForTesting
+  Set<TopicPartition> getAllTopicPartitions() {
+    Set<TopicPartition> current = new HashSet<>();
+    try (Consumer<byte[], byte[]> kafkaConsumer =
+        kafkaConsumerFactoryFn.apply(kafkaConsumerConfig)) {
+      for (Map.Entry<String, List<PartitionInfo>> topicInfo :
+          kafkaConsumer.listTopics().entrySet()) {
+        for (PartitionInfo partition : topicInfo.getValue()) {
+          current.add(new TopicPartition(topicInfo.getKey(), partition.partition()));
+        }
+      }
+    }
+    return current;
+  }
+
+ @ProcessElement
+ public void processElement(
+ @TimerId(TIMER_ID) Timer timer,
+ @StateId(STATE_ID) BagState<TopicPartition> existingTopicPartitions,
+ OutputReceiver<KafkaSourceDescriptor> outputReceiver) {
+    // For the first time, we emit all available TopicPartitions and write them into State.
+    Set<TopicPartition> current = getAllTopicPartitions();
+    current.forEach(
+        topicPartition -> {
+          if (checkStopReadingFn == null || !checkStopReadingFn.apply(topicPartition)) {
+            existingTopicPartitions.add(topicPartition);
+            outputReceiver.output(
+                KafkaSourceDescriptor.of(topicPartition, null, startReadTime, null));
+          }
+        });
+
+ timer.set(Instant.now().plus(checkDuration.getMillis()));
+ }
+
+ @OnTimer(TIMER_ID)
+ public void onTimer(
+ @TimerId(TIMER_ID) Timer timer,
+ @StateId(STATE_ID) BagState<TopicPartition> existingTopicPartitions,
+ OutputReceiver<KafkaSourceDescriptor> outputReceiver) {
+ Set<TopicPartition> readingTopicPartitions = new HashSet<>();
+ existingTopicPartitions
+ .read()
+ .forEach(
+ topicPartition -> {
+ readingTopicPartitions.add(topicPartition);
+ });
+ existingTopicPartitions.clear();
+
+ Set<TopicPartition> currentAll = getAllTopicPartitions();
+
+    // Emit newly added TopicPartitions.
+    Set<TopicPartition> newAdded = Sets.difference(currentAll, readingTopicPartitions);
+    newAdded.forEach(
+        topicPartition -> {
+          if (checkStopReadingFn == null || !checkStopReadingFn.apply(topicPartition)) {
+            outputReceiver.output(
+                KafkaSourceDescriptor.of(topicPartition, null, startReadTime, null));
+          }
+        });
+
+    // Update the State.
+    currentAll.forEach(
+        topicPartition -> {
+          if (checkStopReadingFn == null || !checkStopReadingFn.apply(topicPartition)) {
+            existingTopicPartitions.add(topicPartition);
+          }
+        });
+
+ // Reset the timer.
+ timer.set(Instant.now().plus(checkDuration.getMillis()));
+ }
+}
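The heart of the `onTimer` logic above is a set reconciliation: emit only partitions that are in the current listing but not yet tracked, then rewrite the tracked state from the current listing, applying the stop predicate as a filter in both places. A minimal sketch with plain java.util sets follows; the class and method names (`WatchReconciliationSketch`, `reconcile`) are illustrative, not part of the commit, and strings stand in for TopicPartitions.

```java
import java.util.HashSet;
import java.util.Set;
import java.util.function.Predicate;

public class WatchReconciliationSketch {

  // Emits current-minus-tracked into `emitted` (standing in for outputReceiver.output)
  // and returns the new tracked set, filtered by the stop predicate, mirroring the
  // state rewrite in WatchKafkaTopicPartitionDoFn#onTimer.
  static Set<String> reconcile(
      Set<String> tracked, Set<String> current, Predicate<String> stopReading,
      Set<String> emitted) {
    for (String tp : current) {
      if (!tracked.contains(tp) && !stopReading.test(tp)) {
        emitted.add(tp); // newly added partition: emit a descriptor for it
      }
    }
    Set<String> newTracked = new HashSet<>();
    for (String tp : current) {
      if (!stopReading.test(tp)) {
        newTracked.add(tp); // removed partitions simply drop out of state
      }
    }
    return newTracked;
  }

  public static void main(String[] args) {
    Set<String> emitted = new HashSet<>();
    Set<String> tracked = Set.of("topic1-0", "topic1-1");
    Set<String> current = Set.of("topic1-0", "topic2-0");
    Set<String> newTracked = reconcile(tracked, current, tp -> false, emitted);
    System.out.println(emitted); // only the newly seen topic2-0 is emitted
    System.out.println(newTracked); // topic1-1 disappeared, so it is dropped from state
  }
}
```

This also makes the documented race condition visible: the reconciliation only compares the tracked set against the current listing, so it cannot tell "still present" apart from "removed and re-added between two timer firings".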
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoderTest.java
new file mode 100644
index 0000000..01c5acd
--- /dev/null
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/TopicPartitionCoderTest.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.junit.Assert.assertEquals;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import org.apache.kafka.common.TopicPartition;
+import org.junit.Test;
+
+@SuppressWarnings({"nullness"})
+public class TopicPartitionCoderTest {
+
+ @Test
+ public void testEncodeDecodeRoundTrip() throws Exception {
+ TopicPartitionCoder coder = new TopicPartitionCoder();
+ TopicPartition topicPartition = new TopicPartition("topic", 1);
+ ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ coder.encode(topicPartition, outputStream);
+    assertEquals(
+        topicPartition, coder.decode(new ByteArrayInputStream(outputStream.toByteArray())));
+ }
+}
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFnTest.java
new file mode 100644
index 0000000..14460d6
--- /dev/null
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/WatchKafkaTopicPartitionDoFnTest.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.powermock.api.mockito.PowerMockito.mockStatic;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.Timer;
+import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({Instant.class})
+@SuppressWarnings({"nullness"})
+public class WatchKafkaTopicPartitionDoFnTest {
+
+ @Mock Consumer<byte[], byte[]> mockConsumer;
+ @Mock Timer timer;
+
+  private final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFn =
+      new SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>() {
+        @Override
+        public Consumer<byte[], byte[]> apply(Map<String, Object> input) {
+          return mockConsumer;
+        }
+      };
+
+ @Test
+ public void testGetAllTopicPartitions() throws Exception {
+ when(mockConsumer.listTopics())
+ .thenReturn(
+ ImmutableMap.of(
+ "topic1",
+ ImmutableList.of(
+ new PartitionInfo("topic1", 0, null, null, null),
+ new PartitionInfo("topic1", 1, null, null, null)),
+ "topic2",
+ ImmutableList.of(
+ new PartitionInfo("topic2", 0, null, null, null),
+ new PartitionInfo("topic2", 1, null, null, null))));
+ WatchKafkaTopicPartitionDoFn dofnInstance =
+ new WatchKafkaTopicPartitionDoFn(
+ Duration.millis(1L), consumerFn, null, ImmutableMap.of(), null);
+ assertEquals(
+ ImmutableSet.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic1", 1),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1)),
+ dofnInstance.getAllTopicPartitions());
+ }
+
+ @Test
+  public void testProcessElementWhenNoAvailableTopicPartition() throws Exception {
+ WatchKafkaTopicPartitionDoFn dofnInstance =
+ new WatchKafkaTopicPartitionDoFn(
+ Duration.millis(600L), consumerFn, null, ImmutableMap.of(), null);
+ MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+ when(mockConsumer.listTopics()).thenReturn(ImmutableMap.of());
+ MockBagState bagState = new MockBagState(ImmutableList.of());
+ Instant now = Instant.EPOCH;
+ mockStatic(Instant.class);
+ when(Instant.now()).thenReturn(now);
+
+ dofnInstance.processElement(timer, bagState, outputReceiver);
+ verify(timer, times(1)).set(now.plus(600L));
+ assertTrue(outputReceiver.getOutputs().isEmpty());
+ assertTrue(bagState.getCurrentStates().isEmpty());
+ }
+
+ @Test
+  public void testProcessElementWithAvailableTopicPartitions() throws Exception {
+ Instant startReadTime = Instant.ofEpochMilli(1L);
+ WatchKafkaTopicPartitionDoFn dofnInstance =
+ new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L), consumerFn, null, ImmutableMap.of(), startReadTime);
+ MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+ when(mockConsumer.listTopics())
+ .thenReturn(
+ ImmutableMap.of(
+ "topic1",
+ ImmutableList.of(
+ new PartitionInfo("topic1", 0, null, null, null),
+ new PartitionInfo("topic1", 1, null, null, null)),
+ "topic2",
+ ImmutableList.of(
+ new PartitionInfo("topic2", 0, null, null, null),
+ new PartitionInfo("topic2", 1, null, null, null))));
+ MockBagState bagState = new MockBagState(ImmutableList.of());
+ Instant now = Instant.EPOCH;
+ mockStatic(Instant.class);
+ when(Instant.now()).thenReturn(now);
+
+ dofnInstance.processElement(timer, bagState, outputReceiver);
+
+ verify(timer, times(1)).set(now.plus(600L));
+ Set<TopicPartition> expectedOutputTopicPartitions =
+ ImmutableSet.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic1", 1),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1));
+    Set<KafkaSourceDescriptor> expectedOutputDescriptor =
+        generateDescriptorsFromTopicPartitions(expectedOutputTopicPartitions, startReadTime);
+    assertEquals(expectedOutputDescriptor, new HashSet<>(outputReceiver.getOutputs()));
+    assertEquals(expectedOutputTopicPartitions, bagState.getCurrentStates());
+ }
+
+ @Test
+  public void testProcessElementWithStoppingReadingTopicPartition() throws Exception {
+ Instant startReadTime = Instant.ofEpochMilli(1L);
+ SerializableFunction<TopicPartition, Boolean> checkStopReadingFn =
+ new SerializableFunction<TopicPartition, Boolean>() {
+ @Override
+ public Boolean apply(TopicPartition input) {
+ if (input.equals(new TopicPartition("topic1", 1))) {
+ return true;
+ }
+ return false;
+ }
+ };
+ WatchKafkaTopicPartitionDoFn dofnInstance =
+ new WatchKafkaTopicPartitionDoFn(
+ Duration.millis(600L),
+ consumerFn,
+ checkStopReadingFn,
+ ImmutableMap.of(),
+ startReadTime);
+ MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+ when(mockConsumer.listTopics())
+ .thenReturn(
+ ImmutableMap.of(
+ "topic1",
+ ImmutableList.of(
+ new PartitionInfo("topic1", 0, null, null, null),
+ new PartitionInfo("topic1", 1, null, null, null)),
+ "topic2",
+ ImmutableList.of(
+ new PartitionInfo("topic2", 0, null, null, null),
+ new PartitionInfo("topic2", 1, null, null, null))));
+ MockBagState bagState = new MockBagState(ImmutableList.of());
+ Instant now = Instant.EPOCH;
+ mockStatic(Instant.class);
+ when(Instant.now()).thenReturn(now);
+
+ dofnInstance.processElement(timer, bagState, outputReceiver);
+
+ verify(timer, times(1)).set(now.plus(600L));
+ Set<TopicPartition> expectedOutputTopicPartitions =
+ ImmutableSet.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1));
+    Set<KafkaSourceDescriptor> expectedOutputDescriptor =
+        generateDescriptorsFromTopicPartitions(expectedOutputTopicPartitions, startReadTime);
+    assertEquals(expectedOutputDescriptor, new HashSet<>(outputReceiver.getOutputs()));
+    assertEquals(expectedOutputTopicPartitions, bagState.getCurrentStates());
+ }
+
+ @Test
+ public void testOnTimerWithNoAvailableTopicPartition() throws Exception {
+ WatchKafkaTopicPartitionDoFn dofnInstance =
+ new WatchKafkaTopicPartitionDoFn(
+ Duration.millis(600L), consumerFn, null, ImmutableMap.of(), null);
+ MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+ when(mockConsumer.listTopics()).thenReturn(ImmutableMap.of());
+    MockBagState bagState = new MockBagState(ImmutableList.of(new TopicPartition("topic1", 0)));
+ Instant now = Instant.EPOCH;
+ mockStatic(Instant.class);
+ when(Instant.now()).thenReturn(now);
+
+ dofnInstance.onTimer(timer, bagState, outputReceiver);
+
+ verify(timer, times(1)).set(now.plus(600L));
+ assertTrue(outputReceiver.getOutputs().isEmpty());
+ assertTrue(bagState.getCurrentStates().isEmpty());
+ }
+
+ @Test
+ public void testOnTimerWithAdditionOnly() throws Exception {
+ Instant startReadTime = Instant.ofEpochMilli(1L);
+ WatchKafkaTopicPartitionDoFn dofnInstance =
+ new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L), consumerFn, null, ImmutableMap.of(), startReadTime);
+ MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+ when(mockConsumer.listTopics())
+ .thenReturn(
+ ImmutableMap.of(
+ "topic1",
+ ImmutableList.of(
+ new PartitionInfo("topic1", 0, null, null, null),
+ new PartitionInfo("topic1", 1, null, null, null)),
+ "topic2",
+ ImmutableList.of(
+ new PartitionInfo("topic2", 0, null, null, null),
+ new PartitionInfo("topic2", 1, null, null, null))));
+ MockBagState bagState =
+ new MockBagState(
+            ImmutableList.of(new TopicPartition("topic1", 0), new TopicPartition("topic1", 1)));
+ Instant now = Instant.EPOCH;
+ mockStatic(Instant.class);
+ when(Instant.now()).thenReturn(now);
+
+ dofnInstance.onTimer(timer, bagState, outputReceiver);
+
+ verify(timer, times(1)).set(now.plus(600L));
+ Set<TopicPartition> expectedOutputTopicPartitions =
+        ImmutableSet.of(new TopicPartition("topic2", 0), new TopicPartition("topic2", 1));
+ Set<TopicPartition> expectedCurrentTopicPartitions =
+ ImmutableSet.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic1", 1),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1));
+    Set<KafkaSourceDescriptor> expectedOutputDescriptor =
+        generateDescriptorsFromTopicPartitions(expectedOutputTopicPartitions, startReadTime);
+    assertEquals(expectedOutputDescriptor, new HashSet<>(outputReceiver.getOutputs()));
+    assertEquals(expectedCurrentTopicPartitions, bagState.getCurrentStates());
+ }
+
+ @Test
+ public void testOnTimerWithRemovalOnly() throws Exception {
+ Instant startReadTime = Instant.ofEpochMilli(1L);
+ WatchKafkaTopicPartitionDoFn dofnInstance =
+ new WatchKafkaTopicPartitionDoFn(
+            Duration.millis(600L), consumerFn, null, ImmutableMap.of(), startReadTime);
+ MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+ when(mockConsumer.listTopics())
+ .thenReturn(
+ ImmutableMap.of(
+ "topic1",
+ ImmutableList.of(new PartitionInfo("topic1", 0, null, null, null)),
+ "topic2",
+ ImmutableList.of(
+ new PartitionInfo("topic2", 0, null, null, null),
+ new PartitionInfo("topic2", 1, null, null, null))));
+ MockBagState bagState =
+ new MockBagState(
+ ImmutableList.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic1", 1),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1)));
+ Instant now = Instant.EPOCH;
+ mockStatic(Instant.class);
+ when(Instant.now()).thenReturn(now);
+
+ dofnInstance.onTimer(timer, bagState, outputReceiver);
+
+ verify(timer, times(1)).set(now.plus(600L));
+ Set<TopicPartition> expectedCurrentTopicPartitions =
+ ImmutableSet.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1));
+ assertTrue(outputReceiver.getOutputs().isEmpty());
+ assertEquals(expectedCurrentTopicPartitions, bagState.getCurrentStates());
+ }
+
+ @Test
+ public void testOnTimerWithStoppedTopicPartitions() throws Exception {
+ Instant startReadTime = Instant.ofEpochMilli(1L);
+ SerializableFunction<TopicPartition, Boolean> checkStopReadingFn =
+ new SerializableFunction<TopicPartition, Boolean>() {
+ @Override
+ public Boolean apply(TopicPartition input) {
+ if (input.equals(new TopicPartition("topic1", 1))) {
+ return true;
+ }
+ return false;
+ }
+ };
+ WatchKafkaTopicPartitionDoFn dofnInstance =
+ new WatchKafkaTopicPartitionDoFn(
+ Duration.millis(600L),
+ consumerFn,
+ checkStopReadingFn,
+ ImmutableMap.of(),
+ startReadTime);
+ MockOutputReceiver outputReceiver = new MockOutputReceiver();
+
+ when(mockConsumer.listTopics())
+ .thenReturn(
+ ImmutableMap.of(
+ "topic1",
+ ImmutableList.of(
+ new PartitionInfo("topic1", 0, null, null, null),
+ new PartitionInfo("topic1", 1, null, null, null)),
+ "topic2",
+ ImmutableList.of(
+ new PartitionInfo("topic2", 0, null, null, null),
+ new PartitionInfo("topic2", 1, null, null, null))));
+ MockBagState bagState =
+ new MockBagState(
+ ImmutableList.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1)));
+ Instant now = Instant.EPOCH;
+ mockStatic(Instant.class);
+ when(Instant.now()).thenReturn(now);
+
+ dofnInstance.onTimer(timer, bagState, outputReceiver);
+
+ Set<TopicPartition> expectedCurrentTopicPartitions =
+ ImmutableSet.of(
+ new TopicPartition("topic1", 0),
+ new TopicPartition("topic2", 0),
+ new TopicPartition("topic2", 1));
+
+ verify(timer, times(1)).set(now.plus(600L));
+ assertTrue(outputReceiver.getOutputs().isEmpty());
+ assertEquals(expectedCurrentTopicPartitions, bagState.getCurrentStates());
+ }
+
+ private static class MockOutputReceiver implements OutputReceiver<KafkaSourceDescriptor> {
+
+ private List<KafkaSourceDescriptor> outputs = new ArrayList<>();
+
+ @Override
+ public void output(KafkaSourceDescriptor output) {
+ outputs.add(output);
+ }
+
+ @Override
+ public void outputWithTimestamp(KafkaSourceDescriptor output, Instant timestamp) {}
+
+ public List<KafkaSourceDescriptor> getOutputs() {
+ return outputs;
+ }
+ }
+
+ private static class MockBagState implements BagState<TopicPartition> {
+ private Set<TopicPartition> topicPartitions = new HashSet<>();
+
+ MockBagState(List<TopicPartition> readReturn) {
+ topicPartitions.addAll(readReturn);
+ }
+
+ @Override
+ public Iterable<TopicPartition> read() {
+ return topicPartitions;
+ }
+
+ @Override
+ public void add(TopicPartition value) {
+ topicPartitions.add(value);
+ }
+
+ @Override
+ public ReadableState<Boolean> isEmpty() {
+ return null;
+ }
+
+ @Override
+ public BagState<TopicPartition> readLater() {
+ return null;
+ }
+
+ @Override
+ public void clear() {
+ topicPartitions.clear();
+ }
+
+ public Set<TopicPartition> getCurrentStates() {
+ return topicPartitions;
+ }
+ }
+
+ private Set<KafkaSourceDescriptor> generateDescriptorsFromTopicPartitions(
+ Set<TopicPartition> topicPartitions, Instant startReadTime) {
+ return topicPartitions.stream()
+ .map(topicPartition -> KafkaSourceDescriptor.of(topicPartition, null, startReadTime, null))
+ .collect(Collectors.toSet());
+ }
+}