This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new cf0cf3b746c Add an annotation to expose transforms to yaml. (#28208)
cf0cf3b746c is described below
commit cf0cf3b746c55921ebe4a95d2e06052cb96be74d
Author: Robert Bradshaw <[email protected]>
AuthorDate: Tue Sep 12 17:06:29 2023 -0700
Add an annotation to expose transforms to yaml. (#28208)
We should add this to all transforms that are simply parameterized.
---
sdks/python/apache_beam/transforms/ptransform.py | 52 ++++++++++++++++++++++
.../python/apache_beam/yaml/yaml_transform_test.py | 30 +++++++++++++
2 files changed, 82 insertions(+)
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index c7eaa152ae0..28614c6561c 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -38,11 +38,13 @@ FlatMap processing functions.
import copy
import itertools
+import json
import logging
import operator
import os
import sys
import threading
+import warnings
from functools import reduce
from functools import wraps
from typing import TYPE_CHECKING
@@ -83,6 +85,7 @@ from apache_beam.typehints.decorators import getcallargs_forhints
from apache_beam.typehints.trivial_inference import instance_to_type
from apache_beam.typehints.typehints import validate_composite_type_param
from apache_beam.utils import proto_utils
+from apache_beam.utils import python_callable
if TYPE_CHECKING:
from apache_beam import coders
@@ -95,6 +98,7 @@ __all__ = [
'PTransform',
'ptransform_fn',
'label_from_callable',
+ 'annotate_yaml',
]
_LOGGER = logging.getLogger(__name__)
@@ -1096,3 +1100,51 @@ class _NamedPTransform(PTransform):
def expand(self, pvalue):
  """Unconditionally fails; this transform is not meant to be expanded.

  Raises:
    RuntimeError: always.
  """
  error_message = "Should never be expanded directly."
  raise RuntimeError(error_message)
+
+
# Defined here to avoid circular import issues for Beam library transforms.
def annotate_yaml(constructor):
  """Causes instances of this transform to be annotated with their yaml syntax.

  Should only be used for transforms that are fully defined by their
  constructor arguments.

  Each call of the returned wrapper builds the transform via ``constructor``.
  On a best-effort basis, it then replaces the transform's ``annotations``
  method with one that also reports these keys:

  * ``yaml_provider``: ``'python'``
  * ``yaml_type``: ``'PyTransform'``
  * ``yaml_args``: a JSON string holding the constructor's fully qualified
    name plus the positional and keyword arguments of the call

  In two cases a warning is issued and the transform is returned without
  these annotations:

  * the constructor cannot be re-imported under that name;
  * the arguments cannot be serialized as JSON.

  Args:
    constructor: The callable (e.g. a PTransform subclass) to wrap.

  Returns:
    A wrapper carrying ``constructor``'s metadata (via ``functools.wraps``)
    that returns whatever ``constructor`` returns.
  """
  # Invariant across calls, so compute it once rather than per instance.
  fully_qualified_name = (
      f'{constructor.__module__}.{constructor.__qualname__}')

  @wraps(constructor)
  def wrapper(*args, **kwargs):
    transform = constructor(*args, **kwargs)

    # Only annotate when importing `fully_qualified_name` yields this very
    # wrapper, since that is the name recorded in `yaml_args`. The check runs
    # lazily, per call: while a decorator is still executing, the module-level
    # name has not yet been rebound to `wrapper`.
    try:
      imported_constructor = (
          python_callable.PythonCallableWithSource.
          load_from_fully_qualified_name(fully_qualified_name))
      if imported_constructor is not wrapper:
        raise ImportError('Different object.')
    except ImportError:
      warnings.warn(
          f'Cannot import {constructor} as {fully_qualified_name}.',
          stacklevel=2)
      return transform

    try:
      config = json.dumps({
          'constructor': fully_qualified_name,
          'args': args,
          'kwargs': kwargs,
      })
    except (TypeError, ValueError) as exn:
      # TypeError: an argument is not JSON-serializable.
      # ValueError: an argument contains a circular reference.
      warnings.warn(
          f'Cannot serialize arguments for {constructor} as json: {exn}',
          stacklevel=2)
      return transform

    original_annotations = transform.annotations

    def annotations():
      return {
          **original_annotations(),
          # These override whatever may have been provided earlier.
          # The outermost call is expected to be the most specific.
          'yaml_provider': 'python',
          'yaml_type': 'PyTransform',
          'yaml_args': config,
      }

    transform.annotations = annotations
    return transform

  return wrapper
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index f969761092e..26baebec86e 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -250,6 +250,23 @@ class YamlTransformE2ETest(unittest.TestCase):
output: AnotherFilter
''')
def test_annotations(self):
  """Round-trips an annotate_yaml-decorated transform through YamlTransform."""
  import json  # Local import keeps this test self-contained.

  t = LinearTransform(5, b=100)
  annotations = t.annotations()

  # Check the annotations directly, so that a failure to annotate is reported
  # clearly rather than as a KeyError while building the YAML below.
  self.assertEqual(annotations.get('yaml_provider'), 'python')
  self.assertEqual(annotations.get('yaml_type'), 'PyTransform')
  self.assertEqual(
      json.loads(annotations['yaml_args']),
      {
          'constructor':
          f'{LinearTransform.__module__}.{LinearTransform.__qualname__}',
          'args': [5],
          'kwargs': {'b': 100},
      })

  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    result = p | YamlTransform(
        '''
        type: chain
        transforms:
          - type: Create
            config:
              elements: [0, 1, 2, 3]
          - type: %r
            config: %s
        ''' % (annotations['yaml_type'], annotations['yaml_args']))
    assert_that(result, equal_to([100, 105, 110, 115]))
+
class CreateTimestamped(beam.PTransform):
def __init__(self, elements):
@@ -631,6 +648,19 @@ class ProviderAffinityTest(unittest.TestCase):
label='StartWith3')
@beam.ptransform.annotate_yaml
class LinearTransform(beam.PTransform):
  """A transform used for testing annotate_yaml.

  Maps each input element ``x`` to ``a * x + b``.
  """
  def __init__(self, a, b):
    self._a = a
    self._b = b

  def expand(self, pcoll):
    # Bind to locals so the lambda captures plain values, not `self`.
    a = self._a
    b = self._b
    return pcoll | beam.Map(lambda x: a * x + b)
+
+
if __name__ == '__main__':
  # Show INFO-level logs when the test module is run directly.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()