This is an automated email from the ASF dual-hosted git repository.
ahmedabualsaud pushed a commit to branch release-2.57.0
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/release-2.57.0 by this push:
new a36857f10a3 Cherrypick (#31362) kafka schematransform translation (#31518)
a36857f10a3 is described below
commit a36857f10a3314b7dac1445373a6d98dd3c85a31
Author: Ahmed Abualsaud <[email protected]>
AuthorDate: Wed Jun 5 16:37:24 2024 -0400
Cherrypick (#31362) kafka schematransform translation (#31518)
* kafka schematransform translation and tests
* cleanup
* spotless
* address failing tests
* switch existing schematransform tests to use Managed API
* fix nullness
* add some more mappings
* fix mapping
* typo
* more accurate test name
* cleanup after merging snake_case PR
* spotless
---
.../java/org/apache/beam/sdk/io/kafka/KafkaIO.java | 2 +-
.../KafkaReadSchemaTransformConfiguration.java | 6 +
.../io/kafka/KafkaReadSchemaTransformProvider.java | 260 +++++++++++----------
.../io/kafka/KafkaSchemaTransformTranslation.java | 93 ++++++++
.../kafka/KafkaWriteSchemaTransformProvider.java | 16 ++
.../org/apache/beam/sdk/io/kafka/KafkaIOIT.java | 45 ++--
.../KafkaReadSchemaTransformProviderTest.java | 49 ++--
.../kafka/KafkaSchemaTransformTranslationTest.java | 216 +++++++++++++++++
8 files changed, 522 insertions(+), 165 deletions(-)
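
For context on the test changes below: KafkaIOIT now exercises these transforms through the Managed API instead of constructing the providers directly, and the read side gains a max_read_time_seconds option so tests can bound an otherwise unbounded read. A minimal sketch of such a Managed Kafka read (the broker address "localhost:9092" and topic "my_topic" are placeholders, not values from this commit):

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.managed.Managed;
    import org.apache.beam.sdk.values.PCollectionRowTuple;
    import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;

    Pipeline p = Pipeline.create();
    PCollectionRowTuple rows =
        PCollectionRowTuple.empty(p)
            .apply(
                Managed.read(Managed.KAFKA)
                    .withConfig(
                        ImmutableMap.<String, Object>builder()
                            .put("format", "RAW")
                            .put("topic", "my_topic")                  // placeholder
                            .put("bootstrap_servers", "localhost:9092") // placeholder
                            // New in this change: stop reading after a bounded time.
                            .put("max_read_time_seconds", 30)
                            .build()));
    // The read emits rows on the "output" tag, as shown in the KafkaIOIT changes below.

The snake_case keys mirror the sorted, snake_case configuration Row that the new payload translators serialize.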
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 8f995a63a10..35aabbbfd97 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -2665,7 +2665,7 @@ public class KafkaIO {
abstract Builder<K, V> setProducerFactoryFn(
@Nullable SerializableFunction<Map<String, Object>, Producer<K, V>> fn);
- abstract Builder<K, V> setKeySerializer(Class<? extends Serializer<K>> serializer);
+ abstract Builder<K, V> setKeySerializer(@Nullable Class<? extends Serializer<K>> serializer);
abstract Builder<K, V> setValueSerializer(Class<? extends Serializer<V>> serializer);
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
index 13f5249a6c3..693c1371f78 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
@@ -149,6 +149,10 @@ public abstract class KafkaReadSchemaTransformConfiguration {
/** Sets the topic from which to read. */
public abstract String getTopic();
+ @SchemaFieldDescription("Upper bound of how long to read from Kafka.")
+ @Nullable
+ public abstract Integer getMaxReadTimeSeconds();
+
@SchemaFieldDescription("This option specifies whether and where to output
unwritable rows.")
@Nullable
public abstract ErrorHandling getErrorHandling();
@@ -179,6 +183,8 @@ public abstract class KafkaReadSchemaTransformConfiguration {
/** Sets the topic from which to read. */
public abstract Builder setTopic(String value);
+ public abstract Builder setMaxReadTimeSeconds(Integer maxReadTimeSeconds);
+
public abstract Builder setErrorHandling(ErrorHandling errorHandling);
/** Builds a {@link KafkaReadSchemaTransformConfiguration} instance. */
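
The new max_read_time_seconds field replaces the test-only constructor flags that the provider used to carry (removed in the next file) by making the read bound part of the configuration itself. A minimal sketch of building the configuration with the new setter (topic and bootstrap server values are placeholders):

    KafkaReadSchemaTransformConfiguration config =
        KafkaReadSchemaTransformConfiguration.builder()
            .setFormat("RAW")
            .setTopic("my_topic")                   // placeholder
            .setBootstrapServers("localhost:9092")  // placeholder
            .setMaxReadTimeSeconds(30)              // new: upper bound on how long to read
            .build();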
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
index 13240ea9dc4..b2eeb1a54d1 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.io.kafka;
+import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+
import com.google.auto.service.AutoService;
import java.io.FileOutputStream;
import java.io.IOException;
@@ -38,7 +40,9 @@ import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.Convert;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
@@ -56,7 +60,6 @@ import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
@@ -76,19 +79,6 @@ public class KafkaReadSchemaTransformProvider
public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
- final Boolean isTest;
- final Integer testTimeoutSecs;
-
- public KafkaReadSchemaTransformProvider() {
- this(false, 0);
- }
-
- @VisibleForTesting
- KafkaReadSchemaTransformProvider(Boolean isTest, Integer testTimeoutSecs) {
- this.isTest = isTest;
- this.testTimeoutSecs = testTimeoutSecs;
- }
-
@Override
protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
return KafkaReadSchemaTransformConfiguration.class;
@@ -99,113 +89,7 @@ public class KafkaReadSchemaTransformProvider
})
@Override
protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) {
- configuration.validate();
-
- final String inputSchema = configuration.getSchema();
- final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
- final String autoOffsetReset =
- MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");
-
- Map<String, Object> consumerConfigs =
- new HashMap<>(
- MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
- consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
- consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
- consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
- consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);
-
- String format = configuration.getFormat();
- boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
-
- SerializableFunction<byte[], Row> valueMapper;
- Schema beamSchema;
-
- String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
- if (confluentSchemaRegUrl != null) {
- return new SchemaTransform() {
- @Override
- public PCollectionRowTuple expand(PCollectionRowTuple input) {
- final String confluentSchemaRegSubject =
- configuration.getConfluentSchemaRegistrySubject();
- KafkaIO.Read<byte[], GenericRecord> kafkaRead =
- KafkaIO.<byte[], GenericRecord>read()
- .withTopic(configuration.getTopic())
- .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
- .withBootstrapServers(configuration.getBootstrapServers())
- .withConsumerConfigUpdates(consumerConfigs)
- .withKeyDeserializer(ByteArrayDeserializer.class)
- .withValueDeserializer(
- ConfluentSchemaRegistryDeserializerProvider.of(
- confluentSchemaRegUrl, confluentSchemaRegSubject));
- if (isTest) {
- kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
- }
-
- PCollection<GenericRecord> kafkaValues =
- input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
-
- assert kafkaValues.getCoder().getClass() == AvroCoder.class;
- AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
- kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
- return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
- }
- };
- }
- if ("RAW".equals(format)) {
- beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
- valueMapper = getRawBytesToRowFunction(beamSchema);
- } else if ("PROTO".equals(format)) {
- String fileDescriptorPath = configuration.getFileDescriptorPath();
- String messageName = configuration.getMessageName();
- if (fileDescriptorPath != null) {
- beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
- valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
- } else {
- beamSchema = ProtoByteUtils.getBeamSchemaFromProtoSchema(inputSchema, messageName);
- valueMapper = ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(inputSchema, messageName);
- }
- } else if ("JSON".equals(format)) {
- beamSchema = JsonUtils.beamSchemaFromJsonSchema(inputSchema);
- valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
- } else {
- beamSchema = AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema));
- valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
- }
-
- return new SchemaTransform() {
- @Override
- public PCollectionRowTuple expand(PCollectionRowTuple input) {
- KafkaIO.Read<byte[], byte[]> kafkaRead =
- KafkaIO.readBytes()
- .withConsumerConfigUpdates(consumerConfigs)
- .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
- .withTopic(configuration.getTopic())
- .withBootstrapServers(configuration.getBootstrapServers());
- if (isTest) {
- kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
- }
-
- PCollection<byte[]> kafkaValues =
- input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
-
- Schema errorSchema = ErrorHandling.errorSchemaBytes();
- PCollectionTuple outputTuple =
- kafkaValues.apply(
- ParDo.of(
- new ErrorFn(
- "Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
- .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
-
- PCollectionRowTuple outputRows =
- PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));
-
- PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
- if (handleErrors) {
- outputRows = outputRows.and(configuration.getErrorHandling().getOutput(), errorOutput);
- }
- return outputRows;
- }
- };
+ return new KafkaReadSchemaTransform(configuration);
}
public static SerializableFunction<byte[], Row> getRawBytesToRowFunction(Schema rawSchema) {
@@ -232,6 +116,140 @@ public class KafkaReadSchemaTransformProvider
return Arrays.asList("output", "errors");
}
+ static class KafkaReadSchemaTransform extends SchemaTransform {
+ private final KafkaReadSchemaTransformConfiguration configuration;
+
+ KafkaReadSchemaTransform(KafkaReadSchemaTransformConfiguration configuration) {
+ this.configuration = configuration;
+ }
+
+ Row getConfigurationRow() {
+ try {
+ // To stay consistent with our SchemaTransform configuration naming conventions,
+ // we sort lexicographically
+ return SchemaRegistry.createDefault()
+ .getToRowFunction(KafkaReadSchemaTransformConfiguration.class)
+ .apply(configuration)
+ .sorted()
+ .toSnakeCase();
+ } catch (NoSuchSchemaException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public PCollectionRowTuple expand(PCollectionRowTuple input) {
+ configuration.validate();
+
+ final String inputSchema = configuration.getSchema();
+ final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
+ final String autoOffsetReset =
+ MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");
+
+ Map<String, Object> consumerConfigs =
+ new HashMap<>(
+ MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
+ consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
+ consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
+ consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
+ consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);
+
+ String format = configuration.getFormat();
+ boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
+
+ SerializableFunction<byte[], Row> valueMapper;
+ Schema beamSchema;
+
+ String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
+ if (confluentSchemaRegUrl != null) {
+ final String confluentSchemaRegSubject =
+ checkArgumentNotNull(configuration.getConfluentSchemaRegistrySubject());
+ KafkaIO.Read<byte[], GenericRecord> kafkaRead =
+ KafkaIO.<byte[], GenericRecord>read()
+ .withTopic(configuration.getTopic())
+ .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
+ .withBootstrapServers(configuration.getBootstrapServers())
+ .withConsumerConfigUpdates(consumerConfigs)
+ .withKeyDeserializer(ByteArrayDeserializer.class)
+ .withValueDeserializer(
+ ConfluentSchemaRegistryDeserializerProvider.of(
+ confluentSchemaRegUrl, confluentSchemaRegSubject));
+ Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
+ if (maxReadTimeSeconds != null) {
+ kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
+ }
+
+ PCollection<GenericRecord> kafkaValues =
+ input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
+
+ assert kafkaValues.getCoder().getClass() == AvroCoder.class;
+ AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
+ kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
+ return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
+ }
+
+ if ("RAW".equals(format)) {
+ beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
+ valueMapper = getRawBytesToRowFunction(beamSchema);
+ } else if ("PROTO".equals(format)) {
+ String fileDescriptorPath = configuration.getFileDescriptorPath();
+ String messageName = checkArgumentNotNull(configuration.getMessageName());
+ if (fileDescriptorPath != null) {
+ beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
+ valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
+ } else {
+ beamSchema =
+ ProtoByteUtils.getBeamSchemaFromProtoSchema(
+ checkArgumentNotNull(inputSchema), messageName);
+ valueMapper =
+ ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(
+ checkArgumentNotNull(inputSchema), messageName);
+ }
+ } else if ("JSON".equals(format)) {
+ beamSchema = JsonUtils.beamSchemaFromJsonSchema(checkArgumentNotNull(inputSchema));
+ valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
+ } else {
+ beamSchema =
+ AvroUtils.toBeamSchema(
+ new org.apache.avro.Schema.Parser().parse(checkArgumentNotNull(inputSchema)));
+ valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
+ }
+
+ KafkaIO.Read<byte[], byte[]> kafkaRead =
+ KafkaIO.readBytes()
+ .withConsumerConfigUpdates(consumerConfigs)
+ .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
+ .withTopic(configuration.getTopic())
+ .withBootstrapServers(configuration.getBootstrapServers());
+ Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
+ if (maxReadTimeSeconds != null) {
+ kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
+ }
+
+ PCollection<byte[]> kafkaValues =
+ input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
+
+ Schema errorSchema = ErrorHandling.errorSchemaBytes();
+ PCollectionTuple outputTuple =
+ kafkaValues.apply(
+ ParDo.of(
+ new ErrorFn(
+ "Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
+ .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+ PCollectionRowTuple outputRows =
+ PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));
+
+ PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
+ if (handleErrors) {
+ outputRows =
+ outputRows.and(
+ checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
+ }
+ return outputRows;
+ }
+ }
+
public static class ErrorFn extends DoFn<byte[], Row> {
private final SerializableFunction<byte[], Row> valueMapper;
private final Counter errorCounter;
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java
new file mode 100644
index 00000000000..4b83e2b6f55
--- /dev/null
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform;
+import static org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform;
+import static org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation.SchemaTransformPayloadTranslator;
+
+import com.google.auto.service.AutoService;
+import java.util.Map;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.construction.PTransformTranslation;
+import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+
+public class KafkaSchemaTransformTranslation {
+ static class KafkaReadSchemaTransformTranslator
+ extends SchemaTransformPayloadTranslator<KafkaReadSchemaTransform> {
+ @Override
+ public SchemaTransformProvider provider() {
+ return new KafkaReadSchemaTransformProvider();
+ }
+
+ @Override
+ public Row toConfigRow(KafkaReadSchemaTransform transform) {
+ return transform.getConfigurationRow();
+ }
+ }
+
+ @AutoService(TransformPayloadTranslatorRegistrar.class)
+ public static class ReadRegistrar implements TransformPayloadTranslatorRegistrar {
+ @Override
+ @SuppressWarnings({
+ "rawtypes",
+ })
+ public Map<
+ ? extends Class<? extends PTransform>,
+ ? extends PTransformTranslation.TransformPayloadTranslator>
+ getTransformPayloadTranslators() {
+ return ImmutableMap
+ .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder()
+ .put(KafkaReadSchemaTransform.class, new KafkaReadSchemaTransformTranslator())
+ .build();
+ }
+ }
+
+ static class KafkaWriteSchemaTransformTranslator
+ extends SchemaTransformPayloadTranslator<KafkaWriteSchemaTransform> {
+ @Override
+ public SchemaTransformProvider provider() {
+ return new KafkaWriteSchemaTransformProvider();
+ }
+
+ @Override
+ public Row toConfigRow(KafkaWriteSchemaTransform transform) {
+ return transform.getConfigurationRow();
+ }
+ }
+
+ @AutoService(TransformPayloadTranslatorRegistrar.class)
+ public static class WriteRegistrar implements TransformPayloadTranslatorRegistrar {
+ @Override
+ @SuppressWarnings({
+ "rawtypes",
+ })
+ public Map<
+ ? extends Class<? extends PTransform>,
+ ? extends PTransformTranslation.TransformPayloadTranslator>
+ getTransformPayloadTranslators() {
+ return ImmutableMap
+ .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder()
+ .put(KafkaWriteSchemaTransform.class, new KafkaWriteSchemaTransformTranslator())
+ .build();
+ }
+ }
+}
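
These translators let the Kafka read and write SchemaTransforms be captured as a configuration Row and rebuilt later during pipeline proto translation. A minimal round-trip sketch in the style of KafkaSchemaTransformTranslationTest further below (it assumes same-package access and a Row named readConfig built against the provider's configuration schema, like READ_CONFIG in that test):

    KafkaReadSchemaTransformProvider provider = new KafkaReadSchemaTransformProvider();
    KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform readTransform =
        (KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform) provider.from(readConfig);

    KafkaSchemaTransformTranslation.KafkaReadSchemaTransformTranslator translator =
        new KafkaSchemaTransformTranslation.KafkaReadSchemaTransformTranslator();

    // Serialize the transform to its sorted, snake_case configuration Row...
    Row row = translator.toConfigRow(readTransform);

    // ...and reconstruct an equivalent transform from that Row.
    KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform rebuilt =
        translator.fromConfigRow(row, PipelineOptionsFactory.create());
    // rebuilt.getConfigurationRow() equals the original row, as asserted in the test.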
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
index 26f37b790ef..09b338492b4 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
@@ -31,7 +31,9 @@ import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.AutoValueSchema;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
@@ -99,6 +101,20 @@ public class KafkaWriteSchemaTransformProvider
this.configuration = configuration;
}
+ Row getConfigurationRow() {
+ try {
+ // To stay consistent with our SchemaTransform configuration naming conventions,
+ // we sort lexicographically
+ return SchemaRegistry.createDefault()
+ .getToRowFunction(KafkaWriteSchemaTransformConfiguration.class)
+ .apply(configuration)
+ .sorted()
+ .toSnakeCase();
+ } catch (NoSuchSchemaException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
public static class ErrorCounterFn extends DoFn<Row, KV<byte[], byte[]>> {
private final SerializableFunction<Row, byte[]> toBytesFn;
private final Counter errorCounter;
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
index ab6ac52e318..4d38636892c 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
@@ -48,6 +48,7 @@ import org.apache.beam.sdk.io.kafka.KafkaIOTest.FailingLongSerializer;
import org.apache.beam.sdk.io.kafka.ReadFromKafkaDoFnTest.FailingDeserializer;
import org.apache.beam.sdk.io.synthetic.SyntheticBoundedSource;
import org.apache.beam.sdk.io.synthetic.SyntheticSourceOptions;
+import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.options.Default;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.ExperimentalOptions;
@@ -607,18 +608,18 @@ public class KafkaIOIT {
private static final int FIVE_MINUTES_IN_MS = 5 * 60 * 1000;
@Test(timeout = FIVE_MINUTES_IN_MS)
- public void testKafkaViaSchemaTransformJson() {
- runReadWriteKafkaViaSchemaTransforms(
+ public void testKafkaViaManagedSchemaTransformJson() {
+ runReadWriteKafkaViaManagedSchemaTransforms(
"JSON", SCHEMA_IN_JSON,
JsonUtils.beamSchemaFromJsonSchema(SCHEMA_IN_JSON));
}
@Test(timeout = FIVE_MINUTES_IN_MS)
- public void testKafkaViaSchemaTransformAvro() {
- runReadWriteKafkaViaSchemaTransforms(
+ public void testKafkaViaManagedSchemaTransformAvro() {
+ runReadWriteKafkaViaManagedSchemaTransforms(
"AVRO", AvroUtils.toAvroSchema(KAFKA_TOPIC_SCHEMA).toString(),
KAFKA_TOPIC_SCHEMA);
}
- public void runReadWriteKafkaViaSchemaTransforms(
+ public void runReadWriteKafkaViaManagedSchemaTransforms(
String format, String schemaDefinition, Schema beamSchema) {
String topicName = options.getKafkaTopic() + "-schema-transform" + UUID.randomUUID();
PCollectionRowTuple.of(
@@ -646,13 +647,12 @@ public class KafkaIOIT {
.setRowSchema(beamSchema))
.apply(
"Write to Kafka",
- new KafkaWriteSchemaTransformProvider()
- .from(
- KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransformConfiguration
- .builder()
- .setTopic(topicName)
- .setBootstrapServers(options.getKafkaBootstrapServerAddresses())
- .setFormat(format)
+ Managed.write(Managed.KAFKA)
+ .withConfig(
+ ImmutableMap.<String, Object>builder()
+ .put("topic", topicName)
+ .put("bootstrap_servers", options.getKafkaBootstrapServerAddresses())
+ .put("format", format)
.build()));
PAssert.that(
@@ -661,15 +661,18 @@ public class KafkaIOIT {
"Read from unbounded Kafka",
// A timeout of 30s for local, container-based tests, and 2 minutes for
// real-kafka tests.
- new KafkaReadSchemaTransformProvider(
- true, options.isWithTestcontainers() ? 30 : 120)
- .from(
- KafkaReadSchemaTransformConfiguration.builder()
- .setFormat(format)
- .setAutoOffsetResetConfig("earliest")
- .setSchema(schemaDefinition)
- .setTopic(topicName)
- .setBootstrapServers(options.getKafkaBootstrapServerAddresses())
+ Managed.read(Managed.KAFKA)
+ .withConfig(
+ ImmutableMap.<String, Object>builder()
+ .put("format", format)
+ .put("auto_offset_reset_config", "earliest")
+ .put("schema", schemaDefinition)
+ .put("topic", topicName)
+ .put(
+ "bootstrap_servers",
options.getKafkaBootstrapServerAddresses())
+ .put(
+ "max_read_time_seconds",
+ options.isWithTestcontainers() ? 30 : 120)
.build()))
.get("output"))
.containsInAnyOrder(
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
index dfe062e1eef..19c336e1d24 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
@@ -34,9 +34,11 @@ import java.util.stream.StreamSupport;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.managed.ManagedTransformConstants;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.utils.YamlUtils;
import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
@@ -131,7 +133,8 @@ public class KafkaReadSchemaTransformProviderTest {
"confluent_schema_registry_url",
"error_handling",
"file_descriptor_path",
- "message_name"),
+ "message_name",
+ "max_read_time_seconds"),
kafkaProvider.configurationSchema().getFields().stream()
.map(field -> field.getName())
.collect(Collectors.toSet()));
@@ -232,22 +235,23 @@ public class KafkaReadSchemaTransformProviderTest {
.collect(Collectors.toList());
KafkaReadSchemaTransformProvider kafkaProvider =
(KafkaReadSchemaTransformProvider) providers.get(0);
+ SchemaTransform transform =
+ kafkaProvider.from(
+ KafkaReadSchemaTransformConfiguration.builder()
+ .setTopic("anytopic")
+ .setBootstrapServers("anybootstrap")
+ .setFormat("PROTO")
+ .setMessageName("MyOtherMessage")
+ .setFileDescriptorPath(
+ Objects.requireNonNull(
+ getClass()
+ .getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
+ .getPath())
+ .build());
assertThrows(
NullPointerException.class,
- () ->
- kafkaProvider.from(
- KafkaReadSchemaTransformConfiguration.builder()
- .setTopic("anytopic")
- .setBootstrapServers("anybootstrap")
- .setFormat("PROTO")
- .setMessageName("MyOtherMessage")
- .setFileDescriptorPath(
- Objects.requireNonNull(
- getClass()
-
.getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
- .getPath())
- .build()));
+ () -> transform.expand(PCollectionRowTuple.empty(Pipeline.create())));
}
@Test
@@ -281,17 +285,18 @@ public class KafkaReadSchemaTransformProviderTest {
.collect(Collectors.toList());
KafkaReadSchemaTransformProvider kafkaProvider =
(KafkaReadSchemaTransformProvider) providers.get(0);
+ SchemaTransform transform =
+ kafkaProvider.from(
+ KafkaReadSchemaTransformConfiguration.builder()
+ .setTopic("anytopic")
+ .setBootstrapServers("anybootstrap")
+ .setFormat("PROTO")
+ .setMessageName("MyMessage")
+ .build());
assertThrows(
IllegalArgumentException.class,
- () ->
- kafkaProvider.from(
- KafkaReadSchemaTransformConfiguration.builder()
- .setTopic("anytopic")
- .setBootstrapServers("anybootstrap")
- .setFormat("PROTO")
- .setMessageName("MyMessage")
- .build()));
+ () -> transform.expand(PCollectionRowTuple.empty(Pipeline.create())));
}
@Test
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java
new file mode 100644
index 00000000000..b297227bb7a
--- /dev/null
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;
+import static org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform;
+import static org.apache.beam.sdk.io.kafka.KafkaSchemaTransformTranslation.KafkaReadSchemaTransformTranslator;
+import static org.apache.beam.sdk.io.kafka.KafkaSchemaTransformTranslation.KafkaWriteSchemaTransformTranslator;
+import static org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform;
+import static org.junit.Assert.assertEquals;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.util.construction.BeamUrns;
+import org.apache.beam.sdk.util.construction.PipelineTranslation;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.rules.TemporaryFolder;
+
+public class KafkaSchemaTransformTranslationTest {
+ @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder();
+
+ @Rule public transient ExpectedException thrown = ExpectedException.none();
+
+ static final KafkaWriteSchemaTransformProvider WRITE_PROVIDER =
+ new KafkaWriteSchemaTransformProvider();
+ static final KafkaReadSchemaTransformProvider READ_PROVIDER =
+ new KafkaReadSchemaTransformProvider();
+
+ static final Row READ_CONFIG =
+ Row.withSchema(READ_PROVIDER.configurationSchema())
+ .withFieldValue("format", "RAW")
+ .withFieldValue("topic", "test_topic")
+ .withFieldValue("bootstrap_servers", "host:port")
+ .withFieldValue("confluent_schema_registry_url", null)
+ .withFieldValue("confluent_schema_registry_subject", null)
+ .withFieldValue("schema", null)
+ .withFieldValue("file_descriptor_path", "testPath")
+ .withFieldValue("message_name", "test_message")
+ .withFieldValue("auto_offset_reset_config", "earliest")
+ .withFieldValue("consumer_config_updates", ImmutableMap.<String,
String>builder().build())
+ .withFieldValue("error_handling", null)
+ .build();
+
+ static final Row WRITE_CONFIG =
+ Row.withSchema(WRITE_PROVIDER.configurationSchema())
+ .withFieldValue("format", "RAW")
+ .withFieldValue("topic", "test_topic")
+ .withFieldValue("bootstrap_servers", "host:port")
+ .withFieldValue("producer_config_updates", ImmutableMap.<String,
String>builder().build())
+ .withFieldValue("error_handling", null)
+ .withFieldValue("file_descriptor_path", "testPath")
+ .withFieldValue("message_name", "test_message")
+ .withFieldValue("schema", "test_schema")
+ .build();
+
+ @Test
+ public void testRecreateWriteTransformFromRow() {
+ KafkaWriteSchemaTransform writeTransform =
+ (KafkaWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG);
+
+ KafkaWriteSchemaTransformTranslator translator = new KafkaWriteSchemaTransformTranslator();
+ Row translatedRow = translator.toConfigRow(writeTransform);
+
+ KafkaWriteSchemaTransform writeTransformFromRow =
+ translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create());
+
+ assertEquals(WRITE_CONFIG, writeTransformFromRow.getConfigurationRow());
+ }
+
+ @Test
+ public void testWriteTransformProtoTranslation()
+ throws InvalidProtocolBufferException, IOException {
+ // First build a pipeline
+ Pipeline p = Pipeline.create();
+ Schema inputSchema = Schema.builder().addByteArrayField("b").build();
+ PCollection<Row> input =
+ p.apply(
+ Create.of(
+ Collections.singletonList(
+ Row.withSchema(inputSchema).addValue(new byte[] {1, 2, 3}).build())))
+ .setRowSchema(inputSchema);
+
+ KafkaWriteSchemaTransform writeTransform =
+ (KafkaWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG);
+ PCollectionRowTuple.of("input", input).apply(writeTransform);
+
+ // Then translate the pipeline to a proto and extract KafkaWriteSchemaTransform proto
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+ List<RunnerApi.PTransform> writeTransformProto =
+ pipelineProto.getComponents().getTransformsMap().values().stream()
+ .filter(
+ tr -> {
+ RunnerApi.FunctionSpec spec = tr.getSpec();
+ try {
+ return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
+ && SchemaTransformPayload.parseFrom(spec.getPayload())
+ .getIdentifier()
+ .equals(WRITE_PROVIDER.identifier());
+ } catch (InvalidProtocolBufferException e) {
+ throw new RuntimeException(e);
+ }
+ })
+ .collect(Collectors.toList());
+ assertEquals(1, writeTransformProto.size());
+ RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec();
+
+ // Check that the proto contains correct values
+ SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
+ Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
+ assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec);
+ Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput());
+
+ assertEquals(WRITE_CONFIG, rowFromSpec);
+
+ // Use the information in the proto to recreate the KafkaWriteSchemaTransform
+ KafkaWriteSchemaTransformTranslator translator = new KafkaWriteSchemaTransformTranslator();
+ KafkaWriteSchemaTransform writeTransformFromSpec =
+ translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create());
+
+ assertEquals(WRITE_CONFIG, writeTransformFromSpec.getConfigurationRow());
+ }
+
+ @Test
+ public void testReCreateReadTransformFromRow() {
+ // setting a subset of fields here.
+ KafkaReadSchemaTransform readTransform =
+ (KafkaReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG);
+
+ KafkaReadSchemaTransformTranslator translator = new KafkaReadSchemaTransformTranslator();
+ Row row = translator.toConfigRow(readTransform);
+
+ KafkaReadSchemaTransform readTransformFromRow =
+ translator.fromConfigRow(row, PipelineOptionsFactory.create());
+
+ assertEquals(READ_CONFIG, readTransformFromRow.getConfigurationRow());
+ }
+
+ @Test
+ public void testReadTransformProtoTranslation()
+ throws InvalidProtocolBufferException, IOException {
+ // First build a pipeline
+ Pipeline p = Pipeline.create();
+
+ KafkaReadSchemaTransform readTransform =
+ (KafkaReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG);
+
+ PCollectionRowTuple.empty(p).apply(readTransform);
+
+ // Then translate the pipeline to a proto and extract KafkaReadSchemaTransform proto
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+ List<RunnerApi.PTransform> readTransformProto =
+ pipelineProto.getComponents().getTransformsMap().values().stream()
+ .filter(
+ tr -> {
+ RunnerApi.FunctionSpec spec = tr.getSpec();
+ try {
+ return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
+ && SchemaTransformPayload.parseFrom(spec.getPayload())
+ .getIdentifier()
+ .equals(READ_PROVIDER.identifier());
+ } catch (InvalidProtocolBufferException e) {
+ throw new RuntimeException(e);
+ }
+ })
+ .collect(Collectors.toList());
+ assertEquals(1, readTransformProto.size());
+ RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec();
+
+ // Check that the proto contains correct values
+ SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
+ Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
+ assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec);
+ Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput());
+ assertEquals(READ_CONFIG, rowFromSpec);
+
+ // Use the information in the proto to recreate the KafkaReadSchemaTransform
+ KafkaReadSchemaTransformTranslator translator = new KafkaReadSchemaTransformTranslator();
+ KafkaReadSchemaTransform readTransformFromSpec =
+ translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create());
+
+ assertEquals(READ_CONFIG, readTransformFromSpec.getConfigurationRow());
+ }
+}