This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ada507dc3b3 Implement basic mapping capabilities for YAML. (#27096)
ada507dc3b3 is described below

commit ada507dc3b389121a58ca1b3765ef2853c786f4e
Author: Robert Bradshaw <[email protected]>
AuthorDate: Thu Jun 15 10:40:26 2023 -0700

    Implement basic mapping capabilities for YAML. (#27096)
---
 sdks/python/apache_beam/typehints/schemas.py      |   8 +-
 sdks/python/apache_beam/yaml/readme_test.py       |  43 ++--
 sdks/python/apache_beam/yaml/yaml_mapping.md      | 196 +++++++++++++++++
 sdks/python/apache_beam/yaml/yaml_mapping.py      | 255 ++++++++++++++++++++++
 sdks/python/apache_beam/yaml/yaml_mapping_test.py | 138 ++++++++++++
 sdks/python/apache_beam/yaml/yaml_provider.py     |  32 ++-
 sdks/python/apache_beam/yaml/yaml_transform.py    |   3 +-
 7 files changed, 652 insertions(+), 23 deletions(-)

diff --git a/sdks/python/apache_beam/typehints/schemas.py 
b/sdks/python/apache_beam/typehints/schemas.py
index 2f5c366b4cb..156b877d07e 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -291,7 +291,13 @@ class SchemaTranslation(object):
       result.nullable = True
       return result
 
-    elif _safe_issubclass(type_, Sequence):
+    elif type_ == range:
+      return schema_pb2.FieldType(
+          array_type=schema_pb2.ArrayType(
+              element_type=schema_pb2.FieldType(
+                  atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))
+
+    elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_, 
str):
       element_type = self.typing_to_runner_api(_get_args(type_)[0])
       return schema_pb2.FieldType(
           array_type=schema_pb2.ArrayType(element_type=element_type))
diff --git a/sdks/python/apache_beam/yaml/readme_test.py 
b/sdks/python/apache_beam/yaml/readme_test.py
index 1c014a4280f..26df760f285 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -68,6 +68,8 @@ class FakeSql(beam.PTransform):
           typ = float
         else:
           typ = str
+      elif '+' in expr:
+        typ = float
       else:
         part = parts[0]
         if '.' in part:
@@ -172,6 +174,8 @@ def replace_recursive(spec, transform_type, arg_name, 
arg_value):
 
 
 def create_test_method(test_type, test_name, test_yaml):
+  test_yaml = test_yaml.replace('pkg.module.fn', 'str')
+
   def test(self):
     with TestEnvironment() as env:
       spec = yaml.load(test_yaml, Loader=SafeLoader)
@@ -202,6 +206,7 @@ def create_test_method(test_type, test_name, test_yaml):
 
 
 def parse_test_methods(markdown_lines):
+  # pylint: disable=too-many-nested-blocks
   code_lines = None
   for ix, line in enumerate(markdown_lines):
     line = line.rstrip()
@@ -211,26 +216,38 @@ def parse_test_methods(markdown_lines):
         test_type = 'RUN'
         test_name = f'test_line_{ix + 2}'
       else:
-        if code_lines and code_lines[0] == 'pipeline:':
-          yaml_pipeline = '\n'.join(code_lines)
-          if 'providers:' in yaml_pipeline:
-            test_type = 'PARSE'
-          yield test_name, create_test_method(
-              test_type,
-              test_name,
-              yaml_pipeline)
+        if code_lines:
+          if code_lines[0].startswith('- type:'):
+            # Treat this as a fragment of a larger pipeline.
+            code_lines = [
+                'pipeline:',
+                '  type: chain',
+                '  transforms:',
+                '    - type: ReadFromCsv',
+                '      path: whatever',
+            ] + [
+                '    ' + line for line in code_lines
+            ]  # pylint: disable=not-an-iterable
+          if code_lines[0] == 'pipeline:':
+            yaml_pipeline = '\n'.join(code_lines)
+            if 'providers:' in yaml_pipeline:
+              test_type = 'PARSE'
+            yield test_name, create_test_method(
+                test_type,
+                test_name,
+                yaml_pipeline)
         code_lines = None
     elif code_lines is not None:
       code_lines.append(line)
 
 
-def createTestSuite():
-  with open(os.path.join(os.path.dirname(__file__), 'README.md')) as readme:
-    return type(
-        'ReadMeTest', (unittest.TestCase, ), dict(parse_test_methods(readme)))
+def createTestSuite(name, path):
+  with open(path) as readme:
+    return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))
 
 
-ReadMeTest = createTestSuite()
+ReadMeTest = createTestSuite(
+    'ReadMeTest', os.path.join(os.path.dirname(__file__), 'README.md'))
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.md 
b/sdks/python/apache_beam/yaml/yaml_mapping.md
new file mode 100644
index 00000000000..193abac610e
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.md
@@ -0,0 +1,196 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+
+# Beam YAML mappings
+
+Beam YAML has the ability to do simple transformations which can be used to
+get data into the correct shape. The simplest of these is `MapToFields`
+which creates records with new fields defined in terms of the input fields.
+
+## Field renames
+
+To rename fields one can write
+
+```
+- type: MapToFields
+  fields:
+    new_col1: col1
+    new_col2: col2
+```
+
+will result in an output where each record has two fields,
+`new_col1` and `new_col2`, whose values are those of `col1` and `col2`
+respectively.
+
+One can specify the append parameter which indicates the original fields should
+be retained similar to the use of `*` in an SQL select statement. For example
+
+```
+- type: MapToFields
+  append: true
+  fields:
+    new_col1: col1
+    new_col2: col2
+```
+
+will output records that have `new_col1` and `new_col2` as *additional*
+fields.  When the append field is specified, one can drop fields as well, e.g.
+
+```
+- type: MapToFields
+  append: true
+  drop:
+    - col3
+  fields:
+    new_col1: col1
+    new_col2: col2
+```
+
+which includes all original fields *except* col3 in addition to outputting the
+two new ones.
+
+
+## Mapping functions
+
+Of course one may want to do transformations beyond just dropping and renaming
+fields.  Beam YAML has the ability to inline simple UDFs.
+This requires a language specification. For example
+
+```
+- type: MapToFields
+  language: python
+  fields:
+    new_col: "col1.upper()"
+    another_col: "col2 + col3"
+```
+
+In addition, one can provide a full Python callable that takes the row as an
+argument to do more complex mappings
+(see 
[PythonCallableSource](https://beam.apache.org/releases/pydoc/current/apache_beam.utils.python_callable.html#apache_beam.utils.python_callable.PythonCallableWithSource)
+for acceptable formats). Thus one can write
+
+```
+- type: MapToFields
+  language: python
+  fields:
+    new_col:
+      callable: |
+        import re
+        def my_mapping(row):
+          if re.match("[0-9]+", row.col1) and row.col2 > 0:
+            return "good"
+          else:
+            return "bad"
+```
+
+Once one reaches a certain level of complexity, it may be preferable to package
+this up as a dependency and simply refer to it by fully qualified name, e.g.
+
+```
+- type: MapToFields
+  language: python
+  fields:
+    new_col:
+      callable: pkg.module.fn
+```
+
+Currently, in addition to Python, SQL expressions are supported as well
+
+```
+- type: MapToFields
+  language: sql
+  fields:
+    new_col: "UPPER(col1)"
+    another_col: "col2 + col3"
+```
+
+## FlatMap
+
+Sometimes it may be desirable to emit more (or less) than one record for each
+input record.  This can be accomplished by mapping to an iterable type and
+noting that the specific field should be exploded, e.g.
+
+```
+- type: MapToFields
+  language: python
+  fields:
+    new_col: "[col1.upper(), col1.lower(), col1.title()]"
+    another_col: "col2 + col3"
+  explode: new_col
+```
+
+will result in three output records for every input record.
+
+If more than one record is to be exploded, one must specify whether the cross
+product over all fields should be taken. For example
+
+```
+- type: MapToFields
+  language: python
+  fields:
+    new_col: "[col1.upper(), col1.lower(), col1.title()]"
+    another_col: "[col2 - 1, col2, col2 + 1]"
+  explode: [new_col, another_col]
+  cross_product: true
+```
+
+will emit nine records whereas
+
+```
+- type: MapToFields
+  language: python
+  fields:
+    new_col: "[col1.upper(), col1.lower(), col1.title()]"
+    another_col: "[col2 - 1, col2, col2 + 1]"
+  explode: [new_col, another_col]
+  cross_product: false
+```
+
+will only emit three.
+
+If one is only exploding existing fields, a simpler `Explode` transform may be
+used instead
+
+```
+- type: Explode
+  explode: [col1]
+```
+
+## Filtering
+
+Sometimes it can be desirable to only keep records that satisfy a certain
+criteria. This can be accomplished by specifying a keep parameter, e.g.
+
+```
+- type: MapToFields
+  language: python
+  fields:
+    new_col: "col1.upper()"
+    another_col: "col2 + col3"
+  keep: "col2 > 0"
+```
+
+Like explode, there is a simpler `Filter` transform useful when no mapping is
+being done
+
+```
+- type: Filter
+  language: sql
+  keep: "col2 > 0"
+```
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py 
b/sdks/python/apache_beam/yaml/yaml_mapping.py
new file mode 100644
index 00000000000..7f959773320
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -0,0 +1,255 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines the basic MapToFields operation."""
+
+import itertools
+
+import apache_beam as beam
+from apache_beam.typehints import row_type
+from apache_beam.typehints import trivial_inference
+from apache_beam.typehints.schemas import named_fields_from_element_type
+from apache_beam.utils import python_callable
+from apache_beam.yaml import yaml_provider
+
+
+def _as_callable(original_fields, expr):
+  if expr in original_fields:
+    return expr
+  else:
+    # TODO(yaml): support a type parameter
+    # TODO(yaml): support an imports parameter
+    # TODO(yaml): support a requirements parameter (possibly at a higher level)
+    if isinstance(expr, str):
+      expr = {'expression': expr}
+    if not isinstance(expr, dict):
+      raise ValueError(
+          f"Ambiguous expression type (perhaps missing quoting?): {expr}")
+    elif len(expr) != 1:
+      raise ValueError(f"Ambiguous expression type: {list(expr.keys())}")
+    if 'expression' in expr:
+      # TODO(robertwb): Consider constructing a single callable that takes
+      # the row and returns the new row, rather than invoking (and unpacking)
+      # for each field individually.
+      source = '\n'.join(['def fn(__row__):'] + [
+          f'  {name} = __row__.{name}'
+          for name in original_fields if name in expr['expression']
+      ] + ['  return (' + expr['expression'] + ')'])
+    elif 'callable' in expr:
+      source = expr['callable']
+    else:
+      raise ValueError(f"Unknown expression type: {list(expr.keys())}")
+    return python_callable.PythonCallableWithSource(source)
+
+
+# TODO(yaml): This should be available in all environments, in which case
+# we choose the one that matches best.
+class _Explode(beam.PTransform):
+  def __init__(self, fields, cross_product):
+    self._fields = fields
+    self._cross_product = cross_product
+
+  def expand(self, pcoll):
+    all_fields = [
+        x for x, _ in named_fields_from_element_type(pcoll.element_type)
+    ]
+    to_explode = self._fields
+
+    def explode_cross_product(base, fields):
+      if fields:
+        copy = dict(base)
+        for value in base[fields[0]]:
+          copy[fields[0]] = value
+          yield from explode_cross_product(copy, fields[1:])
+      else:
+        yield beam.Row(**base)
+
+    def explode_zip(base, fields):
+      to_zip = [base[field] for field in fields]
+      copy = dict(base)
+      for values in itertools.zip_longest(*to_zip, fillvalue=None):
+        for ix, field in enumerate(fields):
+          copy[field] = values[ix]
+        yield beam.Row(**copy)
+
+    return pcoll | beam.FlatMap(
+        lambda row: (
+        explode_cross_product if self._cross_product else explode_zip)(
+            {name: getattr(row, name) for name in all_fields},  # yapf break
+            to_explode))
+
+  def infer_output_type(self, input_type):
+    return row_type.RowTypeConstraint.from_fields([(
+        name,
+        trivial_inference.element_type(typ) if name in self._fields else
+        typ) for (name, typ) in named_fields_from_element_type(input_type)])
+
+
+# TODO(yaml): Should Filter and Explode be distinct operations from Project?
+# We'll want these per-language.
[email protected]_fn
+def _PythonProjectionTransform(
+    pcoll, *, fields, keep=None, explode=(), cross_product=True):
+  original_fields = [
+      name for (name, _) in named_fields_from_element_type(pcoll.element_type)
+  ]
+
+  if keep:
+    if isinstance(keep, str) and keep in original_fields:
+      keep_fn = lambda row: getattr(row, keep)
+    else:
+      keep_fn = _as_callable(original_fields, keep)
+    filtered = pcoll | beam.Filter(keep_fn)
+  else:
+    filtered = pcoll
+
+  if list(fields.items()) == [(name, name) for name in original_fields]:
+    projected = filtered
+  else:
+    projected = filtered | beam.Select(
+        **{
+            name: _as_callable(original_fields, expr)
+            for (name, expr) in fields.items()
+        })
+
+  if explode:
+    result = projected | _Explode(explode, cross_product=cross_product)
+  else:
+    result = projected
+
+  return result
+
+
[email protected]_fn
+def MapToFields(
+    pcoll,
+    yaml_create_transform,
+    *,
+    fields,
+    keep=None,
+    explode=(),
+    cross_product=None,
+    append=False,
+    drop=(),
+    language=None,
+    **language_keywords):
+
+  if isinstance(explode, str):
+    explode = [explode]
+  if cross_product is None:
+    if len(explode) > 1:
+      # TODO(robertwb): Consider if true is an OK default.
+      raise ValueError(
+          'cross_product must be specified true or false '
+          'when exploding multiple fields')
+    else:
+      # Doesn't matter.
+      cross_product = True
+
+  input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+  if drop and not append:
+    raise ValueError("Can only drop fields if append is true.")
+  for name in drop:
+    if name not in input_schema:
+      raise ValueError(f'Dropping unknown field "{name}"')
+  for name in explode:
+    if not (name in fields or (append and name in input_schema)):
+      raise ValueError(f'Exploding unknown field "{name}"')
+  if append:
+    for name in fields:
+      if name in input_schema and name not in drop:
+        raise ValueError(f'Redefinition of field "{name}"')
+
+  if append:
+    fields = {
+        **{name: name
+           for name in input_schema.keys() if name not in drop},
+        **fields
+    }
+
+  if language is None:
+    for name, expr in fields.items():
+      if not isinstance(expr, str) or expr not in input_schema:
+        # TODO(robertw): Could consider defaulting to SQL, or another
+        # lowest-common-denominator expression language.
+        raise ValueError("Missing language specification.")
+
+    # We should support this for all languages.
+    language = "python"
+
+  if language in ("sql", "calcite"):
+    selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
+    query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
+    if keep:
+      query += " WHERE " + keep
+
+    result = pcoll | yaml_create_transform({
+        'type': 'Sql', 'query': query, **language_keywords
+    })
+    if explode:
+      # TODO(yaml): Implement via unnest.
+      result = result | _Explode(explode, cross_product)
+
+    return result
+
+  elif language == 'python':
+    return pcoll | yaml_create_transform({
+        'type': 'PyTransform',
+        'constructor': __name__ + '._PythonProjectionTransform',
+        'kwargs': {
+            'fields': fields,
+            'keep': keep,
+            'explode': explode,
+            'cross_product': cross_product,
+        },
+        **language_keywords
+    })
+
+  else:
+    # TODO(yaml): Support javascript expressions and UDFs.
+    # TODO(yaml): Support java by fully qualified name.
+    # TODO(yaml): Maybe support java lambdas?
+    raise ValueError(
+        f'Unknown language: {language}. '
+        'Supported languages are "sql" (alias calcite) and "python."')
+
+
+def create_mapping_provider():
+  # These are MetaInlineProviders because their expansion is in terms of other
+  # YamlTransforms, but in a way that needs to be deferred until the input
+  # schema is known.
+  return yaml_provider.MetaInlineProvider({
+      'MapToFields': MapToFields,
+      'Filter': (
+          lambda yaml_create_transform,
+          keep,
+          **kwargs: MapToFields(
+              yaml_create_transform,
+              keep=keep,
+              fields={},
+              append=True,
+              **kwargs)),
+      'Explode': (
+          lambda yaml_create_transform,
+          explode,
+          **kwargs: MapToFields(
+              yaml_create_transform,
+              explode=explode,
+              fields={},
+              append=True,
+              **kwargs)),
+  })
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py 
b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
new file mode 100644
index 00000000000..3305ff8b92a
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
@@ -0,0 +1,138 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.yaml.readme_test import createTestSuite
+from apache_beam.yaml.yaml_transform import YamlTransform
+
+DATA = [
+    beam.Row(label='11a', conductor=11, rank=0),
+    beam.Row(label='37a', conductor=37, rank=1),
+    beam.Row(label='389a', conductor=389, rank=2),
+]
+
+
+class YamlMappingTest(unittest.TestCase):
+  def test_basic(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: MapToFields
+          input: input
+          language: python
+          fields:
+            label: label
+            isogeny: "label[-1]"
+          ''')
+      assert_that(
+          result,
+          equal_to([
+              beam.Row(label='11a', isogeny='a'),
+              beam.Row(label='37a', isogeny='a'),
+              beam.Row(label='389a', isogeny='a'),
+          ]))
+
+  def test_drop(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: MapToFields
+          input: input
+          fields: {}
+          append: true
+          drop: [conductor]
+          ''')
+      assert_that(
+          result,
+          equal_to([
+              beam.Row(label='11a', rank=0),
+              beam.Row(label='37a', rank=1),
+              beam.Row(label='389a', rank=2),
+          ]))
+
+  def test_filter(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: MapToFields
+          input: input
+          language: python
+          fields:
+            label: label
+          keep: "rank > 0"
+          ''')
+      assert_that(
+          result, equal_to([
+              beam.Row(label='37a'),
+              beam.Row(label='389a'),
+          ]))
+
+  def test_explode(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(a=2, b='abc', c=.25),
+          beam.Row(a=3, b='xy', c=.125),
+      ])
+      result = elements | YamlTransform(
+          '''
+          type: MapToFields
+          input: input
+          language: python
+          append: true
+          fields:
+            range: "range(a)"
+          explode: [range, b]
+          cross_product: true
+          ''')
+      assert_that(
+          result,
+          equal_to([
+              beam.Row(a=2, b='a', c=.25, range=0),
+              beam.Row(a=2, b='a', c=.25, range=1),
+              beam.Row(a=2, b='b', c=.25, range=0),
+              beam.Row(a=2, b='b', c=.25, range=1),
+              beam.Row(a=2, b='c', c=.25, range=0),
+              beam.Row(a=2, b='c', c=.25, range=1),
+              beam.Row(a=3, b='x', c=.125, range=0),
+              beam.Row(a=3, b='x', c=.125, range=1),
+              beam.Row(a=3, b='x', c=.125, range=2),
+              beam.Row(a=3, b='y', c=.125, range=0),
+              beam.Row(a=3, b='y', c=.125, range=1),
+              beam.Row(a=3, b='y', c=.125, range=2),
+          ]))
+
+
+YamlMappingDocTest = createTestSuite(
+    'YamlMappingDocTest',
+    os.path.join(os.path.dirname(__file__), 'yaml_mapping.md'))
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py 
b/sdks/python/apache_beam/yaml/yaml_provider.py
index e3d544781c3..209e2178b8e 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -27,6 +27,7 @@ import subprocess
 import sys
 import uuid
 from typing import Any
+from typing import Callable
 from typing import Iterable
 from typing import Mapping
 
@@ -59,7 +60,11 @@ class Provider:
     raise NotImplementedError(type(self))
 
   def create_transform(
-      self, typ: str, args: Mapping[str, Any]) -> beam.PTransform:
+      self,
+      typ: str,
+      args: Mapping[str, Any],
+      yaml_create_transform: Callable[[Mapping[str, Any]], beam.PTransform]
+  ) -> beam.PTransform:
     """Creates a PTransform instance for the given transform type and 
arguments.
     """
     raise NotImplementedError(type(self))
@@ -88,7 +93,7 @@ class ExternalProvider(Provider):
   def provided_transforms(self):
     return self._urns.keys()
 
-  def create_transform(self, type, args):
+  def create_transform(self, type, args, yaml_create_transform):
     if callable(self._service):
       self._service = self._service()
     if self._schema_transforms is None:
@@ -245,13 +250,18 @@ class InlineProvider(Provider):
   def provided_transforms(self):
     return self._transform_factories.keys()
 
-  def create_transform(self, type, args):
+  def create_transform(self, type, args, yaml_create_transform):
     return self._transform_factories[type](**args)
 
   def to_json(self):
     return {'type': "InlineProvider"}
 
 
+class MetaInlineProvider(InlineProvider):
+  def create_transform(self, type, args, yaml_create_transform):
+    return self._transform_factories[type](yaml_create_transform, **args)
+
+
 PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
     py_type.__name__: schema_type
     for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
@@ -446,17 +456,23 @@ def parse_providers(provider_specs):
 def merge_providers(*provider_sets):
   result = collections.defaultdict(list)
   for provider_set in provider_sets:
+    if isinstance(provider_set, Provider):
+      provider = provider_set
+      provider_set = {
+          transform_type: [provider]
+          for transform_type in provider.provided_transforms()
+      }
     for transform_type, providers in provider_set.items():
       result[transform_type].extend(providers)
   return result
 
 
 def standard_providers():
-  builtin_providers = collections.defaultdict(list)
-  builtin_provider = create_builtin_provider()
-  for transform_type in builtin_provider.provided_transforms():
-    builtin_providers[transform_type].append(builtin_provider)
+  from apache_beam.yaml.yaml_mapping import create_mapping_provider
   with open(os.path.join(os.path.dirname(__file__),
                          'standard_providers.yaml')) as fin:
     standard_providers = yaml.load(fin, Loader=SafeLoader)
-  return merge_providers(builtin_providers, 
parse_providers(standard_providers))
+  return merge_providers(
+      create_builtin_provider(),
+      create_mapping_provider(),
+      parse_providers(standard_providers))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py 
b/sdks/python/apache_beam/yaml/yaml_transform.py
index f3e605adf03..925aa0d85b9 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -235,7 +235,8 @@ class Scope(LightweightScope):
     real_args = SafeLineLoader.strip_metadata(args)
     try:
       # pylint: disable=undefined-loop-variable
-      ptransform = provider.create_transform(spec['type'], real_args)
+      ptransform = provider.create_transform(
+          spec['type'], real_args, self.create_ptransform)
       # TODO(robertwb): Should we have a better API for adding annotations
       # than this?
       annotations = dict(

Reply via email to