This is an automated email from the ASF dual-hosted git repository.
ahmedabualsaud pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 921e40a12f4 Dynamic SchemaTransform wrapper provider (#29561)
921e40a12f4 is described below
commit 921e40a12f467c51161bf33a0144fb8a1d4ca334
Author: Ahmed Abualsaud <[email protected]>
AuthorDate: Thu Dec 14 05:58:29 2023 +0300
Dynamic SchemaTransform wrapper provider (#29561)
* wrapper provider
* add typehints
* add GenerateSequence schematransform and expansion service for java core
* address comments; add description property
* add description test
* add experimental note
---
.../GenerateSequenceSchemaTransformProvider.java | 201 +++++++++++++++
.../apache/beam/sdk/providers/package-info.java | 23 ++
...enerateSequenceSchemaTransformProviderTest.java | 61 +++++
.../io/external/xlang_kafkaio_it_test.py | 4 +-
sdks/python/apache_beam/transforms/external.py | 1 +
.../external_schematransform_provider.py | 277 +++++++++++++++++++++
.../external_schematransform_provider_test.py | 140 +++++++++++
sdks/python/apache_beam/typehints/schemas.py | 3 +
sdks/python/pytest.ini | 2 +-
sdks/python/test-suites/dataflow/build.gradle | 4 +-
sdks/python/test-suites/direct/build.gradle | 8 +-
sdks/python/test-suites/gradle.properties | 4 +-
sdks/python/test-suites/xlang/build.gradle | 9 +-
13 files changed, 723 insertions(+), 14 deletions(-)
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java
new file mode 100644
index 00000000000..f4cada661b0
--- /dev/null
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProvider.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.providers;
+
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import static
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.auto.service.AutoService;
+import com.google.auto.value.AutoValue;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.io.GenerateSequence;
+import
org.apache.beam.sdk.providers.GenerateSequenceSchemaTransformProvider.GenerateSequenceConfiguration;
+import org.apache.beam.sdk.schemas.AutoValueSchema;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
+import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.joda.time.Duration;
+
+@AutoService(SchemaTransformProvider.class)
+public class GenerateSequenceSchemaTransformProvider
+ extends TypedSchemaTransformProvider<GenerateSequenceConfiguration> {
+ public static final String OUTPUT_ROWS_TAG = "output";
+ public static final Schema OUTPUT_SCHEMA =
Schema.builder().addInt64Field("value").build();
+
+ @Override
+ public String identifier() {
+ return "beam:schematransform:org.apache.beam:generate_sequence:v1";
+ }
+
+ @Override
+ public List<String> inputCollectionNames() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public List<String> outputCollectionNames() {
+ return Collections.singletonList(OUTPUT_ROWS_TAG);
+ }
+
+ @Override
+ public String description() {
+ return String.format(
+ "Outputs a PCollection of Beam Rows, each containing a single INT64 "
+ + "number called \"value\". The count is produced from the given
\"start\""
+ + "value and either up to the given \"end\" or until 2^63 - 1.\n"
+ + "To produce an unbounded PCollection, simply do not specify an
\"end\" value. "
+ + "Unbounded sequences can specify a \"rate\" for output
elements.\n"
+ + "In all cases, the sequence of numbers is generated in parallel,
so there is no "
+ + "inherent ordering between the generated values");
+ }
+
+ @Override
+ public Class<GenerateSequenceConfiguration> configurationClass() {
+ return GenerateSequenceConfiguration.class;
+ }
+
+ @Override
+ public SchemaTransform from(GenerateSequenceConfiguration configuration) {
+ return new GenerateSequenceSchemaTransform(configuration);
+ }
+
+ @DefaultSchema(AutoValueSchema.class)
+ @AutoValue
+ public abstract static class GenerateSequenceConfiguration {
+ @AutoValue
+ public abstract static class Rate {
+ @SchemaFieldDescription("Number of elements component of the rate.")
+ public abstract Long getElements();
+
+ @SchemaFieldDescription("Number of seconds component of the rate.")
+ @Nullable
+ public abstract Long getSeconds();
+
+ public static Builder builder() {
+ return new
AutoValue_GenerateSequenceSchemaTransformProvider_GenerateSequenceConfiguration_Rate
+ .Builder();
+ }
+
+ @AutoValue.Builder
+ public abstract static class Builder {
+ public abstract Builder setElements(Long elements);
+
+ public abstract Builder setSeconds(Long seconds);
+
+ public abstract Rate build();
+ }
+ }
+
+ public static Builder builder() {
+ return new
AutoValue_GenerateSequenceSchemaTransformProvider_GenerateSequenceConfiguration
+ .Builder();
+ }
+
+ @SchemaFieldDescription("The minimum number to generate (inclusive).")
+ public abstract Long getStart();
+
+ @SchemaFieldDescription(
+ "The maximum number to generate (exclusive). Will be an unbounded
sequence if left unspecified.")
+ @Nullable
+ public abstract Long getEnd();
+
+ @SchemaFieldDescription(
+ "Specifies the rate to generate a given number of elements per a given
number of seconds. "
+ + "Applicable only to unbounded sequences.")
+ @Nullable
+ public abstract Rate getRate();
+
+ @AutoValue.Builder
+ public abstract static class Builder {
+
+ public abstract Builder setStart(Long start);
+
+ public abstract Builder setEnd(Long end);
+
+ public abstract Builder setRate(Rate rate);
+
+ public abstract GenerateSequenceConfiguration build();
+ }
+
+ public void validate() {
+ checkNotNull(this.getStart(), "Must specify a starting point
\"start\".");
+ Long start = this.getStart();
+ Long end = this.getEnd();
+ if (end != null) {
+ checkArgument(end == -1 || end >= start, "Invalid range [%s, %s)",
start, end);
+ }
+ Rate rate = this.getRate();
+ if (rate != null) {
+ checkArgument(
+ rate.getElements() > 0,
+ "Invalid rate specification. Expected positive elements component
but received %s.",
+ rate.getElements());
+ checkArgument(
+ Optional.ofNullable(rate.getSeconds()).orElse(1L) > 0,
+ "Invalid rate specification. Expected positive seconds component
but received %s.",
+ rate.getSeconds());
+ }
+ }
+ }
+
+ protected static class GenerateSequenceSchemaTransform extends
SchemaTransform {
+ private final GenerateSequenceConfiguration configuration;
+
+ GenerateSequenceSchemaTransform(GenerateSequenceConfiguration
configuration) {
+ configuration.validate();
+ this.configuration = configuration;
+ }
+
+ @Override
+ public PCollectionRowTuple expand(PCollectionRowTuple input) {
+ checkArgument(
+ input.getAll().isEmpty(), "Expected no inputs but got: %s",
input.getAll().keySet());
+
+ Long end = Optional.ofNullable(configuration.getEnd()).orElse(-1L);
+ GenerateSequenceConfiguration.Rate rate = configuration.getRate();
+
+ GenerateSequence sequence =
GenerateSequence.from(configuration.getStart()).to(end);
+ if (rate != null) {
+ sequence =
+ sequence.withRate(
+ rate.getElements(),
+
Duration.standardSeconds(Optional.ofNullable(rate.getSeconds()).orElse(1L)));
+ }
+
+ return PCollectionRowTuple.of(
+ OUTPUT_ROWS_TAG,
+ input
+ .getPipeline()
+ .apply(sequence)
+ .apply(
+ MapElements.into(TypeDescriptors.rows())
+ .via(l ->
Row.withSchema(OUTPUT_SCHEMA).withFieldValue("value", l).build()))
+ .setRowSchema(OUTPUT_SCHEMA));
+ }
+ }
+}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/package-info.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/package-info.java
new file mode 100644
index 00000000000..6d90b7d018a
--- /dev/null
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/providers/package-info.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Defines {@link
org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider}s for transforms
in
+ * the core module.
+ */
+package org.apache.beam.sdk.providers;
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProviderTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProviderTest.java
new file mode 100644
index 00000000000..dcff3dedb84
--- /dev/null
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/providers/GenerateSequenceSchemaTransformProviderTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.providers;
+
+import java.util.ArrayList;
+import java.util.List;
+import
org.apache.beam.sdk.providers.GenerateSequenceSchemaTransformProvider.GenerateSequenceConfiguration;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class GenerateSequenceSchemaTransformProviderTest {
+ @Rule public transient TestPipeline p = TestPipeline.create();
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testGenerateSequence() {
+ GenerateSequenceConfiguration config =
+
GenerateSequenceConfiguration.builder().setStart(0L).setEnd(10L).build();
+ SchemaTransform sequence = new
GenerateSequenceSchemaTransformProvider().from(config);
+
+ List<Row> expected = new ArrayList<>(10);
+ for (long i = 0L; i < 10L; i++) {
+ expected.add(
+ Row.withSchema(GenerateSequenceSchemaTransformProvider.OUTPUT_SCHEMA)
+ .withFieldValue("value", i)
+ .build());
+ }
+
+ PAssert.that(
+ PCollectionRowTuple.empty(p)
+ .apply(sequence)
+ .get(GenerateSequenceSchemaTransformProvider.OUTPUT_ROWS_TAG))
+ .containsInAnyOrder(expected);
+ p.run();
+ }
+}
diff --git a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
index a2f350e8cb7..a7bf686d064 100644
--- a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
+++ b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
@@ -143,7 +143,7 @@ class CrossLanguageKafkaIOTest(unittest.TestCase):
self.run_kafka_write(pipeline_creator)
self.run_kafka_read(pipeline_creator, None)
- @pytest.mark.uses_io_expansion_service
+ @pytest.mark.uses_io_java_expansion_service
@unittest.skipUnless(
os.environ.get('EXPANSION_PORT'),
"EXPANSION_PORT environment var is not provided.")
@@ -162,7 +162,7 @@ class CrossLanguageKafkaIOTest(unittest.TestCase):
self.run_kafka_write(pipeline_creator)
self.run_kafka_read(pipeline_creator, b'key')
- @pytest.mark.uses_io_expansion_service
+ @pytest.mark.uses_io_java_expansion_service
@unittest.skipUnless(
os.environ.get('EXPANSION_PORT'),
"EXPANSION_PORT environment var is not provided.")
diff --git a/sdks/python/apache_beam/transforms/external.py
b/sdks/python/apache_beam/transforms/external.py
index 997cea347d3..71dfc545204 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -1067,6 +1067,7 @@ class BeamJarExpansionService(JavaJarExpansionService):
append_args=None):
path_to_jar = subprocess_server.JavaJarServer.path_to_beam_jar(
gradle_target, gradle_appendix)
+ self.gradle_target = gradle_target
super().__init__(
path_to_jar, extra_args, classpath=classpath, append_args=append_args)
diff --git
a/sdks/python/apache_beam/transforms/external_schematransform_provider.py
b/sdks/python/apache_beam/transforms/external_schematransform_provider.py
new file mode 100644
index 00000000000..fd650087893
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/external_schematransform_provider.py
@@ -0,0 +1,277 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import re
+from collections import namedtuple
+from typing import Dict
+from typing import List
+from typing import Tuple
+
+from apache_beam.transforms import PTransform
+from apache_beam.transforms.external import BeamJarExpansionService
+from apache_beam.transforms.external import SchemaAwareExternalTransform
+from apache_beam.transforms.external import SchemaTransformsConfig
+from apache_beam.typehints.schemas import named_tuple_to_schema
+from apache_beam.typehints.schemas import typing_from_runner_api
+
+__all__ = ['ExternalSchemaTransform', 'ExternalSchemaTransformProvider']
+
+
def snake_case_to_upper_camel_case(string):
  """Convert a snake_case name to its UpperCamelCase equivalent.

  Each underscore-separated component is capitalized (which also lowercases
  any remaining letters, so ``TEST_CAPITALIZED`` becomes ``TestCapitalized``)
  and the components are joined without a separator.
  """
  return ''.join(part.capitalize() for part in string.split('_'))
+
+
def snake_case_to_lower_camel_case(string):
  """Convert a snake_case name to lowerCamelCase.

  Strings of length one or less are simply lowercased; otherwise the
  UpperCamelCase form is computed and its first character lowercased.
  """
  if len(string) <= 1:
    return string.lower()
  camel = ''.join(part.capitalize() for part in string.split('_'))
  return camel[0].lower() + camel[1:]
+
+
def camel_case_to_snake_case(string):
  """Convert camelCase (or UpperCamelCase) to snake_case.

  All-caps runs are treated as single words: ``AllCapsENDING`` becomes
  ``all_caps_ending`` and ``BEGINNINGAllCaps`` becomes ``beginning_all_caps``.
  """
  words = []
  current = []
  for index, char in enumerate(string):
    prev_is_lower = index > 0 and string[index - 1].islower()
    next_is_lower = index + 1 < len(string) and string[index + 1].islower()
    # An uppercase letter after a lowercase one ends a word; an uppercase
    # letter followed by a lowercase one ends an all-caps run.
    if char.isupper() and (prev_is_lower or next_is_lower):
      words.append(''.join(current))
      current = [char.lower()]
    else:
      current.append(char.lower())
  words.append(''.join(current))
  return '_'.join(words).strip('_')
+
+
+# Information regarding a Wrapper parameter.
+ParamInfo = namedtuple('ParamInfo', ['type', 'description', 'original_name'])
+
+
def get_config_with_descriptions(
    schematransform: SchemaTransformsConfig) -> Dict[str, ParamInfo]:
  """Build a map of snake_case parameter name to :class:`ParamInfo`.

  Combines each configuration field's decoded Python type with its
  description (taken from ``_field_descriptions`` on the configuration
  NamedTuple) and its original camelCase name.
  """
  schema = named_tuple_to_schema(schematransform.configuration_schema)
  descriptions = schematransform.configuration_schema._field_descriptions
  return {
      camel_case_to_snake_case(field.name): ParamInfo(
          typing_from_runner_api(field.type),
          descriptions[field.name],
          field.name)
      for field in schema.fields
  }
+
+
class ExternalSchemaTransform(PTransform):
  """Base class for dynamically generated external SchemaTransform wrappers.

  Concrete subclasses are created at runtime by
  :class:`ExternalSchemaTransformProvider`; this class is not meant to be
  instantiated directly.

  Experimental; no backwards compatibility guarantees."""

  # Generated wrapper types override these class attributes.
  default_expansion_service = None
  description: str = ""
  identifier: str = ""
  configuration_schema: Dict[str, ParamInfo] = {}

  def __init__(self, expansion_service=None, **kwargs):
    self._kwargs = kwargs
    self._expansion_service = (
        expansion_service or self.default_expansion_service)

  def expand(self, input):
    # The underlying external transform expects lowerCamelCase parameter
    # names, so rename the snake_case kwargs before expansion.
    renamed_kwargs = {
        snake_case_to_lower_camel_case(name): value
        for name, value in self._kwargs.items()
    }

    return input | SchemaAwareExternalTransform(
        identifier=self.identifier,
        expansion_service=self._expansion_service,
        rearrange_based_on_discovery=True,
        **renamed_kwargs)
+
+
+STANDARD_URN_PATTERN = r"^beam:schematransform:org.apache.beam:([\w-]+):(\w+)$"
+
+
def infer_name_from_identifier(identifier: str, pattern: str):
  """Infer a wrapper class name from a transform identifier.

  Returns ``None`` when ``identifier`` does not match ``pattern``.
  Otherwise each captured group is converted to UpperCamelCase and the
  pieces are concatenated.
  """
  match = re.match(pattern, identifier)
  if match is None:
    return None

  parts = [snake_case_to_upper_camel_case(group) for group in match.groups()]
  # Standard SchemaTransform URNs drop the version suffix for v1, so the
  # first version of a transform gets the bare name.
  if pattern == STANDARD_URN_PATTERN and parts[1].lower() == 'v1':
    return parts[0]
  return ''.join(parts)
+
+
class ExternalSchemaTransformProvider:
  """Dynamically discovers Schema-aware external transforms from a given list
  of expansion services and provides them as ready PTransforms.

  A :class:`ExternalSchemaTransform` subclass is generated for each external
  transform, and is named based on what can be inferred from the URN
  (see the ``urn_pattern`` parameter).

  These classes are generated when :class:`ExternalSchemaTransformProvider` is
  initialized. We need to give it one or more expansion service addresses that
  are already up and running:
  >>> provider = ExternalSchemaTransformProvider(["localhost:12345",
  ...         "localhost:12121"])
  We can also give it the gradle target of a standard Beam expansion service:
  >>> provider = ExternalSchemaTransformProvider(BeamJarExpansionService(
  ...         "sdks:java:io:google-cloud-platform:expansion-service:shadowJar"))
  Let's take a look at the output of :func:`get_available()` to know the
  available transforms in the expansion service(s) we provided:
  >>> provider.get_available()
  [('JdbcWrite', 'beam:schematransform:org.apache.beam:jdbc_write:v1'),
  ('BigtableRead', 'beam:schematransform:org.apache.beam:bigtable_read:v1'),
  ...]

  Then retrieve a transform by :func:`get()`, :func:`get_urn()`, or by directly
  accessing it as an attribute of :class:`ExternalSchemaTransformProvider`.
  All of the following commands do the same thing:
  >>> provider.get('BigqueryStorageRead')
  >>> provider.get_urn(
  ...     'beam:schematransform:org.apache.beam:bigquery_storage_read:v1')
  >>> provider.BigqueryStorageRead

  To know more about the usage of a given transform, take a look at the
  `description` attribute. This returns some documentation IF the underlying
  SchemaTransform provides any.
  >>> provider.BigqueryStorageRead.description

  Similarly, the `configuration_schema` attribute returns information about the
  parameters, including their names, types, and any documentation that the
  underlying SchemaTransform may provide:
  >>> provider.BigqueryStorageRead.configuration_schema
  {'query': ParamInfo(type=typing.Optional[str], description='The SQL query to
  be executed to read from the BigQuery table.', original_name='query'),
  'row_restriction': ParamInfo(type=typing.Optional[str]...}

  The retrieved external transform can be used as a normal PTransform like so::

    with Pipeline() as p:
      _ = (p
        | 'Read from BigQuery' >> provider.BigqueryStorageRead(
            query=query,
            row_restriction=restriction)
        | 'Some processing' >> beam.Map(...))

  Experimental; no backwards compatibility guarantees.
  """
  def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN):
    # NOTE: this was previously an f-string, which Python does not treat as a
    # docstring; a plain literal makes it visible to help() and doc tools.
    """Initialize an ExternalSchemaTransformProvider

    :param expansion_services:
      A list of expansion services to discover transforms from.
      Supported forms:

      * a string representing the expansion service address
      * a :attr:`BeamJarExpansionService` pointing to a gradle target

    :param urn_pattern:
      The regular expression used to match valid transforms. In addition to
      validating, the captured groups are used to infer a name for each class.
      By default, :data:`STANDARD_URN_PATTERN` is used.
    """
    self._urn_pattern = urn_pattern
    # Maps transform identifier (URN) -> generated wrapper class, and
    # inferred class name -> identifier.  (The previous annotation called
    # type(...) at runtime instead of using a proper type annotation.)
    self._transforms: Dict[str, type] = {}
    self._name_to_urn: Dict[str, str] = {}

    # Accept a single service, a set, or a list; normalize to a list so that
    # discovery order (and hence URN-conflict priority) is well defined.
    if isinstance(expansion_services, set):
      expansion_services = list(expansion_services)
    if not isinstance(expansion_services, list):
      expansion_services = [expansion_services]
    self.expansion_services = expansion_services
    self._create_wrappers()

  def _create_wrappers(self):
    # Multiple services can overlap and include the same URNs. If this
    # happens, we prioritize by the order of services in the list.
    identifiers = set()
    for service in self.expansion_services:
      # Use the gradle target (when available) for human-readable logging.
      target = service
      if isinstance(service, BeamJarExpansionService):
        target = service.gradle_target
      try:
        schematransform_configs = SchemaAwareExternalTransform.discover(service)
      except Exception as e:
        # Discovery is best-effort per service: log and move on so one bad
        # service does not prevent wrapping transforms from the others.
        logging.exception(
            "Encountered an error while discovering expansion service %s:\n%s",
            target,
            e)
        continue
      skipped_urns = []
      for config in schematransform_configs:
        identifier = config.identifier
        if identifier not in identifiers:
          identifiers.add(identifier)

          name = infer_name_from_identifier(identifier, self._urn_pattern)
          if name is None:
            # URN doesn't follow the expected pattern; remember it so we can
            # report all skipped URNs for this service at once.
            skipped_urns.append(identifier)
            continue

          # Dynamically create a wrapper subclass for this transform.
          self._transforms[identifier] = type(
              name, (ExternalSchemaTransform, ),
              dict(
                  identifier=identifier,
                  default_expansion_service=service,
                  schematransform=config,
                  description=config.description,
                  configuration_schema=get_config_with_descriptions(config)))
          self._name_to_urn[name] = identifier

      if skipped_urns:
        logging.info(
            "Skipped URN(s) in %s that don't follow the pattern [%s]: %s",
            target,
            self._urn_pattern,
            skipped_urns)

    # Expose each generated wrapper as an attribute of this provider.
    for transform in self._transforms.values():
      setattr(self, transform.__name__, transform)

  def get_available(self) -> List[Tuple[str, str]]:
    """Get a list of available ExternalSchemaTransform names and identifiers"""
    return list(self._name_to_urn.items())

  def get(self, name) -> ExternalSchemaTransform:
    """Get an ExternalSchemaTransform by its inferred class name"""
    return self._transforms[self._name_to_urn[name]]

  def get_urn(self, identifier) -> ExternalSchemaTransform:
    """Get an ExternalSchemaTransform by its SchemaTransform identifier"""
    return self._transforms[identifier]
diff --git
a/sdks/python/apache_beam/transforms/external_schematransform_provider_test.py
b/sdks/python/apache_beam/transforms/external_schematransform_provider_test.py
new file mode 100644
index 00000000000..bf951e671c2
--- /dev/null
+++
b/sdks/python/apache_beam/transforms/external_schematransform_provider_test.py
@@ -0,0 +1,140 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+import os
+import unittest
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.transforms.external import BeamJarExpansionService
+from apache_beam.transforms.external_schematransform_provider import
STANDARD_URN_PATTERN
+from apache_beam.transforms.external_schematransform_provider import
ExternalSchemaTransformProvider
+from apache_beam.transforms.external_schematransform_provider import
camel_case_to_snake_case
+from apache_beam.transforms.external_schematransform_provider import
infer_name_from_identifier
+from apache_beam.transforms.external_schematransform_provider import
snake_case_to_lower_camel_case
+from apache_beam.transforms.external_schematransform_provider import
snake_case_to_upper_camel_case
+
+
class NameUtilsTest(unittest.TestCase):
  """Unit tests for the name-conversion helper functions."""
  def test_snake_case_to_upper_camel_case(self):
    cases = [("", ""), ("test", "Test"), ("test_name", "TestName"),
             ("test_double_underscore", "TestDoubleUnderscore"),
             ("TEST_CAPITALIZED", "TestCapitalized"),
             ("_prepended_underscore", "PrependedUnderscore"),
             ("appended_underscore_", "AppendedUnderscore")]
    for given, expected in cases:
      self.assertEqual(expected, snake_case_to_upper_camel_case(given))

  def test_snake_case_to_lower_camel_case(self):
    cases = [("", ""), ("test", "test"), ("test_name", "testName"),
             ("test_double_underscore", "testDoubleUnderscore"),
             ("TEST_CAPITALIZED", "testCapitalized"),
             ("_prepended_underscore", "prependedUnderscore"),
             ("appended_underscore_", "appendedUnderscore")]
    for given, expected in cases:
      self.assertEqual(expected, snake_case_to_lower_camel_case(given))

  def test_camel_case_to_snake_case(self):
    cases = [("", ""), ("Test", "test"), ("TestName", "test_name"),
             ("TestDoubleUnderscore", "test_double_underscore"),
             ("MyToLoFo", "my_to_lo_fo"),
             ("BEGINNINGAllCaps", "beginning_all_caps"),
             ("AllCapsENDING", "all_caps_ending"),
             ("AllCapsMIDDLEWord", "all_caps_middle_word"),
             ("lowerCamelCase", "lower_camel_case")]
    for given, expected in cases:
      self.assertEqual(expected, camel_case_to_snake_case(given))

  def test_infer_name_from_identifier(self):
    standard_cases = [
        ("beam:schematransform:org.apache.beam:transform:v1", "Transform"),
        ("beam:schematransform:org.apache.beam:my_transform:v1",
         "MyTransform"),
        ("beam:schematransform:org.apache.beam:my_transform:v2",
         "MyTransformV2"),
        ("beam:schematransform:org.apache.beam:fe_fi_fo_fum:v2", "FeFiFoFumV2"),
        ("beam:schematransform:bad_match:my_transform:v1", None)
    ]
    for urn, expected in standard_cases:
      self.assertEqual(
          expected, infer_name_from_identifier(urn, STANDARD_URN_PATTERN))

    custom_pattern_cases = [
        # (<pattern>, <urn>, <expected output>)
        (
            r"^custom:transform:([\w-]+):(\w+)$",
            "custom:transform:my_transform:v1",
            "MyTransformV1"),
        (
            r"^org.user:([\w-]+):([\w-]+):([\w-]+):external$",
            "org.user:some:custom_transform:we_made:external",
            "SomeCustomTransformWeMade"),
        (
            r"^([\w-]+):user.transforms",
            "my_eXTErnal:user.transforms",
            "MyExternal"),
        (r"^([\w-]+):user.transforms", "my_external:badinput.transforms", None),
    ]
    for pattern, urn, expected in custom_pattern_cases:
      self.assertEqual(expected, infer_name_from_identifier(urn, pattern))
+
+
@pytest.mark.uses_io_java_expansion_service
@unittest.skipUnless(
    os.environ.get('EXPANSION_PORT'),
    "EXPANSION_PORT environment var is not provided.")
class ExternalSchemaTransformProviderTest(unittest.TestCase):
  """Integration tests that discover and run a Java GenerateSequence."""
  def setUp(self):
    self.test_pipeline = TestPipeline(is_integration_test=True)

  def test_generate_sequence_config_schema_and_description(self):
    provider = ExternalSchemaTransformProvider(
        BeamJarExpansionService(":sdks:java:io:expansion-service:shadowJar"))

    # The Java GenerateSequence SchemaTransform should have been discovered.
    self.assertIn((
        'GenerateSequence',
        'beam:schematransform:org.apache.beam:generate_sequence:v1'),
                  provider.get_available())

    generate_sequence = provider.get('GenerateSequence')
    config_schema = generate_sequence.configuration_schema
    for param in ['start', 'end', 'rate']:
      self.assertIn(param, config_schema)

    # The description should carry over from the Java provider.
    expected_snippet = (
        "Outputs a PCollection of Beam Rows, each "
        "containing a single INT64")
    self.assertIn(expected_snippet, generate_sequence.description)

  def test_run_generate_sequence(self):
    provider = ExternalSchemaTransformProvider(
        BeamJarExpansionService(":sdks:java:io:expansion-service:shadowJar"))

    with beam.Pipeline() as p:
      numbers = (
          p
          | provider.GenerateSequence(start=0, end=10)
          | beam.Map(lambda row: row.value))

      assert_that(numbers, equal_to(list(range(10))))
+
+
+if __name__ == '__main__':
+  # Surface INFO-level logs (e.g. expansion-service discovery errors) when
+  # the test module is run directly.
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/apache_beam/typehints/schemas.py
b/sdks/python/apache_beam/typehints/schemas.py
index b8176dccb8e..147a46f0bea 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -540,6 +540,7 @@ class SchemaTranslation(object):
type_name = 'BeamSchema_{}'.format(schema.id.replace('-', '_'))
subfields = []
+ descriptions = {}
for field in schema.fields:
try:
field_py_type = self.typing_from_runner_api(field.type)
@@ -550,6 +551,7 @@ class SchemaTranslation(object):
"Failed to decode schema due to an issue with Field proto:\n\n"
f"{text_format.MessageToString(field)}") from e
+ descriptions[field.name] = field.description
subfields.append((field.name, field_py_type))
user_type = NamedTuple(type_name, subfields)
@@ -560,6 +562,7 @@ class SchemaTranslation(object):
user_type,
'__reduce__',
_named_tuple_reduce_method(schema.SerializeToString()))
+ setattr(user_type, "_field_descriptions", descriptions)
setattr(user_type, row_type._BEAM_SCHEMA_ID, schema.id)
self.schema_registry.add(user_type, schema)
diff --git a/sdks/python/pytest.ini b/sdks/python/pytest.ini
index 140476b29e5..4ffbb4524c0 100644
--- a/sdks/python/pytest.ini
+++ b/sdks/python/pytest.ini
@@ -33,7 +33,7 @@ markers =
uses_gcp_java_expansion_service: collect Cross Language GCP Java
transforms test runs
uses_java_expansion_service: collect Cross Language Java transforms test
runs
uses_python_expansion_service: collect Cross Language Python transforms
test runs
- uses_io_expansion_service: collect Cross Language transform test runs
(with Kafka bootstrap server)
+ uses_io_java_expansion_service: collect Cross Language IO Java transform
test runs (with Kafka bootstrap server)
uses_transform_service: collect Cross Language test runs that uses the
Transform Service
xlang_sql_expansion_service: collect for Cross Language with SQL expansion
service test runs
it_postcommit: collect for post-commit integration test runs
diff --git a/sdks/python/test-suites/dataflow/build.gradle
b/sdks/python/test-suites/dataflow/build.gradle
index b55716a42df..04a79683fd3 100644
--- a/sdks/python/test-suites/dataflow/build.gradle
+++ b/sdks/python/test-suites/dataflow/build.gradle
@@ -67,13 +67,13 @@ task examplesPostCommit {
}
task gcpCrossLanguagePostCommit {
- getVersionsAsList('cross_language_validates_gcp_py_versions').each {
+ getVersionsAsList('cross_language_validates_py_versions').each {
dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:gcpCrossLanguagePythonUsingJava")
}
}
task ioCrossLanguagePostCommit {
- getVersionsAsList('cross_language_validates_gcp_py_versions').each {
+ getVersionsAsList('cross_language_validates_py_versions').each {
dependsOn.add(":sdks:python:test-suites:dataflow:py${getVersionSuffix(it)}:ioCrossLanguagePythonUsingJava")
}
}
diff --git a/sdks/python/test-suites/direct/build.gradle
b/sdks/python/test-suites/direct/build.gradle
index fc408411ec2..ea643c3303a 100644
--- a/sdks/python/test-suites/direct/build.gradle
+++ b/sdks/python/test-suites/direct/build.gradle
@@ -32,7 +32,13 @@ tasks.register("examplesPostCommit") {
}
task gcpCrossLanguagePostCommit {
- getVersionsAsList('cross_language_validates_gcp_py_versions').each {
+ getVersionsAsList('cross_language_validates_py_versions').each {
dependsOn.add(":sdks:python:test-suites:direct:py${getVersionSuffix(it)}:gcpCrossLanguagePythonUsingJava")
}
}
+
+task ioCrossLanguagePostCommit {
+ getVersionsAsList('cross_language_validates_py_versions').each {
+
dependsOn.add(":sdks:python:test-suites:direct:py${getVersionSuffix(it)}:ioCrossLanguagePythonUsingJava")
+ }
+}
diff --git a/sdks/python/test-suites/gradle.properties
b/sdks/python/test-suites/gradle.properties
index 72fc651733d..c4e94bed4d8 100644
--- a/sdks/python/test-suites/gradle.properties
+++ b/sdks/python/test-suites/gradle.properties
@@ -47,5 +47,5 @@ samza_validates_runner_postcommit_py_versions=3.8,3.11
# spark runner test-suites
spark_examples_postcommit_py_versions=3.8,3.11
-# cross language gcp io postcommit python test suites
-cross_language_validates_gcp_py_versions=3.8,3.11
+# cross language postcommit python test suites
+cross_language_validates_py_versions=3.8,3.11
diff --git a/sdks/python/test-suites/xlang/build.gradle
b/sdks/python/test-suites/xlang/build.gradle
index df3ebdd1582..5a124ac20ce 100644
--- a/sdks/python/test-suites/xlang/build.gradle
+++ b/sdks/python/test-suites/xlang/build.gradle
@@ -56,20 +56,17 @@ def gcpXlangCommon = new CrossLanguageTaskCommon().tap {
startJobServer = setupTask
cleanupJobServer = cleanupTask
}
-xlangTasks.add(gcpXlangCommon)
-
-def ioExpansionProject = project.project(':sdks:java:io:expansion-service')
def ioXlangCommon = new CrossLanguageTaskCommon().tap {
name = "ioCrossLanguage"
- expansionProjectPath = ioExpansionProject.getPath()
- collectMarker = "uses_io_expansion_service"
+ expansionProjectPath =
project.project(':sdks:java:io:expansion-service').getPath()
+ collectMarker = "uses_io_java_expansion_service"
startJobServer = setupTask
cleanupJobServer = cleanupTask
//See .test-infra/kafka/bitnami/README.md for setup instructions
additionalEnvs =
["KAFKA_BOOTSTRAP_SERVER":project.findProperty('kafkaBootstrapServer')]
}
-xlangTasks.add(ioXlangCommon)
+xlangTasks.addAll(gcpXlangCommon, ioXlangCommon)
ext.xlangTasks = xlangTasks
\ No newline at end of file