This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new d5fc02479c2 [YAML] Add a basic aggregating transform to Beam Yaml.
(#29167)
d5fc02479c2 is described below
commit d5fc02479c25cec4b98322495c3cb4fd866cc746
Author: Robert Bradshaw <[email protected]>
AuthorDate: Mon Oct 30 17:18:47 2023 -0700
[YAML] Add a basic aggregating transform to Beam Yaml. (#29167)
This follows the spirit of the schema-aware GroupBy operations in
Java, Python, Typescript, etc. Syntactic sugar is provided for
simplified forms.
This is currently guarded as experimental until we've fully vetted
the API.
---
sdks/python/apache_beam/yaml/readme_test.py | 23 ++-
sdks/python/apache_beam/yaml/yaml_combine.md | 166 ++++++++++++++++++
sdks/python/apache_beam/yaml/yaml_combine.py | 205 ++++++++++++++++++++++
sdks/python/apache_beam/yaml/yaml_combine_test.py | 173 ++++++++++++++++++
sdks/python/apache_beam/yaml/yaml_mapping.py | 67 ++-----
sdks/python/apache_beam/yaml/yaml_provider.py | 41 +++++
sdks/python/apache_beam/yaml/yaml_transform.py | 12 +-
7 files changed, 619 insertions(+), 68 deletions(-)
diff --git a/sdks/python/apache_beam/yaml/readme_test.py
b/sdks/python/apache_beam/yaml/readme_test.py
index d918d18e11d..7f2d193bf35 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -26,13 +26,13 @@ import sys
import tempfile
import unittest
+import mock
import yaml
from yaml.loader import SafeLoader
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.typehints import trivial_inference
-from apache_beam.yaml import yaml_mapping
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform
@@ -200,7 +200,10 @@ def create_test_method(test_type, test_name, test_yaml):
if write in test_yaml:
spec = replace_recursive(spec, write, 'path', env.output_file())
modified_yaml = yaml.dump(spec)
- options = {'pickle_library': 'cloudpickle'}
+ options = {
+ 'pickle_library': 'cloudpickle',
+ 'yaml_experimental_features': ['Combine']
+ }
if RENDER_DIR is not None:
options['runner'] = 'apache_beam.runners.render.RenderRunner'
options['render_output'] = [
@@ -208,13 +211,12 @@ def create_test_method(test_type, test_name, test_yaml):
]
options['render_leaf_composite_nodes'] = ['.*']
test_provider = TestProvider(TEST_TRANSFORMS)
- test_sql_mapping_provider =
yaml_mapping.SqlMappingProvider(test_provider)
- p = beam.Pipeline(options=PipelineOptions(**options))
- yaml_transform.expand_pipeline(
- p,
- modified_yaml,
- yaml_provider.merge_providers(
- [test_provider, test_sql_mapping_provider]))
+ with mock.patch(
+ 'apache_beam.yaml.yaml_provider.SqlBackedProvider.sql_provider',
+ lambda self: test_provider):
+ p = beam.Pipeline(options=PipelineOptions(**options))
+ yaml_transform.expand_pipeline(
+ p, modified_yaml, yaml_provider.merge_providers([test_provider]))
if test_type == 'BUILD':
return
p.run().wait_until_finish()
@@ -270,6 +272,9 @@ ErrorHandlingTest = createTestSuite(
'ErrorHandlingTest',
os.path.join(os.path.dirname(__file__), 'yaml_errors.md'))
+CombineTest = createTestSuite(
+ 'CombineTest', os.path.join(os.path.dirname(__file__), 'yaml_combine.md'))
+
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--render_dir', default=None)
diff --git a/sdks/python/apache_beam/yaml/yaml_combine.md
b/sdks/python/apache_beam/yaml/yaml_combine.md
new file mode 100644
index 00000000000..e2fef304fb0
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_combine.md
@@ -0,0 +1,166 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Beam YAML Aggregations
+
+Beam YAML has the EXPERIMENTAL ability to do aggregations to group and combine
+values across records. This is accomplished via the `Combine` transform type.
+Currently `Combine` needs to be in the `yaml_experimental_features`
+option to use this transform.
+
+For example, one can write
+
+```
+- type: Combine
+ config:
+ group_by: col1
+ combine:
+ total:
+ value: col2
+ fn:
+ type: sum
+```
+
+If the function has no configuration requirements, it can be provided directly
+as a string
+
+```
+- type: Combine
+ config:
+ group_by: col1
+ combine:
+ total:
+ value: col2
+ fn: sum
+```
+
+This can be simplified further if the output field name is the same as the
+input field name
+
+```
+- type: Combine
+ config:
+ group_by: col1
+ combine:
+ col2: sum
+```
+
+One can aggregate over many fields at once
+
+```
+- type: Combine
+ config:
+ group_by: col1
+ combine:
+ col2: sum
+ col3: max
+```
+
+and/or group by more than one field
+
+```
+- type: Combine
+ config:
+ group_by: [col1, col2]
+ combine:
+ col3: sum
+```
+
+or none at all (which will result in a global combine with a single output)
+
+```
+- type: Combine
+ config:
+ group_by: []
+ combine:
+ col2: sum
+ col3: max
+```
+
+## Windowed aggregation
+
+As with all transforms, `Combine` can take a windowing parameter
+
+```
+- type: Combine
+ windowing:
+ type: fixed
+ size: 60
+ config:
+ group_by: col1
+ combine:
+ col2: sum
+ col3: max
+```
+
+If no windowing specification is provided, it inherits the windowing
+parameters from upstream, e.g.
+
+```
+- type: WindowInto
+ windowing:
+ type: fixed
+ size: 60
+- type: Combine
+ config:
+ group_by: col1
+ combine:
+ col2: sum
+ col3: max
+```
+
+is equivalent to the previous example.
+
+
+## Custom aggregation functions
+
+One can use aggregation functions defined in Python by setting the language
+parameter.
+
+```
+- type: Combine
+ config:
+ language: python
+ group_by: col1
+ combine:
+ biggest:
+ value: "col2 + col2"
+ fn:
+ type: 'apache_beam.transforms.combiners.TopCombineFn'
+ config:
+ n: 10
+```
+
+## SQL-style aggregations
+
+By setting the language to SQL, one can provide full SQL snippets as the
+combine fn.
+
+```
+- type: Combine
+ config:
+ language: sql
+ group_by: col1
+ combine:
+ num_values: "count(*)"
+ total: "sum(col2)"
+```
+
+One can of course also use the `Sql` transform type and provide a query
+directly.
diff --git a/sdks/python/apache_beam/yaml/yaml_combine.py
b/sdks/python/apache_beam/yaml/yaml_combine.py
new file mode 100644
index 00000000000..ef4974cff35
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_combine.py
@@ -0,0 +1,205 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines the basic Combine operation."""
+
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+
+import apache_beam as beam
+from apache_beam import typehints
+from apache_beam.typehints import row_type
+from apache_beam.typehints import trivial_inference
+from apache_beam.typehints.decorators import get_type_hints
+from apache_beam.typehints.schemas import named_fields_from_element_type
+from apache_beam.utils import python_callable
+from apache_beam.yaml import options
+from apache_beam.yaml import yaml_mapping
+from apache_beam.yaml import yaml_provider
+
+# Combine fns that PyJsYamlCombine resolves by name (e.g. ``fn: sum``) before
+# consulting the configured language.
+BUILTIN_COMBINE_FNS = dict(
+    sum=sum,
+    max=max,
+    min=min,
+    all=all,
+    any=any,
+    mean=beam.transforms.combiners.MeanCombineFn(),
+    count=beam.transforms.combiners.CountCombineFn(),
+)
+
+
+def normalize_combine(spec):
+  """Expands various shorthand specs for combine (which can otherwise be quite
+  verbose for simple cases.) We do this here so that it doesn't need to be done
+  per language. The following are all equivalent::
+
+      dest: fn_type
+
+      dest:
+        value: dest
+        fn: fn_type
+
+      dest:
+        value: dest
+        fn:
+          type: fn_type
+
+  When the config's language is ``sql`` or ``calcite`` the ``value`` is NOT
+  defaulted to ``dest``, so that a complete SQL aggregation expression such as
+  ``count(*)`` can be given as the fn. A string ``group_by`` is normalized to
+  a one-element list.
+
+  Args:
+    spec: A transform spec; only specs of type ``Combine`` are modified.
+
+  Returns:
+    The spec, with its Combine config normalized in place.
+
+  Raises:
+    ValueError: If a Combine spec has no ``combine`` parameter (including when
+      it has no config at all), or an aggregation has no ``fn``.
+  """
+  from apache_beam.yaml.yaml_transform import SafeLineLoader
+  if spec['type'] == 'Combine':
+    # A missing/None config falls through to the "Missing combine" error below
+    # instead of failing with an AttributeError on None.
+    config = spec.get('config') or {}
+    if isinstance(config.get('group_by'), str):
+      config['group_by'] = [config['group_by']]
+
+    # The language is part of the transform's config, not a top-level key of
+    # the spec. Both of these names are routed to the SQL implementation.
+    is_sql = config.get('language') in ('sql', 'calcite')
+
+    def normalize_agg(dest, agg):
+      if isinstance(agg, str):
+        agg = {'fn': agg}
+      if 'fn' not in agg:
+        raise ValueError(f'Missing fn for combine output {dest}.')
+      if 'value' not in agg and not is_sql:
+        agg['value'] = dest
+      if isinstance(agg['fn'], str):
+        agg['fn'] = {'type': agg['fn']}
+      return agg
+
+    if 'combine' not in config:
+      raise ValueError('Missing combine parameter in Combine config.')
+    config['combine'] = {
+        dest: normalize_agg(dest, agg)
+        for (dest, agg) in SafeLineLoader.strip_metadata(
+            config['combine']).items()
+    }
+  return spec
+
+
+class PyJsYamlCombine(beam.PTransform):
+  """Implements Combine via ``beam.GroupBy(...).aggregate_field(...)``.
+
+  Registered for the generic, python and javascript languages.
+
+  Args:
+    group_by: Names of the input fields to group on. May be empty, in which
+      case all records are combined into a single output.
+    combine: Mapping of output field name to a normalized aggregation spec
+      with a ``value`` expression and an ``fn`` mapping (``type`` and optional
+      ``config``).
+    language: Language used to interpret the ``value`` expressions. Only for
+      ``python`` may ``fn.type`` name something other than a builtin in
+      BUILTIN_COMBINE_FNS.
+  """
+  def __init__(
+      self,
+      group_by: Iterable[str],
+      combine: Mapping[str, Mapping[str, Any]],
+      language: Optional[str] = None):
+    self._group_by = group_by
+    self._combine = combine
+    self._language = language
+
+  def expand(self, pcoll):
+    options.YamlOptions.check_enabled(pcoll.pipeline, 'Combine')
+    input_types = dict(named_fields_from_element_type(pcoll.element_type))
+    all_fields = list(input_types.keys())
+    unknown_keys = set(self._group_by) - set(all_fields)
+    if unknown_keys:
+      raise ValueError(f'Unknown grouping columns: {list(unknown_keys)}')
+
+    def create_combine_fn(fn_spec):
+      # Builtins take precedence; arbitrary fns are only loaded for python.
+      if 'type' not in fn_spec:
+        raise ValueError(f'CombineFn spec missing type: {fn_spec}')
+      elif fn_spec['type'] in BUILTIN_COMBINE_FNS:
+        return BUILTIN_COMBINE_FNS[fn_spec['type']]
+      elif self._language == 'python':
+        # TODO(yaml): Support output_type here as well.
+        fn = python_callable.PythonCallableWithSource.load_from_source(
+            fn_spec['type'])
+        if 'config' in fn_spec:
+          fn = fn(**fn_spec['config'])
+        return fn
+      else:
+        raise TypeError(f'Unknown CombineFn: {fn_spec}')
+
+    def extract_return_type(expr):
+      # Best-effort type inference for a field name, callable or CombineFn;
+      # falls back to Any when nothing better is known.
+      if isinstance(expr, str) and expr in input_types:
+        return input_types[expr]
+      expr_hints = get_type_hints(expr)
+      if (expr_hints and expr_hints.has_simple_output_type() and
+          expr_hints.simple_output_type(None) != typehints.Any):
+        return expr_hints.simple_output_type(None)
+      elif callable(expr):
+        return trivial_inference.infer_return_type(expr, [pcoll.element_type])
+      else:
+        return Any
+
+    # TODO(yaml): Support error handling.
+    transform = beam.GroupBy(*self._group_by)
+    output_types = [(k, input_types[k]) for k in self._group_by]
+
+    for output, agg in self._combine.items():
+      expr = yaml_mapping._as_callable(
+          all_fields, agg['value'], 'Combine', self._language)
+      fn = create_combine_fn(agg['fn'])
+      transform = transform.aggregate_field(expr, fn, output)
+
+      # TODO(yaml): See if this logic can be pushed into GroupBy itself.
+      if isinstance(fn, beam.CombineFn):
+        # TODO(yaml): Better inference on CombineFns whose outputs types are
+        # functions of their input types
+        combined_type = extract_return_type(fn)
+      elif fn in (sum, min, max):
+        # Assumed to produce a value of the same type as the aggregated expr.
+        combined_type = extract_return_type(expr)
+      elif fn in (any, all):
+        combined_type = bool
+      else:
+        combined_type = Any
+      output_types.append((output, combined_type))
+
+    return pcoll | transform.with_output_types(
+        row_type.RowTypeConstraint.from_fields(output_types))
+
+
+@beam.ptransform.ptransform_fn
+def _SqlCombineTransform(
+    pcoll, sql_transform_constructor, group_by, combine, language=None):
+  """Implements Combine by generating a SQL ``SELECT ... GROUP BY`` query.
+
+  Args:
+    pcoll: The input PCollection (must have named fields).
+    sql_transform_constructor: Callable mapping a SQL query over the table
+      ``PCOLLECTION`` to a PTransform.
+    group_by: Names of the fields to group on. If empty, a global aggregation
+      is performed and the query has no GROUP BY clause.
+    combine: Mapping of output field name to a normalized aggregation spec.
+      ``fn.type`` is used as the SQL function, called with ``value`` and any
+      ``fn.config`` values as arguments. If neither is given, ``fn.type`` is
+      used verbatim, so it may be a full expression such as ``count(*)``.
+    language: Unused here; accepted because it is part of the Combine config.
+
+  Raises:
+    ValueError: If any grouping column is not a field of the input.
+  """
+  options.YamlOptions.check_enabled(pcoll.pipeline, 'Combine')
+  all_fields = [
+      x for x, _ in named_fields_from_element_type(pcoll.element_type)
+  ]
+  unknown_keys = set(group_by) - set(all_fields)
+  if unknown_keys:
+    raise ValueError(f'Unknown grouping columns: {list(unknown_keys)}')
+
+  def combine_col(dest, fn_spec):
+    fn = fn_spec['fn']
+    if 'value' in fn_spec or 'config' in fn:
+      args = [str(fn_spec['value'])] if 'value' in fn_spec else []
+      # Config values may be non-strings in YAML (e.g. ``n: 10``).
+      args.extend(str(arg) for arg in fn.get('config', {}).values())
+      expr = '%s(%s)' % (fn['type'], ', '.join(args))
+    else:
+      expr = fn['type']
+    return f'{expr} as {dest}'
+
+  selects = list(group_by) + [
+      combine_col(dest, fn_spec) for dest, fn_spec in combine.items()
+  ]
+  query = 'SELECT %s FROM PCOLLECTION' % ', '.join(selects)
+  if group_by:
+    # A global combine must omit the clause; "GROUP BY " alone is invalid SQL.
+    query += ' GROUP BY %s' % ', '.join(group_by)
+  return pcoll | sql_transform_constructor(query)
+
+
+def create_combine_providers():
+  """Returns the providers implementing Combine for each supported language.
+
+  The ``generic`` language is offered by both the Python/JavaScript provider
+  and the SQL-backed provider.
+  """
+  py_js_languages = ('generic', 'python', 'javascript')
+  sql_languages = ('generic', 'sql', 'calcite')
+  py_js_provider = yaml_provider.InlineProvider(
+      {f'Combine-{lang}': PyJsYamlCombine
+       for lang in py_js_languages})
+  sql_provider = yaml_provider.SqlBackedProvider(
+      {f'Combine-{lang}': _SqlCombineTransform
+       for lang in sql_languages})
+  return [py_js_provider, sql_provider]
diff --git a/sdks/python/apache_beam/yaml/yaml_combine_test.py
b/sdks/python/apache_beam/yaml/yaml_combine_test.py
new file mode 100644
index 00000000000..ef696c89379
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_combine_test.py
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.yaml.yaml_transform import YamlTransform
+
+DATA = [
+    beam.Row(a='x', b=1, c=101),
+    beam.Row(a='x', b=1, c=102),
+    beam.Row(a='y', b=1, c=103),
+    beam.Row(a='y', b=2, c=104),
+]
+
+
+class YamlCombineTest(unittest.TestCase):
+  """End-to-end tests of the experimental YAML Combine transform on DATA."""
+  @staticmethod
+  def _pipeline():
+    # Combine is experimental and must be explicitly enabled.
+    return beam.Pipeline(
+        options=beam.options.pipeline_options.PipelineOptions(
+            pickle_library='cloudpickle',
+            yaml_experimental_features=['Combine']))
+
+  def _assert_combine(self, yaml_spec, expected_rows):
+    """Applies yaml_spec to DATA and asserts the output equals expected_rows."""
+    with self._pipeline() as p:
+      result = p | beam.Create(DATA) | YamlTransform(yaml_spec)
+      # Normalize outputs to beam.Row for comparison with expected_rows.
+      assert_that(
+          result | beam.Map(lambda x: beam.Row(**x._asdict())),
+          equal_to(expected_rows))
+
+  def test_multiple_aggregations(self):
+    self._assert_combine(
+        '''
+        type: Combine
+        config:
+          group_by: a
+          combine:
+            b: sum
+            c: max
+        ''',
+        [
+            beam.Row(a='x', b=2, c=102),
+            beam.Row(a='y', b=3, c=104),
+        ])
+
+  def test_multiple_keys(self):
+    self._assert_combine(
+        '''
+        type: Combine
+        config:
+          group_by: [a, b]
+          combine:
+            c: sum
+        ''',
+        [
+            beam.Row(a='x', b=1, c=203),
+            beam.Row(a='y', b=1, c=103),
+            beam.Row(a='y', b=2, c=104),
+        ])
+
+  def test_no_keys(self):
+    self._assert_combine(
+        '''
+        type: Combine
+        config:
+          group_by: []
+          combine:
+            c: sum
+        ''',
+        [
+            beam.Row(c=410),
+        ])
+
+  def test_multiple_combines(self):
+    self._assert_combine(
+        '''
+        type: Combine
+        config:
+          group_by: a
+          combine:
+            min_c:
+              fn: min
+              value: c
+            max_c:
+              fn: max
+              value: c
+        ''',
+        [
+            beam.Row(a='x', min_c=101, max_c=102),
+            beam.Row(a='y', min_c=103, max_c=104),
+        ])
+
+  def test_expression(self):
+    self._assert_combine(
+        '''
+        type: Combine
+        config:
+          language: python
+          group_by: a
+          combine:
+            max:
+              fn: max
+              value: b + c
+        ''',
+        [
+            beam.Row(a='x', max=103),
+            beam.Row(a='y', max=106),
+        ])
+
+  def test_config(self):
+    self._assert_combine(
+        '''
+        type: Combine
+        config:
+          language: python
+          group_by: b
+          combine:
+            biggest:
+              fn:
+                type: 'apache_beam.transforms.combiners.TopCombineFn'
+                config:
+                  n: 2
+              value: c
+        ''',
+        [
+            beam.Row(b=1, biggest=[103, 102]),
+            beam.Row(b=2, biggest=[104]),
+        ])
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py
b/sdks/python/apache_beam/yaml/yaml_mapping.py
index e217ab28584..501c7a5c57b 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -21,7 +21,6 @@ from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
-from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Union
@@ -399,62 +398,14 @@ def _PyJsMapToFields(pcoll, language='generic',
**mapping_args):
})
-class SqlMappingProvider(yaml_provider.Provider):
- def __init__(self, sql_provider=None):
- if sql_provider is None:
- sql_provider = yaml_provider.beam_jar(
- urns={'Sql': 'beam:external:java:sql:v1'},
- gradle_target='sdks:java:extensions:sql:expansion-service:shadowJar')
- self._sql_provider = sql_provider
-
- def available(self):
- return self._sql_provider.available()
-
- def cache_artifacts(self):
- return self._sql_provider.cache_artifacts()
-
- def provided_transforms(self) -> Iterable[str]:
- return [
- 'Filter-sql',
- 'Filter-calcite',
- 'MapToFields-sql',
- 'MapToFields-calcite'
- ]
-
- def create_transform(
- self,
- typ: str,
- args: Mapping[str, Any],
- yaml_create_transform: Callable[
- [Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]
- ) -> beam.PTransform:
- if typ.startswith('Filter-'):
- return _SqlFilterTransform(
- self._sql_provider, yaml_create_transform, **args)
- if typ.startswith('MapToFields-'):
- return _SqlMapToFieldsTransform(
- self._sql_provider, yaml_create_transform, **args)
- else:
- raise NotImplementedError(typ)
-
- def underlying_provider(self):
- return self._sql_provider
-
- def to_json(self):
- return {'type': "SqlMappingProvider"}
-
-
@beam.ptransform.ptransform_fn
-def _SqlFilterTransform(
- pcoll, sql_provider, yaml_create_transform, keep, language):
- return pcoll | sql_provider.create_transform(
- 'Sql', {'query': f'SELECT * FROM PCOLLECTION WHERE {keep}'},
- yaml_create_transform)
+def _SqlFilterTransform(pcoll, sql_transform_constructor, keep, language):
+ return pcoll | sql_transform_constructor(
+ f'SELECT * FROM PCOLLECTION WHERE {keep}')
@beam.ptransform.ptransform_fn
-def _SqlMapToFieldsTransform(
- pcoll, sql_provider, yaml_create_transform, **mapping_args):
+def _SqlMapToFieldsTransform(pcoll, sql_transform_constructor, **mapping_args):
_, fields = normalize_fields(pcoll, **mapping_args)
def extract_expr(name, v):
@@ -470,8 +421,7 @@ def _SqlMapToFieldsTransform(
for (name, expr) in fields.items()
]
query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
- return pcoll | sql_provider.create_transform(
- 'Sql', {'query': query}, yaml_create_transform)
+ return pcoll | sql_transform_constructor(query)
def create_mapping_providers():
@@ -487,5 +437,10 @@ def create_mapping_providers():
'MapToFields-javascript': _PyJsMapToFields,
'MapToFields-generic': _PyJsMapToFields,
}),
- SqlMappingProvider(),
+ yaml_provider.SqlBackedProvider({
+ 'Filter-sql': _SqlFilterTransform,
+ 'Filter-calcite': _SqlFilterTransform,
+ 'MapToFields-sql': _SqlMapToFieldsTransform,
+ 'MapToFields-calcite': _SqlMapToFieldsTransform,
+ }),
]
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py
b/sdks/python/apache_beam/yaml/yaml_provider.py
index 33c16380ece..a21db09b50c 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -453,6 +453,45 @@ class MetaInlineProvider(InlineProvider):
return self._transform_factories[type](yaml_create_transform, **args)
+class SqlBackedProvider(Provider):
+  """A provider of transforms that are implemented as generated SQL queries.
+
+  Each factory in ``transforms`` is called as
+  ``factory(sql_transform_constructor, **args)``, where
+  ``sql_transform_constructor`` maps a query string to the ``Sql`` transform
+  created by the underlying SQL provider.
+
+  Args:
+    transforms: Mapping of transform type name to its factory.
+    sql_provider: Provider of the ``Sql`` transform. Defaults to the provider
+      built by ``beam_jar`` for the Java SQL expansion service.
+  """
+  def __init__(
+      self,
+      transforms: Mapping[str, Callable[..., beam.PTransform]],
+      sql_provider: Optional[Provider] = None):
+    self._transforms = transforms
+    if sql_provider is None:
+      sql_provider = beam_jar(
+          urns={'Sql': 'beam:external:java:sql:v1'},
+          gradle_target='sdks:java:extensions:sql:expansion-service:shadowJar')
+    self._sql_provider = sql_provider
+
+  def sql_provider(self):
+    # All access goes through this method so it can be substituted
+    # (readme_test patches it with a test provider).
+    return self._sql_provider
+
+  def provided_transforms(self):
+    return self._transforms.keys()
+
+  def available(self):
+    return self.sql_provider().available()
+
+  def cache_artifacts(self):
+    return self.sql_provider().cache_artifacts()
+
+  def underlying_provider(self):
+    return self.sql_provider()
+
+  def to_json(self):
+    return {'type': "SqlBackedProvider"}
+
+  def create_transform(
+      self, typ: str, args: Mapping[str, Any],
+      yaml_create_transform: Any) -> beam.PTransform:
+    """Builds the transform ``typ`` from ``args``.
+
+    Raises:
+      NotImplementedError: If ``typ`` is not one of the provided transforms.
+    """
+    if typ not in self._transforms:
+      raise NotImplementedError(typ)
+    factory = self._transforms[typ]
+
+    def sql_transform_constructor(query):
+      return self.sql_provider().create_transform(
+          'Sql', {'query': query}, yaml_create_transform)
+
+    return factory(sql_transform_constructor, **args)
+
+
PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
py_type.__name__: schema_type
for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
@@ -781,6 +820,7 @@ def merge_providers(*provider_sets):
def standard_providers():
+ from apache_beam.yaml.yaml_combine import create_combine_providers
from apache_beam.yaml.yaml_mapping import create_mapping_providers
from apache_beam.yaml.yaml_io import io_providers
with open(os.path.join(os.path.dirname(__file__),
@@ -790,6 +830,7 @@ def standard_providers():
return merge_providers(
create_builtin_provider(),
create_mapping_providers(),
+ create_combine_providers(),
io_providers(),
parse_providers(standard_providers))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py
b/sdks/python/apache_beam/yaml/yaml_transform.py
index 7ab8da33f1a..ca63834e283 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -34,6 +34,7 @@ from yaml.loader import SafeLoader
import apache_beam as beam
from apache_beam.transforms.fully_qualified_named_transform import
FullyQualifiedNamedTransform
from apache_beam.yaml import yaml_provider
+from apache_beam.yaml.yaml_combine import normalize_combine
__all__ = ["YamlTransform"]
@@ -885,7 +886,7 @@ def preprocess(spec, verbose=False, known_transforms=None):
return spec
def preprocess_langauges(spec):
- if spec['type'] in ('Filter', 'MapToFields'):
+ if spec['type'] in ('Filter', 'MapToFields', 'Combine'):
language = spec.get('config', {}).get('language', 'generic')
new_type = spec['type'] + '-' + language
if known_transforms and new_type not in known_transforms:
@@ -900,6 +901,7 @@ def preprocess(spec, verbose=False, known_transforms=None):
for phase in [
ensure_transforms_have_types,
+ normalize_combine,
preprocess_langauges,
ensure_transforms_have_providers,
preprocess_source_sink,
@@ -951,14 +953,18 @@ class YamlTransform(beam.PTransform):
root = next(iter(pcolls.values())).pipeline
if not self._spec['input']:
self._spec['input'] = {name: name for name in pcolls.keys()}
+ python_provider = yaml_provider.InlineProvider({})
result = expand_transform(
self._spec,
Scope(
root,
pcolls,
- transforms=[],
+ transforms=[self._spec],
providers=self._providers,
- input_providers={}))
+ input_providers={
+ pcoll: python_provider
+ for pcoll in pcolls.values()
+ }))
if len(result) == 1:
return only_element(result.values())
else: