This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new d5fc02479c2 [YAML] Add a basic aggregating transform to Beam Yaml. (#29167)
d5fc02479c2 is described below

commit d5fc02479c25cec4b98322495c3cb4fd866cc746
Author: Robert Bradshaw <[email protected]>
AuthorDate: Mon Oct 30 17:18:47 2023 -0700

    [YAML] Add a basic aggregating transform to Beam Yaml. (#29167)
    
    This follows the spirit of the schema-aware GroupBy operations in
    Java, Python, Typescript, etc. Syntactic sugar is provided for
    simplified forms.
    
    This is currently guarded as experimental until we've fully vetted
    the API.
---
 sdks/python/apache_beam/yaml/readme_test.py       |  23 ++-
 sdks/python/apache_beam/yaml/yaml_combine.md      | 166 ++++++++++++++++++
 sdks/python/apache_beam/yaml/yaml_combine.py      | 205 ++++++++++++++++++++++
 sdks/python/apache_beam/yaml/yaml_combine_test.py | 173 ++++++++++++++++++
 sdks/python/apache_beam/yaml/yaml_mapping.py      |  67 ++-----
 sdks/python/apache_beam/yaml/yaml_provider.py     |  41 +++++
 sdks/python/apache_beam/yaml/yaml_transform.py    |  12 +-
 7 files changed, 619 insertions(+), 68 deletions(-)

diff --git a/sdks/python/apache_beam/yaml/readme_test.py b/sdks/python/apache_beam/yaml/readme_test.py
index d918d18e11d..7f2d193bf35 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -26,13 +26,13 @@ import sys
 import tempfile
 import unittest
 
+import mock
 import yaml
 from yaml.loader import SafeLoader
 
 import apache_beam as beam
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.typehints import trivial_inference
-from apache_beam.yaml import yaml_mapping
 from apache_beam.yaml import yaml_provider
 from apache_beam.yaml import yaml_transform
 
@@ -200,7 +200,10 @@ def create_test_method(test_type, test_name, test_yaml):
         if write in test_yaml:
           spec = replace_recursive(spec, write, 'path', env.output_file())
       modified_yaml = yaml.dump(spec)
-      options = {'pickle_library': 'cloudpickle'}
+      options = {
+          'pickle_library': 'cloudpickle',
+          'yaml_experimental_features': ['Combine']
+      }
       if RENDER_DIR is not None:
         options['runner'] = 'apache_beam.runners.render.RenderRunner'
         options['render_output'] = [
@@ -208,13 +211,12 @@ def create_test_method(test_type, test_name, test_yaml):
         ]
         options['render_leaf_composite_nodes'] = ['.*']
       test_provider = TestProvider(TEST_TRANSFORMS)
-      test_sql_mapping_provider = 
yaml_mapping.SqlMappingProvider(test_provider)
-      p = beam.Pipeline(options=PipelineOptions(**options))
-      yaml_transform.expand_pipeline(
-          p,
-          modified_yaml,
-          yaml_provider.merge_providers(
-              [test_provider, test_sql_mapping_provider]))
+      with mock.patch(
+          'apache_beam.yaml.yaml_provider.SqlBackedProvider.sql_provider',
+          lambda self: test_provider):
+        p = beam.Pipeline(options=PipelineOptions(**options))
+        yaml_transform.expand_pipeline(
+            p, modified_yaml, yaml_provider.merge_providers([test_provider]))
       if test_type == 'BUILD':
         return
       p.run().wait_until_finish()
@@ -270,6 +272,9 @@ ErrorHandlingTest = createTestSuite(
     'ErrorHandlingTest',
     os.path.join(os.path.dirname(__file__), 'yaml_errors.md'))
 
+CombineTest = createTestSuite(
+    'CombineTest', os.path.join(os.path.dirname(__file__), 'yaml_combine.md'))
+
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
   parser.add_argument('--render_dir', default=None)
diff --git a/sdks/python/apache_beam/yaml/yaml_combine.md b/sdks/python/apache_beam/yaml/yaml_combine.md
new file mode 100644
index 00000000000..e2fef304fb0
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_combine.md
@@ -0,0 +1,166 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+
+# Beam YAML Aggregations
+
+Beam YAML has the EXPERIMENTAL ability to do aggregations that group and
+combine values across records. This is accomplished via the `Combine`
+transform type. Currently `Combine` needs to be listed in the
+`yaml_experimental_features` option to use this transform.
+
+For example, one can write
+
+```
+- type: Combine
+  config:
+    group_by: col1
+    combine:
+      total:
+        value: col2
+        fn:
+          type: sum
+```
+
+If the function has no configuration requirements, it can be provided directly
+as a string
+
+```
+- type: Combine
+  config:
+    group_by: col1
+    combine:
+      total:
+        value: col2
+        fn: sum
+```
+
+This can be simplified further if the output field name is the same as the input
+field name:
+
+```
+- type: Combine
+  config:
+    group_by: col1
+    combine:
+      col2: sum
+```
+
+One can aggregate over many fields at once
+
+```
+- type: Combine
+  config:
+    group_by: col1
+    combine:
+      col2: sum
+      col3: max
+```
+
+and/or group by more than one field
+
+```
+- type: Combine
+  config:
+    group_by: [col1, col2]
+    combine:
+      col3: sum
+```
+
+or none at all (which will result in a global combine with a single output)
+
+```
+- type: Combine
+  config:
+    group_by: []
+    combine:
+      col2: sum
+      col3: max
+```
+
+## Windowed aggregation
+
+As with all transforms, `Combine` can take a windowing parameter
+
+```
+- type: Combine
+  windowing:
+    type: fixed
+    size: 60
+  config:
+    group_by: col1
+    combine:
+      col2: sum
+      col3: max
+```
+
+If no windowing specification is provided, it inherits the windowing
+parameters from upstream, e.g.
+
+```
+- type: WindowInto
+  windowing:
+    type: fixed
+    size: 60
+- type: Combine
+  config:
+    group_by: col1
+    combine:
+      col2: sum
+      col3: max
+```
+
+is equivalent to the previous example.
+
+
+## Custom aggregation functions
+
+One can use aggregation functions defined in Python by setting the language
+parameter.
+
+```
+- type: Combine
+  config:
+    language: python
+    group_by: col1
+    combine:
+      biggest:
+        value: "col2 + col2"
+        fn:
+          type: 'apache_beam.transforms.combiners.TopCombineFn'
+          config:
+            n: 10
+```
+
+## SQL-style aggregations
+
+By setting the language to SQL, one can provide full SQL snippets as the
+combine fn.
+
+```
+- type: Combine
+  config:
+    language: sql
+    group_by: col1
+    combine:
+      num_values: "count(*)"
+      total: "sum(col2)"
+```
+
+One can of course also use the `Sql` transform type and provide a query
+directly.
diff --git a/sdks/python/apache_beam/yaml/yaml_combine.py b/sdks/python/apache_beam/yaml/yaml_combine.py
new file mode 100644
index 00000000000..ef4974cff35
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_combine.py
@@ -0,0 +1,205 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines the basic Combine operation."""
+
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+
+import apache_beam as beam
+from apache_beam import typehints
+from apache_beam.typehints import row_type
+from apache_beam.typehints import trivial_inference
+from apache_beam.typehints.decorators import get_type_hints
+from apache_beam.typehints.schemas import named_fields_from_element_type
+from apache_beam.utils import python_callable
+from apache_beam.yaml import options
+from apache_beam.yaml import yaml_mapping
+from apache_beam.yaml import yaml_provider
+
+BUILTIN_COMBINE_FNS = {
+    'sum': sum,
+    'max': max,
+    'min': min,
+    'all': all,
+    'any': any,
+    'mean': beam.transforms.combiners.MeanCombineFn(),
+    'count': beam.transforms.combiners.CountCombineFn(),
+}
+
+
+def normalize_combine(spec):
+  """Expands various shorthand specs for combine (which can otherwise be quite
+  verbose for simple cases.)  We do this here so that it doesn't need to be 
done
+  per language.  The following are all equivalent::
+
+      dest: fn_type
+
+      dest:
+        value: dest
+        fn: fn_type
+
+      dest:
+        value: dest
+        fn:
+          type: fn_type
+  """
+  from apache_beam.yaml.yaml_transform import SafeLineLoader
+  if spec['type'] == 'Combine':
+    config = spec.get('config')
+    if isinstance(config.get('group_by'), str):
+      config['group_by'] = [config['group_by']]
+
+    def normalize_agg(dest, agg):
+      if isinstance(agg, str):
+        agg = {'fn': agg}
+      if 'value' not in agg and spec.get('language') != 'sql':
+        agg['value'] = dest
+      if isinstance(agg['fn'], str):
+        agg['fn'] = {'type': agg['fn']}
+      return agg
+
+    if 'combine' not in config:
+      raise ValueError('Missing combine parameter in Combine config.')
+    config['combine'] = {
+        dest: normalize_agg(dest, agg)
+        for (dest,
+             agg) in SafeLineLoader.strip_metadata(config['combine']).items()
+    }
+  return spec
+
+
+class PyJsYamlCombine(beam.PTransform):
+  def __init__(
+      self,
+      group_by: Iterable[str],
+      combine: Mapping[str, Mapping[str, Any]],
+      language: Optional[str] = None):
+    self._group_by = group_by
+    self._combine = combine
+    self._language = language
+
+  def expand(self, pcoll):
+    options.YamlOptions.check_enabled(pcoll.pipeline, 'Combine')
+    input_types = dict(named_fields_from_element_type(pcoll.element_type))
+    all_fields = list(input_types.keys())
+    unknown_keys = set(self._group_by) - set(all_fields)
+    if unknown_keys:
+      raise ValueError(f'Unknown grouping columns: {list(unknown_keys)}')
+
+    def create_combine_fn(fn_spec):
+      if 'type' not in fn_spec:
+        raise ValueError(f'CombineFn spec missing type: {fn_spec}')
+      elif fn_spec['type'] in BUILTIN_COMBINE_FNS:
+        return BUILTIN_COMBINE_FNS[fn_spec['type']]
+      elif self._language == 'python':
+        # TODO(yaml): Support output_type here as well.
+        fn = python_callable.PythonCallableWithSource.load_from_source(
+            fn_spec['type'])
+        if 'config' in fn_spec:
+          fn = fn(**fn_spec['config'])
+        return fn
+      else:
+        raise TypeError('Unknown CombineFn: {fn_spec}')
+
+    def extract_return_type(expr):
+      if isinstance(expr, str) and expr in input_types:
+        return input_types[expr]
+      expr_hints = get_type_hints(expr)
+      if (expr_hints and expr_hints.has_simple_output_type() and
+          expr_hints.simple_output_type(None) != typehints.Any):
+        return expr_hints.simple_output_type(None)
+      elif callable(expr):
+        return trivial_inference.infer_return_type(expr, [pcoll.element_type])
+      else:
+        return Any
+
+    # TODO(yaml): Support error handling.
+    transform = beam.GroupBy(*self._group_by)
+    output_types = [(k, input_types[k]) for k in self._group_by]
+
+    for output, agg in self._combine.items():
+      expr = yaml_mapping._as_callable(
+          all_fields, agg['value'], 'Combine', self._language)
+      fn = create_combine_fn(agg['fn'])
+      transform = transform.aggregate_field(expr, fn, output)
+
+      # TODO(yaml): See if this logic can be pushed into GroupBy itself.
+      expr_type = extract_return_type(expr)
+      print('expr', expr, 'expr_type', expr_type)
+      if isinstance(fn, beam.CombineFn):
+        # TODO(yaml): Better inference on CombineFns whose outputs types are
+        # functions of their input types
+        combined_type = extract_return_type(fn)
+      elif fn in (sum, min, max):
+        combined_type = expr_type
+      elif fn in (any, all):
+        combined_type = bool
+      else:
+        combined_type = Any
+      output_types.append((output, combined_type))
+
+    return pcoll | transform.with_output_types(
+        row_type.RowTypeConstraint.from_fields(output_types))
+
+
[email protected]_fn
+def _SqlCombineTransform(
+    pcoll, sql_transform_constructor, group_by, combine, language=None):
+  options.YamlOptions.check_enabled(pcoll.pipeline, 'Combine')
+  all_fields = [
+      x for x, _ in named_fields_from_element_type(pcoll.element_type)
+  ]
+  unknown_keys = set(group_by) - set(all_fields)
+  if unknown_keys:
+    raise ValueError(f'Unknown grouping columns: {list(unknown_keys)}')
+
+  def combine_col(dest, fn_spec):
+    if 'value' in fn_spec or 'config' in fn_spec['fn']:
+      expr = '%s(%s)' % (
+          fn_spec['fn']['type'],
+          ', '.join([fn_spec['value']] +
+                    list(fn_spec['fn'].get('config', {}).values())))
+    else:
+      expr = fn_spec['fn']['type']
+    return f'{expr} as {dest}'
+
+  return pcoll | sql_transform_constructor(
+      'SELECT %s FROM PCOLLECTION GROUP BY %s' % (
+          ', '.join(
+              list(group_by) +
+              [combine_col(dest, fn_spec)
+               for dest, fn_spec in combine.items()]),
+          ', '.join(group_by),
+      ))
+
+
+def create_combine_providers():
+  return [
+      yaml_provider.InlineProvider({
+          'Combine-generic': PyJsYamlCombine,
+          'Combine-python': PyJsYamlCombine,
+          'Combine-javascript': PyJsYamlCombine,
+      }),
+      yaml_provider.SqlBackedProvider({
+          'Combine-generic': _SqlCombineTransform,
+          'Combine-sql': _SqlCombineTransform,
+          'Combine-calcite': _SqlCombineTransform,
+      }),
+  ]
diff --git a/sdks/python/apache_beam/yaml/yaml_combine_test.py b/sdks/python/apache_beam/yaml/yaml_combine_test.py
new file mode 100644
index 00000000000..ef696c89379
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_combine_test.py
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.yaml.yaml_transform import YamlTransform
+
+DATA = [
+    beam.Row(a='x', b=1, c=101),
+    beam.Row(a='x', b=1, c=102),
+    beam.Row(a='y', b=1, c=103),
+    beam.Row(a='y', b=2, c=104),
+]
+
+
+class YamlCombineTest(unittest.TestCase):
+  def test_multiple_aggregations(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle', yaml_experimental_features=['Combine'
+                                                                  ])) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: Combine
+          config:
+            group_by: a
+            combine:
+              b: sum
+              c: max
+          ''')
+      assert_that(
+          result | beam.Map(lambda x: beam.Row(**x._asdict())),
+          equal_to([
+              beam.Row(a='x', b=2, c=102),
+              beam.Row(a='y', b=3, c=104),
+          ]))
+
+  def test_multiple_keys(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle', yaml_experimental_features=['Combine'
+                                                                  ])) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: Combine
+          config:
+            group_by: [a, b]
+            combine:
+              c: sum
+          ''')
+      assert_that(
+          result | beam.Map(lambda x: beam.Row(**x._asdict())),
+          equal_to([
+              beam.Row(a='x', b=1, c=203),
+              beam.Row(a='y', b=1, c=103),
+              beam.Row(a='y', b=2, c=104),
+          ]))
+
+  def test_no_keys(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle', yaml_experimental_features=['Combine'
+                                                                  ])) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: Combine
+          config:
+            group_by: []
+            combine:
+              c: sum
+          ''')
+      assert_that(
+          result | beam.Map(lambda x: beam.Row(**x._asdict())),
+          equal_to([
+              beam.Row(c=410),
+          ]))
+
+  def test_multiple_combines(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle', yaml_experimental_features=['Combine'
+                                                                  ])) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: Combine
+          config:
+            group_by: a
+            combine:
+              min_c:
+                fn: min
+                value: c
+              max_c:
+                fn: max
+                value: c
+          ''')
+      assert_that(
+          result | beam.Map(lambda x: beam.Row(**x._asdict())),
+          equal_to([
+              beam.Row(a='x', min_c=101, max_c=102),
+              beam.Row(a='y', min_c=103, max_c=104),
+          ]))
+
+  def test_expression(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle', yaml_experimental_features=['Combine'
+                                                                  ])) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: Combine
+          config:
+            language: python
+            group_by: a
+            combine:
+              max:
+                fn: max
+                value: b + c
+          ''')
+      assert_that(
+          result | beam.Map(lambda x: beam.Row(**x._asdict())),
+          equal_to([
+              beam.Row(a='x', max=103),
+              beam.Row(a='y', max=106),
+          ]))
+
+  def test_config(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle', yaml_experimental_features=['Combine'
+                                                                  ])) as p:
+      elements = p | beam.Create(DATA)
+      result = elements | YamlTransform(
+          '''
+          type: Combine
+          config:
+            language: python
+            group_by: b
+            combine:
+              biggest:
+                fn:
+                  type: 'apache_beam.transforms.combiners.TopCombineFn'
+                  config:
+                    n: 2
+                value: c
+          ''')
+      assert_that(
+          result | beam.Map(lambda x: beam.Row(**x._asdict())),
+          equal_to([
+              beam.Row(b=1, biggest=[103, 102]),
+              beam.Row(b=2, biggest=[104]),
+          ]))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py
index e217ab28584..501c7a5c57b 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -21,7 +21,6 @@ from typing import Any
 from typing import Callable
 from typing import Collection
 from typing import Dict
-from typing import Iterable
 from typing import Mapping
 from typing import Optional
 from typing import Union
@@ -399,62 +398,14 @@ def _PyJsMapToFields(pcoll, language='generic', **mapping_args):
       })
 
 
-class SqlMappingProvider(yaml_provider.Provider):
-  def __init__(self, sql_provider=None):
-    if sql_provider is None:
-      sql_provider = yaml_provider.beam_jar(
-          urns={'Sql': 'beam:external:java:sql:v1'},
-          gradle_target='sdks:java:extensions:sql:expansion-service:shadowJar')
-    self._sql_provider = sql_provider
-
-  def available(self):
-    return self._sql_provider.available()
-
-  def cache_artifacts(self):
-    return self._sql_provider.cache_artifacts()
-
-  def provided_transforms(self) -> Iterable[str]:
-    return [
-        'Filter-sql',
-        'Filter-calcite',
-        'MapToFields-sql',
-        'MapToFields-calcite'
-    ]
-
-  def create_transform(
-      self,
-      typ: str,
-      args: Mapping[str, Any],
-      yaml_create_transform: Callable[
-          [Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]
-  ) -> beam.PTransform:
-    if typ.startswith('Filter-'):
-      return _SqlFilterTransform(
-          self._sql_provider, yaml_create_transform, **args)
-    if typ.startswith('MapToFields-'):
-      return _SqlMapToFieldsTransform(
-          self._sql_provider, yaml_create_transform, **args)
-    else:
-      raise NotImplementedError(typ)
-
-  def underlying_provider(self):
-    return self._sql_provider
-
-  def to_json(self):
-    return {'type': "SqlMappingProvider"}
-
-
 @beam.ptransform.ptransform_fn
-def _SqlFilterTransform(
-    pcoll, sql_provider, yaml_create_transform, keep, language):
-  return pcoll | sql_provider.create_transform(
-      'Sql', {'query': f'SELECT * FROM PCOLLECTION WHERE {keep}'},
-      yaml_create_transform)
+def _SqlFilterTransform(pcoll, sql_transform_constructor, keep, language):
+  return pcoll | sql_transform_constructor(
+      f'SELECT * FROM PCOLLECTION WHERE {keep}')
 
 
 @beam.ptransform.ptransform_fn
-def _SqlMapToFieldsTransform(
-    pcoll, sql_provider, yaml_create_transform, **mapping_args):
+def _SqlMapToFieldsTransform(pcoll, sql_transform_constructor, **mapping_args):
   _, fields = normalize_fields(pcoll, **mapping_args)
 
   def extract_expr(name, v):
@@ -470,8 +421,7 @@ def _SqlMapToFieldsTransform(
       for (name, expr) in fields.items()
   ]
   query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
-  return pcoll | sql_provider.create_transform(
-      'Sql', {'query': query}, yaml_create_transform)
+  return pcoll | sql_transform_constructor(query)
 
 
 def create_mapping_providers():
@@ -487,5 +437,10 @@ def create_mapping_providers():
           'MapToFields-javascript': _PyJsMapToFields,
           'MapToFields-generic': _PyJsMapToFields,
       }),
-      SqlMappingProvider(),
+      yaml_provider.SqlBackedProvider({
+          'Filter-sql': _SqlFilterTransform,
+          'Filter-calcite': _SqlFilterTransform,
+          'MapToFields-sql': _SqlMapToFieldsTransform,
+          'MapToFields-calcite': _SqlMapToFieldsTransform,
+      }),
   ]
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index 33c16380ece..a21db09b50c 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -453,6 +453,45 @@ class MetaInlineProvider(InlineProvider):
     return self._transform_factories[type](yaml_create_transform, **args)
 
 
+class SqlBackedProvider(Provider):
+  def __init__(
+      self,
+      transforms: Mapping[str, Callable[..., beam.PTransform]],
+      sql_provider: Optional[Provider] = None):
+    self._transforms = transforms
+    if sql_provider is None:
+      sql_provider = beam_jar(
+          urns={'Sql': 'beam:external:java:sql:v1'},
+          gradle_target='sdks:java:extensions:sql:expansion-service:shadowJar')
+    self._sql_provider = sql_provider
+
+  def sql_provider(self):
+    return self._sql_provider
+
+  def provided_transforms(self):
+    return self._transforms.keys()
+
+  def available(self):
+    return self.sql_provider().available()
+
+  def cache_artifacts(self):
+    return self.sql_provider().cache_artifacts()
+
+  def underlying_provider(self):
+    return self.sql_provider()
+
+  def to_json(self):
+    return {'type': "SqlBackedProvider"}
+
+  def create_transform(
+      self, typ: str, args: Mapping[str, Any],
+      yaml_create_transform: Any) -> beam.PTransform:
+    return self._transforms[typ](
+        lambda query: self.sql_provider().create_transform(
+            'Sql', {'query': query}, yaml_create_transform),
+        **args)
+
+
 PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
     py_type.__name__: schema_type
     for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
@@ -781,6 +820,7 @@ def merge_providers(*provider_sets):
 
 
 def standard_providers():
+  from apache_beam.yaml.yaml_combine import create_combine_providers
   from apache_beam.yaml.yaml_mapping import create_mapping_providers
   from apache_beam.yaml.yaml_io import io_providers
   with open(os.path.join(os.path.dirname(__file__),
@@ -790,6 +830,7 @@ def standard_providers():
   return merge_providers(
       create_builtin_provider(),
       create_mapping_providers(),
+      create_combine_providers(),
       io_providers(),
       parse_providers(standard_providers))
 
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index 7ab8da33f1a..ca63834e283 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -34,6 +34,7 @@ from yaml.loader import SafeLoader
 import apache_beam as beam
 from apache_beam.transforms.fully_qualified_named_transform import 
FullyQualifiedNamedTransform
 from apache_beam.yaml import yaml_provider
+from apache_beam.yaml.yaml_combine import normalize_combine
 
 __all__ = ["YamlTransform"]
 
@@ -885,7 +886,7 @@ def preprocess(spec, verbose=False, known_transforms=None):
     return spec
 
   def preprocess_langauges(spec):
-    if spec['type'] in ('Filter', 'MapToFields'):
+    if spec['type'] in ('Filter', 'MapToFields', 'Combine'):
       language = spec.get('config', {}).get('language', 'generic')
       new_type = spec['type'] + '-' + language
       if known_transforms and new_type not in known_transforms:
@@ -900,6 +901,7 @@ def preprocess(spec, verbose=False, known_transforms=None):
 
   for phase in [
       ensure_transforms_have_types,
+      normalize_combine,
       preprocess_langauges,
       ensure_transforms_have_providers,
       preprocess_source_sink,
@@ -951,14 +953,18 @@ class YamlTransform(beam.PTransform):
       root = next(iter(pcolls.values())).pipeline
       if not self._spec['input']:
         self._spec['input'] = {name: name for name in pcolls.keys()}
+    python_provider = yaml_provider.InlineProvider({})
     result = expand_transform(
         self._spec,
         Scope(
             root,
             pcolls,
-            transforms=[],
+            transforms=[self._spec],
             providers=self._providers,
-            input_providers={}))
+            input_providers={
+                pcoll: python_provider
+                for pcoll in pcolls.values()
+            }))
     if len(result) == 1:
       return only_element(result.values())
     else:

Reply via email to