This is an automated email from the ASF dual-hosted git repository.
anton pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9d131d4  [BEAM-7896] Implementing RateEstimation for KafkaTable with Unit and Integration Tests
new cd2ab9e Merge pull request #9298 from riazela/KafkaRateEstimation2
9d131d4 is described below
commit 9d131d490dfa1b4838d0303a3f17f36202c0874b
Author: Alireza Samadian <[email protected]>
AuthorDate: Tue Aug 6 16:56:03 2019 -0700
[BEAM-7896] Implementing RateEstimation for KafkaTable with Unit and Integration Tests
---
sdks/java/extensions/sql/build.gradle | 1 +
.../sql/meta/provider/kafka/BeamKafkaTable.java | 147 +++++++++--
.../meta/provider/kafka/BeamKafkaCSVTableTest.java | 118 ++++++++-
.../sql/meta/provider/kafka/KafkaCSVTableIT.java | 292 +++++++++++++++++++++
.../sql/meta/provider/kafka/KafkaCSVTestTable.java | 197 ++++++++++++++
.../sql/meta/provider/kafka/KafkaTestRecord.java | 39 +++
6 files changed, 777 insertions(+), 17 deletions(-)
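In short, the new BeamKafkaTable.getTableStatistics() estimates an unbounded
rate by seeking back numberOfRecordsForRate records in each partition and
counting how many records arrived per second after the newest per-partition
starting timestamp. A sketch of the arithmetic with illustrative values
(records spaced 500 ms apart, mirroring the unit tests below, not the
implementation itself):

    // Polled timestamps 0, 500, ..., 49500 ms; 100 records in one partition.
    long maxMinTimeStamp = 0;    // newest "first seen" timestamp over all partitions
    long maxTimeStamp = 49500;   // newest timestamp over all polled records
    int numberOfRecords = 99;    // records newer than maxMinTimeStamp
    // (99 * 1000.) / (49500 - 0) = 2.0 records per second
    double rate = (numberOfRecords * 1000.) / ((double) maxTimeStamp - maxMinTimeStamp);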
diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle
index b4a7079..fe07bfe 100644
--- a/sdks/java/extensions/sql/build.gradle
+++ b/sdks/java/extensions/sql/build.gradle
@@ -203,6 +203,7 @@ task integrationTest(type: Test) {
systemProperty "beamTestPipelineOptions", JsonOutput.toJson(pipelineOptions)
include '**/*IT.class'
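+ // KafkaCSVTableIT requires an externally running Kafka server (see BEAM-7523), so it is excluded here.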
+ exclude '**/KafkaCSVTableIT.class'
maxParallelForks 4
classpath = project(":sdks:java:extensions:sql")
.sourceSets
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
index 0e1dab3..11c12f6 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
@@ -19,9 +19,13 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Properties;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable;
@@ -34,9 +38,15 @@ import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.Row;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.ByteArrayDeserializer;
import org.apache.kafka.common.serialization.ByteArraySerializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* {@code BeamKafkaTable} represents a Kafka topic, as source or target. Need to extend to convert
@@ -47,6 +57,10 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
private List<String> topics;
private List<TopicPartition> topicPartitions;
private Map<String, Object> configUpdates;
+ private BeamTableStatistics rowCountStatistics = null;
+ private static final Logger LOGGER = LoggerFactory.getLogger(BeamKafkaTable.class);
+ // The number of records read from each partition when the rate is estimated.
+ protected int numberOfRecordsForRate = 50;
protected BeamKafkaTable(Schema beamSchema) {
super(beamSchema);
@@ -84,7 +98,14 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
@Override
public PCollection<Row> buildIOReader(PBegin begin) {
- KafkaIO.Read<byte[], byte[]> kafkaRead = null;
+ return begin
+ .apply("read", createKafkaRead().withoutMetadata())
+ .apply("in_format", getPTransformForInput())
+ .setRowSchema(getSchema());
+ }
+
+ KafkaIO.Read<byte[], byte[]> createKafkaRead() {
+ KafkaIO.Read<byte[], byte[]> kafkaRead;
if (topics != null) {
kafkaRead =
KafkaIO.<byte[], byte[]>read()
@@ -104,28 +125,25 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
} else {
throw new IllegalArgumentException("Either topics or topicPartitions must be configured.");
}
-
- return begin
- .apply("read", kafkaRead.withoutMetadata())
- .apply("in_format", getPTransformForInput())
- .setRowSchema(getSchema());
+ return kafkaRead;
}
@Override
public POutput buildIOWriter(PCollection<Row> input) {
checkArgument(
topics != null && topics.size() == 1, "Exactly one topic must be configured as output.");
- assert topics != null;
return input
.apply("out_reformat", getPTransformForOutput())
- .apply(
- "persistent",
- KafkaIO.<byte[], byte[]>write()
- .withBootstrapServers(bootstrapServers)
- .withTopic(topics.get(0))
- .withKeySerializer(ByteArraySerializer.class)
- .withValueSerializer(ByteArraySerializer.class));
+ .apply("persistent", createKafkaWrite());
+ }
+
+ private KafkaIO.Write<byte[], byte[]> createKafkaWrite() {
+ return KafkaIO.<byte[], byte[]>write()
+ .withBootstrapServers(bootstrapServers)
+ .withTopic(topics.get(0))
+ .withKeySerializer(ByteArraySerializer.class)
+ .withValueSerializer(ByteArraySerializer.class);
}
public String getBootstrapServers() {
@@ -138,6 +156,105 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
@Override
public BeamTableStatistics getTableStatistics(PipelineOptions options) {
- return BeamTableStatistics.UNBOUNDED_UNKNOWN;
+ if (rowCountStatistics == null) {
+ try {
+ rowCountStatistics =
+ BeamTableStatistics.createUnboundedTableStatistics(
+ this.computeRate(numberOfRecordsForRate));
+ } catch (Exception e) {
+ LOGGER.warn("Could not get the row count for the topics " + getTopics(), e);
+ rowCountStatistics = BeamTableStatistics.UNBOUNDED_UNKNOWN;
+ }
+ }
+
+ return rowCountStatistics;
+ }
+
+ /**
+ * Returns an estimate of the rate (records per second) for this table, computed from the last
+ * {@code numberOfRecords} records in each partition.
+ */
double computeRate(int numberOfRecords) throws NoEstimationException {
+ Properties props = new Properties();
+
+ props.put("bootstrap.servers", bootstrapServers);
+ props.put("session.timeout.ms", "30000");
+ props.put("key.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+ props.put("value.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+
+ // The consumer is typed to match the byte-array deserializers above, and is closed once the
+ // estimate has been computed.
+ try (KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(props)) {
+ return computeRate(consumer, numberOfRecords);
+ }
+ }
+
+ <T> double computeRate(Consumer<T, T> consumer, int numberOfRecordsToCheck)
+ throws NoEstimationException {
+
+ Stream<TopicPartition> topicPartitionStream =
+ getTopics().stream()
+ .map(consumer::partitionsFor)
+ .flatMap(Collection::stream)
+ .map(parInf -> new TopicPartition(parInf.topic(), parInf.partition()));
+ List<TopicPartition> topicPartitions = topicPartitionStream.collect(Collectors.toList());
+
+ consumer.assign(topicPartitions);
+ // endOffsets() returns the current end offset of every partition assigned to the consumer,
+ // i.e., the offset just past the last record in each partition. Note that each topic can have
+ // multiple partitions. Since this consumer is not assigned to any consumer group, changing the
+ // offset or consuming messages has no effect on other consumers (or on the data that our
+ // table is receiving).
+ Map<TopicPartition, Long> offsets = consumer.endOffsets(topicPartitions);
+ long nParsSeen = 0;
+ for (TopicPartition par : topicPartitions) {
+ long offset = offsets.get(par);
+ nParsSeen = (offset == 0) ? nParsSeen : nParsSeen + 1;
+ consumer.seek(par, Math.max(0L, offset - numberOfRecordsToCheck));
+ }
+
+ if (nParsSeen == 0) {
+ throw new NoEstimationException("There is no partition with messages in it.");
+ }
+
+ ConsumerRecords<T, T> records = consumer.poll(1000);
+
+ // Kafka guarantees message ordering within each partition. Therefore the first message seen
+ // from each partition is the earliest message that arrived in it. We take the first message
+ // of each partition, treat the latest of those timestamps as the starting point, and discard
+ // messages that arrived earlier than it from the rate estimation.
+ Map<Integer, Long> minTimeStamps = new HashMap<>();
+ long maxMinTimeStamp = 0;
+ for (ConsumerRecord<T, T> record : records) {
+ if (!minTimeStamps.containsKey(record.partition())) {
+ minTimeStamps.put(record.partition(), record.timestamp());
+
+ nParsSeen--;
+ maxMinTimeStamp = Math.max(record.timestamp(), maxMinTimeStamp);
+ if (nParsSeen == 0) {
+ break;
+ }
+ }
+ }
+
+ int numberOfRecords = 0;
+ long maxTimeStamp = 0;
+ for (ConsumerRecord<T, T> record : records) {
+ maxTimeStamp = Math.max(maxTimeStamp, record.timestamp());
+ numberOfRecords =
+ record.timestamp() > maxMinTimeStamp ? numberOfRecords + 1 : numberOfRecords;
+ }
+
+ if (maxTimeStamp == maxMinTimeStamp) {
+ throw new NoEstimationException("All records have the same arrival time.");
+ }
+
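+ // Rate in records/second: the count of records newer than maxMinTimeStamp, divided by the
+ // elapsed milliseconds between maxMinTimeStamp and maxTimeStamp, scaled by 1000.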
+ return (numberOfRecords * 1000.) / ((double) maxTimeStamp - maxMinTimeStamp);
+ }
+
+ /** Thrown when the rate for a Kafka table cannot be estimated. */
+ static class NoEstimationException extends Exception {
+ NoEstimationException(String message) {
+ super(message);
+ }
}
}
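As a usage sketch (not part of the commit; the schema variable, broker address,
and topic name below are illustrative):

    BeamKafkaTable table =
        new BeamKafkaCSVTable(schema, "localhost:9092", ImmutableList.of("topic1"));
    // The estimate is computed once and cached; failures fall back to UNBOUNDED_UNKNOWN.
    BeamTableStatistics stats = table.getTableStatistics(null);
    if (!stats.isUnknown()) {
      System.out.println("Estimated rate (records/sec): " + stats.getRate());
    }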
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
index 710a1a5..c407ff4 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
@@ -20,7 +20,13 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
import static java.nio.charset.StandardCharsets.UTF_8;
import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.sdk.extensions.sql.BeamSqlTable;
+import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
+import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
@@ -30,11 +36,13 @@ import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.commons.csv.CSVFormat;
+import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
@@ -46,8 +54,101 @@ public class BeamKafkaCSVTableTest {
private static final Row ROW2 = Row.withSchema(genSchema()).addValues(2L, 2, 2.0).build();
+ private static Map<String, BeamSqlTable> tables = new HashMap<>();
+ protected static BeamSqlEnv env = BeamSqlEnv.readOnly("test", tables);
+
+ @Test
+ public void testOrderedArrivalSinglePartitionRate() {
+ KafkaCSVTestTable table = getTable(1);
+ for (int i = 0; i < 100; i++) {
+ table.addRecord(KafkaTestRecord.create("key1", i + ",1,2", "topic1", 500
* i));
+ }
+
+ BeamTableStatistics stats = table.getTableStatistics(null);
+ Assert.assertEquals(2d, stats.getRate(), 0.001);
+ }
+
+ @Test
+ public void testOrderedArrivalMultiplePartitionsRate() {
+ KafkaCSVTestTable table = getTable(3);
+ for (int i = 0; i < 100; i++) {
+ table.addRecord(KafkaTestRecord.create("key" + i, i + ",1,2", "topic1",
500 * i));
+ }
+
+ BeamTableStatistics stats = table.getTableStatistics(null);
+ Assert.assertEquals(2d, stats.getRate(), 0.001);
+ }
+
+ @Test
+ public void testOnePartitionAheadRate() {
+ KafkaCSVTestTable table = getTable(3);
+ for (int i = 0; i < 100; i++) {
+ table.addRecord(KafkaTestRecord.create("1", i + ",1,2", "topic1", 1000 *
i));
+ table.addRecord(KafkaTestRecord.create("2", i + ",1,2", "topic1", 500 *
i));
+ }
+
+ table.setNumberOfRecordsForRate(20);
+ BeamTableStatistics stats = table.getTableStatistics(null);
+ Assert.assertEquals(1d, stats.getRate(), 0.001);
+ }
+
+ @Test
+ public void testLateRecords() {
+ KafkaCSVTestTable table = getTable(3);
+
+ table.addRecord(KafkaTestRecord.create("1", 132 + ",1,2", "topic1", 1000));
+ for (int i = 0; i < 98; i++) {
+ table.addRecord(KafkaTestRecord.create("1", i + ",1,2", "topic1", 500));
+ }
+ table.addRecord(KafkaTestRecord.create("1", 133 + ",1,2", "topic1", 2000));
+
+ table.setNumberOfRecordsForRate(200);
+ BeamTableStatistics stats = table.getTableStatistics(null);
+ Assert.assertEquals(1d, stats.getRate(), 0.001);
+ }
+
@Test
- public void testCsvRecorderDecoder() throws Exception {
+ public void testAllLate() {
+ KafkaCSVTestTable table = getTable(3);
+
+ table.addRecord(KafkaTestRecord.create("1", 132 + ",1,2", "topic1", 1000));
+ for (int i = 0; i < 98; i++) {
+ table.addRecord(KafkaTestRecord.create("1", i + ",1,2", "topic1", 500));
+ }
+
+ table.setNumberOfRecordsForRate(200);
+ BeamTableStatistics stats = table.getTableStatistics(null);
+ Assert.assertTrue(stats.isUnknown());
+ }
+
+ @Test
+ public void testEmptyPartitionsRate() {
+ KafkaCSVTestTable table = getTable(3);
+ BeamTableStatistics stats = table.getTableStatistics(null);
+ Assert.assertTrue(stats.isUnknown());
+ }
+
+ @Test
+ public void allTheRecordsSameTimeRate() {
+ KafkaCSVTestTable table = getTable(3);
+ for (int i = 0; i < 100; i++) {
+ table.addRecord(KafkaTestRecord.create("key" + i, i + ",1,2", "topic1",
1000));
+ }
+ BeamTableStatistics stats = table.getTableStatistics(null);
+ Assert.assertTrue(stats.isUnknown());
+ }
+
+ private static class PrintDoFn extends DoFn<Row, Row> {
+
+ @ProcessElement
+ public void process(ProcessContext c) {
+ System.out.println("PrintDoFn saw: " + c.element().getValues());
+ }
+ }
+
+ @Test
+ public void testCsvRecorderDecoder() {
PCollection<Row> result =
pipeline
.apply(Create.of("1,\"1\",1.0", "2,2,2.0"))
@@ -60,7 +161,7 @@ public class BeamKafkaCSVTableTest {
}
@Test
- public void testCsvRecorderEncoder() throws Exception {
+ public void testCsvRecorderEncoder() {
PCollection<Row> result =
pipeline
.apply(Create.of(ROW1, ROW2))
@@ -90,4 +191,17 @@ public class BeamKafkaCSVTableTest {
ctx.output(KV.of(new byte[] {}, ctx.element().getBytes(UTF_8)));
}
}
+
+ private KafkaCSVTestTable getTable(int numberOfPartitions) {
+ return new KafkaCSVTestTable(
+ TestTableUtils.buildBeamSqlSchema(
+ Schema.FieldType.INT32,
+ "order_id",
+ Schema.FieldType.INT32,
+ "site_id",
+ Schema.FieldType.INT32,
+ "price"),
+ ImmutableList.of("topic1", "topic2"),
+ numberOfPartitions);
+ }
}
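A worked check of testOnePartitionAheadRate above (keys "1" and "2" hash to
different mock partitions, and numberOfRecordsForRate is set to 20):

    partition of "1": timestamps 0, 1000, ..., 99000 ms; the last 20 begin at 80000
    partition of "2": timestamps 0,  500, ..., 49500 ms; the last 20 begin at 40000
    maxMinTimeStamp = max(80000, 40000) = 80000
    records newer than 80000 ms: the 19 records at 81000..99000
    rate = 19 * 1000 / (99000 - 80000) = 1.0 record per second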
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
new file mode 100644
index 0000000..201a1df
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT32;
+import static org.apache.beam.sdk.schemas.Schema.toSchema;
+
+import com.alibaba.fastjson.JSON;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.direct.DirectOptions;
+import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider;
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.Validation;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p21p0.com.google.common.base.MoreObjects;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Integration test for KafkaCSVTable. A Kafka server must be running, and its address must be
+ * passed to the test. (https://issues.apache.org/jira/projects/BEAM/issues/BEAM-7523)
+ */
+public class KafkaCSVTableIT {
+
+ @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+ private static final Schema TEST_TABLE_SCHEMA =
+ Schema.builder()
+ .addNullableField("order_id", Schema.FieldType.INT32)
+ .addNullableField("member_id", Schema.FieldType.INT32)
+ .addNullableField("item_name", Schema.FieldType.INT32)
+ .build();
+
+ @BeforeClass
+ public static void prepare() {
+ PipelineOptionsFactory.register(KafkaOptions.class);
+ }
+
+ @Test
+ @SuppressWarnings("FutureReturnValueIgnored")
+ public void testFake2() throws BeamKafkaTable.NoEstimationException {
+ KafkaOptions kafkaOptions = pipeline.getOptions().as(KafkaOptions.class);
+ Table table =
+ Table.builder()
+ .name("kafka_table")
+ .comment("kafka" + " table")
+ .location("")
+ .schema(
+ Stream.of(
+ Schema.Field.nullable("order_id", INT32),
+ Schema.Field.nullable("member_id", INT32),
+ Schema.Field.nullable("item_name", INT32))
+ .collect(toSchema()))
+ .type("kafka")
+ .properties(JSON.parseObject(getKafkaPropertiesString(kafkaOptions)))
+ .build();
+ BeamKafkaTable kafkaTable = (BeamKafkaTable) new KafkaTableProvider().buildBeamSqlTable(table);
+ produceSomeRecordsWithDelay(100, 20);
+ double rate1 = kafkaTable.computeRate(20);
+ produceSomeRecordsWithDelay(100, 10);
+ double rate2 = kafkaTable.computeRate(20);
+ Assert.assertTrue(rate2 > rate1);
+ }
+
+ private String getKafkaPropertiesString(KafkaOptions kafkaOptions) {
+ return "{ \"bootstrap.servers\" : \""
+ + kafkaOptions.getKafkaBootstrapServerAddress()
+ + "\",\"topics\":[\""
+ + kafkaOptions.getKafkaTopic()
+ + "\"] }";
+ }
+
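+ // Signals success from the streaming pipeline back to the test thread; keyed by the
+ // pipeline's options id so that concurrently running pipelines stay isolated.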
+ static final Map<Long, Boolean> FLAG = new ConcurrentHashMap<>();
+
+ @Test
+ public void testFake() throws InterruptedException {
+ KafkaOptions kafkaOptions = pipeline.getOptions().as(KafkaOptions.class);
+ pipeline.getOptions().as(DirectOptions.class).setBlockOnRun(false);
+ String createTableString =
+ "CREATE EXTERNAL TABLE kafka_table(\n"
+ + "order_id INTEGER, \n"
+ + "member_id INTEGER, \n"
+ + "item_name INTEGER \n"
+ + ") \n"
+ + "TYPE 'kafka' \n"
+ + "LOCATION '"
+ + "'\n"
+ + "TBLPROPERTIES '"
+ + getKafkaPropertiesString(kafkaOptions)
+ + "'";
+ TableProvider tb = new KafkaTableProvider();
+ BeamSqlEnv env = BeamSqlEnv.inMemory(tb);
+
+ env.executeDdl(createTableString);
+
+ PCollection<Row> queryOutput =
+ BeamSqlRelUtils.toPCollection(pipeline, env.parseQuery("SELECT * FROM kafka_table"));
+
+ queryOutput
+ .apply(ParDo.of(new FakeKvPair()))
+ .apply(
+ "waitForSuccess",
+ ParDo.of(
+ new StreamAssertEqual(
+ ImmutableSet.of(
+ row(TEST_TABLE_SCHEMA, 0, 1, 0),
+ row(TEST_TABLE_SCHEMA, 1, 2, 1),
+ row(TEST_TABLE_SCHEMA, 2, 3, 2)))));
+ queryOutput.apply(logRecords(""));
+ pipeline.run();
+ TimeUnit.MILLISECONDS.sleep(3000);
+ produceSomeRecords(3);
+
+ for (int i = 0; i < 200; i++) {
+ if (FLAG.getOrDefault(pipeline.getOptions().getOptionsId(), false)) {
+ return;
+ }
+ TimeUnit.MILLISECONDS.sleep(60);
+ }
+ Assert.fail();
+ }
+
+ private static MapElements<Row, Void> logRecords(String suffix) {
+ return MapElements.via(
+ new SimpleFunction<Row, Void>() {
+ @Override
+ public @Nullable Void apply(Row input) {
+ System.out.println(input.getValues() + suffix);
+ return null;
+ }
+ });
+ }
+
+ /** Stateful DoFns require KV inputs, so this wraps each row in a KV with a fake key. */
+ public static class FakeKvPair extends DoFn<Row, KV<String, Row>> {
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ c.output(KV.of("fake_key", c.element()));
+ }
+ }
+
+ /** This DoFn sets a flag once all the expected elements have been seen. */
+ public static class StreamAssertEqual extends DoFn<KV<String, Row>, Void> {
+ private final Set<Row> expected;
+
+ StreamAssertEqual(Set<Row> expected) {
+ super();
+ this.expected = expected;
+ }
+
+ @DoFn.StateId("seenValues")
+ private final StateSpec<BagState<Row>> seenRows = StateSpecs.bag();
+
+ @StateId("count")
+ private final StateSpec<ValueState<Integer>> countState = StateSpecs.value();
+
+ @ProcessElement
+ public void process(
+ ProcessContext context,
+ @StateId("seenValues") BagState<Row> seenValues,
+ @StateId("count") ValueState<Integer> countState) {
+ // Note: this read-modify-write on the count state is not safe if executed in parallel.
+ int count = MoreObjects.firstNonNull(countState.read(), 0);
+ count = count + 1;
+ countState.write(count);
+ seenValues.add(context.element().getValue());
+
+ if (count >= expected.size()) {
+ if (StreamSupport.stream(seenValues.read().spliterator(), false)
+ .collect(Collectors.toSet())
+ .containsAll(expected)) {
+ System.out.println("in second if");
+ FLAG.put(context.getPipelineOptions().getOptionsId(), true);
+ }
+ }
+ }
+ }
+
+ private Row row(Schema schema, Object... values) {
+ return Row.withSchema(schema).addValues(values).build();
+ }
+
+ @SuppressWarnings("FutureReturnValueIgnored")
+ private void produceSomeRecords(int num) {
+ Producer<String, String> producer = new KafkaProducer<>(producerProps());
+ String topicName = pipeline.getOptions().as(KafkaOptions.class).getKafkaTopic();
+ for (int i = 0; i < num; i++) {
+ producer.send(
+ new ProducerRecord<String, String>(
+ topicName, "k" + i, i + "," + ((i % 3) + 1) + "," + i));
+ }
+ producer.flush();
+ producer.close();
+ }
+
+ @SuppressWarnings("FutureReturnValueIgnored")
+ private void produceSomeRecordsWithDelay(int num, int delayMillis) {
+ Producer<String, String> producer = new KafkaProducer<>(producerProps());
+ String topicName = pipeline.getOptions().as(KafkaOptions.class).getKafkaTopic();
+ for (int i = 0; i < num; i++) {
+ producer.send(
+ new ProducerRecord<String, String>(
+ topicName, "k" + i, i + "," + ((i % 3) + 1) + "," + i));
+ try {
TimeUnit.MILLISECONDS.sleep(delayMillis);
+ } catch (InterruptedException e) {
+ throw new RuntimeException("Could not wait for producing", e);
+ }
+ }
+ producer.flush();
+ producer.close();
+ }
+
+ private Properties producerProps() {
+ KafkaOptions options = pipeline.getOptions().as(KafkaOptions.class);
+ Properties props = new Properties();
+ props.put("bootstrap.servers", options.getKafkaBootstrapServerAddress());
+ props.put("key.serializer",
"org.apache.kafka.common.serialization.StringSerializer");
+ props.put("value.serializer",
"org.apache.kafka.common.serialization.StringSerializer");
+ props.put("buffer.memory", 33554432);
+ props.put("acks", "all");
+ props.put("request.required.acks", "1");
+ props.put("retries", 0);
+ props.put("linger.ms", 1);
+ return props;
+ }
+
+ /** Pipeline options specific for this test. */
+ public interface KafkaOptions extends PipelineOptions {
+
+ @Description("Kafka server address")
+ @Validation.Required
+ @Default.String("localhost:9092")
+ String getKafkaBootstrapServerAddress();
+
+ void setKafkaBootstrapServerAddress(String address);
+
+ @Description("Kafka topic")
+ @Validation.Required
+ @Default.String("test")
+ String getKafkaTopic();
+
+ void setKafkaTopic(String topic);
+ }
+}
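To run this IT against a local broker, the options above are supplied through
the beamTestPipelineOptions system property; assuming the usual PipelineOptions
flag names derived from the getters, something like:

    ["--kafkaBootstrapServerAddress=localhost:9092", "--kafkaTopic=test"]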
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java
new file mode 100644
index 0000000..749adea
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.io.kafka.KafkaIO;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.MockConsumer;
+import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.record.TimestampType;
+
+/** A mock implementation of BeamKafkaCSVTable, backed by a MockConsumer. */
+public class KafkaCSVTestTable extends BeamKafkaCSVTable {
+ private int partitionsPerTopic;
+ private List<KafkaTestRecord> records;
+ private static final String TIMESTAMP_TYPE_CONFIG = "test.timestamp.type";
+
+ public KafkaCSVTestTable(Schema beamSchema, List<String> topics, int partitionsPerTopic) {
+ super(beamSchema, "server:123", topics);
+ this.partitionsPerTopic = partitionsPerTopic;
+ this.records = new ArrayList<>();
+ }
+
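+ // Route KafkaIO reads through the MockConsumer instead of a real broker.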
+ @Override
+ KafkaIO.Read<byte[], byte[]> createKafkaRead() {
+ return super.createKafkaRead().withConsumerFactoryFn(this::mkMockConsumer);
+ }
+
+ public void addRecord(KafkaTestRecord record) {
+ records.add(record);
+ }
+
+ @Override
+ double computeRate(int numberOfRecords) throws NoEstimationException {
+ return super.computeRate(mkMockConsumer(new HashMap<>()), numberOfRecords);
+ }
+
+ public void setNumberOfRecordsForRate(int numberOfRecordsForRate) {
+ this.numberOfRecordsForRate = numberOfRecordsForRate;
+ }
+
+ private MockConsumer<byte[], byte[]> mkMockConsumer(Map<String, Object> config) {
+ OffsetResetStrategy offsetResetStrategy = OffsetResetStrategy.EARLIEST;
+ final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> kafkaRecords = new HashMap<>();
+ Map<String, List<PartitionInfo>> partitionInfoMap = new HashMap<>();
+ Map<String, List<TopicPartition>> partitionMap = new HashMap<>();
+
+ // Create topic partitions.
+ for (String topic : this.getTopics()) {
+ List<PartitionInfo> partIds = new ArrayList<>(partitionsPerTopic);
+ List<TopicPartition> topicPartitions = new ArrayList<>(partitionsPerTopic);
+ for (int i = 0; i < partitionsPerTopic; i++) {
+ TopicPartition tp = new TopicPartition(topic, i);
+ topicPartitions.add(tp);
+ partIds.add(new PartitionInfo(topic, i, null, null, null));
+ kafkaRecords.put(tp, new ArrayList<>());
+ }
+ partitionInfoMap.put(topic, partIds);
+ partitionMap.put(topic, topicPartitions);
+ }
+
+ TimestampType timestampType =
+ TimestampType.forName(
+ (String)
+ config.getOrDefault(
+ TIMESTAMP_TYPE_CONFIG, TimestampType.LOG_APPEND_TIME.toString()));
+
+ for (KafkaTestRecord record : this.records) {
+ // floorMod keeps the index non-negative even when hashCode() is negative.
+ int partitionIndex = Math.floorMod(record.getKey().hashCode(), partitionsPerTopic);
+ TopicPartition tp = partitionMap.get(record.getTopic()).get(partitionIndex);
+ byte[] key = record.getKey().getBytes(UTF_8);
+ byte[] value = record.getValue().getBytes(UTF_8);
+ kafkaRecords
+ .get(tp)
+ .add(
+ new ConsumerRecord<>(
+ tp.topic(),
+ tp.partition(),
+ kafkaRecords.get(tp).size(),
+ record.getTimeStamp(),
+ timestampType,
+ 0,
+ key.length,
+ value.length,
+ key,
+ value));
+ }
+
+ // This is updated when reader assigns partitions.
+ final AtomicReference<List<TopicPartition>> assignedPartitions =
+ new AtomicReference<>(Collections.<TopicPartition>emptyList());
+ final MockConsumer<byte[], byte[]> consumer =
+ new MockConsumer<byte[], byte[]>(offsetResetStrategy) {
+ @Override
+ public synchronized void assign(final Collection<TopicPartition> assigned) {
+ Collection<TopicPartition> realPartitions =
+ assigned.stream()
+ .map(part -> partitionMap.get(part.topic()).get(part.partition()))
+ .collect(Collectors.toList());
+ super.assign(realPartitions);
+ assignedPartitions.set(ImmutableList.copyOf(realPartitions));
+ for (TopicPartition tp : realPartitions) {
+ updateBeginningOffsets(ImmutableMap.of(tp, 0L));
+ updateEndOffsets(ImmutableMap.of(tp, (long) kafkaRecords.get(tp).size()));
+ }
+ }
+ // Override offsetsForTimes() in order to look up the offsets by timestamp.
+ @Override
+ public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
+ Map<TopicPartition, Long> timestampsToSearch) {
+ return timestampsToSearch.entrySet().stream()
+ .map(
+ e -> {
+ // Within this test, timestamp == offset.
+ long maxOffset = kafkaRecords.get(e.getKey()).size();
+ long offset = e.getValue();
+ OffsetAndTimestamp value =
+ (offset >= maxOffset) ? null : new OffsetAndTimestamp(offset, offset);
+ return new AbstractMap.SimpleEntry<>(e.getKey(), value);
+ })
+ .collect(
+ Collectors.toMap(
+ AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue));
+ }
+ };
+
+ for (String topic : getTopics()) {
+ consumer.updatePartitions(topic, partitionInfoMap.get(topic));
+ }
+
+ Runnable recordEnqueueTask =
+ new Runnable() {
+ @Override
+ public void run() {
+ // add all the records with offset >= current partition position.
+ int recordsAdded = 0;
+ for (TopicPartition tp : assignedPartitions.get()) {
+ long curPos = consumer.position(tp);
+ for (ConsumerRecord<byte[], byte[]> r : kafkaRecords.get(tp)) {
+ if (r.offset() >= curPos) {
+ consumer.addRecord(r);
+ recordsAdded++;
+ }
+ }
+ }
+ if (recordsAdded == 0) {
+ if (config.get("inject.error.at.eof") != null) {
+ consumer.setException(new KafkaException("Injected error in consumer.poll()"));
+ }
+ // MockConsumer.poll(timeout) does not actually wait even when there aren't any records.
+ // Add a small wait here in order to avoid busy looping in the reader.
+ Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS);
+ }
+ consumer.schedulePollTask(this);
+ }
+ };
+
+ consumer.schedulePollTask(recordEnqueueTask);
+
+ return consumer;
+ }
+}
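A typical use of this mock from a unit test (mirroring BeamKafkaCSVTableTest
above; the schema and record values are illustrative):

    KafkaCSVTestTable table = new KafkaCSVTestTable(schema, ImmutableList.of("topic1"), 3);
    table.addRecord(KafkaTestRecord.create("key1", "1,2,3", "topic1", 500L));
    table.setNumberOfRecordsForRate(20);
    BeamTableStatistics stats = table.getTableStatistics(null);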
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java
new file mode 100644
index 0000000..015ac8b
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import com.google.auto.value.AutoValue;
+import java.io.Serializable;
+
+/** This class exists because Kafka consumer records are not serializable. */
+@AutoValue
+public abstract class KafkaTestRecord implements Serializable {
+
+ public abstract String getKey();
+
+ public abstract String getValue();
+
+ public abstract String getTopic();
+
+ public abstract long getTimeStamp();
+
+ public static KafkaTestRecord create(
+ String newKey, String newValue, String newTopic, long newTimeStamp) {
+ return new AutoValue_KafkaTestRecord(newKey, newValue, newTopic, newTimeStamp);
+ }
+}