robertwb commented on code in PR #27096:
URL: https://github.com/apache/beam/pull/27096#discussion_r1227256961


##########
sdks/python/apache_beam/yaml/yaml_mapping.py:
##########
@@ -0,0 +1,249 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines the basic MapToFields operation."""
+
+import itertools
+
+import apache_beam as beam
+from apache_beam.typehints import row_type
+from apache_beam.typehints import trivial_inference
+from apache_beam.typehints.schemas import named_fields_from_element_type
+from apache_beam.utils import python_callable
+from apache_beam.yaml import yaml_provider
+
+
+def _as_callable(original_fields, expr):
+  if expr in original_fields:
+    return expr
+  else:
+    # TODO(yaml): support a type parameter
+    # TODO(yaml): support an imports parameter
+    # TODO(yaml): support a requirements parameter (possibly at a higher level)
+    if isinstance(expr, str):
+      expr = {'expression': expr}
+    if not isinstance(expr, dict):
+      raise ValueError(
+          f"Ambiguous expression type (perhaps missing quoting?): {expr}")
+    elif len(expr) != 1:
+      raise ValueError(f"Ambiguous expression type: {list(expr.keys())}")
+    if 'expression' in expr:
+      # TODO(robertwb): Consider constructing a single callable that takes
+      # the row and returns the new row, rather than invoking (and unpacking)
+      # for each field individually.
+      source = '\n'.join(['def fn(__row__):'] + [
+          f'  {name} = __row__.{name}'
+          for name in original_fields if name in expr['expression']
+      ] + ['  return (' + expr['expression'] + ')'])
+    elif 'callable' in expr:
+      source = expr['callable']
+    else:
+      raise ValueError(f"Unknown expression type: {list(expr.keys())}")
+    return python_callable.PythonCallableWithSource(source)
+
+
+# TODO(yaml): This should be available in all environments, in which case
+# we choose the one that matches best.
+class _Explode(beam.PTransform):
+  def __init__(self, fields, cross_product):
+    self._fields = fields
+    self._cross_product = cross_product
+
+  def expand(self, pcoll):
+    all_fields = [
+        x for x, _ in named_fields_from_element_type(pcoll.element_type)
+    ]
+    to_explode = self._fields
+
+    def explode_cross_product(base, fields):
+      if fields:
+        copy = dict(base)
+        for value in base[fields[0]]:
+          copy[fields[0]] = value
+          yield from explode_cross_product(copy, fields[1:])
+      else:
+        yield beam.Row(**base)
+
+    def explode_zip(base, fields):
+      to_zip = [base[field] for field in fields]
+      copy = dict(base)
+      for values in itertools.zip_longest(*to_zip, fillvalue=None):
+        for ix, field in enumerate(fields):
+          copy[field] = values[ix]
+        yield beam.Row(**copy)
+
+    return pcoll | beam.FlatMap(
+        lambda row: (
+        explode_cross_product if self._cross_product else explode_zip)(
+            {name: getattr(row, name) for name in all_fields},  # yapf break
+            to_explode))
+
+  def infer_output_type(self, input_type):
+    return row_type.RowTypeConstraint.from_fields([(
+        name,
+        trivial_inference.element_type(typ) if name in self._fields else
+        typ) for (name, typ) in named_fields_from_element_type(input_type)])
+
+
+# TODO(yaml): Should Filter and Explode be distinct operations from Project?
+# We'll want these per-language.
+@beam.ptransform.ptransform_fn
+def _PythonProjectionTransform(
+    pcoll, *, fields, keep=None, explode=(), cross_product=True):
+  original_fields = [
+      name for (name, _) in named_fields_from_element_type(pcoll.element_type)
+  ]
+
+  if keep:
+    if isinstance(keep, str) and keep in original_fields:
+      keep_fn = lambda row: getattr(row, keep)
+    else:
+      keep_fn = _as_callable(original_fields, keep)
+    filtered = pcoll | beam.Filter(keep_fn)
+  else:
+    filtered = pcoll
+
+  projected = filtered | beam.Select(
+      **{
+          name: _as_callable(original_fields, expr)
+          for (name, expr) in fields.items()
+      })
+
+  if explode:
+    result = projected | _Explode(explode, cross_product=cross_product)
+  else:
+    result = projected
+
+  return result
+
+
+@beam.ptransform.ptransform_fn
+def MapToFields(
+    pcoll,
+    yaml_create_transform,
+    *,
+    fields,
+    keep=None,
+    explode=(),
+    cross_product=None,
+    append=False,
+    drop=(),
+    language=None,
+    **language_keywords):
+
+  if isinstance(explode, str):
+    explode = [explode]
+  if cross_product is None:
+    if len(explode) > 1:
+      # TODO(robertwb): Consider if true is an OK default.
+      raise ValueError(
+          'cross_product must be specified true or false '
+          'when exploding multiple fields')
+    else:
+      # Doesn't matter.
+      cross_product = True
+
+  input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+  if drop and not append:
+    raise ValueError("Can only drop fields if append is true.")
+  for name in drop:
+    if name not in input_schema:
+      raise ValueError(f'Dropping unknown field "{name}"')
+  for name in explode:
+    if not (name in fields or (append and name in input_schema)):
+      raise ValueError(f'Exploding unknown field "{name}"')
+  if append:
+    for name in fields:
+      if name in input_schema and name not in drop:
+        raise ValueError(f'Redefinition of field "{name}"')
+
+  if append:
+    fields = {
+        **{name: name
+           for name in input_schema.keys() if name not in drop},
+        **fields
+    }
+
+  if language is None:
+    for name, expr in fields.items():
+      if not isinstance(expr, str) or expr not in input_schema:
+        # TODO(robertw): Could consider defaulting to SQL, or another
+        # lowest-common-denominator expression language.
+        raise ValueError("Missing language specification.")
+
+    # We should support this for all languages.
+    language = "python"
+
+  if language in ("sql", "calcite"):
+    selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
+    query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
+    if keep:
+      query += " WHERE " + keep
+    print("QUERY", query)

Review Comment:
   Leftover debugging. Removed.



##########
sdks/python/apache_beam/yaml/yaml_mapping.py:
##########
@@ -0,0 +1,249 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines the basic MapToFields operation."""
+
+import itertools
+
+import apache_beam as beam
+from apache_beam.typehints import row_type
+from apache_beam.typehints import trivial_inference
+from apache_beam.typehints.schemas import named_fields_from_element_type
+from apache_beam.utils import python_callable
+from apache_beam.yaml import yaml_provider
+
+
+def _as_callable(original_fields, expr):
+  if expr in original_fields:
+    return expr
+  else:
+    # TODO(yaml): support a type parameter
+    # TODO(yaml): support an imports parameter
+    # TODO(yaml): support a requirements parameter (possibly at a higher level)
+    if isinstance(expr, str):
+      expr = {'expression': expr}
+    if not isinstance(expr, dict):
+      raise ValueError(
+          f"Ambiguous expression type (perhaps missing quoting?): {expr}")
+    elif len(expr) != 1:
+      raise ValueError(f"Ambiguous expression type: {list(expr.keys())}")
+    if 'expression' in expr:
+      # TODO(robertwb): Consider constructing a single callable that takes
+      # the row and returns the new row, rather than invoking (and unpacking)
+      # for each field individually.
+      source = '\n'.join(['def fn(__row__):'] + [
+          f'  {name} = __row__.{name}'
+          for name in original_fields if name in expr['expression']
+      ] + ['  return (' + expr['expression'] + ')'])
+    elif 'callable' in expr:
+      source = expr['callable']
+    else:
+      raise ValueError(f"Unknown expression type: {list(expr.keys())}")
+    return python_callable.PythonCallableWithSource(source)
+
+
+# TODO(yaml): This should be available in all environments, in which case
+# we choose the one that matches best.
+class _Explode(beam.PTransform):
+  def __init__(self, fields, cross_product):
+    self._fields = fields
+    self._cross_product = cross_product
+
+  def expand(self, pcoll):
+    all_fields = [
+        x for x, _ in named_fields_from_element_type(pcoll.element_type)
+    ]
+    to_explode = self._fields
+
+    def explode_cross_product(base, fields):
+      if fields:
+        copy = dict(base)
+        for value in base[fields[0]]:
+          copy[fields[0]] = value
+          yield from explode_cross_product(copy, fields[1:])
+      else:
+        yield beam.Row(**base)
+
+    def explode_zip(base, fields):
+      to_zip = [base[field] for field in fields]
+      copy = dict(base)
+      for values in itertools.zip_longest(*to_zip, fillvalue=None):
+        for ix, field in enumerate(fields):
+          copy[field] = values[ix]
+        yield beam.Row(**copy)
+
+    return pcoll | beam.FlatMap(
+        lambda row: (
+        explode_cross_product if self._cross_product else explode_zip)(
+            {name: getattr(row, name) for name in all_fields},  # yapf break
+            to_explode))
+
+  def infer_output_type(self, input_type):
+    return row_type.RowTypeConstraint.from_fields([(
+        name,
+        trivial_inference.element_type(typ) if name in self._fields else
+        typ) for (name, typ) in named_fields_from_element_type(input_type)])
+
+
+# TODO(yaml): Should Filter and Explode be distinct operations from Project?
+# We'll want these per-language.
+@beam.ptransform.ptransform_fn
+def _PythonProjectionTransform(
+    pcoll, *, fields, keep=None, explode=(), cross_product=True):
+  original_fields = [
+      name for (name, _) in named_fields_from_element_type(pcoll.element_type)
+  ]
+
+  if keep:
+    if isinstance(keep, str) and keep in original_fields:
+      keep_fn = lambda row: getattr(row, keep)
+    else:
+      keep_fn = _as_callable(original_fields, keep)
+    filtered = pcoll | beam.Filter(keep_fn)
+  else:
+    filtered = pcoll
+
+  projected = filtered | beam.Select(
+      **{
+          name: _as_callable(original_fields, expr)
+          for (name, expr) in fields.items()
+      })
+
+  if explode:
+    result = projected | _Explode(explode, cross_product=cross_product)
+  else:
+    result = projected
+
+  return result
+
+
+@beam.ptransform.ptransform_fn
+def MapToFields(
+    pcoll,
+    yaml_create_transform,
+    *,
+    fields,
+    keep=None,
+    explode=(),
+    cross_product=None,
+    append=False,
+    drop=(),
+    language=None,
+    **language_keywords):
+
+  if isinstance(explode, str):
+    explode = [explode]
+  if cross_product is None:
+    if len(explode) > 1:
+      # TODO(robertwb): Consider if true is an OK default.
+      raise ValueError(
+          'cross_product must be specified true or false '
+          'when exploding multiple fields')
+    else:
+      # Doesn't matter.
+      cross_product = True
+
+  input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+  if drop and not append:
+    raise ValueError("Can only drop fields if append is true.")
+  for name in drop:
+    if name not in input_schema:
+      raise ValueError(f'Dropping unknown field "{name}"')
+  for name in explode:
+    if not (name in fields or (append and name in input_schema)):
+      raise ValueError(f'Exploding unknown field "{name}"')
+  if append:
+    for name in fields:
+      if name in input_schema and name not in drop:
+        raise ValueError(f'Redefinition of field "{name}"')
+
+  if append:
+    fields = {
+        **{name: name
+           for name in input_schema.keys() if name not in drop},
+        **fields
+    }
+
+  if language is None:
+    for name, expr in fields.items():
+      if not isinstance(expr, str) or expr not in input_schema:
+        # TODO(robertw): Could consider defaulting to SQL, or another
+        # lowest-common-denominator expression language.
+        raise ValueError("Missing language specification.")
+
+    # We should support this for all languages.
+    language = "python"
+
+  if language in ("sql", "calcite"):
+    selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
+    query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
+    if keep:
+      query += " WHERE " + keep
+    print("QUERY", query)
+
+    result = pcoll | yaml_create_transform({
+        'type': 'Sql', 'query': query, **language_keywords
+    })
+    if explode:
+      # TODO(yaml): Implement via unnest.
+      result = result | _PythonProjectionTransform(
+          {}, append=True, explode=explode, cross_product=cross_product)
+
+    return result
+
+  elif language == 'python':
+    return pcoll | yaml_create_transform({
+        'type': 'PyTransform',
+        'constructor': __name__ + '._PythonProjectionTransform',
+        'kwargs': {
+            'fields': fields,
+            'keep': keep,
+            'explode': explode,
+            'cross_product': cross_product,
+        },
+        **language_keywords
+    })
+
+  else:
+    # TODO(yaml): Support javascript expressions and UDFs.
+    # TODO(yaml): Support java by fully qualified name.
+    # TODO(yaml): Maybe support java lambdas?
+    raise ValueError(f'Unknown language: {language}')

Review Comment:
   Done.



##########
sdks/python/apache_beam/yaml/yaml_mapping_test.py:
##########
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.yaml.yaml_transform import YamlTransform
+from apache_beam.yaml.readme_test import createTestSuite
+
+DATA = [
+    beam.Row(label='11a', conductor=11, rank=0),
+    beam.Row(label='37a', conductor=37, rank=1),
+    beam.Row(label='389a', conductor=389, rank=2),
+]
+
+
+class YamlMappingTest(unittest.TestCase):
+  def test_basic(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: MapToFields
+          input: input
+          language: python
+          fields:
+            label: label
+            isogeny: "label[-1]"
+          ''')
+      assert_that(
+          result,
+          equal_to([
+              beam.Row(label='11a', isogeny='a'),
+              beam.Row(label='37a', isogeny='a'),
+              beam.Row(label='389a', isogeny='a'),
+          ]))
+
+  def test_drop(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: MapToFields
+          input: input
+          fields: {}
+          append: true
+          drop: [conductor]
+          ''')
+      assert_that(
+          result,
+          equal_to([
+              beam.Row(label='11a', rank=0),
+              beam.Row(label='37a', rank=1),
+              beam.Row(label='389a', rank=2),
+          ]))
+
+  def test_filter(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: MapToFields
+          input: input
+          language: python
+          fields:
+            label: label
+          keep: "rank > 0"
+          ''')
+      assert_that(
+          result, equal_to([
+              beam.Row(label='37a'),
+              beam.Row(label='389a'),
+          ]))
+
+  def test_explode(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(a=2, b='abc', c=.25),
+          beam.Row(a=3, b='xy', c=.125),
+      ])
+      result = elements | YamlTransform(
+          '''
+          type: MapToFields
+          input: input
+          language: python
+          append: true
+          fields:
+            range: "range(a)"
+          explode: [range, b]
+          cross_product: true
+          ''')
+      assert_that(
+          result,
+          equal_to([
+              beam.Row(a=2, b='a', c=.25, range=0),
+              beam.Row(a=2, b='a', c=.25, range=1),
+              beam.Row(a=2, b='b', c=.25, range=0),
+              beam.Row(a=2, b='b', c=.25, range=1),
+              beam.Row(a=2, b='c', c=.25, range=0),
+              beam.Row(a=2, b='c', c=.25, range=1),
+              beam.Row(a=3, b='x', c=.125, range=0),
+              beam.Row(a=3, b='x', c=.125, range=1),
+              beam.Row(a=3, b='x', c=.125, range=2),
+              beam.Row(a=3, b='y', c=.125, range=0),
+              beam.Row(a=3, b='y', c=.125, range=1),
+              beam.Row(a=3, b='y', c=.125, range=2),
+          ]))
+
+
+YamlMappingDocTest = createTestSuite(
+    'YamlMappingDocTest',
+    os.path.join(os.path.dirname(__file__), 'yaml_mapping.md'))
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)

Review Comment:
   Fixed.



##########
sdks/python/apache_beam/yaml/readme_test.py:
##########
@@ -211,28 +215,39 @@ def parse_test_methods(markdown_lines):
         test_type = 'RUN'
         test_name = f'test_line_{ix + 2}'
       else:
-        if code_lines and code_lines[0] == 'pipeline:':
-          yaml_pipeline = '\n'.join(code_lines)
-          if 'providers:' in yaml_pipeline:
-            test_type = 'PARSE'
-          yield test_name, create_test_method(
-              test_type,
-              test_name,
-              yaml_pipeline)
+        if code_lines:
+          if code_lines[0].startswith('- type:'):
+            # Treat this as a fragment of a larger pipeline.
+            code_lines = [
+                'pipeline:',
+                '  type: chain',
+                '  transforms:',
+                '    - type: ReadFromCsv',
+                '      path: whatever',
+            ] + ['    ' + line for line in code_lines]
+          if code_lines[0] == 'pipeline:':
+            yaml_pipeline = '\n'.join(code_lines)
+            if 'providers:' in yaml_pipeline:
+              test_type = 'PARSE'
+            yield test_name, create_test_method(
+                test_type,
+                test_name,
+                yaml_pipeline)
         code_lines = None
     elif code_lines is not None:
       code_lines.append(line)
 
 
-def createTestSuite():
-  with open(os.path.join(os.path.dirname(__file__), 'README.md')) as readme:
-    return type(
-        'ReadMeTest', (unittest.TestCase, ), dict(parse_test_methods(readme)))
+def createTestSuite(name, path):
+  with open(path) as readme:
+    return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))
 
 
-ReadMeTest = createTestSuite()
+ReadMeTest = createTestSuite(
+    'ReadMeTest', os.path.join(os.path.dirname(__file__), 'README.md'))

Review Comment:
   I have such a block in yaml_mapping_test. I'm open to putting it here 
instead if you think that'd be preferable. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to