This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b8f8c09f5e Initial start of a yaml-based declarative way of building 
pipelines. (#24667)
6b8f8c09f5e is described below

commit 6b8f8c09f5e0f31ffa5617c4650297907810003a
Author: Robert Bradshaw <[email protected]>
AuthorDate: Tue Jan 24 09:49:58 2023 -0800

    Initial start of a yaml-based declarative way of building pipelines. 
(#24667)
---
 sdks/python/apache_beam/transforms/external.py     |   3 +
 sdks/python/apache_beam/yaml/__init__.py           |  18 +
 sdks/python/apache_beam/yaml/main.py               |  63 +++
 sdks/python/apache_beam/yaml/pipeline.schema.yaml  | 130 ++++++
 .../apache_beam/yaml/standard_providers.yaml       |  25 ++
 sdks/python/apache_beam/yaml/yaml_provider.py      | 437 ++++++++++++++++++++
 sdks/python/apache_beam/yaml/yaml_transform.py     | 450 +++++++++++++++++++++
 .../python/apache_beam/yaml/yaml_transform_test.py |  90 +++++
 8 files changed, 1216 insertions(+)

diff --git a/sdks/python/apache_beam/transforms/external.py 
b/sdks/python/apache_beam/transforms/external.py
index 7a51379a0e3..1c4a6dd0519 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -777,6 +777,9 @@ class ExpansionAndArtifactRetrievalStub(
     return beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceStub(
         self._channel, **self._kwargs)
 
+  def ready(self, timeout_sec):
+    grpc.channel_ready_future(self._channel).result(timeout=timeout_sec)
+
 
 class JavaJarExpansionService(object):
   """An expansion service based on an Java Jar file.
diff --git a/sdks/python/apache_beam/yaml/__init__.py 
b/sdks/python/apache_beam/yaml/__init__.py
new file mode 100644
index 00000000000..73dd57dd7a6
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from apache_beam.yaml.yaml_transform import *
diff --git a/sdks/python/apache_beam/yaml/main.py 
b/sdks/python/apache_beam/yaml/main.py
new file mode 100644
index 00000000000..7bdeccbc2db
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/main.py
@@ -0,0 +1,63 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import argparse
+
+import yaml
+
+import apache_beam as beam
+from apache_beam.yaml import yaml_transform
+
+
+def run(argv=None):
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--pipeline_spec',
+      description='A yaml description of the pipeline to run.')
+  parser.add_argument(
+      '--pipeline_spec_file',
+      description='A file containing a yaml description of the pipeline to run.'
+  )
+  known_args, pipeline_args = parser.parse_known_args(argv)
+
+  if known_args.pipeline_spec_file and known_args.pipeline_spec:
+    raise ValueError(
+        "Exactly one of pipeline_spec or pipeline_spec_file must be set.")
+  elif known_args.pipeline_spec_file:
+    with open(known_args.pipeline_spec_file) as fin:
+      pipeline_yaml = fin.read()
+  elif known_args.pipeline_spec:
+    pipeline_yaml = known_args.pipeline_spec
+  else:
+    raise ValueError(
+        "Exactly one of pipeline_spec or pipeline_spec_file must be set.")
+
+  pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)
+
+  yaml_transform._LOGGER.setLevel('INFO')
+
+  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+      pipeline_args,
+      pickle_library='cloudpickle',
+      **pipeline_spec.get('options', {}))) as p:
+    print("Building pipeline...")
+    yaml_transform.expand_pipeline(p, known_args.pipeline_spec)
+    print("Running pipeline...")
+
+
+if __name__ == '__main__':
+  run()
diff --git a/sdks/python/apache_beam/yaml/pipeline.schema.yaml 
b/sdks/python/apache_beam/yaml/pipeline.schema.yaml
new file mode 100644
index 00000000000..36c35af49d8
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/pipeline.schema.yaml
@@ -0,0 +1,130 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+$schema: 'http://json-schema.org/schema#'
+$id: https://github.com/apache/beam/tree/master/sdks/python/apache_beam/yaml/pipeline.schema.yaml
+
+$defs:
+
+  transformBase:
+    type: object
+    properties:
+      type: { type: string }
+      name: { type: string }
+      input:
+        oneOf:
+          - type: string
+          - type: object
+            additionalProperties:
+              type: string
+      output:
+        oneOf:
+          - type: string
+          - type: object
+            additionalProperties:
+              type: string
+    additionalProperties: true
+    required:
+      - type
+
+  leafTransform:
+    allOf:
+      - $ref: '#/$defs/transformBase'
+      - type: object
+        properties:
+          type:
+            not:
+              anyOf:
+                - const: chain
+                - const: composite
+          args:
+            type: object
+
+  chainTransform:
+    allOf:
+      - $ref: '#/$defs/transformBase'
+      - type: object
+        properties:
+          type:
+            const: chain
+          name: {}
+          input: {}
+          output: {}
+          transforms:
+            type: array
+            items:
+              allOf:
+                - $ref: '#/$defs/transform'
+                - type: object
+                  properties:
+                    # Must be implicit.
+                    input: { not: {} }
+                    output: { not: {} }
+        additionalProperties: false
+        required:
+          - transforms
+
+  compositeTransform:
+    allOf:
+      - $ref: '#/$defs/transformBase'
+      - type: object
+        properties:
+          type:
+            const: composite
+          name: {}
+          input: {}
+          output: {}
+          transforms:
+            type: array
+            items:
+              $ref: '#/$defs/transform'
+        additionalProperties: false
+        required:
+          - transforms
+
+  transform:
+    oneOf:
+      - $ref: '#/$defs/leafTransform'
+      - $ref: '#/$defs/chainTransform'
+      - $ref: '#/$defs/compositeTransform'
+
+  provider:
+    # TODO(robertwb): Consider enumerating the provider types along with
+    # the arguments they accept/expect (possibly in a separate schema file).
+    type: object
+    properties:
+      type: { type: string }
+      transforms:
+        type: object
+        additionalProperties:
+          type: string
+    required:
+      - type
+      - transforms
+
+type: object
+properties:
+  pipeline:
+    type: array
+    items:
+      $ref: '#/$defs/transform'
+  providers:
+    type: array
+    items:
+      $ref: '#/$defs/provider'
+required:
+  - pipeline
diff --git a/sdks/python/apache_beam/yaml/standard_providers.yaml 
b/sdks/python/apache_beam/yaml/standard_providers.yaml
new file mode 100644
index 00000000000..04d33b53d64
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/standard_providers.yaml
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# TODO(robertwb): Add more providers.
+# TODO(robertwb): Perhaps auto-generate this file?
+
+- type: 'beamJar'
+  gradle_target: 'sdks:java:extensions:sql:expansion-service:shadowJar'
+  version: BEAM_VERSION
+  transforms:
+     Sql: 'beam:external:java:sql:v1'
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py 
b/sdks/python/apache_beam/yaml/yaml_provider.py
new file mode 100644
index 00000000000..a558de3507a
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -0,0 +1,437 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines Providers usable from yaml, which specify where to
+find and how to invoke services that vend implementations of various
+PTransforms."""
+
+import collections
+import hashlib
+import json
+import os
+import subprocess
+import sys
+import uuid
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+
+import yaml
+from yaml.loader import SafeLoader
+
+import apache_beam as beam
+import apache_beam.dataframe.io
+import apache_beam.io
+import apache_beam.transforms.util
+from apache_beam.portability.api import schema_pb2
+from apache_beam.transforms import external
+from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
+from apache_beam.typehints import schemas
+from apache_beam.typehints import trivial_inference
+from apache_beam.utils import python_callable
+from apache_beam.utils import subprocess_server
+from apache_beam.version import __version__ as beam_version
+
+
+class Provider:
+  """Maps transform type names and args to concrete PTransform instances."""
+  def available(self) -> bool:
+    """Returns whether this provider is available to use in this environment."""
+    raise NotImplementedError(type(self))
+
+  def provided_transforms(self) -> Iterable[str]:
+    """Returns a list of transform type names this provider can handle."""
+    raise NotImplementedError(type(self))
+
+  def create_transform(
+      self, typ: str, args: Mapping[str, Any]) -> beam.PTransform:
+    """Creates a PTransform instance for the given transform type and arguments.
+    """
+    raise NotImplementedError(type(self))
+
+
+class ExternalProvider(Provider):
+  """A Provider implemented via the cross language transform service."""
+  def __init__(self, urns, service):
+    self._urns = urns
+    self._service = service
+    self._schema_transforms = None
+
+  def provided_transforms(self):
+    return self._urns.keys()
+
+  def create_transform(self, type, args):
+    if callable(self._service):
+      self._service = self._service()
+    if self._schema_transforms is None:
+      try:
+        self._schema_transforms = [
+            config.identifier
+            for config in external.SchemaAwareExternalTransform.discover(
+                self._service)
+        ]
+      except Exception:
+        self._schema_transforms = []
+    urn = self._urns[type]
+    if urn in self._schema_transforms:
+      return external.SchemaAwareExternalTransform(urn, self._service, **args)
+    else:
+      return type >> self.create_external_transform(urn, args)
+
+  def create_external_transform(self, urn, args):
+    return external.ExternalTransform(
+        urn,
+        external.ImplicitSchemaPayloadBuilder(args).payload(),
+        self._service)
+
+  @staticmethod
+  def provider_from_spec(spec):
+    urns = spec['transforms']
+    type = spec['type']
+    if spec.get('version', None) == 'BEAM_VERSION':
+      spec['version'] = beam_version
+    if type == 'jar':
+      return ExternalJavaProvider(urns, lambda: spec['jar'])
+    elif type == 'mavenJar':
+      return ExternalJavaProvider(
+          urns,
+          lambda: subprocess_server.JavaJarServer.path_to_maven_jar(
+              **{
+                  key: value
+                  for (key, value) in spec.items() if key in [
+                      'artifact_id',
+                      'group_id',
+                      'version',
+                      'repository',
+                      'classifier',
+                      'appendix'
+                  ]
+              }))
+    elif type == 'beamJar':
+      return ExternalJavaProvider(
+          urns,
+          lambda: subprocess_server.JavaJarServer.path_to_beam_jar(
+              **{
+                  key: value
+                  for (key, value) in spec.items() if key in
+                  ['gradle_target', 'version', 'appendix', 'artifact_id']
+              }))
+    elif type == 'pypi':
+      return ExternalPythonProvider(urns, spec['packages'])
+    elif type == 'remote':
+      return RemoteProvider(spec['address'])
+    elif type == 'docker':
+      raise NotImplementedError()
+    else:
+      raise NotImplementedError(f'Unknown provider type: {type}')
+
+
+class RemoteProvider(ExternalProvider):
+  _is_available = None
+
+  def available(self):
+    if self._is_available is None:
+      try:
+        with external.ExternalTransform.service(self._service) as service:
+          service.ready(1)
+          self._is_available = True
+      except Exception:
+        self._is_available = False
+    return self._is_available
+
+
+class ExternalJavaProvider(ExternalProvider):
+  def __init__(self, urns, jar_provider):
+    super().__init__(
+        urns, lambda: external.JavaJarExpansionService(jar_provider()))
+
+  def available(self):
+    # pylint: disable=subprocess-run-check
+    return subprocess.run(['which', 'java'],
+                          capture_output=True).returncode == 0
+
+
+class ExternalPythonProvider(ExternalProvider):
+  def __init__(self, urns, packages):
+    super().__init__(urns, PypiExpansionService(packages))
+
+  def available(self):
+    return True  # If we're running this script, we have Python installed.
+
+  def create_external_transform(self, urn, args):
+    # Python transforms are "registered" by fully qualified name.
+    return external.ExternalTransform(
+        "beam:transforms:python:fully_qualified_named",
+        external.ImplicitSchemaPayloadBuilder({
+            'constructor': urn,
+            'kwargs': args,
+        }).payload(),
+        self._service)
+
+
+# This is needed because type inference can't handle *args, **kwargs forwarding.
+# TODO(BEAM-24755): Add support for type inference through kwargs calls.
+def fix_pycallable():
+  from apache_beam.transforms.ptransform import label_from_callable
+
+  def default_label(self):
+    src = self._source.strip()
+    last_line = src.split('\n')[-1]
+    if last_line[0] != ' ' and len(last_line) < 72:
+      return last_line
+    return label_from_callable(self._callable)
+
+  def _argspec_fn(self):
+    return self._callable
+
+  python_callable.PythonCallableWithSource.default_label = default_label
+  python_callable.PythonCallableWithSource._argspec_fn = property(_argspec_fn)
+
+  original_infer_return_type = trivial_inference.infer_return_type
+
+  def infer_return_type(fn, *args, **kwargs):
+    if isinstance(fn, python_callable.PythonCallableWithSource):
+      fn = fn._callable
+    return original_infer_return_type(fn, *args, **kwargs)
+
+  trivial_inference.infer_return_type = infer_return_type
+
+  original_fn_takes_side_inputs = (
+      apache_beam.transforms.util.fn_takes_side_inputs)
+
+  def fn_takes_side_inputs(fn):
+    if isinstance(fn, python_callable.PythonCallableWithSource):
+      fn = fn._callable
+    return original_fn_takes_side_inputs(fn)
+
+  apache_beam.transforms.util.fn_takes_side_inputs = fn_takes_side_inputs
+
+
+class InlineProvider(Provider):
+  def __init__(self, transform_factories):
+    self._transform_factories = transform_factories
+
+  def available(self):
+    return True
+
+  def provided_transforms(self):
+    return self._transform_factories.keys()
+
+  def create_transform(self, type, args):
+    return self._transform_factories[type](**args)
+
+  def to_json(self):
+    return {'type': "InlineProvider"}
+
+
+PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
+    py_type.__name__: schema_type
+    for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
+    if py_type.__module__ != 'typing'
+}
+
+
+def create_builtin_provider():
+  def with_schema(**args):
+    # TODO: This is preliminary.
+    def parse_type(spec):
+      if spec in PRIMITIVE_NAMES_TO_ATOMIC_TYPE:
+        return schema_pb2.FieldType(
+            atomic_type=PRIMITIVE_NAMES_TO_ATOMIC_TYPE[spec])
+      elif isinstance(spec, list):
+        if len(spec) != 1:
+          raise ValueError("Use single-element lists to denote list types.")
+        else:
+          return schema_pb2.FieldType(
+              iterable_type=schema_pb2.IterableType(
+                  element_type=parse_type(spec[0])))
+      elif isinstance(spec, dict):
+        return schema_pb2.FieldType(
+            iterable_type=schema_pb2.RowType(schema=parse_schema(spec[0])))
+      else:
+        raise ValueError("Unknown schema type: {spec}")
+
+    def parse_schema(spec):
+      return schema_pb2.Schema(
+          fields=[
+              schema_pb2.Field(name=key, type=parse_type(value), id=ix)
+              for (ix, (key, value)) in enumerate(spec.items())
+          ],
+          id=str(uuid.uuid4()))
+
+    named_tuple = schemas.named_tuple_from_schema(parse_schema(args))
+    names = list(args.keys())
+
+    def extract_field(x, name):
+      if isinstance(x, dict):
+        return x[name]
+      else:
+        return getattr(x, name)
+
+    return 'WithSchema(%s)' % ', '.join(names) >> beam.Map(
+        lambda x: named_tuple(*[extract_field(x, name) for name in names])
+    ).with_output_types(named_tuple)
+
+  # Or should this be posargs, args?
+  # pylint: disable=dangerous-default-value
+  def fully_qualified_named_transform(constructor, args=(), kwargs={}):
+    with FullyQualifiedNamedTransform.with_filter('*'):
+      return constructor >> FullyQualifiedNamedTransform(
+          constructor, args, kwargs)
+
+  # This intermediate is needed because there is no way to specify a tuple of
+  # exactly zero or one PCollection in yaml (as they would be interpreted as
+  # PBegin and the PCollection itself respectively).
+  class Flatten(beam.PTransform):
+    def expand(self, pcolls):
+      if isinstance(pcolls, beam.PCollection):
+        pipeline_arg = {}
+        pcolls = (pcolls, )
+      elif isinstance(pcolls, dict):
+        pipeline_arg = {}
+        pcolls = tuple(pcolls.values())
+      else:
+        pipeline_arg = {'pipeline': pcolls.pipeline}
+        pcolls = ()
+      return pcolls | beam.Flatten(**pipeline_arg)
+
+  ios = {
+      key: getattr(apache_beam.io, key)
+      for key in dir(apache_beam.io)
+      if key.startswith('ReadFrom') or key.startswith('WriteTo')
+  }
+  ios['ReadFromCsv'] = lambda **kwargs: apache_beam.dataframe.io.ReadViaPandas(
+      'csv', **kwargs)
+  ios['WriteToCsv'] = lambda **kwargs: apache_beam.dataframe.io.WriteViaPandas(
+      'csv', **kwargs)
+  ios['ReadFromJson'] = (
+      lambda *,
+      orient='records',
+      lines=True,
+      **kwargs: apache_beam.dataframe.io.ReadViaPandas(
+          'json', orient=orient, lines=lines, **kwargs))
+  ios['WriteToJson'] = (
+      lambda *,
+      orient='records',
+      lines=True,
+      **kwargs: apache_beam.dataframe.io.WriteViaPandas(
+          'json', orient=orient, lines=lines, **kwargs))
+
+  return InlineProvider(
+      dict({
+          'Create': lambda elements,
+          reshuffle=True: beam.Create(elements, reshuffle),
+          'PyMap': lambda fn: beam.Map(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyMapTuple': lambda fn: beam.MapTuple(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyFlatMap': lambda fn: beam.FlatMap(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyFilter': lambda fn: beam.Filter(
+              python_callable.PythonCallableWithSource(fn)),
+          'PyTransform': fully_qualified_named_transform,
+          'PyToRow': lambda fields: beam.Select(
+              **{
+                  name: python_callable.PythonCallableWithSource(fn)
+                  for (name, fn) in fields.items()
+              }),
+          'WithSchema': with_schema,
+          'Flatten': Flatten,
+          'GroupByKey': beam.GroupByKey,
+      },
+           **ios))
+
+
+class PypiExpansionService:
+  """Expands transforms by fully qualified name in a virtual environment
+  with the given dependencies.
+  """
+  VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/venvs")
+
+  def __init__(self, packages, base_python=sys.executable):
+    self._packages = packages
+    self._base_python = base_python
+
+  def _key(self):
+    return json.dumps({'binary': self._base_python, 'packages': self._packages})
+
+  def _venv(self):
+    venv = os.path.join(
+        self.VENV_CACHE,
+        hashlib.sha256(self._key().encode('utf-8')).hexdigest())
+    if not os.path.exists(venv):
+      python_binary = os.path.join(venv, 'bin', 'python')
+      subprocess.run([self._base_python, '-m', 'venv', venv], check=True)
+      subprocess.run([python_binary, '-m', 'ensurepip'], check=True)
+      subprocess.run([python_binary, '-m', 'pip', 'install'] + self._packages,
+                     check=True)
+      with open(venv + '-requirements.txt', 'w') as fout:
+        fout.write('\n'.join(self._packages))
+    return venv
+
+  def __enter__(self):
+    venv = self._venv()
+    self._service_provider = subprocess_server.SubprocessServer(
+        external.ExpansionAndArtifactRetrievalStub,
+        [
+            os.path.join(venv, 'bin', 'python'),
+            '-m',
+            'apache_beam.runners.portability.expansion_service_main',
+            '--port',
+            '{{PORT}}',
+            '--fully_qualified_name_glob=*',
+            '--pickle_library=cloudpickle',
+            '--requirements_file=' + os.path.join(venv + '-requirements.txt')
+        ])
+    self._service = self._service_provider.__enter__()
+    return self._service
+
+  def __exit__(self, *args):
+    self._service_provider.__exit__(*args)
+    self._service = None
+
+
+def parse_providers(provider_specs):
+  providers = collections.defaultdict(list)
+  for provider_spec in provider_specs:
+    provider = ExternalProvider.provider_from_spec(provider_spec)
+    for transform_type in provider.provided_transforms():
+      providers[transform_type].append(provider)
+      # TODO: Do this better.
+      provider.to_json = lambda result=provider_spec: result
+  return providers
+
+
+def merge_providers(*provider_sets):
+  result = collections.defaultdict(list)
+  for provider_set in provider_sets:
+    for transform_type, providers in provider_set.items():
+      result[transform_type].extend(providers)
+  return result
+
+
+def standard_providers():
+  builtin_providers = collections.defaultdict(list)
+  builtin_provider = create_builtin_provider()
+  for transform_type in builtin_provider.provided_transforms():
+    builtin_providers[transform_type].append(builtin_provider)
+  with open(os.path.join(os.path.dirname(__file__),
+                         'standard_providers.yaml')) as fin:
+    standard_providers = yaml.load(fin, Loader=SafeLoader)
+  return merge_providers(builtin_providers, parse_providers(standard_providers))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py 
b/sdks/python/apache_beam/yaml/yaml_transform.py
new file mode 100644
index 00000000000..1e8495e308a
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -0,0 +1,450 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This module is experimental. No backwards-compatibility guarantees.
+
import collections
import functools
import json
import logging
import re
import uuid
from typing import Iterable
from typing import Mapping

import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.yaml import yaml_provider
+
+__all__ = ["YamlTransform"]
+
+_LOGGER = logging.getLogger(__name__)
+yaml_provider.fix_pycallable()
+
+
def memoize_method(func):
  """Decorator caching a method's results per instance, keyed by its args.

  The cache is stored on the instance itself (created lazily), so distinct
  instances never share memoized results. Arguments must be hashable.
  """
  # functools.wraps preserves the wrapped method's __name__/__doc__, which
  # the bare wrapper previously clobbered.
  @functools.wraps(func)
  def wrapper(self, *args):
    if not hasattr(self, '_cache'):
      self._cache = {}
    # The function name is part of the key so several memoized methods on
    # one instance can share the single _cache dict without colliding.
    key = func.__name__, args
    if key not in self._cache:
      self._cache[key] = func(self, *args)
    return self._cache[key]

  return wrapper
+
+
def only_element(xs):
  """Returns the sole element of xs, raising if there is not exactly one."""
  (sole, ) = xs
  return sole
+
+
class SafeLineLoader(SafeLoader):
  """A yaml loader that attaches line information to mappings and strings."""
  class TaggedString(str):
    """A string subclass that can carry metadata.

    Used to trace a string's origin back to its position in a yaml file.
    """
    def __reduce__(self):
      # Pickle as an ordinary string.
      return str, (str(self), )

  def construct_scalar(self, node):
    scalar = super().construct_scalar(node)
    if not isinstance(scalar, str):
      return scalar
    # Wrap strings so their source line can be recovered later.
    tagged = SafeLineLoader.TaggedString(scalar)
    tagged._line_ = node.start_mark.line + 1
    return tagged

  def construct_mapping(self, node, deep=False):
    mapping = super().construct_mapping(node, deep=deep)
    # Record where the mapping begins (1-based) and give it a stable id.
    mapping['__line__'] = node.start_mark.line + 1
    mapping['__uuid__'] = str(uuid.uuid4())
    return mapping

  @classmethod
  def strip_metadata(cls, spec, tagged_str=True):
    """Recursively removes __line__/__uuid__ keys (and, optionally,
    downgrades TaggedStrings to plain str)."""
    if isinstance(spec, Mapping):
      stripped = {}
      for key, value in spec.items():
        if key in ('__line__', '__uuid__'):
          continue
        stripped[key] = cls.strip_metadata(value, tagged_str)
      return stripped
    if isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)):
      return [cls.strip_metadata(item, tagged_str) for item in spec]
    if tagged_str and isinstance(spec, SafeLineLoader.TaggedString):
      return str(spec)
    return spec

  @staticmethod
  def get_line(obj):
    """Returns the recorded source line of obj, or 'unknown'."""
    if isinstance(obj, dict):
      return obj.get('__line__', 'unknown')
    return getattr(obj, '_line_', 'unknown')
+
+
class Scope(object):
  """Looks up PCollections (typically outputs of prior transforms) by name.

  A Scope holds the sibling transform specs at one nesting level together
  with any externally bound inputs, and resolves string references such as
  "MyTransform" or "MyTransform.tag" to concrete PCollections.
  """
  def __init__(self, root, inputs, transforms, providers):
    # root: the pipeline root (PBegin) used for source transforms.
    # inputs: mapping of externally bound names to PCollections.
    # transforms: list of transform specs visible in this scope.
    # providers: mapping of transform type -> list of candidate providers.
    self.root = root
    self.providers = providers
    self._inputs = inputs
    self._transforms = transforms
    # Transforms are canonically identified by the uuid attached at yaml
    # parse time; names and types are merely (possibly ambiguous) aliases.
    self._transforms_by_uuid = {t['__uuid__']: t for t in self._transforms}
    self._uuid_by_name = collections.defaultdict(list)
    for spec in self._transforms:
      if 'name' in spec:
        self._uuid_by_name[spec['name']].append(spec['__uuid__'])
      if 'type' in spec:
        self._uuid_by_name[spec['type']].append(spec['__uuid__'])
    # Labels already used for applied transforms, to keep names unique.
    self._seen_names = set()

  def compute_all(self):
    """Expands every transform in this scope (memoized per transform)."""
    for transform_id in self._transforms_by_uuid.keys():
      self.compute_outputs(transform_id)

  def get_pcollection(self, name):
    """Resolves a string reference to a single PCollection.

    Accepts an externally bound input name, a "transform.output" pair, or a
    bare transform name when that transform has exactly one output.

    Raises:
      ValueError: if the reference is unknown or ambiguous.
    """
    if name in self._inputs:
      return self._inputs[name]
    elif '.' in name:
      # rsplit: transform names may themselves contain dots.
      transform, output = name.rsplit('.', 1)
      outputs = self.get_outputs(transform)
      if output in outputs:
        return outputs[output]
      else:
        raise ValueError(
            f'Unknown output {repr(output)} '
            f'at line {SafeLineLoader.get_line(name)}: '
            f'{transform} only has outputs {list(outputs.keys())}')
    else:
      outputs = self.get_outputs(name)
      if len(outputs) == 1:
        return only_element(outputs.values())
      else:
        raise ValueError(
            f'Ambiguous output at line {SafeLineLoader.get_line(name)}: '
            f'{name} has outputs {list(outputs.keys())}')

  def get_outputs(self, transform_name):
    """Returns the output dict of the transform named by alias or uuid.

    Raises:
      ValueError: if the name matches no transform, or more than one.
    """
    if transform_name in self._transforms_by_uuid:
      transform_id = transform_name
    else:
      candidates = self._uuid_by_name[transform_name]
      if not candidates:
        raise ValueError(
            f'Unknown transform at line '
            f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
      elif len(candidates) > 1:
        raise ValueError(
            f'Ambiguous transform at line '
            f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
      else:
        transform_id = only_element(candidates)
    return self.compute_outputs(transform_id)

  @memoize_method
  def compute_outputs(self, transform_id):
    # Memoized so each transform is expanded (applied) at most once, no
    # matter how many downstream references it has.
    return expand_transform(self._transforms_by_uuid[transform_id], self)

  # A method on scope as providers may be scoped...
  def create_ptransform(self, spec):
    """Instantiates the PTransform described by spec via the first
    available provider for its type.

    Raises:
      ValueError: on missing/unknown type, unavailable providers, bad args,
        or any error raised while constructing the transform.
    """
    if 'type' not in spec:
      raise ValueError(f'Missing transform type: {identify_object(spec)}')

    if spec['type'] not in self.providers:
      raise ValueError(
          'Unknown transform type %r at %s' %
          (spec['type'], identify_object(spec)))

    # for/else: `provider` deliberately leaks out of the loop as the first
    # available provider; the else clause fires only if none was available.
    for provider in self.providers.get(spec['type']):
      if provider.available():
        break
    else:
      raise ValueError(
          'No available provider for type %r at %s' %
          (spec['type'], identify_object(spec)))

    if 'args' in spec:
      args = spec['args']
      if not isinstance(args, dict):
        raise ValueError(
            'Arguments for transform at %s must be a mapping.' %
            identify_object(spec))
    else:
      # No explicit args block: treat all non-structural keys as arguments.
      args = {
          key: value
          for (key, value) in spec.items()
          if key not in ('type', 'name', 'input', 'output')
      }
    real_args = SafeLineLoader.strip_metadata(args)
    try:
      # pylint: disable=undefined-loop-variable
      ptransform = provider.create_transform(spec['type'], real_args)
      # TODO(robertwb): Should we have a better API for adding annotations
      # than this?
      annotations = dict(
          yaml_type=spec['type'],
          yaml_args=json.dumps(real_args),
          yaml_provider=json.dumps(provider.to_json()),
          **ptransform.annotations())
      ptransform.annotations = lambda: annotations
      return ptransform
    except Exception as exn:
      if isinstance(exn, TypeError):
        # Create a slightly more generic error message for argument errors.
        msg = str(exn).replace('positional', '').replace('keyword', '')
        msg = re.sub(r'\S+lambda\S+', '', msg)
        msg = re.sub('  +', ' ', msg).strip()
      else:
        msg = str(exn)
      raise ValueError(
          f'Invalid transform specification at {identify_object(spec)}: {msg}'
      ) from exn

  def unique_name(self, spec, ptransform, strictness=0):
    """Returns a unique label for applying this transform.

    Explicitly named specs raise on collision (strictness bumped past the
    threshold); otherwise a colliding name is disambiguated with its yaml
    line number.
    """
    if 'name' in spec:
      name = spec['name']
      strictness += 1
    else:
      name = ptransform.label
    if name in self._seen_names:
      if strictness >= 2:
        raise ValueError(f'Duplicate name at {identify_object(spec)}: {name}')
      else:
        name = f'{name}@{SafeLineLoader.get_line(spec)}'
    self._seen_names.add(name)
    return name
+
+
def expand_transform(spec, scope):
  """Expands a transform spec into its output PCollection(s).

  Dispatches on the spec's type: 'composite' and 'chain' are structural and
  expanded specially; everything else is treated as a leaf transform.

  Raises:
    TypeError: if the spec has no type.
  """
  if 'type' not in spec:
    raise TypeError(
        f'Missing type parameter for transform at {identify_object(spec)}')
  # Named spec_type (not `type`) to avoid shadowing the builtin.
  spec_type = spec['type']
  if spec_type == 'composite':
    return expand_composite_transform(spec, scope)
  elif spec_type == 'chain':
    return expand_chain_transform(spec, scope)
  else:
    return expand_leaf_transform(spec, scope)
+
+
def expand_leaf_transform(spec, scope):
  """Expands a non-composite spec and applies it to its resolved inputs.

  Returns:
    A dict mapping output names to PCollections.

  Raises:
    ValueError: if application fails or yields an unexpected type.
  """
  spec = normalize_inputs_outputs(spec)
  inputs_dict = {
      key: scope.get_pcollection(value)
      for (key, value) in spec['input'].items()
  }
  input_type = spec.get('input_type', 'default')
  if input_type == 'list':
    inputs = tuple(inputs_dict.values())
  elif input_type == 'map':
    inputs = inputs_dict
  else:
    # Default behavior: no inputs means a root (source) transform, a single
    # input is passed bare, and multiple inputs are passed as a dict.
    if len(inputs_dict) == 0:
      inputs = scope.root
    elif len(inputs_dict) == 1:
      inputs = next(iter(inputs_dict.values()))
    else:
      inputs = inputs_dict
  _LOGGER.info("Expanding %s ", identify_object(spec))
  ptransform = scope.create_ptransform(spec)
  try:
    # TODO: Move validation to construction?
    with FullyQualifiedNamedTransform.with_filter('*'):
      outputs = inputs | scope.unique_name(spec, ptransform) >> ptransform
  except Exception as exn:
    # (Fixed: message previously read "Errror apply transform".)
    raise ValueError(
        f"Error applying transform {identify_object(spec)}: {exn}") from exn
  if isinstance(outputs, dict):
    # TODO: Handle (or at least reject) nested case.
    return outputs
  elif isinstance(outputs, (tuple, list)):
    # Fixed: the key was previously the literal 'out{ix}' (missing f-prefix),
    # which collapsed every output onto a single dict entry.
    return {f'out{ix}': pcoll for (ix, pcoll) in enumerate(outputs)}
  elif isinstance(outputs, beam.PCollection):
    return {'out': outputs}
  else:
    raise ValueError(
        f'Transform {identify_object(spec)} returned an unexpected type '
        f'{type(outputs)}')
+
+
def expand_composite_transform(spec, scope):
  """Expands a composite spec into its declared output PCollections.

  Inner transforms are expanded inside a fresh Scope that sees only the
  composite's own bound inputs, its inner transforms, and its providers
  (inner providers take precedence, being merged first).

  NOTE(review): mutates spec by inserting a default 'name' — presumably
  intentional, confirm.
  """
  spec = normalize_inputs_outputs(spec)

  inner_scope = Scope(
      scope.root, {
          key: scope.get_pcollection(value)
          for key,
          value in spec['input'].items()
      },
      spec['transforms'],
      yaml_provider.merge_providers(
          yaml_provider.parse_providers(spec.get('providers', [])),
          scope.providers))

  class CompositePTransform(beam.PTransform):
    @staticmethod
    def expand(inputs):
      # Inputs are ignored here: the inner transforms reach their inputs
      # through inner_scope, captured by closure.
      inner_scope.compute_all()
      return {
          key: inner_scope.get_pcollection(value)
          for (key, value) in spec['output'].items()
      }

  if 'name' not in spec:
    spec['name'] = 'Composite'
  if spec['name'] is None:  # top-level pipeline, don't nest
    return CompositePTransform.expand(None)
  else:
    _LOGGER.info("Expanding %s ", identify_object(spec))
    # An inputless composite ({} is falsy) is applied to the pipeline root.
    return ({
        key: scope.get_pcollection(value)
        for key,
        value in spec['input'].items()
    } or scope.root) | scope.unique_name(spec, None) >> CompositePTransform()
+
+
def expand_chain_transform(spec, scope):
  """Expands a chain spec by first rewriting it as a composite."""
  composite_spec = chain_as_composite(spec)
  return expand_composite_transform(composite_spec, scope)
+
+
def chain_as_composite(spec):
  """Rewrites a chain spec as an equivalent composite spec.

  A chain is simply a composite transform where all inputs and outputs
  are implicit: each transform consumes the previous transform's output,
  the first consumes the chain's inputs, and the last provides the chain's
  outputs.

  Raises:
    TypeError: if the chain has no transforms property.
    ValueError: if an inner transform declares explicit inputs or outputs.
  """
  if 'transforms' not in spec:
    raise TypeError(
        f"Chain at {identify_object(spec)} missing transforms property.")
  has_explicit_outputs = 'output' in spec
  composite_spec = normalize_inputs_outputs(spec)
  new_transforms = []
  for ix, transform in enumerate(composite_spec['transforms']):
    # (The membership tuple previously listed 'input' and 'output' twice;
    # the duplicates were redundant.)
    if any(io in transform for io in ('input', 'output')):
      raise ValueError(
          f'Transform {identify_object(transform)} is part of a chain, '
          'must have implicit inputs and outputs.')
    if ix == 0:
      # The first link inherits the chain's own inputs by name.
      transform['input'] = {key: key for key in composite_spec['input'].keys()}
    else:
      # Every later link consumes the previous link's (sole) output.
      transform['input'] = new_transforms[-1]['__uuid__']
    new_transforms.append(transform)
  composite_spec['transforms'] = new_transforms

  last_transform = new_transforms[-1]['__uuid__']
  if has_explicit_outputs:
    # Re-point declared outputs at the final link's outputs.
    composite_spec['output'] = {
        key: f'{last_transform}.{value}'
        for (key, value) in composite_spec['output'].items()
    }
  else:
    composite_spec['output'] = last_transform
  if 'name' not in composite_spec:
    composite_spec['name'] = 'Chain'
  composite_spec['type'] = 'composite'
  return composite_spec
+
+
def pipeline_as_composite(spec):
  """Wraps a top-level pipeline spec as an unnamed composite transform.

  A mapping is simply coerced to type 'composite' with a null name (a null
  name marks the top level — "don't nest" — for the expansion logic);
  a bare list of transforms gets a synthesized composite envelope.
  """
  if not isinstance(spec, list):
    return dict(spec, name=None, type='composite')
  return {
      'type': 'composite',
      'name': None,
      'transforms': spec,
      '__line__': spec[0]['__line__'],
      '__uuid__': str(uuid.uuid4()),
  }
+
+
def normalize_inputs_outputs(spec):
  """Returns a copy of spec whose input and output are both dicts.

  A bare string becomes a single-entry dict keyed by the tag; a list is
  keyed tag0, tag1, ...; an existing dict passes through with its line/uuid
  metadata stripped.
  """
  normalized = dict(spec)

  def as_dict(tag):
    value = normalized.get(tag, {})
    if isinstance(value, str):
      return {tag: value}
    if isinstance(value, list):
      return {'%s%d' % (tag, ix): item for ix, item in enumerate(value)}
    return SafeLineLoader.strip_metadata(value, tagged_str=False)

  normalized['input'] = as_dict('input')
  normalized['output'] = as_dict('output')
  return normalized
+
+
def identify_object(spec):
  """Describes a spec, by name and yaml line number, for error messages."""
  line = SafeLineLoader.get_line(spec)
  name = extract_name(spec)
  return f'"{name}" at line {line}' if name else f'at line {line}'
+
+
def extract_name(spec):
  """Returns a human-readable name for a spec, or '' if none is found.

  Preference order: explicit name, then id, then type; a single-entry
  mapping is searched recursively through its sole value.
  """
  for field in ('name', 'id', 'type'):
    if field in spec:
      return spec[field]
  if len(spec) == 1:
    return extract_name(next(iter(spec.values())))
  return ''
+
+
class YamlTransform(beam.PTransform):
  """A PTransform constructed from a yaml (or pre-parsed) transform spec."""
  def __init__(self, spec, providers=None):
    """Initializes the transform.

    Args:
      spec: a yaml string, or an already-parsed spec mapping.
      providers: optional mapping of transform type to provider list,
        merged with the standard providers.
    """
    if isinstance(spec, str):
      spec = yaml.load(spec, Loader=SafeLineLoader)
    self._spec = spec
    # None (not a mutable {} default) avoids sharing state across instances.
    self._providers = yaml_provider.merge_providers(
        providers or {}, yaml_provider.standard_providers())

  def expand(self, pcolls):
    # Normalize the incoming value to (root, dict-of-named-inputs).
    if isinstance(pcolls, beam.pvalue.PBegin):
      root = pcolls
      pcolls = {}
    elif isinstance(pcolls, beam.PCollection):
      # A single PCollection is bound under the implicit name 'input'.
      root = pcolls.pipeline
      pcolls = {'input': pcolls}
    else:
      root = next(iter(pcolls.values())).pipeline
    result = expand_transform(
        self._spec,
        Scope(root, pcolls, transforms=[], providers=self._providers))
    # Unwrap a single output for convenience; otherwise return the dict.
    if len(result) == 1:
      return only_element(result.values())
    else:
      return result
+
+
def expand_pipeline(pipeline, pipeline_spec):
  """Applies a full yaml pipeline spec to the given pipeline object."""
  if isinstance(pipeline_spec, str):
    pipeline_spec = yaml.load(pipeline_spec, Loader=SafeLineLoader)
  providers = yaml_provider.parse_providers(pipeline_spec.get('providers', []))
  transform = YamlTransform(
      pipeline_as_composite(pipeline_spec['pipeline']), providers)
  # Call expand directly to avoid an outer layer of nesting.
  return transform.expand(beam.pvalue.PBegin(pipeline))
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py 
b/sdks/python/apache_beam/yaml/yaml_transform_test.py
new file mode 100644
index 00000000000..e3b7097df24
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.yaml.yaml_transform import YamlTransform
+
+
class YamlTransformTest(unittest.TestCase):
  """End-to-end tests expanding yaml specs inside a real local pipeline.

  All pipelines use cloudpickle so the lambdas in the yaml `fn` strings can
  be pickled.
  """
  def test_composite(self):
    # A composite with one named input, two parallel branches, and an
    # explicit Flatten output.
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      elements = p | beam.Create([1, 2, 3])
      # TODO(robertwb): Consider making the input implicit (and below).
      result = elements | YamlTransform(
          '''
          type: composite
          input:
              elements: input
          transforms:
            - type: PyMap
              name: Square
              input: elements
              fn: "lambda x: x * x"
            - type: PyMap
              name: Cube
              input: elements
              fn: "lambda x: x * x * x"
            - type: Flatten
              input: [Square, Cube]
          output:
              Flatten
          ''')
      assert_that(result, equal_to([1, 4, 9, 1, 8, 27]))

  def test_chain_with_input(self):
    # A chain fed by an upstream PCollection: inner transforms have
    # implicit inputs/outputs.
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      elements = p | beam.Create(range(10))
      result = elements | YamlTransform(
          '''
          type: chain
          input:
              elements: input
          transforms:
            - type: PyMap
              fn: "lambda x: x * x + x"
            - type: PyMap
              fn: "lambda x: x + 41"
          ''')
      assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131]))

  def test_chain_with_root(self):
    # A chain applied directly to the pipeline root: the first transform
    # (Create) acts as the source.
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      result = p | YamlTransform(
          '''
          type: chain
          transforms:
            - type: Create
              elements: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
            - type: PyMap
              fn: "lambda x: x * x + x"
            - type: PyMap
              fn: "lambda x: x + 41"
          ''')
      assert_that(result, equal_to([41, 43, 47, 53, 61, 71, 83, 97, 113, 131]))
+
+
if __name__ == '__main__':
  # Surface the transform-expansion INFO logs when running directly.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()


Reply via email to