This is an automated email from the ASF dual-hosted git repository.

ahmedabualsaud pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 921e40a12f4 Dynamic SchemaTransform wrapper provider (#29561)
921e40a12f4 is described below

commit 921e40a12f467c51161bf33a0144fb8a1d4ca334
Author: Ahmed Abualsaud <[email protected]>
AuthorDate: Thu Dec 14 05:58:29 2023 +0300

    Dynamic SchemaTransform wrapper provider (#29561)
    
    * wrapper provider
    
    * add typehints
    
    * add GenerateSequence schematransform and expansion service for java core
    
    * address comments; add description property
    
    * add description test
    
    * add experimental note
---
 .../GenerateSequenceSchemaTransformProvider.java   | 201 +++++++++++++++
 .../apache/beam/sdk/providers/package-info.java    |  23 ++
 ...enerateSequenceSchemaTransformProviderTest.java |  61 +++++
 .../io/external/xlang_kafkaio_it_test.py           |   4 +-
 sdks/python/apache_beam/transforms/external.py     |   1 +
 .../external_schematransform_provider.py           | 277 +++++++++++++++++++++
 .../external_schematransform_provider_test.py      | 140 +++++++++++
 sdks/python/apache_beam/typehints/schemas.py       |   3 +
 sdks/python/pytest.ini                             |   2 +-
 sdks/python/test-suites/dataflow/build.gradle      |   4 +-
 sdks/python/test-suites/direct/build.gradle        |   8 +-
 sdks/python/test-suites/gradle.properties          |   4 +-
 sdks/python/test-suites/xlang/build.gradle         |   9 +-
 13 files changed, 723 insertions(+), 14 deletions(-)

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java
new file mode 100644
index 00000000000..f4cada661b0
--- /dev/null
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.providers;
+
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import static 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.auto.service.AutoService;
+import com.google.auto.value.AutoValue;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.GenerateSequence;
+import 
org.apache.beam.sdk.providers.GenerateSequenceSchemaTransformProvider.GenerateSequenceConfiguration;
+import org.apache.beam.sdk.schemas.AutoValueSchema;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
+import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.joda.time.Duration;
+
+@AutoService(SchemaTransformProvider.class)
+public class GenerateSequenceSchemaTransformProvider
+    extends TypedSchemaTransformProvider<GenerateSequenceConfiguration> {
+  public static final String OUTPUT_ROWS_TAG = "output";
+  public static final Schema OUTPUT_SCHEMA = 
Schema.builder().addInt64Field("value").build();
+
+  @Override
+  public String identifier() {
+    return "beam:schematransform:org.apache.beam:generate_sequence:v1";
+  }
+
+  @Override
+  public List<String> inputCollectionNames() {
+    return Collections.emptyList();
+  }
+
+  @Override
+  public List<String> outputCollectionNames() {
+    return Collections.singletonList(OUTPUT_ROWS_TAG);
+  }
+
+  @Override
+  public String description() {
+    return String.format(
+        "Outputs a PCollection of Beam Rows, each containing a single INT64 "
+            + "number called \"value\". The count is produced from the given 
\"start\""
+            + " value and either up to the given \"end\" or until 2^63 - 1.\n"
+            + "To produce an unbounded PCollection, simply do not specify an 
\"end\" value. "
+            + "Unbounded sequences can specify a \"rate\" for output 
elements.\n"
+            + "In all cases, the sequence of numbers is generated in parallel, 
so there is no "
+            + "inherent ordering between the generated values");
+  }
+
+  @Override
+  public Class<GenerateSequenceConfiguration> configurationClass() {
+    return GenerateSequenceConfiguration.class;
+  }
+
+  @Override
+  public SchemaTransform from(GenerateSequenceConfiguration configuration) {
+    return new GenerateSequenceSchemaTransform(configuration);
+  }
+
+  @DefaultSchema(AutoValueSchema.class)
+  @AutoValue
+  public abstract static class GenerateSequenceConfiguration {
+    @AutoValue
+    public abstract static class Rate {
+      @SchemaFieldDescription("Number of elements component of the rate.")
+      public abstract Long getElements();
+
+      @SchemaFieldDescription("Number of seconds component of the rate.")
+      @Nullable
+      public abstract Long getSeconds();
+
+      public static Builder builder() {
+        return new 
AutoValue_GenerateSequenceSchemaTransformProvider_GenerateSequenceConfiguration_Rate
+            .Builder();
+      }
+
+      @AutoValue.Builder
+      public abstract static class Builder {
+        public abstract Builder setElements(Long elements);
+
+        public abstract Builder setSeconds(Long seconds);
+
+        public abstract Rate build();
+      }
+    }
+
+    public static Builder builder() {
+      return new 
AutoValue_GenerateSequenceSchemaTransformProvider_GenerateSequenceConfiguration
+          .Builder();
+    }
+
+    @SchemaFieldDescription("The minimum number to generate (inclusive).")
+    public abstract Long getStart();
+
+    @SchemaFieldDescription(
+        "The maximum number to generate (exclusive). Will be an unbounded 
sequence if left unspecified.")
+    @Nullable
+    public abstract Long getEnd();
+
+    @SchemaFieldDescription(
+        "Specifies the rate to generate a given number of elements per a given 
number of seconds. "
+            + "Applicable only to unbounded sequences.")
+    @Nullable
+    public abstract Rate getRate();
+
+    @AutoValue.Builder
+    public abstract static class Builder {
+
+      public abstract Builder setStart(Long start);
+
+      public abstract Builder setEnd(Long end);
+
+      public abstract Builder setRate(Rate rate);
+
+      public abstract GenerateSequenceConfiguration build();
+    }
+
+    public void validate() {
+      checkNotNull(this.getStart(), "Must specify a starting point 
\"start\".");
+      Long start = this.getStart();
+      Long end = this.getEnd();
+      if (end != null) {
+        checkArgument(end == -1 || end >= start, "Invalid range [%s, %s)", 
start, end);
+      }
+      Rate rate = this.getRate();
+      if (rate != null) {
+        checkArgument(
+            rate.getElements() > 0,
+            "Invalid rate specification. Expected positive elements component 
but received %s.",
+            rate.getElements());
+        checkArgument(
+            Optional.ofNullable(rate.getSeconds()).orElse(1L) > 0,
+            "Invalid rate specification. Expected positive seconds component 
but received %s.",
+            rate.getSeconds());
+      }
+    }
+  }
+
+  protected static class GenerateSequenceSchemaTransform extends 
SchemaTransform {
+    private final GenerateSequenceConfiguration configuration;
+
+    GenerateSequenceSchemaTransform(GenerateSequenceConfiguration 
configuration) {
+      configuration.validate();
+      this.configuration = configuration;
+    }
+
+    @Override
+    public PCollectionRowTuple expand(PCollectionRowTuple input) {
+      checkArgument(
+          input.getAll().isEmpty(), "Expected no inputs but got: %s", 
input.getAll().keySet());
+
+      Long end = Optional.ofNullable(configuration.getEnd()).orElse(-1L);
+      GenerateSequenceConfiguration.Rate rate = configuration.getRate();
+
+      GenerateSequence sequence = 
GenerateSequence.from(configuration.getStart()).to(end);
+      if (rate != null) {
+        sequence =
+            sequence.withRate(
+                rate.getElements(),
+                
Duration.standardSeconds(Optional.ofNullable(rate.getSeconds()).orElse(1L)));
+      }
+
+      return PCollectionRowTuple.of(
+          OUTPUT_ROWS_TAG,
+          input
+              .getPipeline()
+              .apply(sequence)
+              .apply(
+                  MapElements.into(TypeDescriptors.rows())
+                      .via(l -> 
Row.withSchema(OUTPUT_SCHEMA).withFieldValue("value", l).build()))
+              .setRowSchema(OUTPUT_SCHEMA));
+    }
+  }
+}
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/package-info.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/package-info.java
new file mode 100644
index 00000000000..6d90b7d018a
--- /dev/null
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/package-info.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Defines {@link 
org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider}s for transforms 
in
+ * the core module.
+ */
+package org.apache.beam.sdk.providers;
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProviderTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProviderTest.java
new file mode 100644
index 00000000000..dcff3dedb84
--- /dev/null
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProviderTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.providers;
+
+import java.util.ArrayList;
+import java.util.List;
+import 
org.apache.beam.sdk.providers.GenerateSequenceSchemaTransformProvider.GenerateSequenceConfiguration;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GenerateSequenceSchemaTransformProviderTest {
+  @Rule public transient TestPipeline p = TestPipeline.create();
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGenerateSequence() {
+    GenerateSequenceConfiguration config =
+        
GenerateSequenceConfiguration.builder().setStart(0L).setEnd(10L).build();
+    SchemaTransform sequence = new 
GenerateSequenceSchemaTransformProvider().from(config);
+
+    List<Row> expected = new ArrayList<>(10);
+    for (long i = 0L; i < 10L; i++) {
+      expected.add(
+          Row.withSchema(GenerateSequenceSchemaTransformProvider.OUTPUT_SCHEMA)
+              .withFieldValue("value", i)
+              .build());
+    }
+
+    PAssert.that(
+            PCollectionRowTuple.empty(p)
+                .apply(sequence)
+                .get(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG))
+        .containsInAnyOrder(expected);
+    p.run();
+  }
+}
diff --git a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py 
b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
index a2f350e8cb7..a7bf686d064 100644
--- a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
+++ b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
@@ -143,7 +143,7 @@ class CrossLanguageKafkaIOTest(unittest.TestCase):
       self.run_kafka_write(pipeline_creator)
       self.run_kafka_read(pipeline_creator, None)
 
-  @pytest.mark.uses_io_expansion_service
+  @pytest.mark.uses_io_java_expansion_service
   @unittest.skipUnless(
       os.environ.get('EXPANSION_PORT'),
       "EXPANSION_PORT environment var is not provided.")
@@ -162,7 +162,7 @@ class CrossLanguageKafkaIOTest(unittest.TestCase):
     self.run_kafka_write(pipeline_creator)
     self.run_kafka_read(pipeline_creator, b'key')
 
-  @pytest.mark.uses_io_expansion_service
+  @pytest.mark.uses_io_java_expansion_service
   @unittest.skipUnless(
       os.environ.get('EXPANSION_PORT'),
       "EXPANSION_PORT environment var is not provided.")
diff --git a/sdks/python/apache_beam/transforms/external.py 
b/sdks/python/apache_beam/transforms/external.py
index 997cea347d3..71dfc545204 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -1067,6 +1067,7 @@ class BeamJarExpansionService(JavaJarExpansionService):
       append_args=None):
     path_to_jar = subprocess_server.JavaJarServer.path_to_beam_jar(
         gradle_target, gradle_appendix)
+    self.gradle_target = gradle_target
     super().__init__(
         path_to_jar, extra_args, classpath=classpath, append_args=append_args)
 
diff --git 
a/sdks/python/apache_beam/transforms/external_schematransform_provider.py 
b/sdks/python/apache_beam/transforms/external_schematransform_provider.py
new file mode 100644
index 00000000000..fd650087893
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/external_schematransform_provider.py
@@ -0,0 +1,277 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import re
+from collections import namedtuple
+from typing import Dict
+from typing import List
+from typing import Tuple
+
+from apache_beam.transforms import PTransform
+from apache_beam.transforms.external import BeamJarExpansionService
+from apache_beam.transforms.external import SchemaAwareExternalTransform
+from apache_beam.transforms.external import SchemaTransformsConfig
+from apache_beam.typehints.schemas import named_tuple_to_schema
+from apache_beam.typehints.schemas import typing_from_runner_api
+
+__all__ = ['ExternalSchemaTransform', 'ExternalSchemaTransformProvider']
+
+
+def snake_case_to_upper_camel_case(string):
+  """Convert snake_case to UpperCamelCase"""
+  components = string.split('_')
+  output = ''.join(n.capitalize() for n in components)
+  return output
+
+
+def snake_case_to_lower_camel_case(string):
+  """Convert snake_case to lowerCamelCase"""
+  if len(string) <= 1:
+    return string.lower()
+  upper = snake_case_to_upper_camel_case(string)
+  return upper[0].lower() + upper[1:]
+
+
+def camel_case_to_snake_case(string):
+  """Convert camelCase to snake_case"""
+  arr = []
+  word = []
+  for i, n in enumerate(string):
+    # If seeing an upper letter after a lower letter, we just witnessed a word
+    # If seeing an upper letter and the next letter is lower, we may have just
+    # witnessed an all caps word
+    if n.isupper() and ((i > 0 and string[i - 1].islower()) or
+                        (i + 1 < len(string) and string[i + 1].islower())):
+      arr.append(''.join(word))
+      word = [n.lower()]
+    else:
+      word.append(n.lower())
+  arr.append(''.join(word))
+  return '_'.join(arr).strip('_')
+
+
+# Information regarding a Wrapper parameter.
+ParamInfo = namedtuple('ParamInfo', ['type', 'description', 'original_name'])
+
+
+def get_config_with_descriptions(
+    schematransform: SchemaTransformsConfig) -> Dict[str, ParamInfo]:
+  # Prepare a configuration schema that includes types and descriptions
+  schema = named_tuple_to_schema(schematransform.configuration_schema)
+  descriptions = schematransform.configuration_schema._field_descriptions
+  fields_with_descriptions = {}
+  for field in schema.fields:
+    fields_with_descriptions[camel_case_to_snake_case(field.name)] = ParamInfo(
+        typing_from_runner_api(field.type),
+        descriptions[field.name],
+        field.name)
+
+  return fields_with_descriptions
+
+
+class ExternalSchemaTransform(PTransform):
+  """Template for a wrapper class of an external SchemaTransform
+
+  This is a superclass for dynamically generated SchemaTransform wrappers and
+  is not meant to be manually instantiated.
+
+  Experimental; no backwards compatibility guarantees."""
+
+  # These attributes need to be set when
+  # creating an ExternalSchemaTransform type
+  default_expansion_service = None
+  description: str = ""
+  identifier: str = ""
+  configuration_schema: Dict[str, ParamInfo] = {}
+
+  def __init__(self, expansion_service=None, **kwargs):
+    self._kwargs = kwargs
+    self._expansion_service = \
+        expansion_service or self.default_expansion_service
+
+  def expand(self, input):
+    camel_case_kwargs = {
+        snake_case_to_lower_camel_case(k): v
+        for k, v in self._kwargs.items()
+    }
+
+    external_schematransform = SchemaAwareExternalTransform(
+        identifier=self.identifier,
+        expansion_service=self._expansion_service,
+        rearrange_based_on_discovery=True,
+        **camel_case_kwargs)
+
+    return input | external_schematransform
+
+
+STANDARD_URN_PATTERN = r"^beam:schematransform:org.apache.beam:([\w-]+):(\w+)$"
+
+
+def infer_name_from_identifier(identifier: str, pattern: str):
+  """Infer a class name from an identifier, adhering to the input pattern"""
+  match = re.match(pattern, identifier)
+  if not match:
+    return None
+  groups = match.groups()
+
+  components = [snake_case_to_upper_camel_case(n) for n in groups]
+  # Special handling for standard SchemaTransform identifiers:
+  # We don't include the version number if it's the first version
+  if (pattern == STANDARD_URN_PATTERN and components[1].lower() == 'v1'):
+    return components[0]
+  else:
+    return ''.join(components)
+
+
+class ExternalSchemaTransformProvider:
+  """Dynamically discovers Schema-aware external transforms from a given list
+  of expansion services and provides them as ready PTransforms.
+
+  An :class:`ExternalSchemaTransform` subclass is generated for each external
+  transform, and is named based on what can be inferred from the URN
+  (see the ``urn_pattern`` parameter).
+
+  These classes are generated when :class:`ExternalSchemaTransformProvider` is
+  initialized. We need to give it one or more expansion service addresses that
+  are already up and running:
+  >>> provider = ExternalSchemaTransformProvider(["localhost:12345",
+  ...                                             "localhost:12121"])
+  We can also give it the gradle target of a standard Beam expansion service:
+  >>> provider = ExternalSchemaTransformProvider(BeamJarExpansionService(
+  ...     "sdks:java:io:google-cloud-platform:expansion-service:shadowJar"))
+  Let's take a look at the output of :func:`get_available()` to know the
+  available transforms in the expansion service(s) we provided:
+  >>> provider.get_available()
+  [('JdbcWrite', 'beam:schematransform:org.apache.beam:jdbc_write:v1'),
+  ('BigtableRead', 'beam:schematransform:org.apache.beam:bigtable_read:v1'),
+  ...]
+
+  Then retrieve a transform by :func:`get()`, :func:`get_urn()`, or by directly
+  accessing it as an attribute of :class:`ExternalSchemaTransformProvider`.
+  All of the following commands do the same thing:
+  >>> provider.get('BigqueryStorageRead')
+  >>> provider.get_urn(
+  ...       'beam:schematransform:org.apache.beam:bigquery_storage_read:v1')
+  >>> provider.BigqueryStorageRead
+
+  To know more about the usage of a given transform, take a look at the
+  `description` attribute. This returns some documentation IF the underlying
+  SchemaTransform provides any.
+  >>> provider.BigqueryStorageRead.description
+
+  Similarly, the `configuration_schema` attribute returns information about the
+  parameters, including their names, types, and any documentation that the
+  underlying SchemaTransform may provide:
+  >>> provider.BigqueryStorageRead.configuration_schema
+  {'query': ParamInfo(type=typing.Optional[str], description='The SQL query to
+  be executed to read from the BigQuery table.', original_name='query'),
+  'row_restriction': ParamInfo(type=typing.Optional[str]...}
+
+  The retrieved external transform can be used as a normal PTransform like so::
+
+    with Pipeline() as p:
+      _ = (p
+        | 'Read from BigQuery' >> provider.BigqueryStorageRead(
+                query=query,
+                row_restriction=restriction)
+        | 'Some processing' >> beam.Map(...))
+
+  Experimental; no backwards compatibility guarantees.
+  """
+  def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN):
+    f"""Initialize an ExternalSchemaTransformProvider
+
+    :param expansion_services:
+      A list of expansion services to discover transforms from.
+      Supported forms:
+      * a string representing the expansion service address
+      * a :attr:`BeamJarExpansionService` pointing to a gradle target
+    :param urn_pattern:
+      The regular expression used to match valid transforms. In addition to
+      validating, the captured groups are used to infer a name for each class.
+      By default, the following pattern is used: [{STANDARD_URN_PATTERN}]
+    """
+    self._urn_pattern = urn_pattern
+    self._transforms: Dict[str, type(ExternalSchemaTransform)] = {}
+    self._name_to_urn: Dict[str, str] = {}
+
+    if isinstance(expansion_services, set):
+      expansion_services = list(expansion_services)
+    if not isinstance(expansion_services, list):
+      expansion_services = [expansion_services]
+    self.expansion_services = expansion_services
+    self._create_wrappers()
+
+  def _create_wrappers(self):
+    # multiple services can overlap and include the same URNs. If this happens,
+    # we prioritize by the order of services in the list
+    identifiers = set()
+    for service in self.expansion_services:
+      target = service
+      if isinstance(service, BeamJarExpansionService):
+        target = service.gradle_target
+      try:
+        schematransform_configs = 
SchemaAwareExternalTransform.discover(service)
+      except Exception as e:
+        logging.exception(
+            "Encountered an error while discovering expansion service %s:\n%s",
+            target,
+            e)
+        continue
+      skipped_urns = []
+      for config in schematransform_configs:
+        identifier = config.identifier
+        if identifier not in identifiers:
+          identifiers.add(identifier)
+
+          name = infer_name_from_identifier(identifier, self._urn_pattern)
+          if name is None:
+            skipped_urns.append(identifier)
+            continue
+
+          self._transforms[identifier] = type(
+              name, (ExternalSchemaTransform, ),
+              dict(
+                  identifier=identifier,
+                  default_expansion_service=service,
+                  schematransform=config,
+                  description=config.description,
+                  configuration_schema=get_config_with_descriptions(config)))
+          self._name_to_urn[name] = identifier
+
+      if skipped_urns:
+        logging.info(
+            "Skipped URN(s) in %s that don't follow the pattern [%s]: %s",
+            target,
+            self._urn_pattern,
+            skipped_urns)
+
+    for transform in self._transforms.values():
+      setattr(self, transform.__name__, transform)
+
+  def get_available(self) -> List[Tuple[str, str]]:
+    """Get a list of available ExternalSchemaTransform names and identifiers"""
+    return list(self._name_to_urn.items())
+
+  def get(self, name) -> ExternalSchemaTransform:
+    """Get an ExternalSchemaTransform by its inferred class name"""
+    return self._transforms[self._name_to_urn[name]]
+
+  def get_urn(self, identifier) -> ExternalSchemaTransform:
+    """Get an ExternalSchemaTransform by its SchemaTransform identifier"""
+    return self._transforms[identifier]
diff --git 
a/sdks/python/apache_beam/transforms/external_schematransform_provider_test.py 
b/sdks/python/apache_beam/transforms/external_schematransform_provider_test.py
new file mode 100644
index 00000000000..bf951e671c2
--- /dev/null
+++ 
b/sdks/python/apache_beam/transforms/external_schematransform_provider_test.py
@@ -0,0 +1,140 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+import os
+import unittest
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.transforms.external import BeamJarExpansionService
+from apache_beam.transforms.external_schematransform_provider import 
STANDARD_URN_PATTERN
+from apache_beam.transforms.external_schematransform_provider import 
ExternalSchemaTransformProvider
+from apache_beam.transforms.external_schematransform_provider import 
camel_case_to_snake_case
+from apache_beam.transforms.external_schematransform_provider import 
infer_name_from_identifier
+from apache_beam.transforms.external_schematransform_provider import 
snake_case_to_lower_camel_case
+from apache_beam.transforms.external_schematransform_provider import 
snake_case_to_upper_camel_case
+
+
+class NameUtilsTest(unittest.TestCase):
+  def test_snake_case_to_upper_camel_case(self):
+    test_cases = [("", ""), ("test", "Test"), ("test_name", "TestName"),
+                  ("test_double_underscore", "TestDoubleUnderscore"),
+                  ("TEST_CAPITALIZED", "TestCapitalized"),
+                  ("_prepended_underscore", "PrependedUnderscore"),
+                  ("appended_underscore_", "AppendedUnderscore")]
+    for case in test_cases:
+      self.assertEqual(case[1], snake_case_to_upper_camel_case(case[0]))
+
+  def test_snake_case_to_lower_camel_case(self):
+    test_cases = [("", ""), ("test", "test"), ("test_name", "testName"),
+                  ("test_double_underscore", "testDoubleUnderscore"),
+                  ("TEST_CAPITALIZED", "testCapitalized"),
+                  ("_prepended_underscore", "prependedUnderscore"),
+                  ("appended_underscore_", "appendedUnderscore")]
+    for case in test_cases:
+      self.assertEqual(case[1], snake_case_to_lower_camel_case(case[0]))
+
+  def test_camel_case_to_snake_case(self):
+    test_cases = [("", ""), ("Test", "test"), ("TestName", "test_name"),
+                  ("TestDoubleUnderscore",
+                   "test_double_underscore"), ("MyToLoFo", "my_to_lo_fo"),
+                  ("BEGINNINGAllCaps",
+                   "beginning_all_caps"), ("AllCapsENDING", "all_caps_ending"),
+                  ("AllCapsMIDDLEWord", "all_caps_middle_word"),
+                  ("lowerCamelCase", "lower_camel_case")]
+    for case in test_cases:
+      self.assertEqual(case[1], camel_case_to_snake_case(case[0]))
+
+  def test_infer_name_from_identifier(self):
+    standard_test_cases = [
+        ("beam:schematransform:org.apache.beam:transform:v1", "Transform"),
+        ("beam:schematransform:org.apache.beam:my_transform:v1",
+         "MyTransform"), (
+             "beam:schematransform:org.apache.beam:my_transform:v2",
+             "MyTransformV2"),
+        ("beam:schematransform:org.apache.beam:fe_fi_fo_fum:v2", 
"FeFiFoFumV2"),
+        ("beam:schematransform:bad_match:my_transform:v1", None)
+    ]
+    for case in standard_test_cases:
+      self.assertEqual(
+          case[1], infer_name_from_identifier(case[0], STANDARD_URN_PATTERN))
+
+    custom_pattern_cases = [
+        # (<pattern>, <urn>, <expected output>)
+        (
+            r"^custom:transform:([\w-]+):(\w+)$",
+            "custom:transform:my_transform:v1",
+            "MyTransformV1"),
+        (
+            r"^org.user:([\w-]+):([\w-]+):([\w-]+):external$",
+            "org.user:some:custom_transform:we_made:external",
+            "SomeCustomTransformWeMade"),
+        (
+            r"^([\w-]+):user.transforms",
+            "my_eXTErnal:user.transforms",
+            "MyExternal"),
+        (r"^([\w-]+):user.transforms", "my_external:badinput.transforms", 
None),
+    ]
+    for case in custom_pattern_cases:
+      self.assertEqual(case[2], infer_name_from_identifier(case[1], case[0]))
+
+
[email protected]_io_java_expansion_service
[email protected](
+    os.environ.get('EXPANSION_PORT'),
+    "EXPANSION_PORT environment var is not provided.")
+class ExternalSchemaTransformProviderTest(unittest.TestCase):
+  def setUp(self):
+    self.test_pipeline = TestPipeline(is_integration_test=True)
+
+  def test_generate_sequence_config_schema_and_description(self):
+    provider = ExternalSchemaTransformProvider(
+        BeamJarExpansionService(":sdks:java:io:expansion-service:shadowJar"))
+
+    self.assertTrue((
+        'GenerateSequence',
+        'beam:schematransform:org.apache.beam:generate_sequence:v1'
+    ) in provider.get_available())
+
+    GenerateSequence = provider.get('GenerateSequence')
+    config_schema = GenerateSequence.configuration_schema
+    for param in ['start', 'end', 'rate']:
+      self.assertTrue(param in config_schema)
+
+    description_substring = (
+        "Outputs a PCollection of Beam Rows, each "
+        "containing a single INT64")
+    self.assertTrue(description_substring in GenerateSequence.description)
+
+  def test_run_generate_sequence(self):
+    provider = ExternalSchemaTransformProvider(
+        BeamJarExpansionService(":sdks:java:io:expansion-service:shadowJar"))
+
+    with beam.Pipeline() as p:
+      numbers = p | provider.GenerateSequence(
+          start=0, end=10) | beam.Map(lambda row: row.value)
+
+      assert_that(numbers, equal_to([i for i in range(10)]))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/apache_beam/typehints/schemas.py 
b/sdks/python/apache_beam/typehints/schemas.py
index b8176dccb8e..147a46f0bea 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -540,6 +540,7 @@ class SchemaTranslation(object):
     type_name = 'BeamSchema_{}'.format(schema.id.replace('-', '_'))
 
     subfields = []
+    descriptions = {}
     for field in schema.fields:
       try:
         field_py_type = self.typing_from_runner_api(field.type)
@@ -550,6 +551,7 @@ class SchemaTranslation(object):
             "Failed to decode schema due to an issue with Field proto:\n\n"
             f"{text_format.MessageToString(field)}") from e
 
+      descriptions[field.name] = field.description
       subfields.append((field.name, field_py_type))
 
     user_type = NamedTuple(type_name, subfields)
@@ -560,6 +562,7 @@ class SchemaTranslation(object):
         user_type,
         '__reduce__',
         _named_tuple_reduce_method(schema.SerializeToString()))
+    setattr(user_type, "_field_descriptions", descriptions)
     setattr(user_type, row_type._BEAM_SCHEMA_ID, schema.id)
 
     self.schema_registry.add(user_type, schema)
diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini
index 140476b29e5..4ffbb4524c0 100644
--- a/sdks/python/pytest.ini
+++ b/sdks/python/pytest.ini
@@ -33,7 +33,7 @@ markers =
     uses_gcp_java_expansion_service: collect Cross Language GCP Java 
transforms test runs
     uses_java_expansion_service: collect Cross Language Java transforms test 
runs
     uses_python_expansion_service: collect Cross Language Python transforms 
test runs
-    uses_io_expansion_service: collect Cross Language transform test runs 
(with Kafka bootstrap server)
+    uses_io_java_expansion_service: collect Cross Language IO Java transform 
test runs (with Kafka bootstrap server)
     uses_transform_service: collect Cross Language test runs that uses the 
Transform Service
     xlang_sql_expansion_service: collect for Cross Language with SQL expansion 
service test runs
     it_postcommit: collect for post-commit integration test runs
diff --git a/sdks/python/test-suites/dataflow/build.gradle 
b/sdks/python/test-suites/dataflow/build.gradle
index b55716a42df..04a79683fd3 100644
--- a/sdks/python/test-suites/dataflow/build.gradle
+++ b/sdks/python/test-suites/dataflow/build.gradle
@@ -67,13 +67,13 @@ task examplesPostCommit {
 }
 
+// Uses the renamed 'cross_language_validates_py_versions' property
+// (formerly 'cross_language_validates_gcp_py_versions' in gradle.properties).
 task gcpCrossLanguagePostCommit {
-  getVersionsAsList('cross_language_validates_gcp_py_versions').each {
+  getVersionsAsList('cross_language_validates_py_versions').each {
     
 dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:gcpCrossLanguagePythonUsingJava")
   }
 }
 
+// Switched to the renamed, non-GCP-specific Python versions property.
 task ioCrossLanguagePostCommit {
-  getVersionsAsList('cross_language_validates_gcp_py_versions').each {
+  getVersionsAsList('cross_language_validates_py_versions').each {
     
 dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:ioCrossLanguagePythonUsingJava")
   }
 }
diff --git a/sdks/python/test-suites/direct/build.gradle 
b/sdks/python/test-suites/direct/build.gradle
index fc408411ec2..ea643c3303a 100644
--- a/sdks/python/test-suites/direct/build.gradle
+++ b/sdks/python/test-suites/direct/build.gradle
@@ -32,7 +32,13 @@ tasks.register("examplesPostCommit") {
 }
 
 task gcpCrossLanguagePostCommit {
-  getVersionsAsList('cross_language_validates_gcp_py_versions').each {
+  getVersionsAsList('cross_language_validates_py_versions').each {
     
dependsOn.add(":sdks:python:test-suites:direct:py${getVersionSuffix(it)}:gcpCrossLanguagePythonUsingJava")
   }
 }
+
+// New direct-runner postcommit aggregating the IO cross-language Python
+// suites that exercise the Java expansion service.
+task ioCrossLanguagePostCommit {
+  getVersionsAsList('cross_language_validates_py_versions').each {
+    
dependsOn.add(":sdks:python:test-suites:direct:py${getVersionSuffix(it)}:ioCrossLanguagePythonUsingJava")
+  }
+}
diff --git a/sdks/python/test-suites/gradle.properties 
b/sdks/python/test-suites/gradle.properties
index 72fc651733d..c4e94bed4d8 100644
--- a/sdks/python/test-suites/gradle.properties
+++ b/sdks/python/test-suites/gradle.properties
@@ -47,5 +47,5 @@ samza_validates_runner_postcommit_py_versions=3.8,3.11
 # spark runner test-suites
 spark_examples_postcommit_py_versions=3.8,3.11
 
-# cross language gcp io postcommit python test suites
-cross_language_validates_gcp_py_versions=3.8,3.11
+# cross language postcommit python test suites
+cross_language_validates_py_versions=3.8,3.11
diff --git a/sdks/python/test-suites/xlang/build.gradle 
b/sdks/python/test-suites/xlang/build.gradle
index df3ebdd1582..5a124ac20ce 100644
--- a/sdks/python/test-suites/xlang/build.gradle
+++ b/sdks/python/test-suites/xlang/build.gradle
@@ -56,20 +56,17 @@ def gcpXlangCommon = new CrossLanguageTaskCommon().tap {
     startJobServer = setupTask
     cleanupJobServer = cleanupTask
 }
-xlangTasks.add(gcpXlangCommon)
-
-def ioExpansionProject = project.project(':sdks:java:io:expansion-service')
 
 def ioXlangCommon = new CrossLanguageTaskCommon().tap {
     name = "ioCrossLanguage"
-    expansionProjectPath = ioExpansionProject.getPath()
-    collectMarker = "uses_io_expansion_service"
+    // Inline the expansion-service project lookup; pytest marker renamed to
+    // 'uses_io_java_expansion_service' to match pytest.ini.
+    expansionProjectPath = 
project.project(':sdks:java:io:expansion-service').getPath()
+    collectMarker = "uses_io_java_expansion_service"
     startJobServer = setupTask
     cleanupJobServer = cleanupTask
     //See .test-infra/kafka/bitnami/README.md for setup instructions
     additionalEnvs = 
["KAFKA_BOOTSTRAP_SERVER":project.findProperty('kafkaBootstrapServer')]
 }
 
-xlangTasks.add(ioXlangCommon)
+// Register both the GCP and IO cross-language task groups in one call.
+xlangTasks.addAll(gcpXlangCommon, ioXlangCommon)
 
 ext.xlangTasks = xlangTasks
\ No newline at end of file


Reply via email to