This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c5e6c7962e6 Refactor and cleanup yaml MapToFields. (#28462)
c5e6c7962e6 is described below

commit c5e6c7962e60ee8366bfc92edf0812341e940020
Author: Robert Bradshaw <[email protected]>
AuthorDate: Wed Sep 20 16:37:17 2023 -0700

    Refactor and cleanup yaml MapToFields. (#28462)
    
    * Avoid the use of MetaProviders, which was always kind of hacky.
      We may want to remove this infrastructure altogether as it
      does not play nicely with provider inference.
    
    * Split MapToFields into separate mapping, filtering, and exploding
      operations.
    
    * Allow MapToFields to act on non-schema'd PCollections.
    
    The various language flavors of these UDFs are now handled by a
    preprocessing step.  This will make it easier to extend to other
    languages, including in particular possible multiple (equivalent)
    implementations of JavaScript to minimize cross-language boundary
    crossings.
    
    ---------
    
    Co-authored-by: Danny McCormick <[email protected]>
---
 sdks/python/apache_beam/transforms/core.py         |   8 +
 sdks/python/apache_beam/yaml/readme_test.py        |  23 +-
 sdks/python/apache_beam/yaml/yaml_mapping.md       |  35 +-
 sdks/python/apache_beam/yaml/yaml_mapping.py       | 367 ++++++++++++---------
 sdks/python/apache_beam/yaml/yaml_mapping_test.py  |  32 +-
 sdks/python/apache_beam/yaml/yaml_provider.py      |   7 +-
 sdks/python/apache_beam/yaml/yaml_transform.py     |  15 +
 .../python/apache_beam/yaml/yaml_transform_test.py |  14 +-
 sdks/python/apache_beam/yaml/yaml_udf_test.py      |  38 ++-
 9 files changed, 311 insertions(+), 228 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 66ac8fbad96..671af54e47b 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2258,6 +2258,10 @@ class _PValueWithErrors(object):
     self._exception_handling_args = exception_handling_args
     self._upstream_errors = upstream_errors
 
+  @property
+  def element_type(self):
+    return self._pcoll.element_type
+
   def main_output_tag(self):
     return self._exception_handling_args.get('main_tag', 'good')
 
@@ -2309,6 +2313,10 @@ class _MaybePValueWithErrors(object):
     else:
       self._pvalue = _PValueWithErrors(pvalue, exception_handling_args)
 
+  @property
+  def element_type(self):
+    return self._pvalue.element_type
+
   def __or__(self, transform):
     return self.apply(transform)
 
diff --git a/sdks/python/apache_beam/yaml/readme_test.py b/sdks/python/apache_beam/yaml/readme_test.py
index 958d9cb5783..d918d18e11d 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -32,6 +32,7 @@ from yaml.loader import SafeLoader
 import apache_beam as beam
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.typehints import trivial_inference
+from apache_beam.yaml import yaml_mapping
 from apache_beam.yaml import yaml_provider
 from apache_beam.yaml import yaml_transform
 
@@ -85,13 +86,16 @@ class FakeSql(beam.PTransform):
             typ, = [t for t in typ.__args__ if t is not type(None)]
       return name, typ
 
-    output_schema = [
-        guess_name_and_type(expr) for expr in m.group(1).split(',')
-    ]
-    output_element = beam.Row(**{name: typ() for name, typ in output_schema})
-    return next(iter(inputs.values())) | beam.Map(
-        lambda _: output_element).with_output_types(
-            trivial_inference.instance_to_type(output_element))
+    if m.group(1) == '*':
+      return inputs['PCOLLECTION'] | beam.Filter(lambda _: True)
+    else:
+      output_schema = [
+          guess_name_and_type(expr) for expr in m.group(1).split(',')
+      ]
+      output_element = beam.Row(**{name: typ() for name, typ in output_schema})
+      return next(iter(inputs.values())) | beam.Map(
+          lambda _: output_element).with_output_types(
+              trivial_inference.instance_to_type(output_element))
 
 
 class FakeReadFromPubSub(beam.PTransform):
@@ -204,12 +208,13 @@ def create_test_method(test_type, test_name, test_yaml):
         ]
         options['render_leaf_composite_nodes'] = ['.*']
       test_provider = TestProvider(TEST_TRANSFORMS)
+      test_sql_mapping_provider = yaml_mapping.SqlMappingProvider(test_provider)
       p = beam.Pipeline(options=PipelineOptions(**options))
       yaml_transform.expand_pipeline(
           p,
           modified_yaml,
-          {t: test_provider
-           for t in test_provider.provided_transforms()})
+          yaml_provider.merge_providers(
+              [test_provider, test_sql_mapping_provider]))
       if test_type == 'BUILD':
         return
       p.run().wait_until_finish()
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.md b/sdks/python/apache_beam/yaml/yaml_mapping.md
index b5e84e1a805..653b4abe8b8 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.md
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.md
@@ -131,7 +131,7 @@ Currently, in addition to Python, SQL expressions are supported as well
 
 Sometimes it may be desirable to emit more (or less) than one record for each
 input record.  This can be accomplished by mapping to an iterable type and
-noting that the specific field should be exploded, e.g.
+following the mapping with an Explode operation, e.g.
 
 ```
 - type: MapToFields
@@ -140,7 +140,9 @@ noting that the specific field should be exploded, e.g.
     fields:
       new_col: "[col1.upper(), col1.lower(), col1.title()]"
       another_col: "col2 + col3"
-    explode: new_col
+- type: Explode
+  config:
+    fields: new_col
 ```
 
 will result in three output records for every input record.
@@ -155,7 +157,9 @@ product over all fields should be taken. For example
     fields:
       new_col: "[col1.upper(), col1.lower(), col1.title()]"
       another_col: "[col2 - 1, col2, col2 + 1]"
-    explode: [new_col, another_col]
+- type: Explode
+  config:
+    fields: [new_col, another_col]
     cross_product: true
 ```
 
@@ -168,38 +172,27 @@ will emit nine records whereas
     fields:
       new_col: "[col1.upper(), col1.lower(), col1.title()]"
       another_col: "[col2 - 1, col2, col2 + 1]"
-    explode: [new_col, another_col]
+- type: Explode
+  config:
+    fields: [new_col, another_col]
     cross_product: false
 ```
 
 will only emit three.
 
-If one is only exploding existing fields, a simpler `Explode` transform may be
-used instead
+The `Explode` operation can be used on its own if the field in question is
+already an iterable type.
 
 ```
 - type: Explode
   config:
-    explode: [col1]
+    fields: [col1]
 ```
 
 ## Filtering
 
 Sometimes it can be desirable to only keep records that satisfy a certain
-criteria. This can be accomplished by specifying a keep parameter, e.g.
-
-```
-- type: MapToFields
-  config:
-    language: python
-    fields:
-      new_col: "col1.upper()"
-      another_col: "col2 + col3"
-    keep: "col2 > 0"
-```
-
-Like explode, there is a simpler `Filter` transform useful when no mapping is
-being done
+criterion. This can be accomplished with a `Filter` transform, e.g.
 
 ```
 - type: Filter
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py
index b6dea894b3e..221c6f018d6 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -17,6 +17,14 @@
 
 """This module defines the basic MapToFields operation."""
 import itertools
+from typing import Any
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Union
 
 import js2py
 
@@ -139,18 +147,73 @@ def _as_callable(original_fields, expr, transform_name, language):
         'Supported languages are "javascript" and "python."')
 
 
+def exception_handling_args(error_handling_spec):
+  if error_handling_spec:
+    return {
+        'dead_letter_tag' if k == 'output' else k: v
+        for (k, v) in error_handling_spec.items()
+    }
+  else:
+    return None
+
+
+def _map_errors_to_standard_format():
+  # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple.
+  return beam.Map(
+      lambda x: beam.Row(element=x[0], msg=str(x[1][1]), stack=str(x[1][2])))
+
+
+def maybe_with_exception_handling(inner_expand):
+  def expand(self, pcoll):
+    wrapped_pcoll = beam.core._MaybePValueWithErrors(
+        pcoll, self._exception_handling_args)
+    return inner_expand(self, wrapped_pcoll).as_result(
+        _map_errors_to_standard_format())
+
+  return expand
+
+
+def maybe_with_exception_handling_transform_fn(transform_fn):
+  def expand(pcoll, error_handling=None, **kwargs):
+    wrapped_pcoll = beam.core._MaybePValueWithErrors(
+        pcoll, exception_handling_args(error_handling))
+    return transform_fn(wrapped_pcoll,
+                        **kwargs).as_result(_map_errors_to_standard_format())
+
+  return expand
+
+
 # TODO(yaml): This should be available in all environments, in which case
 # we choose the one that matches best.
 class _Explode(beam.PTransform):
-  def __init__(self, fields, cross_product):
+  def __init__(
+      self,
+      fields: Union[str, Collection[str]],
+      cross_product: Optional[bool] = None,
+      error_handling: Optional[Mapping[str, Any]] = None):
+    if isinstance(fields, str):
+      fields = [fields]
+    if cross_product is None:
+      if len(fields) > 1:
+        raise ValueError(
+            'cross_product must be specified true or false '
+            'when exploding multiple fields')
+      else:
+        # Doesn't matter.
+        cross_product = True
     self._fields = fields
     self._cross_product = cross_product
-    self._exception_handling_args = None
+    # TODO(yaml): Support standard error handling argument.
+    self._exception_handling_args = exception_handling_args(error_handling)
 
+  @maybe_with_exception_handling
   def expand(self, pcoll):
     all_fields = [
         x for x, _ in named_fields_from_element_type(pcoll.element_type)
     ]
+    for field in self._fields:
+      if field not in all_fields:
+        raise ValueError(f'Exploding unknown field "{field}"')
     to_explode = self._fields
 
     def explode_cross_product(base, fields):
@@ -171,12 +234,12 @@ class _Explode(beam.PTransform):
         yield beam.Row(**copy)
 
     return (
-        beam.core._MaybePValueWithErrors(pcoll, self._exception_handling_args)
+        pcoll
         | beam.FlatMap(
             lambda row:
             (explode_cross_product if self._cross_product else explode_zip)
             ({name: getattr(row, name)
-              for name in all_fields}, to_explode))).as_result()
+              for name in all_fields}, to_explode)))
 
   def infer_output_type(self, input_type):
     return row_type.RowTypeConstraint.from_fields([(
@@ -190,189 +253,171 @@ class _Explode(beam.PTransform):
     return self
 
 
-# TODO(yaml): Should Filter and Explode be distinct operations from Project?
-# We'll want these per-language.
 @beam.ptransform.ptransform_fn
-def _PythonProjectionTransform(
-    pcoll,
-    *,
-    fields,
-    transform_name,
-    language,
-    keep=None,
-    explode=(),
-    cross_product=True,
-    error_handling=None):
-  original_fields = [
-      name for (name, _) in named_fields_from_element_type(pcoll.element_type)
-  ]
+@maybe_with_exception_handling_transform_fn
+def _PyJsFilter(
+    pcoll, keep: Union[str, Dict[str, str]], language: Optional[str] = None):
 
-  if error_handling is None:
-    error_handling_args = None
+  input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+  if isinstance(keep, str) and keep in input_schema:
+    keep_fn = lambda row: getattr(row, keep)
   else:
-    error_handling_args = {
-        'dead_letter_tag' if k == 'output' else k: v
-        for (k, v) in error_handling.items()
-    }
+    keep_fn = _as_callable(list(input_schema.keys()), keep, "keep", language)
+  return pcoll | beam.Filter(keep_fn)
 
-  pcoll = beam.core._MaybePValueWithErrors(pcoll, error_handling_args)
 
-  if keep:
-    if isinstance(keep, str) and keep in original_fields:
-      keep_fn = lambda row: getattr(row, keep)
-    else:
-      keep_fn = _as_callable(original_fields, keep, transform_name, language)
-    filtered = pcoll | beam.Filter(keep_fn)
-  else:
-    filtered = pcoll
+def is_expr(v):
+  return isinstance(v, str) or (isinstance(v, dict) and 'expression' in v)
 
-  projected = filtered | beam.Select(
-      **{
-          name: _as_callable(original_fields, expr, transform_name, language)
-          for (name, expr) in fields.items()
-      })
 
-  if explode:
-    result = projected | _Explode(explode, cross_product=cross_product)
-  else:
-    result = projected
-
-  return result.as_result(
-      # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple.
-      beam.Map(
-          lambda x: beam.Row(
-              element=x[0], msg=str(x[1][1]), stack=str(x[1][2]))))
-
-
[email protected]_fn
-def MapToFields(
-    pcoll,
-    yaml_create_transform,
-    *,
-    fields,
-    keep=None,
-    explode=(),
-    cross_product=None,
-    append=False,
-    drop=(),
-    language=None,
-    error_handling=None,
-    transform_name="MapToFields",
-    **language_keywords):
-  if isinstance(explode, str):
-    explode = [explode]
-  if cross_product is None:
-    if len(explode) > 1:
-      # TODO(robertwb): Consider if true is an OK default.
-      raise ValueError(
-          'cross_product must be specified true or false '
-          'when exploding multiple fields')
-    else:
-      # Doesn't matter.
-      cross_product = True
+def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'):
+  try:
+    input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+  except ValueError as exn:
+    if drop:
+      raise ValueError("Can only drop fields on a schema'd input.") from exn
+    if append:
+      raise ValueError("Can only append fields on a schema'd input.") from exn
+    elif any(is_expr(x) for x in fields.values()):
+      raise ValueError("Can only use expressions on a schema'd input.") from exn
+    input_schema = {}
 
-  input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+  if isinstance(drop, str):
+    drop = [drop]
   if drop and not append:
     raise ValueError("Can only drop fields if append is true.")
   for name in drop:
     if name not in input_schema:
       raise ValueError(f'Dropping unknown field "{name}"')
-  for name in explode:
-    if not (name in fields or (append and name in input_schema)):
-      raise ValueError(f'Exploding unknown field "{name}"')
   if append:
     for name in fields:
       if name in input_schema and name not in drop:
-        raise ValueError(f'Redefinition of field "{name}"')
+        raise ValueError(
+            f'Redefinition of field "{name}". '
+            'Cannot append a field that already exists in original input.')
+
+  if language == 'generic':
+    for expr in fields.values():
+      if not isinstance(expr, str):
+        raise ValueError(
+            "Missing language specification. "
+            "Must specify a language when using a map with custom logic.")
+    missing = set(fields.values()) - set(input_schema.keys())
+    if missing:
+      raise ValueError(
+          f"Missing language specification or unknown input fields: {missing}")
 
   if append:
-    fields = {
+    return input_schema, {
         **{name: name
            for name in input_schema.keys() if name not in drop},
         **fields
     }
+  else:
+    return input_schema, fields
 
-  if language is None:
-    for name, expr in fields.items():
-      if not isinstance(expr, str) or expr not in input_schema:
-        # TODO(robertw): Could consider defaulting to SQL, or another
-        # lowest-common-denominator expression language.
-        raise ValueError("Missing language specification.")
-
-    # We should support this for all languages.
-    language = "python"
-
-  if language in ("sql", "calcite"):
-    if error_handling:
-      raise ValueError('Error handling unsupported for sql.')
-    selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
-    query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
-    if keep:
-      query += " WHERE " + keep
-
-    result = pcoll | yaml_create_transform({
-        'type': 'Sql',
-        'config': {
-            'query': query, **language_keywords
-        },
-    }, [pcoll])
-    if explode:
-      # TODO(yaml): Implement via unnest.
-      result = result | _Explode(explode, cross_product)
-
-    return result
-
-  elif language == 'python' or language == 'javascript':
-    return pcoll | yaml_create_transform({
-        'type': 'PyTransform',
-        'config': {
-            'constructor': __name__ + '._PythonProjectionTransform',
-            'kwargs': {
-                'fields': fields,
-                'transform_name': transform_name,
-                'language': language,
-                'keep': keep,
-                'explode': explode,
-                'cross_product': cross_product,
-                'error_handling': error_handling,
-            },
-            **language_keywords
-        },
-    }, [pcoll])
 
-  else:
-    # TODO(yaml): Support javascript expressions and UDFs.
-    # TODO(yaml): Support java by fully qualified name.
-    # TODO(yaml): Maybe support java lambdas?
-    raise ValueError(
-        f'Unknown language: {language}. '
-        'Supported languages are "sql" (alias calcite) and "python."')
[email protected]_fn
+@maybe_with_exception_handling_transform_fn
+def _PyJsMapToFields(pcoll, language='generic', **mapping_args):
+  input_schema, fields = normalize_fields(
+      pcoll, language=language, **mapping_args)
+  original_fields = list(input_schema.keys())
+
+  return pcoll | beam.Select(
+      **{
+          name: _as_callable(original_fields, expr, name, language)
+          for (name, expr) in fields.items()
+      })
+
+
+class SqlMappingProvider(yaml_provider.Provider):
+  def __init__(self, sql_provider=None):
+    if sql_provider is None:
+      sql_provider = yaml_provider.beam_jar(
+          urns={'Sql': 'beam:external:java:sql:v1'},
+          gradle_target='sdks:java:extensions:sql:expansion-service:shadowJar')
+    self._sql_provider = sql_provider
+
+  def available(self):
+    return self._sql_provider.available()
+
+  def cache_artifacts(self):
+    return self._sql_provider.cache_artifacts()
+
+  def provided_transforms(self) -> Iterable[str]:
+    return [
+        'Filter-sql',
+        'Filter-calcite',
+        'MapToFields-sql',
+        'MapToFields-calcite'
+    ]
+
+  def create_transform(
+      self,
+      typ: str,
+      args: Mapping[str, Any],
+      yaml_create_transform: Callable[
+          [Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]
+  ) -> beam.PTransform:
+    if typ.startswith('Filter-'):
+      return _SqlFilterTransform(
+          self._sql_provider, yaml_create_transform, **args)
+    if typ.startswith('MapToFields-'):
+      return _SqlMapToFieldsTransform(
+          self._sql_provider, yaml_create_transform, **args)
+    else:
+      raise NotImplementedError(typ)
+
+  def underlying_provider(self):
+    return self._sql_provider
+
+  def to_json(self):
+    return {'type': "SqlMappingProvider"}
+
+
[email protected]_fn
+def _SqlFilterTransform(
+    pcoll, sql_provider, yaml_create_transform, keep, language):
+  return pcoll | sql_provider.create_transform(
+      'Sql', {'query': f'SELECT * FROM PCOLLECTION WHERE {keep}'},
+      yaml_create_transform)
 
 
-def create_mapping_provider():
[email protected]_fn
+def _SqlMapToFieldsTransform(
+    pcoll, sql_provider, yaml_create_transform, **mapping_args):
+  _, fields = normalize_fields(pcoll, **mapping_args)
+
+  def extract_expr(name, v):
+    if isinstance(v, str):
+      return v
+    elif 'expression' in v:
+      return v['expression']
+    else:
+      raise ValueError("Only expressions allowed in SQL at {name}.")
+
+  selects = [
+      f'({extract_expr(name, expr)}) AS {name}'
+      for (name, expr) in fields.items()
+  ]
+  query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
+  return pcoll | sql_provider.create_transform(
+      'Sql', {'query': query}, yaml_create_transform)
+
+
+def create_mapping_providers():
   # These are MetaInlineProviders because their expansion is in terms of other
   # YamlTransforms, but in a way that needs to be deferred until the input
   # schema is known.
-  return yaml_provider.MetaInlineProvider({
-      'MapToFields': MapToFields,
-      'Filter': (
-          lambda yaml_create_transform,
-          keep,
-          **kwargs: MapToFields(
-              yaml_create_transform,
-              keep=keep,
-              fields={},
-              append=True,
-              transform_name='Filter',
-              **kwargs)),
-      'Explode': (
-          lambda yaml_create_transform,
-          explode,
-          **kwargs: MapToFields(
-              yaml_create_transform,
-              explode=explode,
-              fields={},
-              append=True,
-              transform_name='Explode',
-              **kwargs)),
-  })
+  return [
+      yaml_provider.InlineProvider({
+          'Explode': _Explode,
+          'Filter-python': _PyJsFilter,
+          'Filter-javascript': _PyJsFilter,
+          'MapToFields-python': _PyJsMapToFields,
+          'MapToFields-javascript': _PyJsMapToFields,
+          'MapToFields-generic': _PyJsMapToFields,
+      }),
+      SqlMappingProvider(),
+  ]
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
index 728476b1fd5..55032aeae52 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
@@ -82,18 +82,18 @@ class YamlMappingTest(unittest.TestCase):
       elements = p | beam.Create(DATA)
       result = elements | YamlTransform(
           '''
-          type: MapToFields
+          type: Filter
           input: input
           config:
               language: python
-              fields:
-                label: label
               keep: "rank > 0"
           ''')
       assert_that(
-          result, equal_to([
-              beam.Row(label='37a'),
-              beam.Row(label='389a'),
+          result
+          | beam.Map(lambda named_tuple: beam.Row(**named_tuple._asdict())),
+          equal_to([
+              beam.Row(label='37a', conductor=37, rank=1),
+              beam.Row(label='389a', conductor=389, rank=2),
           ]))
 
   def test_explode(self):
@@ -105,15 +105,19 @@ class YamlMappingTest(unittest.TestCase):
       ])
       result = elements | YamlTransform(
           '''
-          type: MapToFields
+          type: chain
           input: input
-          config:
-              language: python
-              append: true
-              fields:
-                range: "range(a)"
-              explode: [range, b]
-              cross_product: true
+          transforms:
+            - type: MapToFields
+              config:
+                  language: python
+                  append: true
+                  fields:
+                    range: "range(a)"
+            - type: Explode
+              config:
+                  fields: [range, b]
+                  cross_product: true
           ''')
       assert_that(
           result,
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index d01852a69c3..0cd9bdcadcc 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -209,6 +209,7 @@ class ExternalProvider(Provider):
   def register_provider_type(cls, type_name):
     def apply(constructor):
       cls._provider_types[type_name] = constructor
+      return constructor
 
     return apply
 
@@ -709,19 +710,21 @@ def merge_providers(*provider_sets):
           transform_type: [provider]
           for transform_type in provider.provided_transforms()
       }
+    elif isinstance(provider_set, list):
+      provider_set = merge_providers(*provider_set)
     for transform_type, providers in provider_set.items():
       result[transform_type].extend(providers)
   return result
 
 
 def standard_providers():
-  from apache_beam.yaml.yaml_mapping import create_mapping_provider
+  from apache_beam.yaml.yaml_mapping import create_mapping_providers
   from apache_beam.yaml.yaml_io import io_providers
   with open(os.path.join(os.path.dirname(__file__),
                          'standard_providers.yaml')) as fin:
     standard_providers = yaml.load(fin, Loader=SafeLoader)
   return merge_providers(
       create_builtin_provider(),
-      create_mapping_provider(),
+      create_mapping_providers(),
       io_providers(),
       parse_providers(standard_providers))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index da9bf526cd5..78546aa28cb 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -879,8 +879,23 @@ def preprocess(spec, verbose=False, known_transforms=None):
             f'Unknown type or missing provider for {identify_object(spec)}')
     return spec
 
+  def preprocess_langauges(spec):
+    if spec['type'] in ('Filter', 'MapToFields'):
+      language = spec.get('config', {}).get('language', 'generic')
+      new_type = spec['type'] + '-' + language
+      if known_transforms and new_type not in known_transforms:
+        if language == 'generic':
+          raise ValueError(f'Missing language for {identify_object(spec)}')
+        else:
+          raise ValueError(
+              f'Unknown language {language} for {identify_object(spec)}')
+      return dict(spec, type=new_type, name=spec.get('name', spec['type']))
+    else:
+      return spec
+
   for phase in [
       ensure_transforms_have_types,
+      preprocess_langauges,
       ensure_transforms_have_providers,
       preprocess_source_sink,
       preprocess_chain,
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index 993f9ea6639..ebf12710d3f 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -419,21 +419,27 @@ class ErrorHandlingTest(unittest.TestCase):
               input: Create
               config:
                   fn: "lambda x: beam.Row(num=x, str='a' * x or 'bbb')"
+            - type: Filter
+              input: ToRow
+              config:
+                  language: python
+                  keep:
+                    str[1] >= 'a'
+                  error_handling:
+                    output: errors
             - type: MapToFields
               name: MapWithErrorHandling
-              input: ToRow
+              input: Filter
               config:
                   language: python
                   fields:
                     num: num
                     inverse: float(1 / num)
-                  keep:
-                    str[1] >= 'a'
                   error_handling:
                     output: errors
             - type: PyMap
               name: TrimErrors
-              input: MapWithErrorHandling.errors
+              input: [MapWithErrorHandling.errors, Filter.errors]
               config:
                   fn: "lambda x: x.msg"
             - type: MapToFields
diff --git a/sdks/python/apache_beam/yaml/yaml_udf_test.py b/sdks/python/apache_beam/yaml/yaml_udf_test.py
index bb15cd49475..5e9faa08253 100644
--- a/sdks/python/apache_beam/yaml/yaml_udf_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_udf_test.py
@@ -28,6 +28,10 @@ from apache_beam.testing.util import equal_to
 from apache_beam.yaml.yaml_transform import YamlTransform
 
 
+def AsRows():
+  return beam.Map(lambda named_tuple: beam.Row(**named_tuple._asdict()))
+
+
 class YamlUDFMappingTest(unittest.TestCase):
   def __init__(self, method_name='runYamlMappingTest'):
     super().__init__(method_name)
@@ -59,12 +63,11 @@ class YamlUDFMappingTest(unittest.TestCase):
             callable: "function label_map(x) {return x.label + 'x'}"
           conductor:
             callable: "function conductor_map(x) {return x.conductor + 1}"
-        keep:
-          callable: "function filter(x) {return x.rank > 0}"
       ''')
       assert_that(
           result,
           equal_to([
+              beam.Row(label='11ax', conductor=12),
               beam.Row(label='37ax', conductor=38),
               beam.Row(label='389ax', conductor=390),
           ]))
@@ -84,12 +87,11 @@ class YamlUDFMappingTest(unittest.TestCase):
             callable: "lambda x: x.label + 'x'"
           conductor:
             callable: "lambda x: x.conductor + 1"
-        keep: 
-          callable: "lambda x: x.rank > 0"
       ''')
       assert_that(
           result,
           equal_to([
+              beam.Row(label='11ax', conductor=12),
               beam.Row(label='37ax', conductor=38),
               beam.Row(label='389ax', conductor=390),
           ]))
@@ -104,11 +106,11 @@ class YamlUDFMappingTest(unittest.TestCase):
       input: input
       config:
         language: javascript
-        keep: 
+        keep:
           callable: "function filter(x) {return x.rank > 0}"
       ''')
       assert_that(
-          result,
+          result | AsRows(),
           equal_to([
               beam.Row(label='37a', conductor=37, rank=1),
               beam.Row(label='389a', conductor=389, rank=2),
@@ -124,11 +126,11 @@ class YamlUDFMappingTest(unittest.TestCase):
       input: input
       config:
         language: python
-        keep: 
+        keep:
           callable: "lambda x: x.rank > 0"
       ''')
       assert_that(
-          result,
+          result | AsRows(),
           equal_to([
               beam.Row(label='37a', conductor=37, rank=1),
               beam.Row(label='389a', conductor=389, rank=2),
@@ -144,11 +146,12 @@ class YamlUDFMappingTest(unittest.TestCase):
       input: input
       config:
         language: javascript
-        keep: 
+        keep:
           expression: "label.toUpperCase().indexOf('3') == -1 && conductor"
       ''')
       assert_that(
-          result, equal_to([
+          result | AsRows(),
+          equal_to([
               beam.Row(label='11a', conductor=11, rank=0),
           ]))
 
@@ -162,11 +165,12 @@ class YamlUDFMappingTest(unittest.TestCase):
       input: input
       config:
         language: python
-        keep: 
+        keep:
           expression: "'3' not in label"
       ''')
       assert_that(
-          result, equal_to([
+          result | AsRows(),
+          equal_to([
               beam.Row(label='11a', conductor=11, rank=0),
           ]))
 
@@ -175,7 +179,7 @@ class YamlUDFMappingTest(unittest.TestCase):
     function f(x) {
       return x.rank > 0
     }
-    
+
     function g(x) {
       return x.rank > 1
     }
@@ -193,12 +197,12 @@ class YamlUDFMappingTest(unittest.TestCase):
         input: input
         config:
           language: javascript
-          keep: 
+          keep:
             path: {path}
             name: "f"
         ''')
       assert_that(
-          result,
+          result | AsRows(),
           equal_to([
               beam.Row(label='37a', conductor=37, rank=1),
               beam.Row(label='389a', conductor=389, rank=2),
@@ -225,12 +229,12 @@ class YamlUDFMappingTest(unittest.TestCase):
         input: input
         config:
           language: python
-          keep: 
+          keep:
             path: {path}
             name: "f"
         ''')
       assert_that(
-          result,
+          result | AsRows(),
           equal_to([
               beam.Row(label='37a', conductor=37, rank=1),
               beam.Row(label='389a', conductor=389, rank=2),

Reply via email to