This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ada507dc3b3 Implement basic mapping capabilities for YAML. (#27096)
ada507dc3b3 is described below
commit ada507dc3b389121a58ca1b3765ef2853c786f4e
Author: Robert Bradshaw <[email protected]>
AuthorDate: Thu Jun 15 10:40:26 2023 -0700
Implement basic mapping capabilities for YAML. (#27096)
---
sdks/python/apache_beam/typehints/schemas.py | 8 +-
sdks/python/apache_beam/yaml/readme_test.py | 43 ++--
sdks/python/apache_beam/yaml/yaml_mapping.md | 196 +++++++++++++++++
sdks/python/apache_beam/yaml/yaml_mapping.py | 255 ++++++++++++++++++++++
sdks/python/apache_beam/yaml/yaml_mapping_test.py | 138 ++++++++++++
sdks/python/apache_beam/yaml/yaml_provider.py | 32 ++-
sdks/python/apache_beam/yaml/yaml_transform.py | 3 +-
7 files changed, 652 insertions(+), 23 deletions(-)
diff --git a/sdks/python/apache_beam/typehints/schemas.py
b/sdks/python/apache_beam/typehints/schemas.py
index 2f5c366b4cb..156b877d07e 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -291,7 +291,13 @@ class SchemaTranslation(object):
result.nullable = True
return result
- elif _safe_issubclass(type_, Sequence):
+ elif type_ == range:
+ return schema_pb2.FieldType(
+ array_type=schema_pb2.ArrayType(
+ element_type=schema_pb2.FieldType(
+ atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[int])))
+
+ elif _safe_issubclass(type_, Sequence) and not _safe_issubclass(type_,
str):
element_type = self.typing_to_runner_api(_get_args(type_)[0])
return schema_pb2.FieldType(
array_type=schema_pb2.ArrayType(element_type=element_type))
diff --git a/sdks/python/apache_beam/yaml/readme_test.py
b/sdks/python/apache_beam/yaml/readme_test.py
index 1c014a4280f..26df760f285 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -68,6 +68,8 @@ class FakeSql(beam.PTransform):
typ = float
else:
typ = str
+ elif '+' in expr:
+ typ = float
else:
part = parts[0]
if '.' in part:
@@ -172,6 +174,8 @@ def replace_recursive(spec, transform_type, arg_name,
arg_value):
def create_test_method(test_type, test_name, test_yaml):
+ test_yaml = test_yaml.replace('pkg.module.fn', 'str')
+
def test(self):
with TestEnvironment() as env:
spec = yaml.load(test_yaml, Loader=SafeLoader)
@@ -202,6 +206,7 @@ def create_test_method(test_type, test_name, test_yaml):
def parse_test_methods(markdown_lines):
+ # pylint: disable=too-many-nested-blocks
code_lines = None
for ix, line in enumerate(markdown_lines):
line = line.rstrip()
@@ -211,26 +216,38 @@ def parse_test_methods(markdown_lines):
test_type = 'RUN'
test_name = f'test_line_{ix + 2}'
else:
- if code_lines and code_lines[0] == 'pipeline:':
- yaml_pipeline = '\n'.join(code_lines)
- if 'providers:' in yaml_pipeline:
- test_type = 'PARSE'
- yield test_name, create_test_method(
- test_type,
- test_name,
- yaml_pipeline)
+ if code_lines:
+ if code_lines[0].startswith('- type:'):
+ # Treat this as a fragment of a larger pipeline.
+ code_lines = [
+ 'pipeline:',
+ ' type: chain',
+ ' transforms:',
+ ' - type: ReadFromCsv',
+ ' path: whatever',
+ ] + [
+ ' ' + line for line in code_lines
+ ] # pylint: disable=not-an-iterable
+ if code_lines[0] == 'pipeline:':
+ yaml_pipeline = '\n'.join(code_lines)
+ if 'providers:' in yaml_pipeline:
+ test_type = 'PARSE'
+ yield test_name, create_test_method(
+ test_type,
+ test_name,
+ yaml_pipeline)
code_lines = None
elif code_lines is not None:
code_lines.append(line)
-def createTestSuite():
- with open(os.path.join(os.path.dirname(__file__), 'README.md')) as readme:
- return type(
- 'ReadMeTest', (unittest.TestCase, ), dict(parse_test_methods(readme)))
+def createTestSuite(name, path):
+ with open(path) as readme:
+ return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))
-ReadMeTest = createTestSuite()
+ReadMeTest = createTestSuite(
+ 'ReadMeTest', os.path.join(os.path.dirname(__file__), 'README.md'))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.md
b/sdks/python/apache_beam/yaml/yaml_mapping.md
new file mode 100644
index 00000000000..193abac610e
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.md
@@ -0,0 +1,196 @@
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Beam YAML mappings
+
+Beam YAML has the ability to do simple transformations which can be used to
+get data into the correct shape. The simplest of these is `MapToFields`
+which creates records with new fields defined in terms of the input fields.
+
+## Field renames
+
+To rename fields one can write
+
+```
+- type: MapToFields
+ fields:
+ new_col1: col1
+ new_col2: col2
+```
+
+will result in an output where each record has two fields,
+`new_col1` and `new_col2`, whose values are those of `col1` and `col2`
+respectively.
+
+One can specify the append parameter which indicates the original fields should
+be retained similar to the use of `*` in an SQL select statement. For example
+
+```
+- type: MapToFields
+ append: true
+ fields:
+ new_col1: col1
+ new_col2: col2
+```
+
+will output records that have `new_col1` and `new_col2` as *additional*
+fields. When the append field is specified, one can drop fields as well, e.g.
+
+```
+- type: MapToFields
+ append: true
+ drop:
+ - col3
+ fields:
+ new_col1: col1
+ new_col2: col2
+```
+
+which includes all original fields *except* col3 in addition to outputting the
+two new ones.
+
+
+## Mapping functions
+
+Of course one may want to do transformations beyond just dropping and renaming
+fields. Beam YAML has the ability to inline simple UDFs.
+This requires a language specification. For example
+
+```
+- type: MapToFields
+ language: python
+ fields:
+ new_col: "col1.upper()"
+ another_col: "col2 + col3"
+```
+
+In addition, one can provide a full Python callable that takes the row as an
+argument to do more complex mappings
+(see
[PythonCallableSource](https://beam.apache.org/releases/pydoc/current/apache_beam.utils.python_callable.html#apache_beam.utils.python_callable.PythonCallableWithSource)
+for acceptable formats). Thus one can write
+
+```
+- type: MapToFields
+ language: python
+ fields:
+ new_col:
+ callable: |
+ import re
+ def my_mapping(row):
+ if re.match("[0-9]+", row.col1) and row.col2 > 0:
+ return "good"
+ else:
+ return "bad"
+```
+
+Once one reaches a certain level of complexity, it may be preferable to package
+this up as a dependency and simply refer to it by fully qualified name, e.g.
+
+```
+- type: MapToFields
+ language: python
+ fields:
+ new_col:
+ callable: pkg.module.fn
+```
+
+Currently, in addition to Python, SQL expressions are supported as well
+
+```
+- type: MapToFields
+ language: sql
+ fields:
+ new_col: "UPPER(col1)"
+ another_col: "col2 + col3"
+```
+
+## FlatMap
+
+Sometimes it may be desirable to emit more (or less) than one record for each
+input record. This can be accomplished by mapping to an iterable type and
+noting that the specific field should be exploded, e.g.
+
+```
+- type: MapToFields
+ language: python
+ fields:
+ new_col: "[col1.upper(), col1.lower(), col1.title()]"
+ another_col: "col2 + col3"
+ explode: new_col
+```
+
+will result in three output records for every input record.
+
+If more than one record is to be exploded, one must specify whether the cross
+product over all fields should be taken. For example
+
+```
+- type: MapToFields
+ language: python
+ fields:
+ new_col: "[col1.upper(), col1.lower(), col1.title()]"
+ another_col: "[col2 - 1, col2, col2 + 1]"
+ explode: [new_col, another_col]
+ cross_product: true
+```
+
+will emit nine records whereas
+
+```
+- type: MapToFields
+ language: python
+ fields:
+ new_col: "[col1.upper(), col1.lower(), col1.title()]"
+ another_col: "[col2 - 1, col2, col2 + 1]"
+ explode: [new_col, another_col]
+ cross_product: false
+```
+
+will only emit three.
+
+If one is only exploding existing fields, a simpler `Explode` transform may be
+used instead
+
+```
+- type: Explode
+ explode: [col1]
+```
+
+## Filtering
+
+Sometimes it can be desirable to only keep records that satisfy a certain
+criteria. This can be accomplished by specifying a keep parameter, e.g.
+
+```
+- type: MapToFields
+ language: python
+ fields:
+ new_col: "col1.upper()"
+ another_col: "col2 + col3"
+ keep: "col2 > 0"
+```
+
+Like explode, there is a simpler `Filter` transform useful when no mapping is
+being done
+
+```
+- type: Filter
+ language: sql
+ keep: "col2 > 0"
+```
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py
b/sdks/python/apache_beam/yaml/yaml_mapping.py
new file mode 100644
index 00000000000..7f959773320
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -0,0 +1,255 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines the basic MapToFields operation."""
+
+import itertools
+
+import apache_beam as beam
+from apache_beam.typehints import row_type
+from apache_beam.typehints import trivial_inference
+from apache_beam.typehints.schemas import named_fields_from_element_type
+from apache_beam.utils import python_callable
+from apache_beam.yaml import yaml_provider
+
+
+def _as_callable(original_fields, expr):
+ if expr in original_fields:
+ return expr
+ else:
+ # TODO(yaml): support a type parameter
+ # TODO(yaml): support an imports parameter
+ # TODO(yaml): support a requirements parameter (possibly at a higher level)
+ if isinstance(expr, str):
+ expr = {'expression': expr}
+ if not isinstance(expr, dict):
+ raise ValueError(
+ f"Ambiguous expression type (perhaps missing quoting?): {expr}")
+ elif len(expr) != 1:
+ raise ValueError(f"Ambiguous expression type: {list(expr.keys())}")
+ if 'expression' in expr:
+ # TODO(robertwb): Consider constructing a single callable that takes
+ # the row and returns the new row, rather than invoking (and unpacking)
+ # for each field individually.
+ source = '\n'.join(['def fn(__row__):'] + [
+ f' {name} = __row__.{name}'
+ for name in original_fields if name in expr['expression']
+ ] + [' return (' + expr['expression'] + ')'])
+ elif 'callable' in expr:
+ source = expr['callable']
+ else:
+ raise ValueError(f"Unknown expression type: {list(expr.keys())}")
+ return python_callable.PythonCallableWithSource(source)
+
+
+# TODO(yaml): This should be available in all environments, in which case
+# we choose the one that matches best.
+class _Explode(beam.PTransform):
+ def __init__(self, fields, cross_product):
+ self._fields = fields
+ self._cross_product = cross_product
+
+ def expand(self, pcoll):
+ all_fields = [
+ x for x, _ in named_fields_from_element_type(pcoll.element_type)
+ ]
+ to_explode = self._fields
+
+ def explode_cross_product(base, fields):
+ if fields:
+ copy = dict(base)
+ for value in base[fields[0]]:
+ copy[fields[0]] = value
+ yield from explode_cross_product(copy, fields[1:])
+ else:
+ yield beam.Row(**base)
+
+ def explode_zip(base, fields):
+ to_zip = [base[field] for field in fields]
+ copy = dict(base)
+ for values in itertools.zip_longest(*to_zip, fillvalue=None):
+ for ix, field in enumerate(fields):
+ copy[field] = values[ix]
+ yield beam.Row(**copy)
+
+ return pcoll | beam.FlatMap(
+ lambda row: (
+ explode_cross_product if self._cross_product else explode_zip)(
+ {name: getattr(row, name) for name in all_fields}, # yapf break
+ to_explode))
+
+ def infer_output_type(self, input_type):
+ return row_type.RowTypeConstraint.from_fields([(
+ name,
+ trivial_inference.element_type(typ) if name in self._fields else
+ typ) for (name, typ) in named_fields_from_element_type(input_type)])
+
+
+# TODO(yaml): Should Filter and Explode be distinct operations from Project?
+# We'll want these per-language.
[email protected]_fn
+def _PythonProjectionTransform(
+ pcoll, *, fields, keep=None, explode=(), cross_product=True):
+ original_fields = [
+ name for (name, _) in named_fields_from_element_type(pcoll.element_type)
+ ]
+
+ if keep:
+ if isinstance(keep, str) and keep in original_fields:
+ keep_fn = lambda row: getattr(row, keep)
+ else:
+ keep_fn = _as_callable(original_fields, keep)
+ filtered = pcoll | beam.Filter(keep_fn)
+ else:
+ filtered = pcoll
+
+ if list(fields.items()) == [(name, name) for name in original_fields]:
+ projected = filtered
+ else:
+ projected = filtered | beam.Select(
+ **{
+ name: _as_callable(original_fields, expr)
+ for (name, expr) in fields.items()
+ })
+
+ if explode:
+ result = projected | _Explode(explode, cross_product=cross_product)
+ else:
+ result = projected
+
+ return result
+
+
[email protected]_fn
+def MapToFields(
+ pcoll,
+ yaml_create_transform,
+ *,
+ fields,
+ keep=None,
+ explode=(),
+ cross_product=None,
+ append=False,
+ drop=(),
+ language=None,
+ **language_keywords):
+
+ if isinstance(explode, str):
+ explode = [explode]
+ if cross_product is None:
+ if len(explode) > 1:
+ # TODO(robertwb): Consider if true is an OK default.
+ raise ValueError(
+ 'cross_product must be specified true or false '
+ 'when exploding multiple fields')
+ else:
+ # Doesn't matter.
+ cross_product = True
+
+ input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+ if drop and not append:
+ raise ValueError("Can only drop fields if append is true.")
+ for name in drop:
+ if name not in input_schema:
+ raise ValueError(f'Dropping unknown field "{name}"')
+ for name in explode:
+ if not (name in fields or (append and name in input_schema)):
+ raise ValueError(f'Exploding unknown field "{name}"')
+ if append:
+ for name in fields:
+ if name in input_schema and name not in drop:
+ raise ValueError(f'Redefinition of field "{name}"')
+
+ if append:
+ fields = {
+ **{name: name
+ for name in input_schema.keys() if name not in drop},
+ **fields
+ }
+
+ if language is None:
+ for name, expr in fields.items():
+ if not isinstance(expr, str) or expr not in input_schema:
+ # TODO(robertw): Could consider defaulting to SQL, or another
+ # lowest-common-denominator expression language.
+ raise ValueError("Missing language specification.")
+
+ # We should support this for all languages.
+ language = "python"
+
+ if language in ("sql", "calcite"):
+ selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
+ query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
+ if keep:
+ query += " WHERE " + keep
+
+ result = pcoll | yaml_create_transform({
+ 'type': 'Sql', 'query': query, **language_keywords
+ })
+ if explode:
+ # TODO(yaml): Implement via unnest.
+ result = result | _Explode(explode, cross_product)
+
+ return result
+
+ elif language == 'python':
+ return pcoll | yaml_create_transform({
+ 'type': 'PyTransform',
+ 'constructor': __name__ + '._PythonProjectionTransform',
+ 'kwargs': {
+ 'fields': fields,
+ 'keep': keep,
+ 'explode': explode,
+ 'cross_product': cross_product,
+ },
+ **language_keywords
+ })
+
+ else:
+ # TODO(yaml): Support javascript expressions and UDFs.
+ # TODO(yaml): Support java by fully qualified name.
+ # TODO(yaml): Maybe support java lambdas?
+ raise ValueError(
+ f'Unknown language: {language}. '
+ 'Supported languages are "sql" (alias calcite) and "python."')
+
+
+def create_mapping_provider():
+ # These are MetaInlineProviders because their expansion is in terms of other
+ # YamlTransforms, but in a way that needs to be deferred until the input
+ # schema is known.
+ return yaml_provider.MetaInlineProvider({
+ 'MapToFields': MapToFields,
+ 'Filter': (
+ lambda yaml_create_transform,
+ keep,
+ **kwargs: MapToFields(
+ yaml_create_transform,
+ keep=keep,
+ fields={},
+ append=True,
+ **kwargs)),
+ 'Explode': (
+ lambda yaml_create_transform,
+ explode,
+ **kwargs: MapToFields(
+ yaml_create_transform,
+ explode=explode,
+ fields={},
+ append=True,
+ **kwargs)),
+ })
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py
b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
new file mode 100644
index 00000000000..3305ff8b92a
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
@@ -0,0 +1,138 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import os
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.yaml.readme_test import createTestSuite
+from apache_beam.yaml.yaml_transform import YamlTransform
+
+DATA = [
+ beam.Row(label='11a', conductor=11, rank=0),
+ beam.Row(label='37a', conductor=37, rank=1),
+ beam.Row(label='389a', conductor=389, rank=2),
+]
+
+
+class YamlMappingTest(unittest.TestCase):
+ def test_basic(self):
+ with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+ pickle_library='cloudpickle')) as p:
+ elements = p | beam.Create(DATA)
+ result = elements | YamlTransform(
+ '''
+ type: MapToFields
+ input: input
+ language: python
+ fields:
+ label: label
+ isogeny: "label[-1]"
+ ''')
+ assert_that(
+ result,
+ equal_to([
+ beam.Row(label='11a', isogeny='a'),
+ beam.Row(label='37a', isogeny='a'),
+ beam.Row(label='389a', isogeny='a'),
+ ]))
+
+ def test_drop(self):
+ with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+ pickle_library='cloudpickle')) as p:
+ elements = p | beam.Create(DATA)
+ result = elements | YamlTransform(
+ '''
+ type: MapToFields
+ input: input
+ fields: {}
+ append: true
+ drop: [conductor]
+ ''')
+ assert_that(
+ result,
+ equal_to([
+ beam.Row(label='11a', rank=0),
+ beam.Row(label='37a', rank=1),
+ beam.Row(label='389a', rank=2),
+ ]))
+
+ def test_filter(self):
+ with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+ pickle_library='cloudpickle')) as p:
+ elements = p | beam.Create(DATA)
+ result = elements | YamlTransform(
+ '''
+ type: MapToFields
+ input: input
+ language: python
+ fields:
+ label: label
+ keep: "rank > 0"
+ ''')
+ assert_that(
+ result, equal_to([
+ beam.Row(label='37a'),
+ beam.Row(label='389a'),
+ ]))
+
+ def test_explode(self):
+ with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+ pickle_library='cloudpickle')) as p:
+ elements = p | beam.Create([
+ beam.Row(a=2, b='abc', c=.25),
+ beam.Row(a=3, b='xy', c=.125),
+ ])
+ result = elements | YamlTransform(
+ '''
+ type: MapToFields
+ input: input
+ language: python
+ append: true
+ fields:
+ range: "range(a)"
+ explode: [range, b]
+ cross_product: true
+ ''')
+ assert_that(
+ result,
+ equal_to([
+ beam.Row(a=2, b='a', c=.25, range=0),
+ beam.Row(a=2, b='a', c=.25, range=1),
+ beam.Row(a=2, b='b', c=.25, range=0),
+ beam.Row(a=2, b='b', c=.25, range=1),
+ beam.Row(a=2, b='c', c=.25, range=0),
+ beam.Row(a=2, b='c', c=.25, range=1),
+ beam.Row(a=3, b='x', c=.125, range=0),
+ beam.Row(a=3, b='x', c=.125, range=1),
+ beam.Row(a=3, b='x', c=.125, range=2),
+ beam.Row(a=3, b='y', c=.125, range=0),
+ beam.Row(a=3, b='y', c=.125, range=1),
+ beam.Row(a=3, b='y', c=.125, range=2),
+ ]))
+
+
+YamlMappingDocTest = createTestSuite(
+ 'YamlMappingDocTest',
+ os.path.join(os.path.dirname(__file__), 'yaml_mapping.md'))
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py
b/sdks/python/apache_beam/yaml/yaml_provider.py
index e3d544781c3..209e2178b8e 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -27,6 +27,7 @@ import subprocess
import sys
import uuid
from typing import Any
+from typing import Callable
from typing import Iterable
from typing import Mapping
@@ -59,7 +60,11 @@ class Provider:
raise NotImplementedError(type(self))
def create_transform(
- self, typ: str, args: Mapping[str, Any]) -> beam.PTransform:
+ self,
+ typ: str,
+ args: Mapping[str, Any],
+ yaml_create_transform: Callable[[Mapping[str, Any]], beam.PTransform]
+ ) -> beam.PTransform:
"""Creates a PTransform instance for the given transform type and
arguments.
"""
raise NotImplementedError(type(self))
@@ -88,7 +93,7 @@ class ExternalProvider(Provider):
def provided_transforms(self):
return self._urns.keys()
- def create_transform(self, type, args):
+ def create_transform(self, type, args, yaml_create_transform):
if callable(self._service):
self._service = self._service()
if self._schema_transforms is None:
@@ -245,13 +250,18 @@ class InlineProvider(Provider):
def provided_transforms(self):
return self._transform_factories.keys()
- def create_transform(self, type, args):
+ def create_transform(self, type, args, yaml_create_transform):
return self._transform_factories[type](**args)
def to_json(self):
return {'type': "InlineProvider"}
+class MetaInlineProvider(InlineProvider):
+ def create_transform(self, type, args, yaml_create_transform):
+ return self._transform_factories[type](yaml_create_transform, **args)
+
+
PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
py_type.__name__: schema_type
for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
@@ -446,17 +456,23 @@ def parse_providers(provider_specs):
def merge_providers(*provider_sets):
result = collections.defaultdict(list)
for provider_set in provider_sets:
+ if isinstance(provider_set, Provider):
+ provider = provider_set
+ provider_set = {
+ transform_type: [provider]
+ for transform_type in provider.provided_transforms()
+ }
for transform_type, providers in provider_set.items():
result[transform_type].extend(providers)
return result
def standard_providers():
- builtin_providers = collections.defaultdict(list)
- builtin_provider = create_builtin_provider()
- for transform_type in builtin_provider.provided_transforms():
- builtin_providers[transform_type].append(builtin_provider)
+ from apache_beam.yaml.yaml_mapping import create_mapping_provider
with open(os.path.join(os.path.dirname(__file__),
'standard_providers.yaml')) as fin:
standard_providers = yaml.load(fin, Loader=SafeLoader)
- return merge_providers(builtin_providers,
parse_providers(standard_providers))
+ return merge_providers(
+ create_builtin_provider(),
+ create_mapping_provider(),
+ parse_providers(standard_providers))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py
b/sdks/python/apache_beam/yaml/yaml_transform.py
index f3e605adf03..925aa0d85b9 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -235,7 +235,8 @@ class Scope(LightweightScope):
real_args = SafeLineLoader.strip_metadata(args)
try:
# pylint: disable=undefined-loop-variable
- ptransform = provider.create_transform(spec['type'], real_args)
+ ptransform = provider.create_transform(
+ spec['type'], real_args, self.create_ptransform)
# TODO(robertwb): Should we have a better API for adding annotations
# than this?
annotations = dict(