This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new cf0cf3b746c Add an annotation to expose transforms to yaml. (#28208)
cf0cf3b746c is described below

commit cf0cf3b746c55921ebe4a95d2e06052cb96be74d
Author: Robert Bradshaw <[email protected]>
AuthorDate: Tue Sep 12 17:06:29 2023 -0700

    Add an annotation to expose transforms to yaml. (#28208)
    
    We should add this to all transforms that are simply parameterized.
---
 sdks/python/apache_beam/transforms/ptransform.py   | 52 ++++++++++++++++++++++
 .../python/apache_beam/yaml/yaml_transform_test.py | 30 +++++++++++++
 2 files changed, 82 insertions(+)

diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index c7eaa152ae0..28614c6561c 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -38,11 +38,13 @@ FlatMap processing functions.
 
 import copy
 import itertools
+import json
 import logging
 import operator
 import os
 import sys
 import threading
+import warnings
 from functools import reduce
 from functools import wraps
 from typing import TYPE_CHECKING
@@ -83,6 +85,7 @@ from apache_beam.typehints.decorators import getcallargs_forhints
 from apache_beam.typehints.trivial_inference import instance_to_type
 from apache_beam.typehints.typehints import validate_composite_type_param
 from apache_beam.utils import proto_utils
+from apache_beam.utils import python_callable
 
 if TYPE_CHECKING:
   from apache_beam import coders
@@ -95,6 +98,7 @@ __all__ = [
     'PTransform',
     'ptransform_fn',
     'label_from_callable',
+    'annotate_yaml',
 ]
 
 _LOGGER = logging.getLogger(__name__)
@@ -1096,3 +1100,51 @@ class _NamedPTransform(PTransform):
 
   def expand(self, pvalue):
     raise RuntimeError("Should never be expanded directly.")
+
+
+# Defined here to avoid circular import issues for Beam library transforms.
+def annotate_yaml(constructor):
+  """Causes instances of this transform to be annotated with their yaml syntax.
+
+  Should only be used for transforms that are fully defined by their constructor
+  arguments.
+  """
+  @wraps(constructor)
+  def wrapper(*args, **kwargs):
+    transform = constructor(*args, **kwargs)
+
+    fully_qualified_name = (
+        f'{constructor.__module__}.{constructor.__qualname__}')
+    try:
+      imported_constructor = (
+          python_callable.PythonCallableWithSource.
+          load_from_fully_qualified_name(fully_qualified_name))
+      if imported_constructor != wrapper:
+        raise ImportError('Different object.')
+    except ImportError:
+      warnings.warn(f'Cannot import {constructor} as {fully_qualified_name}.')
+      return transform
+
+    try:
+      config = json.dumps({
+          'constructor': fully_qualified_name,
+          'args': args,
+          'kwargs': kwargs,
+      })
+    except TypeError as exn:
+      warnings.warn(
+          f'Cannot serialize arguments for {constructor} as json: {exn}')
+      return transform
+
+    original_annotations = transform.annotations
+    transform.annotations = lambda: {
+        **original_annotations(),
+        # These override whatever may have been provided earlier.
+        # The outermost call is expected to be the most specific.
+        'yaml_provider': 'python',
+        'yaml_type': 'PyTransform',
+        'yaml_args': config,
+    }
+    return transform
+
+  return wrapper
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index f969761092e..26baebec86e 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -250,6 +250,23 @@ class YamlTransformE2ETest(unittest.TestCase):
             output: AnotherFilter
             ''')
 
+  def test_annotations(self):
+    t = LinearTransform(5, b=100)
+    annotations = t.annotations()
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result = p | YamlTransform(
+          '''
+          type: chain
+          transforms:
+            - type: Create
+              config:
+                elements: [0, 1, 2, 3]
+            - type: %r
+              config: %s
+          ''' % (annotations['yaml_type'], annotations['yaml_args']))
+      assert_that(result, equal_to([100, 105, 110, 115]))
+
 
 class CreateTimestamped(beam.PTransform):
   def __init__(self, elements):
@@ -631,6 +648,19 @@ class ProviderAffinityTest(unittest.TestCase):
           label='StartWith3')
 
 
+@beam.ptransform.annotate_yaml
+class LinearTransform(beam.PTransform):
+  """A transform used for testing annotate_yaml."""
+  def __init__(self, a, b):
+    self._a = a
+    self._b = b
+
+  def expand(self, pcoll):
+    a = self._a
+    b = self._b
+    return pcoll | beam.Map(lambda x: a * x + b)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()

Reply via email to