This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c5e6c7962e6 Refactor and cleanup yaml MapToFields. (#28462)
c5e6c7962e6 is described below
commit c5e6c7962e60ee8366bfc92edf0812341e940020
Author: Robert Bradshaw <[email protected]>
AuthorDate: Wed Sep 20 16:37:17 2023 -0700
Refactor and cleanup yaml MapToFields. (#28462)
* Avoid the use of MetaProviders, which was always kind of hacky.
We may want to remove this infrastructure altogether as it
does not play nicely with provider inference.
* Split MapToFields into separate mapping, filtering, and exploding
operations.
* Allow MapToFields to act on non-schema'd PCollections.
The various language flavors of these UDFs are now handled by a
preprocessing step. This will make it easier to extend to other languages,
including, in particular, possibly multiple (equivalent) implementations of
javascript to minimize cross-language boundary crossings.
---------
Co-authored-by: Danny McCormick <[email protected]>
---
sdks/python/apache_beam/transforms/core.py | 8 +
sdks/python/apache_beam/yaml/readme_test.py | 23 +-
sdks/python/apache_beam/yaml/yaml_mapping.md | 35 +-
sdks/python/apache_beam/yaml/yaml_mapping.py | 367 ++++++++++++---------
sdks/python/apache_beam/yaml/yaml_mapping_test.py | 32 +-
sdks/python/apache_beam/yaml/yaml_provider.py | 7 +-
sdks/python/apache_beam/yaml/yaml_transform.py | 15 +
.../python/apache_beam/yaml/yaml_transform_test.py | 14 +-
sdks/python/apache_beam/yaml/yaml_udf_test.py | 38 ++-
9 files changed, 311 insertions(+), 228 deletions(-)
diff --git a/sdks/python/apache_beam/transforms/core.py
b/sdks/python/apache_beam/transforms/core.py
index 66ac8fbad96..671af54e47b 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2258,6 +2258,10 @@ class _PValueWithErrors(object):
self._exception_handling_args = exception_handling_args
self._upstream_errors = upstream_errors
+ @property
+ def element_type(self):
+ return self._pcoll.element_type
+
def main_output_tag(self):
return self._exception_handling_args.get('main_tag', 'good')
@@ -2309,6 +2313,10 @@ class _MaybePValueWithErrors(object):
else:
self._pvalue = _PValueWithErrors(pvalue, exception_handling_args)
+ @property
+ def element_type(self):
+ return self._pvalue.element_type
+
def __or__(self, transform):
return self.apply(transform)
diff --git a/sdks/python/apache_beam/yaml/readme_test.py
b/sdks/python/apache_beam/yaml/readme_test.py
index 958d9cb5783..d918d18e11d 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -32,6 +32,7 @@ from yaml.loader import SafeLoader
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.typehints import trivial_inference
+from apache_beam.yaml import yaml_mapping
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform
@@ -85,13 +86,16 @@ class FakeSql(beam.PTransform):
typ, = [t for t in typ.__args__ if t is not type(None)]
return name, typ
- output_schema = [
- guess_name_and_type(expr) for expr in m.group(1).split(',')
- ]
- output_element = beam.Row(**{name: typ() for name, typ in output_schema})
- return next(iter(inputs.values())) | beam.Map(
- lambda _: output_element).with_output_types(
- trivial_inference.instance_to_type(output_element))
+ if m.group(1) == '*':
+ return inputs['PCOLLECTION'] | beam.Filter(lambda _: True)
+ else:
+ output_schema = [
+ guess_name_and_type(expr) for expr in m.group(1).split(',')
+ ]
+ output_element = beam.Row(**{name: typ() for name, typ in output_schema})
+ return next(iter(inputs.values())) | beam.Map(
+ lambda _: output_element).with_output_types(
+ trivial_inference.instance_to_type(output_element))
class FakeReadFromPubSub(beam.PTransform):
@@ -204,12 +208,13 @@ def create_test_method(test_type, test_name, test_yaml):
]
options['render_leaf_composite_nodes'] = ['.*']
test_provider = TestProvider(TEST_TRANSFORMS)
+ test_sql_mapping_provider =
yaml_mapping.SqlMappingProvider(test_provider)
p = beam.Pipeline(options=PipelineOptions(**options))
yaml_transform.expand_pipeline(
p,
modified_yaml,
- {t: test_provider
- for t in test_provider.provided_transforms()})
+ yaml_provider.merge_providers(
+ [test_provider, test_sql_mapping_provider]))
if test_type == 'BUILD':
return
p.run().wait_until_finish()
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.md
b/sdks/python/apache_beam/yaml/yaml_mapping.md
index b5e84e1a805..653b4abe8b8 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.md
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.md
@@ -131,7 +131,7 @@ Currently, in addition to Python, SQL expressions are
supported as well
Sometimes it may be desirable to emit more (or less) than one record for each
input record. This can be accomplished by mapping to an iterable type and
-noting that the specific field should be exploded, e.g.
+following the mapping with an Explode operation, e.g.
```
- type: MapToFields
@@ -140,7 +140,9 @@ noting that the specific field should be exploded, e.g.
fields:
new_col: "[col1.upper(), col1.lower(), col1.title()]"
another_col: "col2 + col3"
- explode: new_col
+- type: Explode
+ config:
+ fields: new_col
```
will result in three output records for every input record.
@@ -155,7 +157,9 @@ product over all fields should be taken. For example
fields:
new_col: "[col1.upper(), col1.lower(), col1.title()]"
another_col: "[col2 - 1, col2, col2 + 1]"
- explode: [new_col, another_col]
+- type: Explode
+ config:
+ fields: [new_col, another_col]
cross_product: true
```
@@ -168,38 +172,27 @@ will emit nine records whereas
fields:
new_col: "[col1.upper(), col1.lower(), col1.title()]"
another_col: "[col2 - 1, col2, col2 + 1]"
- explode: [new_col, another_col]
+- type: Explode
+ config:
+ fields: [new_col, another_col]
cross_product: false
```
will only emit three.
-If one is only exploding existing fields, a simpler `Explode` transform may be
-used instead
+The `Explode` operation can be used on its own if the field in question is
+already an iterable type.
```
- type: Explode
config:
- explode: [col1]
+ fields: [col1]
```
## Filtering
Sometimes it can be desirable to only keep records that satisfy a certain
-criteria. This can be accomplished by specifying a keep parameter, e.g.
-
-```
-- type: MapToFields
- config:
- language: python
- fields:
- new_col: "col1.upper()"
- another_col: "col2 + col3"
- keep: "col2 > 0"
-```
-
-Like explode, there is a simpler `Filter` transform useful when no mapping is
-being done
+criteria. This can be accomplished with a `Filter` transform, e.g.
```
- type: Filter
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py
b/sdks/python/apache_beam/yaml/yaml_mapping.py
index b6dea894b3e..221c6f018d6 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -17,6 +17,14 @@
"""This module defines the basic MapToFields operation."""
import itertools
+from typing import Any
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Union
import js2py
@@ -139,18 +147,73 @@ def _as_callable(original_fields, expr, transform_name,
language):
'Supported languages are "javascript" and "python."')
+def exception_handling_args(error_handling_spec):
+ if error_handling_spec:
+ return {
+ 'dead_letter_tag' if k == 'output' else k: v
+ for (k, v) in error_handling_spec.items()
+ }
+ else:
+ return None
+
+
+def _map_errors_to_standard_format():
+ # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple.
+ return beam.Map(
+ lambda x: beam.Row(element=x[0], msg=str(x[1][1]), stack=str(x[1][2])))
+
+
+def maybe_with_exception_handling(inner_expand):
+ def expand(self, pcoll):
+ wrapped_pcoll = beam.core._MaybePValueWithErrors(
+ pcoll, self._exception_handling_args)
+ return inner_expand(self, wrapped_pcoll).as_result(
+ _map_errors_to_standard_format())
+
+ return expand
+
+
+def maybe_with_exception_handling_transform_fn(transform_fn):
+ def expand(pcoll, error_handling=None, **kwargs):
+ wrapped_pcoll = beam.core._MaybePValueWithErrors(
+ pcoll, exception_handling_args(error_handling))
+ return transform_fn(wrapped_pcoll,
+ **kwargs).as_result(_map_errors_to_standard_format())
+
+ return expand
+
+
# TODO(yaml): This should be available in all environments, in which case
# we choose the one that matches best.
class _Explode(beam.PTransform):
- def __init__(self, fields, cross_product):
+ def __init__(
+ self,
+ fields: Union[str, Collection[str]],
+ cross_product: Optional[bool] = None,
+ error_handling: Optional[Mapping[str, Any]] = None):
+ if isinstance(fields, str):
+ fields = [fields]
+ if cross_product is None:
+ if len(fields) > 1:
+ raise ValueError(
+ 'cross_product must be specified true or false '
+ 'when exploding multiple fields')
+ else:
+ # Doesn't matter.
+ cross_product = True
self._fields = fields
self._cross_product = cross_product
- self._exception_handling_args = None
+ # TODO(yaml): Support standard error handling argument.
+ self._exception_handling_args = exception_handling_args(error_handling)
+ @maybe_with_exception_handling
def expand(self, pcoll):
all_fields = [
x for x, _ in named_fields_from_element_type(pcoll.element_type)
]
+ for field in self._fields:
+ if field not in all_fields:
+ raise ValueError(f'Exploding unknown field "{field}"')
to_explode = self._fields
def explode_cross_product(base, fields):
@@ -171,12 +234,12 @@ class _Explode(beam.PTransform):
yield beam.Row(**copy)
return (
- beam.core._MaybePValueWithErrors(pcoll, self._exception_handling_args)
+ pcoll
| beam.FlatMap(
lambda row:
(explode_cross_product if self._cross_product else explode_zip)
({name: getattr(row, name)
- for name in all_fields}, to_explode))).as_result()
+ for name in all_fields}, to_explode)))
def infer_output_type(self, input_type):
return row_type.RowTypeConstraint.from_fields([(
@@ -190,189 +253,171 @@ class _Explode(beam.PTransform):
return self
-# TODO(yaml): Should Filter and Explode be distinct operations from Project?
-# We'll want these per-language.
@beam.ptransform.ptransform_fn
-def _PythonProjectionTransform(
- pcoll,
- *,
- fields,
- transform_name,
- language,
- keep=None,
- explode=(),
- cross_product=True,
- error_handling=None):
- original_fields = [
- name for (name, _) in named_fields_from_element_type(pcoll.element_type)
- ]
+@maybe_with_exception_handling_transform_fn
+def _PyJsFilter(
+ pcoll, keep: Union[str, Dict[str, str]], language: Optional[str] = None):
- if error_handling is None:
- error_handling_args = None
+ input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+ if isinstance(keep, str) and keep in input_schema:
+ keep_fn = lambda row: getattr(row, keep)
else:
- error_handling_args = {
- 'dead_letter_tag' if k == 'output' else k: v
- for (k, v) in error_handling.items()
- }
+ keep_fn = _as_callable(list(input_schema.keys()), keep, "keep", language)
+ return pcoll | beam.Filter(keep_fn)
- pcoll = beam.core._MaybePValueWithErrors(pcoll, error_handling_args)
- if keep:
- if isinstance(keep, str) and keep in original_fields:
- keep_fn = lambda row: getattr(row, keep)
- else:
- keep_fn = _as_callable(original_fields, keep, transform_name, language)
- filtered = pcoll | beam.Filter(keep_fn)
- else:
- filtered = pcoll
+def is_expr(v):
+ return isinstance(v, str) or (isinstance(v, dict) and 'expression' in v)
- projected = filtered | beam.Select(
- **{
- name: _as_callable(original_fields, expr, transform_name, language)
- for (name, expr) in fields.items()
- })
- if explode:
- result = projected | _Explode(explode, cross_product=cross_product)
- else:
- result = projected
-
- return result.as_result(
- # TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple.
- beam.Map(
- lambda x: beam.Row(
- element=x[0], msg=str(x[1][1]), stack=str(x[1][2]))))
-
-
[email protected]_fn
-def MapToFields(
- pcoll,
- yaml_create_transform,
- *,
- fields,
- keep=None,
- explode=(),
- cross_product=None,
- append=False,
- drop=(),
- language=None,
- error_handling=None,
- transform_name="MapToFields",
- **language_keywords):
- if isinstance(explode, str):
- explode = [explode]
- if cross_product is None:
- if len(explode) > 1:
- # TODO(robertwb): Consider if true is an OK default.
- raise ValueError(
- 'cross_product must be specified true or false '
- 'when exploding multiple fields')
- else:
- # Doesn't matter.
- cross_product = True
+def normalize_fields(pcoll, fields, drop=(), append=False, language='generic'):
+ try:
+ input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+ except ValueError as exn:
+ if drop:
+ raise ValueError("Can only drop fields on a schema'd input.") from exn
+ if append:
+ raise ValueError("Can only append fields on a schema'd input.") from exn
+ elif any(is_expr(x) for x in fields.values()):
+ raise ValueError("Can only use expressions on a schema'd input.") from
exn
+ input_schema = {}
- input_schema = dict(named_fields_from_element_type(pcoll.element_type))
+ if isinstance(drop, str):
+ drop = [drop]
if drop and not append:
raise ValueError("Can only drop fields if append is true.")
for name in drop:
if name not in input_schema:
raise ValueError(f'Dropping unknown field "{name}"')
- for name in explode:
- if not (name in fields or (append and name in input_schema)):
- raise ValueError(f'Exploding unknown field "{name}"')
if append:
for name in fields:
if name in input_schema and name not in drop:
- raise ValueError(f'Redefinition of field "{name}"')
+ raise ValueError(
+ f'Redefinition of field "{name}". '
+ 'Cannot append a field that already exists in original input.')
+
+ if language == 'generic':
+ for expr in fields.values():
+ if not isinstance(expr, str):
+ raise ValueError(
+ "Missing language specification. "
+ "Must specify a language when using a map with custom logic.")
+ missing = set(fields.values()) - set(input_schema.keys())
+ if missing:
+ raise ValueError(
+ f"Missing language specification or unknown input fields: {missing}")
if append:
- fields = {
+ return input_schema, {
**{name: name
for name in input_schema.keys() if name not in drop},
**fields
}
+ else:
+ return input_schema, fields
- if language is None:
- for name, expr in fields.items():
- if not isinstance(expr, str) or expr not in input_schema:
- # TODO(robertw): Could consider defaulting to SQL, or another
- # lowest-common-denominator expression language.
- raise ValueError("Missing language specification.")
-
- # We should support this for all languages.
- language = "python"
-
- if language in ("sql", "calcite"):
- if error_handling:
- raise ValueError('Error handling unsupported for sql.')
- selects = [f'{expr} AS {name}' for (name, expr) in fields.items()]
- query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
- if keep:
- query += " WHERE " + keep
-
- result = pcoll | yaml_create_transform({
- 'type': 'Sql',
- 'config': {
- 'query': query, **language_keywords
- },
- }, [pcoll])
- if explode:
- # TODO(yaml): Implement via unnest.
- result = result | _Explode(explode, cross_product)
-
- return result
-
- elif language == 'python' or language == 'javascript':
- return pcoll | yaml_create_transform({
- 'type': 'PyTransform',
- 'config': {
- 'constructor': __name__ + '._PythonProjectionTransform',
- 'kwargs': {
- 'fields': fields,
- 'transform_name': transform_name,
- 'language': language,
- 'keep': keep,
- 'explode': explode,
- 'cross_product': cross_product,
- 'error_handling': error_handling,
- },
- **language_keywords
- },
- }, [pcoll])
- else:
- # TODO(yaml): Support javascript expressions and UDFs.
- # TODO(yaml): Support java by fully qualified name.
- # TODO(yaml): Maybe support java lambdas?
- raise ValueError(
- f'Unknown language: {language}. '
- 'Supported languages are "sql" (alias calcite) and "python."')
[email protected]_fn
+@maybe_with_exception_handling_transform_fn
+def _PyJsMapToFields(pcoll, language='generic', **mapping_args):
+ input_schema, fields = normalize_fields(
+ pcoll, language=language, **mapping_args)
+ original_fields = list(input_schema.keys())
+
+ return pcoll | beam.Select(
+ **{
+ name: _as_callable(original_fields, expr, name, language)
+ for (name, expr) in fields.items()
+ })
+
+
+class SqlMappingProvider(yaml_provider.Provider):
+ def __init__(self, sql_provider=None):
+ if sql_provider is None:
+ sql_provider = yaml_provider.beam_jar(
+ urns={'Sql': 'beam:external:java:sql:v1'},
+ gradle_target='sdks:java:extensions:sql:expansion-service:shadowJar')
+ self._sql_provider = sql_provider
+
+ def available(self):
+ return self._sql_provider.available()
+
+ def cache_artifacts(self):
+ return self._sql_provider.cache_artifacts()
+
+ def provided_transforms(self) -> Iterable[str]:
+ return [
+ 'Filter-sql',
+ 'Filter-calcite',
+ 'MapToFields-sql',
+ 'MapToFields-calcite'
+ ]
+
+ def create_transform(
+ self,
+ typ: str,
+ args: Mapping[str, Any],
+ yaml_create_transform: Callable[
+ [Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]
+ ) -> beam.PTransform:
+ if typ.startswith('Filter-'):
+ return _SqlFilterTransform(
+ self._sql_provider, yaml_create_transform, **args)
+ if typ.startswith('MapToFields-'):
+ return _SqlMapToFieldsTransform(
+ self._sql_provider, yaml_create_transform, **args)
+ else:
+ raise NotImplementedError(typ)
+
+ def underlying_provider(self):
+ return self._sql_provider
+
+ def to_json(self):
+ return {'type': "SqlMappingProvider"}
+
+
[email protected]_fn
+def _SqlFilterTransform(
+ pcoll, sql_provider, yaml_create_transform, keep, language):
+ return pcoll | sql_provider.create_transform(
+ 'Sql', {'query': f'SELECT * FROM PCOLLECTION WHERE {keep}'},
+ yaml_create_transform)
-def create_mapping_provider():
[email protected]_fn
+def _SqlMapToFieldsTransform(
+ pcoll, sql_provider, yaml_create_transform, **mapping_args):
+ _, fields = normalize_fields(pcoll, **mapping_args)
+
+ def extract_expr(name, v):
+ if isinstance(v, str):
+ return v
+ elif 'expression' in v:
+ return v['expression']
+ else:
+ raise ValueError("Only expressions allowed in SQL at {name}.")
+
+ selects = [
+ f'({extract_expr(name, expr)}) AS {name}'
+ for (name, expr) in fields.items()
+ ]
+ query = "SELECT " + ", ".join(selects) + " FROM PCOLLECTION"
+ return pcoll | sql_provider.create_transform(
+ 'Sql', {'query': query}, yaml_create_transform)
+
+
+def create_mapping_providers():
# These are MetaInlineProviders because their expansion is in terms of other
# YamlTransforms, but in a way that needs to be deferred until the input
# schema is known.
- return yaml_provider.MetaInlineProvider({
- 'MapToFields': MapToFields,
- 'Filter': (
- lambda yaml_create_transform,
- keep,
- **kwargs: MapToFields(
- yaml_create_transform,
- keep=keep,
- fields={},
- append=True,
- transform_name='Filter',
- **kwargs)),
- 'Explode': (
- lambda yaml_create_transform,
- explode,
- **kwargs: MapToFields(
- yaml_create_transform,
- explode=explode,
- fields={},
- append=True,
- transform_name='Explode',
- **kwargs)),
- })
+ return [
+ yaml_provider.InlineProvider({
+ 'Explode': _Explode,
+ 'Filter-python': _PyJsFilter,
+ 'Filter-javascript': _PyJsFilter,
+ 'MapToFields-python': _PyJsMapToFields,
+ 'MapToFields-javascript': _PyJsMapToFields,
+ 'MapToFields-generic': _PyJsMapToFields,
+ }),
+ SqlMappingProvider(),
+ ]
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py
b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
index 728476b1fd5..55032aeae52 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
@@ -82,18 +82,18 @@ class YamlMappingTest(unittest.TestCase):
elements = p | beam.Create(DATA)
result = elements | YamlTransform(
'''
- type: MapToFields
+ type: Filter
input: input
config:
language: python
- fields:
- label: label
keep: "rank > 0"
''')
assert_that(
- result, equal_to([
- beam.Row(label='37a'),
- beam.Row(label='389a'),
+ result
+ | beam.Map(lambda named_tuple: beam.Row(**named_tuple._asdict())),
+ equal_to([
+ beam.Row(label='37a', conductor=37, rank=1),
+ beam.Row(label='389a', conductor=389, rank=2),
]))
def test_explode(self):
@@ -105,15 +105,19 @@ class YamlMappingTest(unittest.TestCase):
])
result = elements | YamlTransform(
'''
- type: MapToFields
+ type: chain
input: input
- config:
- language: python
- append: true
- fields:
- range: "range(a)"
- explode: [range, b]
- cross_product: true
+ transforms:
+ - type: MapToFields
+ config:
+ language: python
+ append: true
+ fields:
+ range: "range(a)"
+ - type: Explode
+ config:
+ fields: [range, b]
+ cross_product: true
''')
assert_that(
result,
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py
b/sdks/python/apache_beam/yaml/yaml_provider.py
index d01852a69c3..0cd9bdcadcc 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -209,6 +209,7 @@ class ExternalProvider(Provider):
def register_provider_type(cls, type_name):
def apply(constructor):
cls._provider_types[type_name] = constructor
+ return constructor
return apply
@@ -709,19 +710,21 @@ def merge_providers(*provider_sets):
transform_type: [provider]
for transform_type in provider.provided_transforms()
}
+ elif isinstance(provider_set, list):
+ provider_set = merge_providers(*provider_set)
for transform_type, providers in provider_set.items():
result[transform_type].extend(providers)
return result
def standard_providers():
- from apache_beam.yaml.yaml_mapping import create_mapping_provider
+ from apache_beam.yaml.yaml_mapping import create_mapping_providers
from apache_beam.yaml.yaml_io import io_providers
with open(os.path.join(os.path.dirname(__file__),
'standard_providers.yaml')) as fin:
standard_providers = yaml.load(fin, Loader=SafeLoader)
return merge_providers(
create_builtin_provider(),
- create_mapping_provider(),
+ create_mapping_providers(),
io_providers(),
parse_providers(standard_providers))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py
b/sdks/python/apache_beam/yaml/yaml_transform.py
index da9bf526cd5..78546aa28cb 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -879,8 +879,23 @@ def preprocess(spec, verbose=False, known_transforms=None):
f'Unknown type or missing provider for {identify_object(spec)}')
return spec
+ def preprocess_langauges(spec):
+ if spec['type'] in ('Filter', 'MapToFields'):
+ language = spec.get('config', {}).get('language', 'generic')
+ new_type = spec['type'] + '-' + language
+ if known_transforms and new_type not in known_transforms:
+ if language == 'generic':
+ raise ValueError(f'Missing language for {identify_object(spec)}')
+ else:
+ raise ValueError(
+ f'Unknown language {language} for {identify_object(spec)}')
+ return dict(spec, type=new_type, name=spec.get('name', spec['type']))
+ else:
+ return spec
+
for phase in [
ensure_transforms_have_types,
+ preprocess_langauges,
ensure_transforms_have_providers,
preprocess_source_sink,
preprocess_chain,
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py
b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index 993f9ea6639..ebf12710d3f 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -419,21 +419,27 @@ class ErrorHandlingTest(unittest.TestCase):
input: Create
config:
fn: "lambda x: beam.Row(num=x, str='a' * x or 'bbb')"
+ - type: Filter
+ input: ToRow
+ config:
+ language: python
+ keep:
+ str[1] >= 'a'
+ error_handling:
+ output: errors
- type: MapToFields
name: MapWithErrorHandling
- input: ToRow
+ input: Filter
config:
language: python
fields:
num: num
inverse: float(1 / num)
- keep:
- str[1] >= 'a'
error_handling:
output: errors
- type: PyMap
name: TrimErrors
- input: MapWithErrorHandling.errors
+ input: [MapWithErrorHandling.errors, Filter.errors]
config:
fn: "lambda x: x.msg"
- type: MapToFields
diff --git a/sdks/python/apache_beam/yaml/yaml_udf_test.py
b/sdks/python/apache_beam/yaml/yaml_udf_test.py
index bb15cd49475..5e9faa08253 100644
--- a/sdks/python/apache_beam/yaml/yaml_udf_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_udf_test.py
@@ -28,6 +28,10 @@ from apache_beam.testing.util import equal_to
from apache_beam.yaml.yaml_transform import YamlTransform
+def AsRows():
+ return beam.Map(lambda named_tuple: beam.Row(**named_tuple._asdict()))
+
+
class YamlUDFMappingTest(unittest.TestCase):
def __init__(self, method_name='runYamlMappingTest'):
super().__init__(method_name)
@@ -59,12 +63,11 @@ class YamlUDFMappingTest(unittest.TestCase):
callable: "function label_map(x) {return x.label + 'x'}"
conductor:
callable: "function conductor_map(x) {return x.conductor + 1}"
- keep:
- callable: "function filter(x) {return x.rank > 0}"
''')
assert_that(
result,
equal_to([
+ beam.Row(label='11ax', conductor=12),
beam.Row(label='37ax', conductor=38),
beam.Row(label='389ax', conductor=390),
]))
@@ -84,12 +87,11 @@ class YamlUDFMappingTest(unittest.TestCase):
callable: "lambda x: x.label + 'x'"
conductor:
callable: "lambda x: x.conductor + 1"
- keep:
- callable: "lambda x: x.rank > 0"
''')
assert_that(
result,
equal_to([
+ beam.Row(label='11ax', conductor=12),
beam.Row(label='37ax', conductor=38),
beam.Row(label='389ax', conductor=390),
]))
@@ -104,11 +106,11 @@ class YamlUDFMappingTest(unittest.TestCase):
input: input
config:
language: javascript
- keep:
+ keep:
callable: "function filter(x) {return x.rank > 0}"
''')
assert_that(
- result,
+ result | AsRows(),
equal_to([
beam.Row(label='37a', conductor=37, rank=1),
beam.Row(label='389a', conductor=389, rank=2),
@@ -124,11 +126,11 @@ class YamlUDFMappingTest(unittest.TestCase):
input: input
config:
language: python
- keep:
+ keep:
callable: "lambda x: x.rank > 0"
''')
assert_that(
- result,
+ result | AsRows(),
equal_to([
beam.Row(label='37a', conductor=37, rank=1),
beam.Row(label='389a', conductor=389, rank=2),
@@ -144,11 +146,12 @@ class YamlUDFMappingTest(unittest.TestCase):
input: input
config:
language: javascript
- keep:
+ keep:
expression: "label.toUpperCase().indexOf('3') == -1 && conductor"
''')
assert_that(
- result, equal_to([
+ result | AsRows(),
+ equal_to([
beam.Row(label='11a', conductor=11, rank=0),
]))
@@ -162,11 +165,12 @@ class YamlUDFMappingTest(unittest.TestCase):
input: input
config:
language: python
- keep:
+ keep:
expression: "'3' not in label"
''')
assert_that(
- result, equal_to([
+ result | AsRows(),
+ equal_to([
beam.Row(label='11a', conductor=11, rank=0),
]))
@@ -175,7 +179,7 @@ class YamlUDFMappingTest(unittest.TestCase):
function f(x) {
return x.rank > 0
}
-
+
function g(x) {
return x.rank > 1
}
@@ -193,12 +197,12 @@ class YamlUDFMappingTest(unittest.TestCase):
input: input
config:
language: javascript
- keep:
+ keep:
path: {path}
name: "f"
''')
assert_that(
- result,
+ result | AsRows(),
equal_to([
beam.Row(label='37a', conductor=37, rank=1),
beam.Row(label='389a', conductor=389, rank=2),
@@ -225,12 +229,12 @@ class YamlUDFMappingTest(unittest.TestCase):
input: input
config:
language: python
- keep:
+ keep:
path: {path}
name: "f"
''')
assert_that(
- result,
+ result | AsRows(),
equal_to([
beam.Row(label='37a', conductor=37, rank=1),
beam.Row(label='389a', conductor=389, rank=2),