This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 6b8f8c09f5e Initial start of a yaml-based declarative way of building
pipelines. (#24667)
6b8f8c09f5e is described below
commit 6b8f8c09f5e0f31ffa5617c4650297907810003a
Author: Robert Bradshaw <[email protected]>
AuthorDate: Tue Jan 24 09:49:58 2023 -0800
Initial start of a yaml-based declarative way of building pipelines.
(#24667)
---
sdks/python/apache_beam/transforms/external.py | 3 +
sdks/python/apache_beam/yaml/__init__.py | 18 +
sdks/python/apache_beam/yaml/main.py | 63 +++
sdks/python/apache_beam/yaml/pipeline.schema.yaml | 130 ++++++
.../apache_beam/yaml/standard_providers.yaml | 25 ++
sdks/python/apache_beam/yaml/yaml_provider.py | 437 ++++++++++++++++++++
sdks/python/apache_beam/yaml/yaml_transform.py | 450 +++++++++++++++++++++
.../python/apache_beam/yaml/yaml_transform_test.py | 90 +++++
8 files changed, 1216 insertions(+)
diff --git a/sdks/python/apache_beam/transforms/external.py
b/sdks/python/apache_beam/transforms/external.py
index 7a51379a0e3..1c4a6dd0519 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -777,6 +777,9 @@ class ExpansionAndArtifactRetrievalStub(
return beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceStub(
self._channel, **self._kwargs)
+ def ready(self, timeout_sec):
+ grpc.channel_ready_future(self._channel).result(timeout=timeout_sec)
+
class JavaJarExpansionService(object):
"""An expansion service based on an Java Jar file.
diff --git a/sdks/python/apache_beam/yaml/__init__.py
b/sdks/python/apache_beam/yaml/__init__.py
new file mode 100644
index 00000000000..73dd57dd7a6
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from apache_beam.yaml.yaml_transform import *
diff --git a/sdks/python/apache_beam/yaml/main.py
b/sdks/python/apache_beam/yaml/main.py
new file mode 100644
index 00000000000..7bdeccbc2db
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/main.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+
+import yaml
+
+import apache_beam as beam
+from apache_beam.yaml import yaml_transform
+
+
+def run(argv=None):
+ parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--pipeline_spec',
+      help='A yaml description of the pipeline to run.')
+  parser.add_argument(
+      '--pipeline_spec_file',
+      help='A file containing a yaml description of the pipeline to run.'
+  )
+ known_args, pipeline_args = parser.parse_known_args(argv)
+
+ if known_args.pipeline_spec_file and known_args.pipeline_spec:
+ raise ValueError(
+ "Exactly one of pipeline_spec or pipeline_spec_file must be set.")
+ elif known_args.pipeline_spec_file:
+ with open(known_args.pipeline_spec_file) as fin:
+ pipeline_yaml = fin.read()
+ elif known_args.pipeline_spec:
+ pipeline_yaml = known_args.pipeline_spec
+ else:
+ raise ValueError(
+ "Exactly one of pipeline_spec or pipeline_spec_file must be set.")
+
+ pipeline_spec = yaml.load(pipeline_yaml,
Loader=yaml_transform.SafeLineLoader)
+
+ yaml_transform._LOGGER.setLevel('INFO')
+
+ with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+ pipeline_args,
+ pickle_library='cloudpickle',
+ **pipeline_spec.get('options', {}))) as p:
+ print("Building pipeline...")
+      yaml_transform.expand_pipeline(p, pipeline_yaml)
+ print("Running pipeline...")
+
+
+if __name__ == '__main__':
+ run()
diff --git a/sdks/python/apache_beam/yaml/pipeline.schema.yaml
b/sdks/python/apache_beam/yaml/pipeline.schema.yaml
new file mode 100644
index 00000000000..36c35af49d8
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/pipeline.schema.yaml
@@ -0,0 +1,130 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+$schema: 'http://json-schema.org/schema#'
+$id:
https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml/pipeline.schema.yaml
+
+$defs:
+
+ transformBase:
+ type: object
+ properties:
+ type: { type: string }
+ name: { type: string }
+ input:
+ oneOf:
+ - type: string
+ - type: object
+ additionalProperties:
+ type: string
+ output:
+ oneOf:
+ - type: string
+ - type: object
+ additionalProperties:
+ type: string
+ additionalProperties: true
+ required:
+ - type
+
+ leafTransform:
+ allOf:
+ - $ref: '#/$defs/transformBase'
+ - type: object
+ properties:
+ type:
+ not:
+ anyOf:
+ - const: chain
+ - const: composite
+ args:
+ type: object
+
+ chainTransform:
+ allOf:
+ - $ref: '#/$defs/transformBase'
+ - type: object
+ properties:
+ type:
+ const: chain
+ name: {}
+ input: {}
+ output: {}
+ transforms:
+ type: array
+ items:
+ allOf:
+ - $ref: '#/$defs/transform'
+ - type: object
+ properties:
+ # Must be implicit.
+ input: { not: {} }
+ output: { not: {} }
+ additionalProperties: false
+ required:
+ - transforms
+
+ compositeTransform:
+ allOf:
+ - $ref: '#/$defs/transformBase'
+ - type: object
+ properties:
+ type:
+ const: composite
+ name: {}
+ input: {}
+ output: {}
+ transforms:
+ type: array
+ items:
+ $ref: '#/$defs/transform'
+ additionalProperties: false
+ required:
+ - transforms
+
+ transform:
+ oneOf:
+ - $ref: '#/$defs/leafTransform'
+ - $ref: '#/$defs/chainTransform'
+ - $ref: '#/$defs/compositeTransform'
+
+ provider:
+ # TODO(robertwb): Consider enumerating the provider types along with
+ # the arguments they accept/expect (possibly in a separate schema file).
+ type: object
+ properties:
+ type: { type: string }
+ transforms:
+ type: object
+ additionalProperties:
+ type: string
+ required:
+ - type
+ - transforms
+
+type: object
+properties:
+ pipeline:
+ type: array
+ items:
+ $ref: '#/$defs/transform'
+ providers:
+ type: array
+ items:
+ $ref: '#/$defs/provider'
+required:
+ - pipeline
diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml
b/sdks/python/apache_beam/yaml/standard_providers.yaml
new file mode 100644
index 00000000000..04d33b53d64
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/standard_providers.yaml
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# TODO(robertwb): Add more providers.
+# TODO(robertwb): Perhaps auto-generate this file?
+
+- type: 'beamJar'
+ gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar'
+ version: BEAM_VERSION
+ transforms:
+ Sql: 'beam:external:java:sql:v1'
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py
b/sdks/python/apache_beam/yaml/yaml_provider.py
new file mode 100644
index 00000000000..a558de3507a
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -0,0 +1,437 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines Providers usable from yaml, which is a specification
+for where to find and how to invoke services that vend implementations of
+various PTransforms."""
+
+import collections
+import hashlib
+import json
+import os
+import subprocess
+import sys
+import uuid
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+
+import yaml
+from yaml.loader import SafeLoader
+
+import apache_beam as beam
+import apache_beam.dataframe.io
+import apache_beam.io
+import apache_beam.transforms.util
+from apache_beam.portability.api import schema_pb2
+from apache_beam.transforms import external
+from apache_beam.transforms.fully_qualified_named_transform import
FullyQualifiedNamedTransform
+from apache_beam.typehints import schemas
+from apache_beam.typehints import trivial_inference
+from apache_beam.utils import python_callable
+from apache_beam.utils import subprocess_server
+from apache_beam.version import __version__ as beam_version
+
+
+class Provider:
+  """Maps transform type names and args to concrete PTransform instances."""
+ def available(self) -> bool:
+ """Returns whether this provider is available to use in this
environment."""
+ raise NotImplementedError(type(self))
+
+ def provided_transforms(self) -> Iterable[str]:
+ """Returns a list of transform type names this provider can handle."""
+ raise NotImplementedError(type(self))
+
+ def create_transform(
+ self, typ: str, args: Mapping[str, Any]) -> beam.PTransform:
+ """Creates a PTransform instance for the given transform type and
arguments.
+ """
+ raise NotImplementedError(type(self))
+
+
+class ExternalProvider(Provider):
+ """A Provider implemented via the cross language transform service."""
+ def __init__(self, urns, service):
+ self._urns = urns
+ self._service = service
+ self._schema_transforms = None
+
+ def provided_transforms(self):
+ return self._urns.keys()
+
+ def create_transform(self, type, args):
+ if callable(self._service):
+ self._service = self._service()
+ if self._schema_transforms is None:
+ try:
+ self._schema_transforms = [
+ config.identifier
+ for config in external.SchemaAwareExternalTransform.discover(
+ self._service)
+ ]
+ except Exception:
+ self._schema_transforms = []
+ urn = self._urns[type]
+ if urn in self._schema_transforms:
+ return external.SchemaAwareExternalTransform(urn, self._service, **args)
+ else:
+ return type >> self.create_external_transform(urn, args)
+
+ def create_external_transform(self, urn, args):
+ return external.ExternalTransform(
+ urn,
+ external.ImplicitSchemaPayloadBuilder(args).payload(),
+ self._service)
+
+ @staticmethod
+ def provider_from_spec(spec):
+ urns = spec['transforms']
+ type = spec['type']
+ if spec.get('version', None) == 'BEAM_VERSION':
+ spec['version'] = beam_version
+ if type == 'jar':
+ return ExternalJavaProvider(urns, lambda: spec['jar'])
+ elif type == 'mavenJar':
+ return ExternalJavaProvider(
+ urns,
+ lambda: subprocess_server.JavaJarServer.path_to_maven_jar(
+ **{
+ key: value
+ for (key, value) in spec.items() if key in [
+ 'artifact_id',
+ 'group_id',
+ 'version',
+ 'repository',
+ 'classifier',
+ 'appendix'
+ ]
+ }))
+ elif type == 'beamJar':
+ return ExternalJavaProvider(
+ urns,
+ lambda: subprocess_server.JavaJarServer.path_to_beam_jar(
+ **{
+ key: value
+ for (key, value) in spec.items() if key in
+ ['gradle_target', 'version', 'appendix', 'artifact_id']
+ }))
+ elif type == 'pypi':
+ return ExternalPythonProvider(urns, spec['packages'])
+ elif type == 'remote':
+ return RemoteProvider(spec['address'])
+ elif type == 'docker':
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError(f'Unknown provider type: {type}')
+
+
+class RemoteProvider(ExternalProvider):
+ _is_available = None
+
+ def available(self):
+ if self._is_available is None:
+ try:
+ with external.ExternalTransform.service(self._service) as service:
+ service.ready(1)
+ self._is_available = True
+ except Exception:
+ self._is_available = False
+ return self._is_available
+
+
+class ExternalJavaProvider(ExternalProvider):
+ def __init__(self, urns, jar_provider):
+ super().__init__(
+ urns, lambda: external.JavaJarExpansionService(jar_provider()))
+
+ def available(self):
+ # pylint: disable=subprocess-run-check
+ return subprocess.run(['which', 'java'],
+ capture_output=True).returncode == 0
+
+
+class ExternalPythonProvider(ExternalProvider):
+ def __init__(self, urns, packages):
+ super().__init__(urns, PypiExpansionService(packages))
+
+ def available(self):
+ return True # If we're running this script, we have Python installed.
+
+ def create_external_transform(self, urn, args):
+ # Python transforms are "registered" by fully qualified name.
+ return external.ExternalTransform(
+ "beam:transforms:python:fully_qualified_named",
+ external.ImplicitSchemaPayloadBuilder({
+ 'constructor': urn,
+ 'kwargs': args,
+ }).payload(),
+ self._service)
+
+
+# This is needed because type inference can't handle *args, **kwargs forwarding.
+# TODO(BEAM-24755): Add support for type inference through kwargs calls.
+def fix_pycallable():
+ from apache_beam.transforms.ptransform import label_from_callable
+
+ def default_label(self):
+ src = self._source.strip()
+ last_line = src.split('\n')[-1]
+ if last_line[0] != ' ' and len(last_line) < 72:
+ return last_line
+ return label_from_callable(self._callable)
+
+ def _argspec_fn(self):
+ return self._callable
+
+ python_callable.PythonCallableWithSource.default_label = default_label
+ python_callable.PythonCallableWithSource._argspec_fn = property(_argspec_fn)
+
+ original_infer_return_type = trivial_inference.infer_return_type
+
+ def infer_return_type(fn, *args, **kwargs):
+ if isinstance(fn, python_callable.PythonCallableWithSource):
+ fn = fn._callable
+ return original_infer_return_type(fn, *args, **kwargs)
+
+ trivial_inference.infer_return_type = infer_return_type
+
+ original_fn_takes_side_inputs = (
+ apache_beam.transforms.util.fn_takes_side_inputs)
+
+ def fn_takes_side_inputs(fn):
+ if isinstance(fn, python_callable.PythonCallableWithSource):
+ fn = fn._callable
+ return original_fn_takes_side_inputs(fn)
+
+ apache_beam.transforms.util.fn_takes_side_inputs = fn_takes_side_inputs
+
+
+class InlineProvider(Provider):
+ def __init__(self, transform_factories):
+ self._transform_factories = transform_factories
+
+ def available(self):
+ return True
+
+ def provided_transforms(self):
+ return self._transform_factories.keys()
+
+ def create_transform(self, type, args):
+ return self._transform_factories[type](**args)
+
+ def to_json(self):
+ return {'type': "InlineProvider"}
+
+
+PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
+ py_type.__name__: schema_type
+ for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
+ if py_type.__module__ != 'typing'
+}
+
+
+def create_builtin_provider():
+ def with_schema(**args):
+ # TODO: This is preliminary.
+ def parse_type(spec):
+ if spec in PRIMITIVE_NAMES_TO_ATOMIC_TYPE:
+ return schema_pb2.FieldType(
+ atomic_type=PRIMITIVE_NAMES_TO_ATOMIC_TYPE[spec])
+ elif isinstance(spec, list):
+ if len(spec) != 1:
+ raise ValueError("Use single-element lists to denote list types.")
+ else:
+ return schema_pb2.FieldType(
+ iterable_type=schema_pb2.IterableType(
+ element_type=parse_type(spec[0])))
+      elif isinstance(spec, dict):
+        return schema_pb2.FieldType(
+            row_type=schema_pb2.RowType(schema=parse_schema(spec)))
+ else:
+        raise ValueError(f"Unknown schema type: {spec}")
+
+ def parse_schema(spec):
+ return schema_pb2.Schema(
+ fields=[
+ schema_pb2.Field(name=key, type=parse_type(value), id=ix)
+ for (ix, (key, value)) in enumerate(spec.items())
+ ],
+ id=str(uuid.uuid4()))
+
+ named_tuple = schemas.named_tuple_from_schema(parse_schema(args))
+ names = list(args.keys())
+
+ def extract_field(x, name):
+ if isinstance(x, dict):
+ return x[name]
+ else:
+ return getattr(x, name)
+
+ return 'WithSchema(%s)' % ', '.join(names) >> beam.Map(
+ lambda x: named_tuple(*[extract_field(x, name) for name in names])
+ ).with_output_types(named_tuple)
+
+ # Or should this be posargs, args?
+ # pylint: disable=dangerous-default-value
+ def fully_qualified_named_transform(constructor, args=(), kwargs={}):
+ with FullyQualifiedNamedTransform.with_filter('*'):
+ return constructor >> FullyQualifiedNamedTransform(
+ constructor, args, kwargs)
+
+ # This intermediate is needed because there is no way to specify a tuple of
+ # exactly zero or one PCollection in yaml (as they would be interpreted as
+ # PBegin and the PCollection itself respectively).
+ class Flatten(beam.PTransform):
+ def expand(self, pcolls):
+ if isinstance(pcolls, beam.PCollection):
+ pipeline_arg = {}
+ pcolls = (pcolls, )
+ elif isinstance(pcolls, dict):
+ pipeline_arg = {}
+ pcolls = tuple(pcolls.values())
+ else:
+ pipeline_arg = {'pipeline': pcolls.pipeline}
+ pcolls = ()
+ return pcolls | beam.Flatten(**pipeline_arg)
+
+ ios = {
+ key: getattr(apache_beam.io, key)
+ for key in dir(apache_beam.io)
+ if key.startswith('ReadFrom') or key.startswith('WriteTo')
+ }
+ ios['ReadFromCsv'] = lambda **kwargs: apache_beam.dataframe.io.ReadViaPandas(
+ 'csv', **kwargs)
+ ios['WriteToCsv'] = lambda **kwargs: apache_beam.dataframe.io.WriteViaPandas(
+ 'csv', **kwargs)
+ ios['ReadFromJson'] = (
+ lambda *,
+ orient='records',
+ lines=True,
+ **kwargs: apache_beam.dataframe.io.ReadViaPandas(
+ 'json', orient=orient, lines=lines, **kwargs))
+ ios['WriteToJson'] = (
+ lambda *,
+ orient='records',
+ lines=True,
+ **kwargs: apache_beam.dataframe.io.WriteViaPandas(
+ 'json', orient=orient, lines=lines, **kwargs))
+
+ return InlineProvider(
+ dict({
+ 'Create': lambda elements,
+ reshuffle=True: beam.Create(elements, reshuffle),
+ 'PyMap': lambda fn: beam.Map(
+ python_callable.PythonCallableWithSource(fn)),
+ 'PyMapTuple': lambda fn: beam.MapTuple(
+ python_callable.PythonCallableWithSource(fn)),
+ 'PyFlatMap': lambda fn: beam.FlatMap(
+ python_callable.PythonCallableWithSource(fn)),
+ 'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
+ python_callable.PythonCallableWithSource(fn)),
+ 'PyFilter': lambda fn: beam.Filter(
+ python_callable.PythonCallableWithSource(fn)),
+ 'PyTransform': fully_qualified_named_transform,
+ 'PyToRow': lambda fields: beam.Select(
+ **{
+ name: python_callable.PythonCallableWithSource(fn)
+ for (name, fn) in fields.items()
+ }),
+ 'WithSchema': with_schema,
+ 'Flatten': Flatten,
+ 'GroupByKey': beam.GroupByKey,
+ },
+ **ios))
+
+
+class PypiExpansionService:
+ """Expands transforms by fully qualified name in a virtual environment
+ with the given dependencies.
+ """
+ VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/venvs")
+
+ def __init__(self, packages, base_python=sys.executable):
+ self._packages = packages
+ self._base_python = base_python
+
+ def _key(self):
+ return json.dumps({'binary': self._base_python, 'packages':
self._packages})
+
+ def _venv(self):
+ venv = os.path.join(
+ self.VENV_CACHE,
+ hashlib.sha256(self._key().encode('utf-8')).hexdigest())
+ if not os.path.exists(venv):
+ python_binary = os.path.join(venv, 'bin', 'python')
+ subprocess.run([self._base_python, '-m', 'venv', venv], check=True)
+ subprocess.run([python_binary, '-m', 'ensurepip'], check=True)
+ subprocess.run([python_binary, '-m', 'pip', 'install'] + self._packages,
+ check=True)
+ with open(venv + '-requirements.txt', 'w') as fout:
+ fout.write('\n'.join(self._packages))
+ return venv
+
+ def __enter__(self):
+ venv = self._venv()
+ self._service_provider = subprocess_server.SubprocessServer(
+ external.ExpansionAndArtifactRetrievalStub,
+ [
+ os.path.join(venv, 'bin', 'python'),
+ '-m',
+ 'apache_beam.runners.portability.expansion_service_main',
+ '--port',
+ '{{PORT}}',
+ '--fully_qualified_name_glob=*',
+ '--pickle_library=cloudpickle',
+ '--requirements_file=' + os.path.join(venv + '-requirements.txt')
+ ])
+ self._service = self._service_provider.__enter__()
+ return self._service
+
+ def __exit__(self, *args):
+ self._service_provider.__exit__(*args)
+ self._service = None
+
+
+def parse_providers(provider_specs):
+ providers = collections.defaultdict(list)
+ for provider_spec in provider_specs:
+ provider = ExternalProvider.provider_from_spec(provider_spec)
+ for transform_type in provider.provided_transforms():
+ providers[transform_type].append(provider)
+ # TODO: Do this better.
+ provider.to_json = lambda result=provider_spec: result
+ return providers
+
+
+def merge_providers(*provider_sets):
+ result = collections.defaultdict(list)
+ for provider_set in provider_sets:
+ for transform_type, providers in provider_set.items():
+ result[transform_type].extend(providers)
+ return result
+
+
+def standard_providers():
+ builtin_providers = collections.defaultdict(list)
+ builtin_provider = create_builtin_provider()
+ for transform_type in builtin_provider.provided_transforms():
+ builtin_providers[transform_type].append(builtin_provider)
+ with open(os.path.join(os.path.dirname(__file__),
+ 'standard_providers.yaml')) as fin:
+ standard_providers = yaml.load(fin, Loader=SafeLoader)
+ return merge_providers(builtin_providers,
parse_providers(standard_providers))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py
b/sdks/python/apache_beam/yaml/yaml_transform.py
new file mode 100644
index 00000000000..1e8495e308a
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -0,0 +1,450 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This module is experimental. No backwards-compatibility guarantees.
+
+import collections
+import json
+import logging
+import re
+import uuid
+from typing import Iterable
+from typing import Mapping
+
+import yaml
+from yaml.loader import SafeLoader
+
+import apache_beam as beam
+from apache_beam.transforms.fully_qualified_named_transform import
FullyQualifiedNamedTransform
+from apache_beam.yaml import yaml_provider
+
+__all__ = ["YamlTransform"]
+
+_LOGGER = logging.getLogger(__name__)
+yaml_provider.fix_pycallable()
+
+
+def memoize_method(func):
+ def wrapper(self, *args):
+ if not hasattr(self, '_cache'):
+ self._cache = {}
+ key = func.__name__, args
+ if key not in self._cache:
+ self._cache[key] = func(self, *args)
+ return self._cache[key]
+
+ return wrapper
+
+
+def only_element(xs):
+ x, = xs
+ return x
+
+
+class SafeLineLoader(SafeLoader):
+ """A yaml loader that attaches line information to mappings and strings."""
+ class TaggedString(str):
+ """A string class to which we can attach metadata.
+
+ This is primarily used to trace a string's origin back to its place in a
+ yaml file.
+ """
+ def __reduce__(self):
+ # Pickle as an ordinary string.
+ return str, (str(self), )
+
+ def construct_scalar(self, node):
+ value = super().construct_scalar(node)
+ if isinstance(value, str):
+ value = SafeLineLoader.TaggedString(value)
+ value._line_ = node.start_mark.line + 1
+ return value
+
+ def construct_mapping(self, node, deep=False):
+ mapping = super().construct_mapping(node, deep=deep)
+ mapping['__line__'] = node.start_mark.line + 1
+ mapping['__uuid__'] = str(uuid.uuid4())
+ return mapping
+
+ @classmethod
+ def strip_metadata(cls, spec, tagged_str=True):
+ if isinstance(spec, Mapping):
+ return {
+ key: cls.strip_metadata(value, tagged_str)
+ for key,
+ value in spec.items() if key not in ('__line__', '__uuid__')
+ }
+ elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)):
+ return [cls.strip_metadata(value, tagged_str) for value in spec]
+ elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str:
+ return str(spec)
+ else:
+ return spec
+
+ @staticmethod
+ def get_line(obj):
+ if isinstance(obj, dict):
+ return obj.get('__line__', 'unknown')
+ else:
+ return getattr(obj, '_line_', 'unknown')
+
+
+class Scope(object):
+ """To look up PCollections (typically outputs of prior transforms) by
name."""
+ def __init__(self, root, inputs, transforms, providers):
+ self.root = root
+ self.providers = providers
+ self._inputs = inputs
+ self._transforms = transforms
+ self._transforms_by_uuid = {t['__uuid__']: t for t in self._transforms}
+ self._uuid_by_name = collections.defaultdict(list)
+ for spec in self._transforms:
+ if 'name' in spec:
+ self._uuid_by_name[spec['name']].append(spec['__uuid__'])
+ if 'type' in spec:
+ self._uuid_by_name[spec['type']].append(spec['__uuid__'])
+ self._seen_names = set()
+
+ def compute_all(self):
+ for transform_id in self._transforms_by_uuid.keys():
+ self.compute_outputs(transform_id)
+
+ def get_pcollection(self, name):
+ if name in self._inputs:
+ return self._inputs[name]
+ elif '.' in name:
+ transform, output = name.rsplit('.', 1)
+ outputs = self.get_outputs(transform)
+ if output in outputs:
+ return outputs[output]
+ else:
+ raise ValueError(
+ f'Unknown output {repr(output)} '
+ f'at line {SafeLineLoader.get_line(name)}: '
+ f'{transform} only has outputs {list(outputs.keys())}')
+ else:
+ outputs = self.get_outputs(name)
+ if len(outputs) == 1:
+ return only_element(outputs.values())
+ else:
+ raise ValueError(
+ f'Ambiguous output at line {SafeLineLoader.get_line(name)}: '
+ f'{name} has outputs {list(outputs.keys())}')
+
+ def get_outputs(self, transform_name):
+ if transform_name in self._transforms_by_uuid:
+ transform_id = transform_name
+ else:
+ candidates = self._uuid_by_name[transform_name]
+ if not candidates:
+ raise ValueError(
+ f'Unknown transform at line '
+ f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
+ elif len(candidates) > 1:
+ raise ValueError(
+ f'Ambiguous transform at line '
+ f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
+ else:
+ transform_id = only_element(candidates)
+ return self.compute_outputs(transform_id)
+
+ @memoize_method
+ def compute_outputs(self, transform_id):
+ return expand_transform(self._transforms_by_uuid[transform_id], self)
+
+ # A method on scope as providers may be scoped...
+ def create_ptransform(self, spec):
+ if 'type' not in spec:
+ raise ValueError(f'Missing transform type: {identify_object(spec)}')
+
+ if spec['type'] not in self.providers:
+ raise ValueError(
+ 'Unknown transform type %r at %s' %
+ (spec['type'], identify_object(spec)))
+
+ for provider in self.providers.get(spec['type']):
+ if provider.available():
+ break
+ else:
+ raise ValueError(
+ 'No available provider for type %r at %s' %
+ (spec['type'], identify_object(spec)))
+
+ if 'args' in spec:
+ args = spec['args']
+ if not isinstance(args, dict):
+ raise ValueError(
+ 'Arguments for transform at %s must be a mapping.' %
+ identify_object(spec))
+ else:
+ args = {
+ key: value
+ for (key, value) in spec.items()
+ if key not in ('type', 'name', 'input', 'output')
+ }
+ real_args = SafeLineLoader.strip_metadata(args)
+ try:
+ # pylint: disable=undefined-loop-variable
+ ptransform = provider.create_transform(spec['type'], real_args)
+ # TODO(robertwb): Should we have a better API for adding annotations
+ # than this?
+ annotations = dict(
+ yaml_type=spec['type'],
+ yaml_args=json.dumps(real_args),
+ yaml_provider=json.dumps(provider.to_json()),
+ **ptransform.annotations())
+ ptransform.annotations = lambda: annotations
+ return ptransform
+ except Exception as exn:
+ if isinstance(exn, TypeError):
+ # Create a slightly more generic error message for argument errors.
+ msg = str(exn).replace('positional', '').replace('keyword', '')
+ msg = re.sub(r'\S+lambda\S+', '', msg)
+ msg = re.sub(' +', ' ', msg).strip()
+ else:
+ msg = str(exn)
+ raise ValueError(
+ f'Invalid transform specification at {identify_object(spec)}: {msg}'
+ ) from exn
+
+ def unique_name(self, spec, ptransform, strictness=0):
+ if 'name' in spec:
+ name = spec['name']
+ strictness += 1
+ else:
+ name = ptransform.label
+ if name in self._seen_names:
+ if strictness >= 2:
+ raise ValueError(f'Duplicate name at {identify_object(spec)}: {name}')
+ else:
+ name = f'{name}@{SafeLineLoader.get_line(spec)}'
+ self._seen_names.add(name)
+ return name
+
+
+def expand_transform(spec, scope):
+ if 'type' not in spec:
+ raise TypeError(
+ f'Missing type parameter for transform at {identify_object(spec)}')
+ type = spec['type']
+ if type == 'composite':
+ return expand_composite_transform(spec, scope)
+ elif type == 'chain':
+ return expand_chain_transform(spec, scope)
+ else:
+ return expand_leaf_transform(spec, scope)
+
+
def expand_leaf_transform(spec, scope):
  """Expands a non-composite transform spec and applies it to its inputs.

  The shape of the applied input depends on input_type: 'list' passes a
  tuple of PCollections, 'map' a dict, and the default passes a bare
  PCollection (or the pipeline root) when there are one (or zero) inputs.

  Returns:
    A dict mapping output tags to PCollections.
  """
  spec = normalize_inputs_outputs(spec)
  inputs_dict = {
      key: scope.get_pcollection(value)
      for (key, value) in spec['input'].items()
  }
  input_type = spec.get('input_type', 'default')
  if input_type == 'list':
    inputs = tuple(inputs_dict.values())
  elif input_type == 'map':
    inputs = inputs_dict
  else:
    if len(inputs_dict) == 0:
      inputs = scope.root
    elif len(inputs_dict) == 1:
      inputs = next(iter(inputs_dict.values()))
    else:
      inputs = inputs_dict
  _LOGGER.info("Expanding %s ", identify_object(spec))
  ptransform = scope.create_ptransform(spec)
  try:
    # TODO: Move validation to construction?
    with FullyQualifiedNamedTransform.with_filter('*'):
      outputs = inputs | scope.unique_name(spec, ptransform) >> ptransform
  except Exception as exn:
    # Fixed typo: was "Errror apply transform".
    raise ValueError(
        f"Error applying transform {identify_object(spec)}: {exn}") from exn
  if isinstance(outputs, dict):
    # TODO: Handle (or at least reject) nested case.
    return outputs
  elif isinstance(outputs, (tuple, list)):
    # Fixed: the key must be an f-string; the original emitted the literal
    # key 'out{ix}' for every element, collapsing the dict to one entry.
    return {f'out{ix}': pcoll for (ix, pcoll) in enumerate(outputs)}
  elif isinstance(outputs, beam.PCollection):
    return {'out': outputs}
  else:
    raise ValueError(
        f'Transform {identify_object(spec)} returned an unexpected type '
        f'{type(outputs)}')
+
+
def expand_composite_transform(spec, scope):
  """Expands a composite transform spec, computing all inner transforms.

  The composite's inputs are resolved in the enclosing scope and made
  available to the inner transforms; its declared outputs are resolved in
  the inner scope and returned as this transform's outputs.
  """
  spec = normalize_inputs_outputs(spec)

  input_pcolls = {
      key: scope.get_pcollection(value)
      for (key, value) in spec['input'].items()
  }
  inner_providers = yaml_provider.merge_providers(
      yaml_provider.parse_providers(spec.get('providers', [])),
      scope.providers)
  inner_scope = Scope(
      scope.root, input_pcolls, spec['transforms'], inner_providers)

  class CompositePTransform(beam.PTransform):
    @staticmethod
    def expand(inputs):
      inner_scope.compute_all()
      return {
          key: inner_scope.get_pcollection(value)
          for (key, value) in spec['output'].items()
      }

  if 'name' not in spec:
    spec['name'] = 'Composite'
  if spec['name'] is None:  # top-level pipeline, don't nest
    return CompositePTransform.expand(None)
  else:
    _LOGGER.info("Expanding %s ", identify_object(spec))
    return (input_pcolls or
            scope.root) | scope.unique_name(spec, None) >> CompositePTransform()
+
+
def expand_chain_transform(spec, scope):
  """Expands a chain transform by first rewriting it as a composite."""
  composite_spec = chain_as_composite(spec)
  return expand_composite_transform(composite_spec, scope)
+
+
def chain_as_composite(spec):
  """Rewrites a chain spec as an equivalent composite spec.

  A chain is simply a composite transform where all inputs and outputs
  are implicit: each transform consumes the previous transform's output,
  the first consumes the chain's inputs, and the chain's output is the
  last transform's output unless declared explicitly.

  Raises:
    TypeError: if the spec has no transforms property.
    ValueError: if an inner transform declares explicit inputs or outputs.
  """
  if 'transforms' not in spec:
    raise TypeError(
        f"Chain at {identify_object(spec)} missing transforms property.")
  has_explicit_outputs = 'output' in spec
  composite_spec = normalize_inputs_outputs(spec)
  new_transforms = []
  for ix, transform in enumerate(composite_spec['transforms']):
    # Fixed: the original tuple listed 'input' and 'output' twice each.
    if any(io in transform for io in ('input', 'output')):
      raise ValueError(
          f'Transform {identify_object(transform)} is part of a chain, '
          'must have implicit inputs and outputs.')
    if ix == 0:
      transform['input'] = {key: key for key in composite_spec['input']}
    else:
      transform['input'] = new_transforms[-1]['__uuid__']
    new_transforms.append(transform)
  composite_spec['transforms'] = new_transforms

  last_transform = new_transforms[-1]['__uuid__']
  if has_explicit_outputs:
    composite_spec['output'] = {
        key: f'{last_transform}.{value}'
        for (key, value) in composite_spec['output'].items()
    }
  else:
    composite_spec['output'] = last_transform
  if 'name' not in composite_spec:
    composite_spec['name'] = 'Chain'
  composite_spec['type'] = 'composite'
  return composite_spec
+
+
def pipeline_as_composite(spec):
  """Wraps a top-level pipeline spec as an unnamed composite transform.

  A bare list of transforms becomes the composite's transforms property
  (with synthesized metadata); a dict spec is simply retyped as an
  unnamed composite.
  """
  if not isinstance(spec, list):
    return dict(spec, name=None, type='composite')
  return {
      'type': 'composite',
      'name': None,
      'transforms': spec,
      '__line__': spec[0]['__line__'],
      '__uuid__': str(uuid.uuid4()),
  }
+
+
def normalize_inputs_outputs(spec):
  """Returns a copy of spec whose input and output properties are dicts.

  A string becomes a single-entry dict keyed by the tag; a list becomes a
  dict with enumerated keys ('input0', 'input1', ...); anything else is
  assumed to already be a mapping and has its loader metadata stripped.
  """
  spec = dict(spec)

  def as_dict(tag):
    io = spec.get(tag, {})
    if isinstance(io, str):
      return {tag: io}
    if isinstance(io, list):
      return {f'{tag}{ix}': value for ix, value in enumerate(io)}
    return SafeLineLoader.strip_metadata(io, tagged_str=False)

  return dict(spec, input=as_dict('input'), output=as_dict('output'))
+
+
def identify_object(spec):
  """Returns a human-readable description of the spec for error messages."""
  line = SafeLineLoader.get_line(spec)
  name = extract_name(spec)
  prefix = f'"{name}" ' if name else ''
  return f'{prefix}at line {line}'
+
+
def extract_name(spec):
  """Returns a best-effort display name for the given spec.

  Prefers an explicit name, then an id, then the transform type; a
  single-entry mapping delegates to its sole value. Falls back to ''.
  """
  for key in ('name', 'id', 'type'):
    if key in spec:
      return spec[key]
  if len(spec) == 1:
    (sole_value, ) = spec.values()
    return extract_name(sole_value)
  return ''
+
+
class YamlTransform(beam.PTransform):
  """A PTransform built from a yaml (or already-parsed) transform spec."""
  def __init__(self, spec, providers=None):
    """Args:
      spec: a yaml string, or the already-parsed spec mapping.
      providers: optional extra transform providers, merged on top of the
        standard providers. (Changed from a mutable default `{}` to None to
        avoid the shared-mutable-default pitfall; behavior is unchanged.)
    """
    if isinstance(spec, str):
      spec = yaml.load(spec, Loader=SafeLineLoader)
    self._spec = spec
    self._providers = yaml_provider.merge_providers(
        providers or {}, yaml_provider.standard_providers())

  def expand(self, pcolls):
    # Normalize the input into (pipeline root, dict of tagged PCollections).
    if isinstance(pcolls, beam.pvalue.PBegin):
      root = pcolls
      pcolls = {}
    elif isinstance(pcolls, beam.PCollection):
      root = pcolls.pipeline
      pcolls = {'input': pcolls}
    else:
      root = next(iter(pcolls.values())).pipeline
    result = expand_transform(
        self._spec,
        Scope(root, pcolls, transforms=[], providers=self._providers))
    # A single output is returned bare for convenience.
    if len(result) == 1:
      return only_element(result.values())
    else:
      return result
+
+
def expand_pipeline(pipeline, pipeline_spec):
  """Expands the given pipeline spec into an existing beam Pipeline object."""
  if isinstance(pipeline_spec, str):
    pipeline_spec = yaml.load(pipeline_spec, Loader=SafeLineLoader)
  providers = yaml_provider.parse_providers(pipeline_spec.get('providers', []))
  composite = pipeline_as_composite(pipeline_spec['pipeline'])
  # Calling expand directly to avoid outer layer of nesting.
  return YamlTransform(composite, providers).expand(
      beam.pvalue.PBegin(pipeline))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py
b/sdks/python/apache_beam/yaml/yaml_transform_test.py
new file mode 100644
index 00000000000..e3b7097df24
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.yaml.yaml_transform import YamlTransform
+
+
class YamlTransformTest(unittest.TestCase):
  """End-to-end tests applying YamlTransform specs inside a real pipeline."""
  def test_composite(self):
    # Two parallel branches (Square, Cube) over the same input, flattened.
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      elements = p | beam.Create([1, 2, 3])
      # TODO(robertwb): Consider making the input implicit (and below).
      result = elements | YamlTransform(
          '''
          type: composite
          input:
            elements: input
          transforms:
            - type: PyMap
              name: Square
              input: elements
              fn: "lambda x: x * x"
            - type: PyMap
              name: Cube
              input: elements
              fn: "lambda x: x * x * x"
            - type: Flatten
              input: [Square, Cube]
          output:
            Flatten
          ''')
      assert_that(result, equal_to([1, 4, 9, 1, 8, 27]))

  def test_chain_with_input(self):
    # A chain whose (single, implicit) input comes from an upstream transform.
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      elements = p | beam.Create(range(10))
      result = elements | YamlTransform(
          '''
          type: chain
          input:
            elements: input
          transforms:
            - type: PyMap
              fn: "lambda x: x * x + x"
            - type: PyMap
              fn: "lambda x: x + 41"
          ''')
      assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131]))

  def test_chain_with_root(self):
    # A chain applied at the pipeline root, starting from a Create source.
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      result = p | YamlTransform(
          '''
          type: chain
          transforms:
            - type: Create
              elements: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
            - type: PyMap
              fn: "lambda x: x * x + x"
            - type: PyMap
              fn: "lambda x: x + 41"
          ''')
      assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131]))
+
+
if __name__ == '__main__':
  # INFO-level logging surfaces the "Expanding ..." messages during tests.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()