This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 23e59afce97 Updates ExpansionService to support dynamically 
discovering and expanding SchemaTransforms (#23413)
23e59afce97 is described below

commit 23e59afce976d40f3e881094deee7e42c42a0e11
Author: Chamikara Jayalath <[email protected]>
AuthorDate: Wed Nov 23 16:52:50 2022 -0800

    Updates ExpansionService to support dynamically discovering and expanding 
SchemaTransforms (#23413)
    
    * Updates ExpansionService to support dynamically discovering and expanding 
SchemaTransforms
    
    * Fixing checker framework errors.
    
    * Address reviewer comments
    
    * Addressing reviewer comments
    
    * Addressing reviewer comments
---
 .../job_management/v1/beam_expansion_api.proto     |  28 ++
 .../model/pipeline/v1/external_transforms.proto    |  17 +
 .../sdk/expansion/service/ExpansionService.java    |  75 +++-
 .../ExpansionServiceSchemaTransformProvider.java   | 144 ++++++
 ...xpansionServiceSchemaTransformProviderTest.java | 486 +++++++++++++++++++++
 sdks/python/apache_beam/portability/common_urns.py |   1 +
 sdks/python/apache_beam/transforms/external.py     | 158 +++++--
 .../python/apache_beam/transforms/external_test.py |  30 ++
 8 files changed, 898 insertions(+), 41 deletions(-)

diff --git 
a/model/job-management/src/main/proto/org/apache/beam/model/job_management/v1/beam_expansion_api.proto
 
b/model/job-management/src/main/proto/org/apache/beam/model/job_management/v1/beam_expansion_api.proto
index f3ab890005d..568f9c87741 100644
--- 
a/model/job-management/src/main/proto/org/apache/beam/model/job_management/v1/beam_expansion_api.proto
+++ 
b/model/job-management/src/main/proto/org/apache/beam/model/job_management/v1/beam_expansion_api.proto
@@ -30,6 +30,7 @@ option java_package = "org.apache.beam.model.expansion.v1";
 option java_outer_classname = "ExpansionApi";
 
 import "org/apache/beam/model/pipeline/v1/beam_runner_api.proto";
+import "org/apache/beam/model/pipeline/v1/schema.proto";
 
 message ExpansionRequest {
   // Set of components needed to interpret the transform, or which
@@ -72,7 +73,34 @@ message ExpansionResponse {
   string error = 10;
 }
 
+message DiscoverSchemaTransformRequest  {
+}
+
+message SchemaTransformConfig {
+  // Config schema of the SchemaTransform
+  org.apache.beam.model.pipeline.v1.Schema config_schema = 1;
+
+  // Names of input PCollections
+  repeated string input_pcollection_names = 2;
+
+  // Names of output PCollections
+  repeated string output_pcollection_names = 3;
+}
+
+message DiscoverSchemaTransformResponse {
+  // A mapping from SchemaTransform ID to the configuration of each discovered
+  // SchemaTransform
+  map <string, SchemaTransformConfig> schema_transform_configs = 1;
+
+  // If the list of identifiers is empty, this may contain an error.
+  string error = 2;
+}
+
 // Job Service for constructing pipelines
 service ExpansionService {
   rpc Expand (ExpansionRequest) returns (ExpansionResponse);
+
+  // An RPC to discover already registered SchemaTransformProviders.
+  // See https://s.apache.org/easy-multi-language for more details.
+  rpc DiscoverSchemaTransform (DiscoverSchemaTransformRequest) returns 
(DiscoverSchemaTransformResponse);
 }
diff --git 
a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto
 
b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto
index baff2c0436f..18cd02e3942 100644
--- 
a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto
+++ 
b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto
@@ -51,6 +51,11 @@ message ExpansionMethods {
     // Transform payload will be of type JavaClassLookupPayload.
     JAVA_CLASS_LOOKUP = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) =
       "beam:expansion:payload:java_class_lookup:v1"];
+
+    // Expanding a SchemaTransform identified by the expansion service.
+    // Transform payload will be of type SchemaTransformPayload.
+    SCHEMA_TRANSFORM = 1 [(org.apache.beam.model.pipeline.v1.beam_urn) =
+      "beam:expansion:payload:schematransform:v1"];
   }
 }
 
@@ -106,4 +111,16 @@ message BuilderMethod {
   bytes payload = 3;
 }
 
+message SchemaTransformPayload {
+  // The identifier of the SchemaTransform (typically a URN).
+  string identifier = 1;
+
+  // The configuration schema of the SchemaTransform.
+  Schema configuration_schema = 2;
 
+  // The configuration of the SchemaTransform.
+  // Should be decodable via beam:coder:row:v1.
+  // The schema of the Row should be compatible with the schema of the
+  // SchemaTransform.
+  bytes configuration_row = 3;
+}
diff --git 
a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java
 
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java
index fed01d2576e..221c40f7920 100644
--- 
a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java
+++ 
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java
@@ -35,6 +35,9 @@ import java.util.ServiceLoader;
 import java.util.Set;
 import java.util.stream.Collectors;
 import org.apache.beam.model.expansion.v1.ExpansionApi;
+import 
org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformRequest;
+import 
org.apache.beam.model.expansion.v1.ExpansionApi.DiscoverSchemaTransformResponse;
+import org.apache.beam.model.expansion.v1.ExpansionApi.SchemaTransformConfig;
 import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc;
 import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
 import 
org.apache.beam.model.pipeline.v1.ExternalTransforms.ExternalConfigurationPayload;
@@ -63,6 +66,7 @@ import org.apache.beam.sdk.schemas.Schema.Field;
 import org.apache.beam.sdk.schemas.SchemaCoder;
 import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
 import org.apache.beam.sdk.transforms.ExternalTransformBuilder;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -436,6 +440,10 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
     return registeredTransforms;
   }
 
+  private Iterable<SchemaTransformProvider> getRegisteredSchemaTransforms() {
+    return ExpansionServiceSchemaTransformProvider.of().getAllProviders();
+  }
+
   private Map<String, TransformProvider> loadRegisteredTransforms() {
     ImmutableMap.Builder<String, TransformProvider> 
registeredTransformsBuilder =
         ImmutableMap.builder();
@@ -500,6 +508,8 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
           
pipelineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist();
       assert allowList != null;
       transformProvider = new JavaClassLookupTransformProvider(allowList);
+    } else if (getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM).equals(urn)) {
+      transformProvider = ExpansionServiceSchemaTransformProvider.of();
     } else {
       transformProvider = getRegisteredTransforms().get(urn);
       if (transformProvider == null) {
@@ -604,6 +614,42 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
     }
   }
 
+  DiscoverSchemaTransformResponse discover(DiscoverSchemaTransformRequest 
request) {
+    ExpansionServiceSchemaTransformProvider transformProvider =
+        ExpansionServiceSchemaTransformProvider.of();
+    DiscoverSchemaTransformResponse.Builder responseBuilder =
+        DiscoverSchemaTransformResponse.newBuilder();
+    for (org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider 
provider :
+        transformProvider.getAllProviders()) {
+      SchemaTransformConfig.Builder schemaTransformConfigBuider =
+          SchemaTransformConfig.newBuilder();
+      schemaTransformConfigBuider.setConfigSchema(
+          SchemaTranslation.schemaToProto(provider.configurationSchema(), 
true));
+      
schemaTransformConfigBuider.addAllInputPcollectionNames(provider.inputCollectionNames());
+      
schemaTransformConfigBuider.addAllOutputPcollectionNames(provider.outputCollectionNames());
+      responseBuilder.putSchemaTransformConfigs(
+          provider.identifier(), schemaTransformConfigBuider.build());
+    }
+
+    return responseBuilder.build();
+  }
+
+  @Override
+  public void discoverSchemaTransform(
+      DiscoverSchemaTransformRequest request,
+      StreamObserver<DiscoverSchemaTransformResponse> responseObserver) {
+    try {
+      responseObserver.onNext(discover(request));
+      responseObserver.onCompleted();
+    } catch (RuntimeException exn) {
+      responseObserver.onNext(
+          ExpansionApi.DiscoverSchemaTransformResponse.newBuilder()
+              .setError(Throwables.getStackTraceAsString(exn))
+              .build());
+      responseObserver.onCompleted();
+    }
+  }
+
   @Override
   public void close() throws Exception {
     // Nothing to do because the expansion service is stateless.
@@ -618,9 +664,36 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
 
     @SuppressWarnings("nullness")
     ExpansionService service = new ExpansionService(Arrays.copyOfRange(args, 
1, args.length));
+
+    StringBuilder registeredTransformsLog = new StringBuilder();
+    boolean registeredTransformsFound = false;
+    registeredTransformsLog.append("\n");
+    registeredTransformsLog.append("Registered transforms:");
+
     for (Map.Entry<String, TransformProvider> entry :
         service.getRegisteredTransforms().entrySet()) {
-      System.out.println("\t" + entry.getKey() + ": " + entry.getValue());
+      registeredTransformsFound = true;
+      registeredTransformsLog.append("\n\t" + entry.getKey() + ": " + 
entry.getValue());
+    }
+
+    StringBuilder registeredSchemaTransformProvidersLog = new StringBuilder();
+    boolean registeredSchemaTransformProvidersFound = false;
+    registeredSchemaTransformProvidersLog.append("\n");
+    registeredSchemaTransformProvidersLog.append("Registered 
SchemaTransformProviders:");
+
+    for (SchemaTransformProvider provider : 
service.getRegisteredSchemaTransforms()) {
+      registeredSchemaTransformProvidersFound = true;
+      registeredSchemaTransformProvidersLog.append("\n\t" + 
provider.identifier());
+    }
+
+    if (registeredTransformsFound) {
+      System.out.println(registeredTransformsLog.toString());
+    }
+    if (registeredSchemaTransformProvidersFound) {
+      System.out.println(registeredSchemaTransformProvidersLog.toString());
+    }
+    if (!registeredTransformsFound && 
!registeredSchemaTransformProvidersFound) {
+      System.out.println("\nDid not find any registered transforms or 
SchemaTransforms.\n");
     }
 
     Server server =
diff --git 
a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java
 
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java
new file mode 100644
index 00000000000..4657e052402
--- /dev/null
+++ 
b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProvider.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.expansion.service;
+
+import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.ServiceLoader;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
+import 
org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.RowCoder;
+import 
org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import 
org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.InvalidProtocolBufferException;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+@SuppressWarnings({"rawtypes"})
+public class ExpansionServiceSchemaTransformProvider
+    implements TransformProvider<PCollectionRowTuple, PCollectionRowTuple> {
+
+  private Map<String, 
org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider>
+      schemaTransformProviders = new HashMap<>();
+  private static @Nullable ExpansionServiceSchemaTransformProvider 
transformProvider = null;
+
+  private ExpansionServiceSchemaTransformProvider() {
+    try {
+      for (org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider 
schemaTransformProvider :
+          ServiceLoader.load(
+              
org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider.class)) {
+        if 
(schemaTransformProviders.containsKey(schemaTransformProvider.identifier())) {
+          throw new IllegalArgumentException(
+              "Found multiple SchemaTransformProvider implementations with the 
same identifier "
+                  + schemaTransformProvider.identifier());
+        }
+        schemaTransformProviders.put(schemaTransformProvider.identifier(), 
schemaTransformProvider);
+      }
+    } catch (Exception e) {
+      throw new RuntimeException(e.getMessage());
+    }
+  }
+
+  public static ExpansionServiceSchemaTransformProvider of() {
+    if (transformProvider == null) {
+      transformProvider = new ExpansionServiceSchemaTransformProvider();
+    }
+
+    return transformProvider;
+  }
+
+  @Override
+  public PCollectionRowTuple createInput(Pipeline p, Map<String, 
PCollection<?>> inputs) {
+    PCollectionRowTuple inputRowTuple = PCollectionRowTuple.empty(p);
+    for (Map.Entry<String, PCollection<?>> entry : inputs.entrySet()) {
+      inputRowTuple = inputRowTuple.and(entry.getKey(), (PCollection<Row>) 
entry.getValue());
+    }
+    return inputRowTuple;
+  }
+
+  @Override
+  public Map<String, PCollection<?>> extractOutputs(PCollectionRowTuple 
output) {
+    ImmutableMap.Builder<String, PCollection<?>> pCollectionMap = 
ImmutableMap.builder();
+    for (String key : output.getAll().keySet()) {
+      pCollectionMap.put(key, output.get(key));
+    }
+    return pCollectionMap.build();
+  }
+
+  @Override
+  public PTransform getTransform(FunctionSpec spec) {
+    SchemaTransformPayload payload;
+    try {
+      payload = SchemaTransformPayload.parseFrom(spec.getPayload());
+      String identifier = payload.getIdentifier();
+      if (!schemaTransformProviders.containsKey(identifier)) {
+        throw new RuntimeException(
+            "Did not find a SchemaTransformProvider with the identifier " + 
identifier);
+      }
+
+    } catch (InvalidProtocolBufferException e) {
+      throw new IllegalArgumentException(
+          "Invalid payload type for URN " + 
getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM), e);
+    }
+
+    String identifier = payload.getIdentifier();
+    org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider provider =
+        schemaTransformProviders.get(identifier);
+    if (provider == null) {
+      throw new IllegalArgumentException(
+          "Could not find a SchemaTransform with identifier " + identifier);
+    }
+
+    Schema configSchemaFromRequest =
+        SchemaTranslation.schemaFromProto((payload.getConfigurationSchema()));
+    Schema configSchemaFromProvider = provider.configurationSchema();
+
+    if (!configSchemaFromRequest.assignableTo(configSchemaFromProvider)) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Config schema provided with the expansion request %s is not 
compatible with the "
+                  + "config of the Schema transform %s.",
+              configSchemaFromRequest, configSchemaFromProvider));
+    }
+
+    Row configRow;
+    try {
+      configRow =
+          RowCoder.of(provider.configurationSchema())
+              .decode(payload.getConfigurationRow().newInput());
+    } catch (IOException e) {
+      throw new RuntimeException("Error decoding payload", e);
+    }
+
+    return provider.from(configRow).buildTransform();
+  }
+
+  Iterable<org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider> 
getAllProviders() {
+    return schemaTransformProviders.values();
+  }
+}
diff --git 
a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java
 
b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java
new file mode 100644
index 00000000000..5b9b50b248a
--- /dev/null
+++ 
b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceSchemaTransformProviderTest.java
@@ -0,0 +1,486 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.expansion.service;
+
+import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
+import static org.junit.Assert.assertEquals;
+
+import com.google.auto.service.AutoService;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.model.expansion.v1.ExpansionApi;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
+import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.ParDoTranslation;
+import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.schemas.JavaFieldSchema;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
+import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Impulse;
+import org.apache.beam.sdk.transforms.InferableFunction;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import 
org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.InvalidProtocolBufferException;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+
+/** Tests for {@link ExpansionServiceSchemaTransformProvider}. */
+@SuppressWarnings({
+  "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
+})
+public class ExpansionServiceSchemaTransformProviderTest {
+
+  private static final String TEST_NAME = "TestName";
+
+  private static final String TEST_NAMESPACE = "namespace";
+
+  private static final Schema TEST_SCHEMATRANSFORM_CONFIG_SCHEMA =
+      Schema.of(
+          Field.of("str1", FieldType.STRING),
+          Field.of("str2", FieldType.STRING),
+          Field.of("int1", FieldType.INT32),
+          Field.of("int2", FieldType.INT32));
+
+  private ExpansionService expansionService = new ExpansionService();
+
+  @DefaultSchema(JavaFieldSchema.class)
+  public static class TestSchemaTransformConfiguration {
+
+    public final String str1;
+    public final String str2;
+    public final Integer int1;
+    public final Integer int2;
+
+    @SchemaCreate
+    public TestSchemaTransformConfiguration(String str1, String str2, Integer 
int1, Integer int2) {
+      this.str1 = str1;
+      this.str2 = str2;
+      this.int1 = int1;
+      this.int2 = int2;
+    }
+  }
+
+  /** Registers a SchemaTransform. */
+  @AutoService(SchemaTransformProvider.class)
+  public static class TestSchemaTransformProvider
+      extends TypedSchemaTransformProvider<TestSchemaTransformConfiguration> {
+
+    @Override
+    protected Class<TestSchemaTransformConfiguration> configurationClass() {
+      return TestSchemaTransformConfiguration.class;
+    }
+
+    @Override
+    protected SchemaTransform from(TestSchemaTransformConfiguration 
configuration) {
+      return new TestSchemaTransform(
+          configuration.str1, configuration.str2, configuration.int1, 
configuration.int2);
+    }
+
+    @Override
+    public String identifier() {
+      return "dummy_id";
+    }
+
+    @Override
+    public List<String> inputCollectionNames() {
+      return ImmutableList.of("input1");
+    }
+
+    @Override
+    public List<String> outputCollectionNames() {
+      return ImmutableList.of("output1");
+    }
+  }
+
+  public static class TestSchemaTransform implements SchemaTransform {
+
+    private String str1;
+    private String str2;
+    private Integer int1;
+    private Integer int2;
+
+    public TestSchemaTransform(String str1, String str2, Integer int1, Integer 
int2) {
+      this.str1 = str1;
+      this.str2 = str2;
+      this.int1 = int1;
+      this.int2 = int2;
+    }
+
+    @Override
+    public PTransform<PCollectionRowTuple, PCollectionRowTuple> 
buildTransform() {
+      return new TestTransform(str1, str2, int1, int2);
+    }
+  }
+
+  public static class TestDoFn extends DoFn<String, String> {
+
+    public String str1;
+    public String str2;
+    public int int1;
+    public int int2;
+
+    public TestDoFn(String str1, String str2, Integer int1, Integer int2) {
+      this.str1 = str1;
+      this.str2 = str2;
+      this.int1 = int1;
+      this.int2 = int2;
+    }
+
+    @ProcessElement
+    public void processElement(@Element String element, OutputReceiver<String> 
receiver) {
+      receiver.output(element);
+    }
+  }
+
+  public static class TestTransform extends PTransform<PCollectionRowTuple, 
PCollectionRowTuple> {
+
+    private String str1;
+    private String str2;
+    private Integer int1;
+    private Integer int2;
+
+    public TestTransform(String str1, String str2, Integer int1, Integer int2) 
{
+      this.str1 = str1;
+      this.str2 = str2;
+      this.int1 = int1;
+      this.int2 = int2;
+    }
+
+    @Override
+    public PCollectionRowTuple expand(PCollectionRowTuple input) {
+      PCollection<Row> outputPC =
+          input
+              .getAll()
+              .values()
+              .iterator()
+              .next()
+              .apply(
+                  MapElements.via(
+                      new InferableFunction<Row, String>() {
+                        @Override
+                        public String apply(Row input) throws Exception {
+                          return input.getString("in_str");
+                        }
+                      }))
+              .apply(ParDo.of(new TestDoFn(this.str1, this.str2, this.int1, 
this.int2)))
+              .apply(
+                  MapElements.via(
+                      new InferableFunction<String, Row>() {
+                        @Override
+                        public Row apply(String input) throws Exception {
+                          return Row.withSchema(Schema.of(Field.of("out_str", 
FieldType.STRING)))
+                              .withFieldValue("out_str", input)
+                              .build();
+                        }
+                      }))
+              .setRowSchema(Schema.of(Field.of("out_str", FieldType.STRING)));
+      return PCollectionRowTuple.of("output1", outputPC);
+    }
+  }
+
+  /** Registers a SchemaTransform. */
+  @AutoService(SchemaTransformProvider.class)
+  public static class TestSchemaTransformProviderMultiInputMultiOutput
+      extends TypedSchemaTransformProvider<TestSchemaTransformConfiguration> {
+
+    @Override
+    protected Class<TestSchemaTransformConfiguration> configurationClass() {
+      return TestSchemaTransformConfiguration.class;
+    }
+
+    @Override
+    protected SchemaTransform from(TestSchemaTransformConfiguration 
configuration) {
+      return new TestSchemaTransformMultiInputOutput(
+          configuration.str1, configuration.str2, configuration.int1, 
configuration.int2);
+    }
+
+    @Override
+    public String identifier() {
+      return "dummy_id_multi_input_multi_output";
+    }
+
+    @Override
+    public List<String> inputCollectionNames() {
+      return ImmutableList.of("input1", "input2");
+    }
+
+    @Override
+    public List<String> outputCollectionNames() {
+      return ImmutableList.of("output1", "output2");
+    }
+  }
+
+  public static class TestSchemaTransformMultiInputOutput implements 
SchemaTransform {
+
+    private String str1;
+    private String str2;
+    private Integer int1;
+    private Integer int2;
+
+    public TestSchemaTransformMultiInputOutput(
+        String str1, String str2, Integer int1, Integer int2) {
+      this.str1 = str1;
+      this.str2 = str2;
+      this.int1 = int1;
+      this.int2 = int2;
+    }
+
+    @Override
+    public PTransform<PCollectionRowTuple, PCollectionRowTuple> 
buildTransform() {
+      return new TestTransformMultiInputMultiOutput(str1, str2, int1, int2);
+    }
+  }
+
+  public static class TestTransformMultiInputMultiOutput
+      extends PTransform<PCollectionRowTuple, PCollectionRowTuple> {
+
+    private String str1;
+    private String str2;
+    private Integer int1;
+    private Integer int2;
+
+    public TestTransformMultiInputMultiOutput(
+        String str1, String str2, Integer int1, Integer int2) {
+      this.str1 = str1;
+      this.str2 = str2;
+      this.int1 = int1;
+      this.int2 = int2;
+    }
+
+    @Override
+    public PCollectionRowTuple expand(PCollectionRowTuple input) {
+      PCollection<Row> outputPC1 =
+          input
+              .get("input1")
+              .apply(
+                  MapElements.via(
+                      new InferableFunction<Row, String>() {
+                        @Override
+                        public String apply(Row input) throws Exception {
+                          return input.getString("in_str");
+                        }
+                      }))
+              .apply(ParDo.of(new TestDoFn(this.str1, this.str2, this.int1, 
this.int2)))
+              .apply(
+                  MapElements.via(
+                      new InferableFunction<String, Row>() {
+                        @Override
+                        public Row apply(String input) throws Exception {
+                          return Row.withSchema(Schema.of(Field.of("out_str", 
FieldType.STRING)))
+                              .withFieldValue("out_str", input)
+                              .build();
+                        }
+                      }))
+              .setRowSchema(Schema.of(Field.of("out_str", FieldType.STRING)));
+      PCollection<Row> outputPC2 =
+          input
+              .get("input2")
+              .apply(
+                  MapElements.via(
+                      new InferableFunction<Row, String>() {
+                        @Override
+                        public String apply(Row input) throws Exception {
+                          return input.getString("in_str");
+                        }
+                      }))
+              .apply(ParDo.of(new TestDoFn(this.str1, this.str2, this.int1, 
this.int2)))
+              .apply(
+                  MapElements.via(
+                      new InferableFunction<String, Row>() {
+                        @Override
+                        public Row apply(String input) throws Exception {
+                          return Row.withSchema(Schema.of(Field.of("out_str", 
FieldType.STRING)))
+                              .withFieldValue("out_str", input)
+                              .build();
+                        }
+                      }))
+              .setRowSchema(Schema.of(Field.of("out_str", FieldType.STRING)));
+      return PCollectionRowTuple.of("output1", outputPC1, "output2", 
outputPC2);
+    }
+  }
+
+  @Test
+  public void testSchemaTransformDiscovery() {
+    ExpansionApi.DiscoverSchemaTransformRequest discoverRequest =
+        ExpansionApi.DiscoverSchemaTransformRequest.newBuilder().build();
+    ExpansionApi.DiscoverSchemaTransformResponse response =
+        expansionService.discover(discoverRequest);
+    assertEquals(2, response.getSchemaTransformConfigsCount());
+  }
+
+  private void verifyLeafTransforms(ExpansionApi.ExpansionResponse response, 
int count) {
+
+    int leafTransformCount = 0;
+    for (RunnerApi.PTransform transform : 
response.getComponents().getTransformsMap().values()) {
+      if 
(transform.getSpec().getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN))
 {
+        RunnerApi.ParDoPayload parDoPayload;
+        try {
+          parDoPayload = 
RunnerApi.ParDoPayload.parseFrom(transform.getSpec().getPayload());
+          DoFn doFn = ParDoTranslation.getDoFn(parDoPayload);
+          if (!(doFn instanceof TestDoFn)) {
+            continue;
+          }
+          TestDoFn testDoFn = (TestDoFn) doFn;
+          assertEquals("aaa", testDoFn.str1);
+          assertEquals("bbb", testDoFn.str2);
+          assertEquals(111, testDoFn.int1);
+          assertEquals(222, testDoFn.int2);
+          leafTransformCount++;
+        } catch (InvalidProtocolBufferException exc) {
+          throw new RuntimeException(exc);
+        }
+      }
+    }
+    assertEquals(count, leafTransformCount);
+  }
+
+  @Test
+  public void testSchemaTransformExpansion() {
+    Pipeline p = Pipeline.create();
+    p.apply(Impulse.create());
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+
+    String inputPcollId =
+        Iterables.getOnlyElement(
+            
Iterables.getOnlyElement(pipelineProto.getComponents().getTransformsMap().values())
+                .getOutputsMap()
+                .values());
+    Row configRow =
+        Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA)
+            .withFieldValue("str1", "aaa")
+            .withFieldValue("str2", "bbb")
+            .withFieldValue("int1", 111)
+            .withFieldValue("int2", 222)
+            .build();
+
+    ByteStringOutputStream outputStream = new ByteStringOutputStream();
+    try {
+      SchemaCoder.of(configRow.getSchema()).encode(configRow, outputStream);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    ExternalTransforms.SchemaTransformPayload payload =
+        ExternalTransforms.SchemaTransformPayload.newBuilder()
+            .setIdentifier("dummy_id")
+            .setConfigurationRow(outputStream.toByteString())
+            .setConfigurationSchema(
+                
SchemaTranslation.schemaToProto(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, true))
+            .build();
+
+    ExpansionApi.ExpansionRequest request =
+        ExpansionApi.ExpansionRequest.newBuilder()
+            .setComponents(pipelineProto.getComponents())
+            .setTransform(
+                RunnerApi.PTransform.newBuilder()
+                    .setUniqueName(TEST_NAME)
+                    .setSpec(
+                        RunnerApi.FunctionSpec.newBuilder()
+                            
.setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM))
+                            .setPayload(payload.toByteString()))
+                    .putInputs("input1", inputPcollId))
+            .setNamespace(TEST_NAMESPACE)
+            .build();
+    ExpansionApi.ExpansionResponse response = expansionService.expand(request);
+    RunnerApi.PTransform expandedTransform = response.getTransform();
+
+    assertEquals(3, expandedTransform.getSubtransformsCount());
+    assertEquals(1, expandedTransform.getInputsCount());
+    assertEquals(1, expandedTransform.getOutputsCount());
+    verifyLeafTransforms(response, 1);
+  }
+
+  @Test
+  public void testSchemaTransformExpansionMultiInputMultiOutput() {
+    Pipeline p = Pipeline.create();
+    p.apply(Impulse.create());
+    p.apply(Impulse.create());
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+
+    List<String> inputPcollIds = new ArrayList<>();
+    for (RunnerApi.PTransform transform :
+        pipelineProto.getComponents().getTransformsMap().values()) {
+      
inputPcollIds.add(Iterables.getOnlyElement(transform.getOutputsMap().values()));
+    }
+    assertEquals(2, inputPcollIds.size());
+
+    Row configRow =
+        Row.withSchema(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA)
+            .withFieldValue("str1", "aaa")
+            .withFieldValue("str2", "bbb")
+            .withFieldValue("int1", 111)
+            .withFieldValue("int2", 222)
+            .build();
+
+    ByteStringOutputStream outputStream = new ByteStringOutputStream();
+    try {
+      SchemaCoder.of(configRow.getSchema()).encode(configRow, outputStream);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+
+    ExternalTransforms.SchemaTransformPayload payload =
+        ExternalTransforms.SchemaTransformPayload.newBuilder()
+            .setIdentifier("dummy_id_multi_input_multi_output")
+            .setConfigurationRow(outputStream.toByteString())
+            .setConfigurationSchema(
+                
SchemaTranslation.schemaToProto(TEST_SCHEMATRANSFORM_CONFIG_SCHEMA, true))
+            .build();
+
+    ExpansionApi.ExpansionRequest request =
+        ExpansionApi.ExpansionRequest.newBuilder()
+            .setComponents(pipelineProto.getComponents())
+            .setTransform(
+                RunnerApi.PTransform.newBuilder()
+                    .setUniqueName(TEST_NAME)
+                    .setSpec(
+                        RunnerApi.FunctionSpec.newBuilder()
+                            
.setUrn(getUrn(ExpansionMethods.Enum.SCHEMA_TRANSFORM))
+                            .setPayload(payload.toByteString()))
+                    .putInputs("input1", inputPcollIds.get(0))
+                    .putInputs("input2", inputPcollIds.get(1)))
+            .setNamespace(TEST_NAMESPACE)
+            .build();
+    ExpansionApi.ExpansionResponse response = expansionService.expand(request);
+    RunnerApi.PTransform expandedTransform = response.getTransform();
+
+    assertEquals(6, expandedTransform.getSubtransformsCount());
+    assertEquals(2, expandedTransform.getInputsCount());
+    assertEquals(2, expandedTransform.getOutputsCount());
+    verifyLeafTransforms(response, 2);
+  }
+}
diff --git a/sdks/python/apache_beam/portability/common_urns.py 
b/sdks/python/apache_beam/portability/common_urns.py
index 3b47f1ab1e4..3799af5d2e1 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -78,6 +78,7 @@ requirements = StandardRequirements.Enum
 displayData = StandardDisplayData.DisplayData
 
 java_class_lookup = ExpansionMethods.Enum.JAVA_CLASS_LOOKUP
+schematransform_based_expand = ExpansionMethods.Enum.SCHEMA_TRANSFORM
 
 decimal = LogicalTypes.Enum.DECIMAL
 micros_instant = LogicalTypes.Enum.MICROS_INSTANT
diff --git a/sdks/python/apache_beam/transforms/external.py 
b/sdks/python/apache_beam/transforms/external.py
index 58b5182593e..7a51379a0e3 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -28,6 +28,7 @@ import glob
 import logging
 import threading
 from collections import OrderedDict
+from collections import namedtuple
 from typing import Dict
 
 import grpc
@@ -104,6 +105,28 @@ class PayloadBuilder(object):
     """
     return self.build().SerializeToString()
 
+  def _get_schema_proto_and_payload(self, **kwargs):
+    named_fields = []
+    fields_to_values = OrderedDict()
+
+    for key, value in kwargs.items():
+      if not key:
+        raise ValueError('Parameter name cannot be empty')
+      if value is None:
+        raise ValueError(
+            'Received value None for key %s. None values are currently not '
+            'supported' % key)
+      named_fields.append(
+          (key, convert_to_typing_type(instance_to_type(value))))
+      fields_to_values[key] = value
+
+    schema_proto = named_fields_to_schema(named_fields)
+    row = named_tuple_from_schema(schema_proto)(**fields_to_values)
+    schema = named_tuple_to_schema(type(row))
+
+    payload = RowCoder(schema).encode(row)
+    return (schema_proto, payload)
+
 
 class SchemaBasedPayloadBuilder(PayloadBuilder):
   """
@@ -156,6 +179,20 @@ class 
NamedTupleBasedPayloadBuilder(SchemaBasedPayloadBuilder):
     return self._tuple_instance
 
 
+class SchemaTransformPayloadBuilder(PayloadBuilder):
+  def __init__(self, identifier, **kwargs):
+    self._identifier = identifier
+    self._kwargs = kwargs
+
+  def build(self):
+    schema_proto, payload = self._get_schema_proto_and_payload(**self._kwargs)
+    payload = external_transforms_pb2.SchemaTransformPayload(
+        identifier=self._identifier,
+        configuration_schema=schema_proto,
+        configuration_row=payload)
+    return payload
+
+
 class JavaClassLookupPayloadBuilder(PayloadBuilder):
   """
   Builds a payload for directly instantiating a Java transform using a
@@ -177,45 +214,26 @@ class JavaClassLookupPayloadBuilder(PayloadBuilder):
     self._constructor_param_kwargs = None
     self._builder_methods_and_params = OrderedDict()
 
-  def _get_schema_proto_and_payload(self, *args, **kwargs):
-    named_fields = []
-    fields_to_values = OrderedDict()
+  def _args_to_named_fields(self, args):
     next_field_id = 0
+    named_fields = OrderedDict()
     for value in args:
       if value is None:
         raise ValueError(
             'Received value None. None values are currently not supported')
-      named_fields.append(
-          ((JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT % next_field_id),
-           convert_to_typing_type(instance_to_type(value))))
-      fields_to_values[(
+      named_fields[(
           JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT %
           next_field_id)] = value
       next_field_id += 1
-    for key, value in kwargs.items():
-      if not key:
-        raise ValueError('Parameter name cannot be empty')
-      if value is None:
-        raise ValueError(
-            'Received value None for key %s. None values are currently not '
-            'supported' % key)
-      named_fields.append(
-          (key, convert_to_typing_type(instance_to_type(value))))
-      fields_to_values[key] = value
-
-    schema_proto = named_fields_to_schema(named_fields)
-    row = named_tuple_from_schema(schema_proto)(**fields_to_values)
-    schema = named_tuple_to_schema(type(row))
-
-    payload = RowCoder(schema).encode(row)
-    return (schema_proto, payload)
+    return named_fields
 
   def build(self):
-    constructor_param_args = self._constructor_param_args or []
-    constructor_param_kwargs = self._constructor_param_kwargs or {}
+    all_constructor_param_kwargs = self._args_to_named_fields(
+        self._constructor_param_args)
+    if self._constructor_param_kwargs:
+      all_constructor_param_kwargs.update(self._constructor_param_kwargs)
     constructor_schema, constructor_payload = (
-        self._get_schema_proto_and_payload(
-            *constructor_param_args, **constructor_param_kwargs))
+      self._get_schema_proto_and_payload(**all_constructor_param_kwargs))
     payload = external_transforms_pb2.JavaClassLookupPayload(
         class_name=self._class_name,
         constructor_schema=constructor_schema,
@@ -225,9 +243,12 @@ class JavaClassLookupPayloadBuilder(PayloadBuilder):
 
     for builder_method_name, params in 
self._builder_methods_and_params.items():
       builder_method_args, builder_method_kwargs = params
+      all_builder_method_kwargs = self._args_to_named_fields(
+          builder_method_args)
+      if builder_method_kwargs:
+        all_builder_method_kwargs.update(builder_method_kwargs)
       builder_method_schema, builder_method_payload = (
-          self._get_schema_proto_and_payload(
-              *builder_method_args, **builder_method_kwargs))
+        self._get_schema_proto_and_payload(**all_builder_method_kwargs))
       builder_method = external_transforms_pb2.BuilderMethod(
           name=builder_method_name,
           schema=builder_method_schema,
@@ -289,6 +310,64 @@ class JavaClassLookupPayloadBuilder(PayloadBuilder):
         self._constructor_param_kwargs)
 
 
+# Information regarding a SchemaTransform available in an external SDK.
+SchemaTransformsConfig = namedtuple(
+    'SchemaTransformsConfig',
+    ['identifier', 'configuration_schema', 'inputs', 'outputs'])
+
+
+class SchemaAwareExternalTransform(ptransform.PTransform):
+  """A proxy transform for SchemaTransforms implemented in external SDKs.
+
+  This allows Python pipelines to directly use existing SchemaTransforms
+  available to the expansion service without adding additional code in external
+  SDKs.
+
+  :param identifier: unique identifier of the SchemaTransform.
+  :param expansion_service: an expansion service to use. This should already be
+      available and the Schema-aware transforms to be used must already be
+      deployed.
+  :param classpath: (Optional) A list paths to additional jars to place on the
+      expansion service classpath.
+  :kwargs: field name to value mapping for configuring the schema transform.
+      keys map to the field names of the schema of the SchemaTransform
+      (in-order).
+  """
+  def __init__(self, identifier, expansion_service, classpath=None, **kwargs):
+    self._expansion_service = expansion_service
+    self._payload_builder = SchemaTransformPayloadBuilder(identifier, **kwargs)
+    self._classpath = classpath
+
+  def expand(self, pcolls):
+    # Expand the transform using the expansion service.
+    return pcolls | ExternalTransform(
+        common_urns.schematransform_based_expand.urn,
+        self._payload_builder,
+        self._expansion_service)
+
+  @staticmethod
+  def discover(expansion_service):
+    """Discover all SchemaTransforms available to the given expansion service.
+
+    :return: a list of SchemaTransformsConfigs that represent the discovered
+        SchemaTransforms.
+    """
+
+    with ExternalTransform.service(expansion_service) as service:
+      discover_response = service.DiscoverSchemaTransform(
+          beam_expansion_api_pb2.DiscoverSchemaTransformRequest())
+
+      for identifier in discover_response.schema_transform_configs:
+        proto_config = discover_response.schema_transform_configs[identifier]
+        schema = named_tuple_from_schema(proto_config.config_schema)
+
+        yield SchemaTransformsConfig(
+            identifier=identifier,
+            configuration_schema=schema,
+            inputs=proto_config.input_pcollection_names,
+            outputs=proto_config.output_pcollection_names)
+
+
 class JavaExternalTransform(ptransform.PTransform):
   """A proxy for Java-implemented external transforms.
 
@@ -520,7 +599,7 @@ class ExternalTransform(ptransform.PTransform):
         transform=transform_proto,
         output_coder_requests=output_coders)
 
-    with self._service() as service:
+    with ExternalTransform.service(self._expansion_service) as service:
       response = service.Expand(request)
       if response.error:
         raise RuntimeError(response.error)
@@ -549,9 +628,10 @@ class ExternalTransform(ptransform.PTransform):
 
     return self._output_to_pvalueish(self._outputs)
 
+  @staticmethod
   @contextlib.contextmanager
-  def _service(self):
-    if isinstance(self._expansion_service, str):
+  def service(expansion_service):
+    if isinstance(expansion_service, str):
       channel_options = [("grpc.max_receive_message_length", -1),
                          ("grpc.max_send_message_length", -1)]
       if hasattr(grpc, 'local_channel_credentials'):
@@ -560,7 +640,7 @@ class ExternalTransform(ptransform.PTransform):
         # TODO: update this to support secure non-local channels.
         channel_factory_fn = functools.partial(
             grpc.secure_channel,
-            self._expansion_service,
+            expansion_service,
             grpc.local_channel_credentials(),
             options=channel_options)
       else:
@@ -568,15 +648,13 @@ class ExternalTransform(ptransform.PTransform):
         # by older versions of grpc which may be pulled in due to other project
         # dependencies.
         channel_factory_fn = functools.partial(
-            grpc.insecure_channel,
-            self._expansion_service,
-            options=channel_options)
+            grpc.insecure_channel, expansion_service, options=channel_options)
       with channel_factory_fn() as channel:
         yield ExpansionAndArtifactRetrievalStub(channel)
-    elif hasattr(self._expansion_service, 'Expand'):
-      yield self._expansion_service
+    elif hasattr(expansion_service, 'Expand'):
+      yield expansion_service
     else:
-      with self._expansion_service as stub:
+      with expansion_service as stub:
         yield stub
 
   def _resolve_artifacts(self, components, service, dest):
diff --git a/sdks/python/apache_beam/transforms/external_test.py 
b/sdks/python/apache_beam/transforms/external_test.py
index c567f34330d..f38876367c3 100644
--- a/sdks/python/apache_beam/transforms/external_test.py
+++ b/sdks/python/apache_beam/transforms/external_test.py
@@ -44,6 +44,7 @@ from apache_beam.transforms.external import 
JavaClassLookupPayloadBuilder
 from apache_beam.transforms.external import JavaExternalTransform
 from apache_beam.transforms.external import JavaJarExpansionService
 from apache_beam.transforms.external import NamedTupleBasedPayloadBuilder
+from apache_beam.transforms.external import SchemaTransformPayloadBuilder
 from apache_beam.typehints import typehints
 from apache_beam.typehints.native_type_compatibility import 
convert_to_beam_type
 from apache_beam.utils import proto_utils
@@ -445,6 +446,35 @@ class ExternalDataclassesPayloadTest(PayloadBase, 
unittest.TestCase):
     return get_payload(DataclassTransform(**values))
 
 
+class SchemaTransformPayloadBuilderTest(unittest.TestCase):
+  def test_build_payload(self):
+    ComplexType = typing.NamedTuple(
+        "ComplexType", [
+            ("str_sub_field", str),
+            ("int_sub_field", int),
+        ])
+
+    payload_builder = SchemaTransformPayloadBuilder(
+        identifier='dummy_id',
+        str_field='aaa',
+        int_field=123,
+        object_field=ComplexType(str_sub_field="bbb", int_sub_field=456))
+    payload_bytes = payload_builder.payload()
+    payload_from_bytes = proto_utils.parse_Bytes(
+        payload_bytes, external_transforms_pb2.SchemaTransformPayload)
+
+    self.assertEqual('dummy_id', payload_from_bytes.identifier)
+
+    expected_coder = RowCoder(payload_from_bytes.configuration_schema)
+    schema_transform_config = expected_coder.decode(
+        payload_from_bytes.configuration_row)
+
+    self.assertEqual('aaa', schema_transform_config.str_field)
+    self.assertEqual(123, schema_transform_config.int_field)
+    self.assertEqual('bbb', schema_transform_config.object_field.str_sub_field)
+    self.assertEqual(456, schema_transform_config.object_field.int_sub_field)
+
+
 class JavaClassLookupPayloadBuilderTest(unittest.TestCase):
   def _verify_row(self, schema, row_payload, expected_values):
     row = RowCoder(schema).decode(row_payload)

Reply via email to