This is an automated email from the ASF dual-hosted git repository.

ahmedabualsaud pushed a commit to branch release-2.57.0
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/release-2.57.0 by this push:
     new a36857f10a3 Cherrypick (#31362) kafka schematransform translation  (#31518)
a36857f10a3 is described below

commit a36857f10a3314b7dac1445373a6d98dd3c85a31
Author: Ahmed Abualsaud <[email protected]>
AuthorDate: Wed Jun 5 16:37:24 2024 -0400

    Cherrypick (#31362) kafka schematransform translation  (#31518)
    
    * kafka schematransform translation and tests
    
    * cleanup
    
    * spotless
    
    * address failing tests
    
    * switch existing schematransform tests to use Managed API
    
    * fix nullness
    
    * add some more mappings
    
    * fix mapping
    
    * typo
    
    * more accurate test name
    
    * cleanup after merging snake_case PR
    
    * spotless
---
 .../java/org/apache/beam/sdk/io/kafka/KafkaIO.java |   2 +-
 .../KafkaReadSchemaTransformConfiguration.java     |   6 +
 .../io/kafka/KafkaReadSchemaTransformProvider.java | 260 +++++++++++----------
 .../io/kafka/KafkaSchemaTransformTranslation.java  |  93 ++++++++
 .../kafka/KafkaWriteSchemaTransformProvider.java   |  16 ++
 .../org/apache/beam/sdk/io/kafka/KafkaIOIT.java    |  45 ++--
 .../KafkaReadSchemaTransformProviderTest.java      |  49 ++--
 .../kafka/KafkaSchemaTransformTranslationTest.java | 216 +++++++++++++++++
 8 files changed, 522 insertions(+), 165 deletions(-)
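
For context on how these changes are exercised: the integration test below now goes through the Managed API instead of constructing the Kafka schema transform providers directly, and the new "max_read_time_seconds" option replaces the old test-only constructor for bounding reads. A minimal usage sketch (illustrative only, not part of the patch; it assumes an existing Pipeline p, the usual Beam imports, and placeholder topic and bootstrap server values):

    // Illustrative sketch: read Kafka records as Beam Rows via the Managed API,
    // bounding the read with the new max_read_time_seconds option.
    PCollection<Row> rows =
        PCollectionRowTuple.empty(p)
            .apply(
                "Read from Kafka",
                Managed.read(Managed.KAFKA)
                    .withConfig(
                        ImmutableMap.<String, Object>builder()
                            .put("format", "RAW")
                            .put("topic", "my-topic") // placeholder
                            .put("bootstrap_servers", "localhost:9092") // placeholder
                            .put("max_read_time_seconds", 30)
                            .build()))
            .get("output");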

diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 8f995a63a10..35aabbbfd97 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -2665,7 +2665,7 @@ public class KafkaIO {
       abstract Builder<K, V> setProducerFactoryFn(
           @Nullable SerializableFunction<Map<String, Object>, Producer<K, V>> fn);
 
-      abstract Builder<K, V> setKeySerializer(Class<? extends Serializer<K>> serializer);
+      abstract Builder<K, V> setKeySerializer(@Nullable Class<? extends Serializer<K>> serializer);
 
       abstract Builder<K, V> setValueSerializer(Class<? extends Serializer<V>> serializer);
 
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
index 13f5249a6c3..693c1371f78 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformConfiguration.java
@@ -149,6 +149,10 @@ public abstract class KafkaReadSchemaTransformConfiguration {
   /** Sets the topic from which to read. */
   public abstract String getTopic();
 
+  @SchemaFieldDescription("Upper bound of how long to read from Kafka.")
+  @Nullable
+  public abstract Integer getMaxReadTimeSeconds();
+
   @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
   @Nullable
   public abstract ErrorHandling getErrorHandling();
@@ -179,6 +183,8 @@ public abstract class KafkaReadSchemaTransformConfiguration {
     /** Sets the topic from which to read. */
     public abstract Builder setTopic(String value);
 
+    public abstract Builder setMaxReadTimeSeconds(Integer maxReadTimeSeconds);
+
     public abstract Builder setErrorHandling(ErrorHandling errorHandling);
 
     /** Builds a {@link KafkaReadSchemaTransformConfiguration} instance. */
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
index 13240ea9dc4..b2eeb1a54d1 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProvider.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.sdk.io.kafka;
 
+import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
+
 import com.google.auto.service.AutoService;
 import java.io.FileOutputStream;
 import java.io.IOException;
@@ -38,7 +40,9 @@ import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.metrics.Counter;
 import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.schemas.transforms.Convert;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
@@ -56,7 +60,6 @@ import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
@@ -76,19 +79,6 @@ public class KafkaReadSchemaTransformProvider
   public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
   public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
 
-  final Boolean isTest;
-  final Integer testTimeoutSecs;
-
-  public KafkaReadSchemaTransformProvider() {
-    this(false, 0);
-  }
-
-  @VisibleForTesting
-  KafkaReadSchemaTransformProvider(Boolean isTest, Integer testTimeoutSecs) {
-    this.isTest = isTest;
-    this.testTimeoutSecs = testTimeoutSecs;
-  }
-
   @Override
   protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
     return KafkaReadSchemaTransformConfiguration.class;
@@ -99,113 +89,7 @@ public class KafkaReadSchemaTransformProvider
   })
   @Override
   protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) {
-    configuration.validate();
-
-    final String inputSchema = configuration.getSchema();
-    final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
-    final String autoOffsetReset =
-        MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");
-
-    Map<String, Object> consumerConfigs =
-        new HashMap<>(
-            MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
-    consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
-    consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
-    consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
-    consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);
-
-    String format = configuration.getFormat();
-    boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
-
-    SerializableFunction<byte[], Row> valueMapper;
-    Schema beamSchema;
-
-    String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
-    if (confluentSchemaRegUrl != null) {
-      return new SchemaTransform() {
-        @Override
-        public PCollectionRowTuple expand(PCollectionRowTuple input) {
-          final String confluentSchemaRegSubject =
-              configuration.getConfluentSchemaRegistrySubject();
-          KafkaIO.Read<byte[], GenericRecord> kafkaRead =
-              KafkaIO.<byte[], GenericRecord>read()
-                  .withTopic(configuration.getTopic())
-                  .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
-                  .withBootstrapServers(configuration.getBootstrapServers())
-                  .withConsumerConfigUpdates(consumerConfigs)
-                  .withKeyDeserializer(ByteArrayDeserializer.class)
-                  .withValueDeserializer(
-                      ConfluentSchemaRegistryDeserializerProvider.of(
-                          confluentSchemaRegUrl, confluentSchemaRegSubject));
-          if (isTest) {
-            kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
-          }
-
-          PCollection<GenericRecord> kafkaValues =
-              input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
-
-          assert kafkaValues.getCoder().getClass() == AvroCoder.class;
-          AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
-          kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
-          return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
-        }
-      };
-    }
-    if ("RAW".equals(format)) {
-      beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
-      valueMapper = getRawBytesToRowFunction(beamSchema);
-    } else if ("PROTO".equals(format)) {
-      String fileDescriptorPath = configuration.getFileDescriptorPath();
-      String messageName = configuration.getMessageName();
-      if (fileDescriptorPath != null) {
-        beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
-        valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
-      } else {
-        beamSchema = ProtoByteUtils.getBeamSchemaFromProtoSchema(inputSchema, messageName);
-        valueMapper = ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(inputSchema, messageName);
-      }
-    } else if ("JSON".equals(format)) {
-      beamSchema = JsonUtils.beamSchemaFromJsonSchema(inputSchema);
-      valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
-    } else {
-      beamSchema = AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema));
-      valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
-    }
-
-    return new SchemaTransform() {
-      @Override
-      public PCollectionRowTuple expand(PCollectionRowTuple input) {
-        KafkaIO.Read<byte[], byte[]> kafkaRead =
-            KafkaIO.readBytes()
-                .withConsumerConfigUpdates(consumerConfigs)
-                .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
-                .withTopic(configuration.getTopic())
-                .withBootstrapServers(configuration.getBootstrapServers());
-        if (isTest) {
-          kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
-        }
-
-        PCollection<byte[]> kafkaValues =
-            input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
-
-        Schema errorSchema = ErrorHandling.errorSchemaBytes();
-        PCollectionTuple outputTuple =
-            kafkaValues.apply(
-                ParDo.of(
-                        new ErrorFn(
-                            "Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
-                    .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
-
-        PCollectionRowTuple outputRows =
-            PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));
-
-        PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
-        if (handleErrors) {
-          outputRows = outputRows.and(configuration.getErrorHandling().getOutput(), errorOutput);
-        }
-        return outputRows;
-      }
-    };
+    return new KafkaReadSchemaTransform(configuration);
   }
 
   public static SerializableFunction<byte[], Row> getRawBytesToRowFunction(Schema rawSchema) {
@@ -232,6 +116,140 @@ public class KafkaReadSchemaTransformProvider
     return Arrays.asList("output", "errors");
   }
 
+  static class KafkaReadSchemaTransform extends SchemaTransform {
+    private final KafkaReadSchemaTransformConfiguration configuration;
+
+    KafkaReadSchemaTransform(KafkaReadSchemaTransformConfiguration configuration) {
+      this.configuration = configuration;
+    }
+
+    Row getConfigurationRow() {
+      try {
+        // To stay consistent with our SchemaTransform configuration naming conventions,
+        // we sort lexicographically
+        return SchemaRegistry.createDefault()
+            .getToRowFunction(KafkaReadSchemaTransformConfiguration.class)
+            .apply(configuration)
+            .sorted()
+            .toSnakeCase();
+      } catch (NoSuchSchemaException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public PCollectionRowTuple expand(PCollectionRowTuple input) {
+      configuration.validate();
+
+      final String inputSchema = configuration.getSchema();
+      final int groupId = configuration.hashCode() % Integer.MAX_VALUE;
+      final String autoOffsetReset =
+          MoreObjects.firstNonNull(configuration.getAutoOffsetResetConfig(), "latest");
+
+      Map<String, Object> consumerConfigs =
+          new HashMap<>(
+              MoreObjects.firstNonNull(configuration.getConsumerConfigUpdates(), new HashMap<>()));
+      consumerConfigs.put(ConsumerConfig.GROUP_ID_CONFIG, "kafka-read-provider-" + groupId);
+      consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, true);
+      consumerConfigs.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, 100);
+      consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);
+
+      String format = configuration.getFormat();
+      boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
+
+      SerializableFunction<byte[], Row> valueMapper;
+      Schema beamSchema;
+
+      String confluentSchemaRegUrl = configuration.getConfluentSchemaRegistryUrl();
+      if (confluentSchemaRegUrl != null) {
+        final String confluentSchemaRegSubject =
+            checkArgumentNotNull(configuration.getConfluentSchemaRegistrySubject());
+        KafkaIO.Read<byte[], GenericRecord> kafkaRead =
+            KafkaIO.<byte[], GenericRecord>read()
+                .withTopic(configuration.getTopic())
+                .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
+                .withBootstrapServers(configuration.getBootstrapServers())
+                .withConsumerConfigUpdates(consumerConfigs)
+                .withKeyDeserializer(ByteArrayDeserializer.class)
+                .withValueDeserializer(
+                    ConfluentSchemaRegistryDeserializerProvider.of(
+                        confluentSchemaRegUrl, confluentSchemaRegSubject));
+        Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
+        if (maxReadTimeSeconds != null) {
+          kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
+        }
+
+        PCollection<GenericRecord> kafkaValues =
+            input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
+
+        assert kafkaValues.getCoder().getClass() == AvroCoder.class;
+        AvroCoder<GenericRecord> coder = (AvroCoder<GenericRecord>) kafkaValues.getCoder();
+        kafkaValues = kafkaValues.setCoder(AvroUtils.schemaCoder(coder.getSchema()));
+        return PCollectionRowTuple.of("output", kafkaValues.apply(Convert.toRows()));
+      }
+
+      if ("RAW".equals(format)) {
+        beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
+        valueMapper = getRawBytesToRowFunction(beamSchema);
+      } else if ("PROTO".equals(format)) {
+        String fileDescriptorPath = configuration.getFileDescriptorPath();
+        String messageName = checkArgumentNotNull(configuration.getMessageName());
+        if (fileDescriptorPath != null) {
+          beamSchema = ProtoByteUtils.getBeamSchemaFromProto(fileDescriptorPath, messageName);
+          valueMapper = ProtoByteUtils.getProtoBytesToRowFunction(fileDescriptorPath, messageName);
+        } else {
+          beamSchema =
+              ProtoByteUtils.getBeamSchemaFromProtoSchema(
+                  checkArgumentNotNull(inputSchema), messageName);
+          valueMapper =
+              ProtoByteUtils.getProtoBytesToRowFromSchemaFunction(
+                  checkArgumentNotNull(inputSchema), messageName);
+        }
+      } else if ("JSON".equals(format)) {
+        beamSchema = JsonUtils.beamSchemaFromJsonSchema(checkArgumentNotNull(inputSchema));
+        valueMapper = JsonUtils.getJsonBytesToRowFunction(beamSchema);
+      } else {
+        beamSchema =
+            AvroUtils.toBeamSchema(
+                new org.apache.avro.Schema.Parser().parse(checkArgumentNotNull(inputSchema)));
+        valueMapper = AvroUtils.getAvroBytesToRowFunction(beamSchema);
+      }
+
+      KafkaIO.Read<byte[], byte[]> kafkaRead =
+          KafkaIO.readBytes()
+              .withConsumerConfigUpdates(consumerConfigs)
+              .withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
+              .withTopic(configuration.getTopic())
+              .withBootstrapServers(configuration.getBootstrapServers());
+      Integer maxReadTimeSeconds = configuration.getMaxReadTimeSeconds();
+      if (maxReadTimeSeconds != null) {
+        kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(maxReadTimeSeconds));
+      }
+
+      PCollection<byte[]> kafkaValues =
+          input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());
+
+      Schema errorSchema = ErrorHandling.errorSchemaBytes();
+      PCollectionTuple outputTuple =
+          kafkaValues.apply(
+              ParDo.of(
+                      new ErrorFn(
+                          "Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
+                  .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+      PCollectionRowTuple outputRows =
+          PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));
+
+      PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
+      if (handleErrors) {
+        outputRows =
+            outputRows.and(
+                checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
+      }
+      return outputRows;
+    }
+  }
+
   public static class ErrorFn extends DoFn<byte[], Row> {
     private final SerializableFunction<byte[], Row> valueMapper;
     private final Counter errorCounter;
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java
new file mode 100644
index 00000000000..4b83e2b6f55
--- /dev/null
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslation.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform;
+import static org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform;
+import static org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation.SchemaTransformPayloadTranslator;
+
+import com.google.auto.service.AutoService;
+import java.util.Map;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.construction.PTransformTranslation;
+import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+
+public class KafkaSchemaTransformTranslation {
+  static class KafkaReadSchemaTransformTranslator
+      extends SchemaTransformPayloadTranslator<KafkaReadSchemaTransform> {
+    @Override
+    public SchemaTransformProvider provider() {
+      return new KafkaReadSchemaTransformProvider();
+    }
+
+    @Override
+    public Row toConfigRow(KafkaReadSchemaTransform transform) {
+      return transform.getConfigurationRow();
+    }
+  }
+
+  @AutoService(TransformPayloadTranslatorRegistrar.class)
+  public static class ReadRegistrar implements TransformPayloadTranslatorRegistrar {
+    @Override
+    @SuppressWarnings({
+      "rawtypes",
+    })
+    public Map<
+            ? extends Class<? extends PTransform>,
+            ? extends PTransformTranslation.TransformPayloadTranslator>
+        getTransformPayloadTranslators() {
+      return ImmutableMap
+          .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder()
+          .put(KafkaReadSchemaTransform.class, new KafkaReadSchemaTransformTranslator())
+          .build();
+    }
+  }
+
+  static class KafkaWriteSchemaTransformTranslator
+      extends SchemaTransformPayloadTranslator<KafkaWriteSchemaTransform> {
+    @Override
+    public SchemaTransformProvider provider() {
+      return new KafkaWriteSchemaTransformProvider();
+    }
+
+    @Override
+    public Row toConfigRow(KafkaWriteSchemaTransform transform) {
+      return transform.getConfigurationRow();
+    }
+  }
+
+  @AutoService(TransformPayloadTranslatorRegistrar.class)
+  public static class WriteRegistrar implements TransformPayloadTranslatorRegistrar {
+    @Override
+    @SuppressWarnings({
+      "rawtypes",
+    })
+    public Map<
+            ? extends Class<? extends PTransform>,
+            ? extends PTransformTranslation.TransformPayloadTranslator>
+        getTransformPayloadTranslators() {
+      return ImmutableMap
+          .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder()
+          .put(KafkaWriteSchemaTransform.class, new KafkaWriteSchemaTransformTranslator())
+          .build();
+    }
+  }
+}
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
index 26f37b790ef..09b338492b4 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProvider.java
@@ -31,7 +31,9 @@ import org.apache.beam.sdk.extensions.protobuf.ProtoByteUtils;
 import org.apache.beam.sdk.metrics.Counter;
 import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.schemas.AutoValueSchema;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
 import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
@@ -99,6 +101,20 @@ public class KafkaWriteSchemaTransformProvider
       this.configuration = configuration;
     }
 
+    Row getConfigurationRow() {
+      try {
+        // To stay consistent with our SchemaTransform configuration naming conventions,
+        // we sort lexicographically
+        return SchemaRegistry.createDefault()
+            .getToRowFunction(KafkaWriteSchemaTransformConfiguration.class)
+            .apply(configuration)
+            .sorted()
+            .toSnakeCase();
+      } catch (NoSuchSchemaException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
     public static class ErrorCounterFn extends DoFn<Row, KV<byte[], byte[]>> {
       private final SerializableFunction<Row, byte[]> toBytesFn;
       private final Counter errorCounter;
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
index ab6ac52e318..4d38636892c 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOIT.java
@@ -48,6 +48,7 @@ import org.apache.beam.sdk.io.kafka.KafkaIOTest.FailingLongSerializer;
 import org.apache.beam.sdk.io.kafka.ReadFromKafkaDoFnTest.FailingDeserializer;
 import org.apache.beam.sdk.io.synthetic.SyntheticBoundedSource;
 import org.apache.beam.sdk.io.synthetic.SyntheticSourceOptions;
+import org.apache.beam.sdk.managed.Managed;
 import org.apache.beam.sdk.options.Default;
 import org.apache.beam.sdk.options.Description;
 import org.apache.beam.sdk.options.ExperimentalOptions;
@@ -607,18 +608,18 @@ public class KafkaIOIT {
   private static final int FIVE_MINUTES_IN_MS = 5 * 60 * 1000;
 
   @Test(timeout = FIVE_MINUTES_IN_MS)
-  public void testKafkaViaSchemaTransformJson() {
-    runReadWriteKafkaViaSchemaTransforms(
+  public void testKafkaViaManagedSchemaTransformJson() {
+    runReadWriteKafkaViaManagedSchemaTransforms(
         "JSON", SCHEMA_IN_JSON, JsonUtils.beamSchemaFromJsonSchema(SCHEMA_IN_JSON));
   }
 
   @Test(timeout = FIVE_MINUTES_IN_MS)
-  public void testKafkaViaSchemaTransformAvro() {
-    runReadWriteKafkaViaSchemaTransforms(
+  public void testKafkaViaManagedSchemaTransformAvro() {
+    runReadWriteKafkaViaManagedSchemaTransforms(
         "AVRO", AvroUtils.toAvroSchema(KAFKA_TOPIC_SCHEMA).toString(), KAFKA_TOPIC_SCHEMA);
   }
 
-  public void runReadWriteKafkaViaSchemaTransforms(
+  public void runReadWriteKafkaViaManagedSchemaTransforms(
       String format, String schemaDefinition, Schema beamSchema) {
     String topicName = options.getKafkaTopic() + "-schema-transform" + UUID.randomUUID();
     PCollectionRowTuple.of(
@@ -646,13 +647,12 @@ public class KafkaIOIT {
                 .setRowSchema(beamSchema))
         .apply(
             "Write to Kafka",
-            new KafkaWriteSchemaTransformProvider()
-                .from(
-                    KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransformConfiguration
-                        .builder()
-                        .setTopic(topicName)
-                        .setBootstrapServers(options.getKafkaBootstrapServerAddresses())
-                        .setFormat(format)
+            Managed.write(Managed.KAFKA)
+                .withConfig(
+                    ImmutableMap.<String, Object>builder()
+                        .put("topic", topicName)
+                        .put("bootstrap_servers", options.getKafkaBootstrapServerAddresses())
+                        .put("format", format)
                         .build()));
 
     PAssert.that(
@@ -661,15 +661,18 @@ public class KafkaIOIT {
                     "Read from unbounded Kafka",
                     // A timeout of 30s for local, container-based tests, and 2 minutes for
                     // real-kafka tests.
-                    new KafkaReadSchemaTransformProvider(
-                            true, options.isWithTestcontainers() ? 30 : 120)
-                        .from(
-                            KafkaReadSchemaTransformConfiguration.builder()
-                                .setFormat(format)
-                                .setAutoOffsetResetConfig("earliest")
-                                .setSchema(schemaDefinition)
-                                .setTopic(topicName)
-                                .setBootstrapServers(options.getKafkaBootstrapServerAddresses())
+                    Managed.read(Managed.KAFKA)
+                        .withConfig(
+                            ImmutableMap.<String, Object>builder()
+                                .put("format", format)
+                                .put("auto_offset_reset_config", "earliest")
+                                .put("schema", schemaDefinition)
+                                .put("topic", topicName)
+                                .put(
+                                    "bootstrap_servers", options.getKafkaBootstrapServerAddresses())
+                                .put(
+                                    "max_read_time_seconds",
+                                    options.isWithTestcontainers() ? 30 : 120)
                                 .build()))
                 .get("output"))
         .containsInAnyOrder(
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
index dfe062e1eef..19c336e1d24 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java
@@ -34,9 +34,11 @@ import java.util.stream.StreamSupport;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.managed.Managed;
 import org.apache.beam.sdk.managed.ManagedTransformConstants;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
 import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
 import org.apache.beam.sdk.schemas.utils.YamlUtils;
 import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
@@ -131,7 +133,8 @@ public class KafkaReadSchemaTransformProviderTest {
             "confluent_schema_registry_url",
             "error_handling",
             "file_descriptor_path",
-            "message_name"),
+            "message_name",
+            "max_read_time_seconds"),
         kafkaProvider.configurationSchema().getFields().stream()
             .map(field -> field.getName())
             .collect(Collectors.toSet()));
@@ -232,22 +235,23 @@ public class KafkaReadSchemaTransformProviderTest {
             .collect(Collectors.toList());
     KafkaReadSchemaTransformProvider kafkaProvider =
         (KafkaReadSchemaTransformProvider) providers.get(0);
+    SchemaTransform transform =
+        kafkaProvider.from(
+            KafkaReadSchemaTransformConfiguration.builder()
+                .setTopic("anytopic")
+                .setBootstrapServers("anybootstrap")
+                .setFormat("PROTO")
+                .setMessageName("MyOtherMessage")
+                .setFileDescriptorPath(
+                    Objects.requireNonNull(
+                            getClass()
+                                .getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
+                        .getPath())
+                .build());
 
     assertThrows(
         NullPointerException.class,
-        () ->
-            kafkaProvider.from(
-                KafkaReadSchemaTransformConfiguration.builder()
-                    .setTopic("anytopic")
-                    .setBootstrapServers("anybootstrap")
-                    .setFormat("PROTO")
-                    .setMessageName("MyOtherMessage")
-                    .setFileDescriptorPath(
-                        Objects.requireNonNull(
-                                getClass()
-                                    .getResource("/proto_byte/file_descriptor/proto_byte_utils.pb"))
-                            .getPath())
-                    .build()));
+        () -> transform.expand(PCollectionRowTuple.empty(Pipeline.create())));
   }
 
   @Test
@@ -281,17 +285,18 @@ public class KafkaReadSchemaTransformProviderTest {
             .collect(Collectors.toList());
     KafkaReadSchemaTransformProvider kafkaProvider =
         (KafkaReadSchemaTransformProvider) providers.get(0);
+    SchemaTransform transform =
+        kafkaProvider.from(
+            KafkaReadSchemaTransformConfiguration.builder()
+                .setTopic("anytopic")
+                .setBootstrapServers("anybootstrap")
+                .setFormat("PROTO")
+                .setMessageName("MyMessage")
+                .build());
 
     assertThrows(
         IllegalArgumentException.class,
-        () ->
-            kafkaProvider.from(
-                KafkaReadSchemaTransformConfiguration.builder()
-                    .setTopic("anytopic")
-                    .setBootstrapServers("anybootstrap")
-                    .setFormat("PROTO")
-                    .setMessageName("MyMessage")
-                    .build()));
+        () -> transform.expand(PCollectionRowTuple.empty(Pipeline.create())));
   }
 
   @Test
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java
new file mode 100644
index 00000000000..b297227bb7a
--- /dev/null
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaSchemaTransformTranslationTest.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.kafka;
+
+import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM;
+import static org.apache.beam.sdk.io.kafka.KafkaReadSchemaTransformProvider.KafkaReadSchemaTransform;
+import static org.apache.beam.sdk.io.kafka.KafkaSchemaTransformTranslation.KafkaReadSchemaTransformTranslator;
+import static org.apache.beam.sdk.io.kafka.KafkaSchemaTransformTranslation.KafkaWriteSchemaTransformTranslator;
+import static org.apache.beam.sdk.io.kafka.KafkaWriteSchemaTransformProvider.KafkaWriteSchemaTransform;
+import static org.junit.Assert.assertEquals;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.util.construction.BeamUrns;
+import org.apache.beam.sdk.util.construction.PipelineTranslation;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.rules.TemporaryFolder;
+
+public class KafkaSchemaTransformTranslationTest {
+  @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder();
+
+  @Rule public transient ExpectedException thrown = ExpectedException.none();
+
+  static final KafkaWriteSchemaTransformProvider WRITE_PROVIDER =
+      new KafkaWriteSchemaTransformProvider();
+  static final KafkaReadSchemaTransformProvider READ_PROVIDER =
+      new KafkaReadSchemaTransformProvider();
+
+  static final Row READ_CONFIG =
+      Row.withSchema(READ_PROVIDER.configurationSchema())
+          .withFieldValue("format", "RAW")
+          .withFieldValue("topic", "test_topic")
+          .withFieldValue("bootstrap_servers", "host:port")
+          .withFieldValue("confluent_schema_registry_url", null)
+          .withFieldValue("confluent_schema_registry_subject", null)
+          .withFieldValue("schema", null)
+          .withFieldValue("file_descriptor_path", "testPath")
+          .withFieldValue("message_name", "test_message")
+          .withFieldValue("auto_offset_reset_config", "earliest")
+          .withFieldValue("consumer_config_updates", ImmutableMap.<String, String>builder().build())
+          .withFieldValue("error_handling", null)
+          .build();
+
+  static final Row WRITE_CONFIG =
+      Row.withSchema(WRITE_PROVIDER.configurationSchema())
+          .withFieldValue("format", "RAW")
+          .withFieldValue("topic", "test_topic")
+          .withFieldValue("bootstrap_servers", "host:port")
+          .withFieldValue("producer_config_updates", ImmutableMap.<String, String>builder().build())
+          .withFieldValue("error_handling", null)
+          .withFieldValue("file_descriptor_path", "testPath")
+          .withFieldValue("message_name", "test_message")
+          .withFieldValue("schema", "test_schema")
+          .build();
+
+  @Test
+  public void testRecreateWriteTransformFromRow() {
+    KafkaWriteSchemaTransform writeTransform =
+        (KafkaWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG);
+
+    KafkaWriteSchemaTransformTranslator translator = new KafkaWriteSchemaTransformTranslator();
+    Row translatedRow = translator.toConfigRow(writeTransform);
+
+    KafkaWriteSchemaTransform writeTransformFromRow =
+        translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create());
+
+    assertEquals(WRITE_CONFIG, writeTransformFromRow.getConfigurationRow());
+  }
+
+  @Test
+  public void testWriteTransformProtoTranslation()
+      throws InvalidProtocolBufferException, IOException {
+    // First build a pipeline
+    Pipeline p = Pipeline.create();
+    Schema inputSchema = Schema.builder().addByteArrayField("b").build();
+    PCollection<Row> input =
+        p.apply(
+                Create.of(
+                    Collections.singletonList(
+                        Row.withSchema(inputSchema).addValue(new byte[] {1, 2, 3}).build())))
+            .setRowSchema(inputSchema);
+
+    KafkaWriteSchemaTransform writeTransform =
+        (KafkaWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG);
+    PCollectionRowTuple.of("input", input).apply(writeTransform);
+
+    // Then translate the pipeline to a proto and extract KafkaWriteSchemaTransform proto
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    List<RunnerApi.PTransform> writeTransformProto =
+        pipelineProto.getComponents().getTransformsMap().values().stream()
+            .filter(
+                tr -> {
+                  RunnerApi.FunctionSpec spec = tr.getSpec();
+                  try {
+                    return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
+                        && SchemaTransformPayload.parseFrom(spec.getPayload())
+                            .getIdentifier()
+                            .equals(WRITE_PROVIDER.identifier());
+                  } catch (InvalidProtocolBufferException e) {
+                    throw new RuntimeException(e);
+                  }
+                })
+            .collect(Collectors.toList());
+    assertEquals(1, writeTransformProto.size());
+    RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec();
+
+    // Check that the proto contains correct values
+    SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
+    Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
+    assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec);
+    Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput());
+
+    assertEquals(WRITE_CONFIG, rowFromSpec);
+
+    // Use the information in the proto to recreate the KafkaWriteSchemaTransform
+    KafkaWriteSchemaTransformTranslator translator = new KafkaWriteSchemaTransformTranslator();
+    KafkaWriteSchemaTransform writeTransformFromSpec =
+        translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create());
+
+    assertEquals(WRITE_CONFIG, writeTransformFromSpec.getConfigurationRow());
+  }
+
+  @Test
+  public void testReCreateReadTransformFromRow() {
+    // setting a subset of fields here.
+    KafkaReadSchemaTransform readTransform =
+        (KafkaReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG);
+
+    KafkaReadSchemaTransformTranslator translator = new KafkaReadSchemaTransformTranslator();
+    Row row = translator.toConfigRow(readTransform);
+
+    KafkaReadSchemaTransform readTransformFromRow =
+        translator.fromConfigRow(row, PipelineOptionsFactory.create());
+
+    assertEquals(READ_CONFIG, readTransformFromRow.getConfigurationRow());
+  }
+
+  @Test
+  public void testReadTransformProtoTranslation()
+      throws InvalidProtocolBufferException, IOException {
+    // First build a pipeline
+    Pipeline p = Pipeline.create();
+
+    KafkaReadSchemaTransform readTransform =
+        (KafkaReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG);
+
+    PCollectionRowTuple.empty(p).apply(readTransform);
+
+    // Then translate the pipeline to a proto and extract KafkaReadSchemaTransform proto
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    List<RunnerApi.PTransform> readTransformProto =
+        pipelineProto.getComponents().getTransformsMap().values().stream()
+            .filter(
+                tr -> {
+                  RunnerApi.FunctionSpec spec = tr.getSpec();
+                  try {
+                    return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM))
+                        && SchemaTransformPayload.parseFrom(spec.getPayload())
+                            .getIdentifier()
+                            .equals(READ_PROVIDER.identifier());
+                  } catch (InvalidProtocolBufferException e) {
+                    throw new RuntimeException(e);
+                  }
+                })
+            .collect(Collectors.toList());
+    assertEquals(1, readTransformProto.size());
+    RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec();
+
+    // Check that the proto contains correct values
+    SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload());
+    Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema());
+    assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec);
+    Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput());
+    assertEquals(READ_CONFIG, rowFromSpec);
+
+    // Use the information in the proto to recreate the KafkaReadSchemaTransform
+    KafkaReadSchemaTransformTranslator translator = new KafkaReadSchemaTransformTranslator();
+    KafkaReadSchemaTransform readTransformFromSpec =
+        translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create());
+
+    assertEquals(READ_CONFIG, readTransformFromSpec.getConfigurationRow());
+  }
+}

