This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 49ad65c1222 [YAML] Add PubSub reading and writing transforms. (#28595)
49ad65c1222 is described below

commit 49ad65c1222eb85cc5a698aa3232423613a076ac
Author: Robert Bradshaw <[email protected]>
AuthorDate: Mon Sep 25 16:04:47 2023 -0700

    [YAML] Add PubSub reading and writing transforms. (#28595)
---
 sdks/python/apache_beam/testing/util.py        |   7 +
 sdks/python/apache_beam/transforms/core.py     |  37 ++--
 sdks/python/apache_beam/typehints/schemas.py   |   9 +
 sdks/python/apache_beam/yaml/standard_io.yaml  |   2 +
 sdks/python/apache_beam/yaml/yaml_io.py        | 225 +++++++++++++++++++++-
 sdks/python/apache_beam/yaml/yaml_io_test.py   | 257 +++++++++++++++++++++++++
 sdks/python/apache_beam/yaml/yaml_transform.py |   2 +
 7 files changed, 526 insertions(+), 13 deletions(-)

diff --git a/sdks/python/apache_beam/testing/util.py b/sdks/python/apache_beam/testing/util.py
index 8c918128959..10a7a8e86f9 100644
--- a/sdks/python/apache_beam/testing/util.py
+++ b/sdks/python/apache_beam/testing/util.py
@@ -33,6 +33,7 @@ from apache_beam.transforms.core import Map
 from apache_beam.transforms.core import ParDo
 from apache_beam.transforms.core import WindowInto
 from apache_beam.transforms.ptransform import PTransform
+from apache_beam.transforms.ptransform import ptransform_fn
 from apache_beam.transforms.util import CoGroupByKey
 
 __all__ = [
@@ -308,6 +309,12 @@ def assert_that(
   actual | AssertThat()  # pylint: disable=expression-not-assigned
 
 
+@ptransform_fn
+def AssertThat(pcoll, *args, **kwargs):
+  """Like assert_that, but as an applicable PTransform."""
+  return assert_that(pcoll, *args, **kwargs)
+
+
 def open_shards(glob_pattern, mode='rt', encoding='utf-8'):
   """Returns a composite file of all shards matching the given glob pattern.
 
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 671af54e47b..e980dccea74 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -2262,6 +2262,10 @@ class _PValueWithErrors(object):
   def element_type(self):
     return self._pcoll.element_type
 
+  @element_type.setter
+  def element_type(self, value):
+    self._pcoll.element_type = value
+
   def main_output_tag(self):
     return self._exception_handling_args.get('main_tag', 'good')
 
@@ -2272,17 +2276,24 @@ class _PValueWithErrors(object):
     return self.apply(transform)
 
   def apply(self, transform):
-    result = self._pcoll | transform.with_exception_handling(
-        **self._exception_handling_args)
-    if result[self.main_output_tag()].element_type == typehints.Any:
-      result[self.main_output_tag()].element_type = transform.infer_output_type(
-          self._pcoll.element_type)
-    # TODO(BEAM-18957): Add support for tagged type hints.
-    result[self.error_output_tag()].element_type = typehints.Any
-    return _PValueWithErrors(
-        result[self.main_output_tag()],
-        self._exception_handling_args,
-        self._upstream_errors + (result[self.error_output_tag()], ))
+    if hasattr(transform, 'with_exception_handling'):
+      result = self._pcoll | transform.with_exception_handling(
+          **self._exception_handling_args)
+      if result[self.main_output_tag()].element_type == typehints.Any:
+        result[
+            self.main_output_tag()].element_type = transform.infer_output_type(
+                self._pcoll.element_type)
+      # TODO(BEAM-18957): Add support for tagged type hints.
+      result[self.error_output_tag()].element_type = typehints.Any
+      return _PValueWithErrors(
+          result[self.main_output_tag()],
+          self._exception_handling_args,
+          self._upstream_errors + (result[self.error_output_tag()], ))
+    else:
+      return _PValueWithErrors(
+          self._pcoll | transform,
+          self._exception_handling_args,
+          self._upstream_errors)
 
   def accumulated_errors(self):
     if len(self._upstream_errors) == 1:
@@ -2317,6 +2328,10 @@ class _MaybePValueWithErrors(object):
   def element_type(self):
     return self._pvalue.element_type
 
+  @element_type.setter
+  def element_type(self, value):
+    self._pvalue.element_type = value
+
   def __or__(self, transform):
     return self.apply(transform)
 
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
index 5b900f29668..229a8af20bb 100644
--- a/sdks/python/apache_beam/typehints/schemas.py
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -227,6 +227,15 @@ def option_from_runner_api(
       schema_registry=schema_registry).option_from_runner_api(option_proto)
 
 
+def schema_field(
+    name: str, field_type: Union[schema_pb2.FieldType,
+                                 type]) -> schema_pb2.Field:
+  return schema_pb2.Field(
+      name=name,
+      type=field_type if isinstance(field_type, schema_pb2.FieldType) else
+      typing_to_runner_api(field_type))
+
+
 class SchemaTranslation(object):
   def __init__(self, schema_registry: SchemaTypeRegistry = SCHEMA_REGISTRY):
     self.schema_registry = schema_registry
diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml
index 1738110539c..9ad4f53ba1f 100644
--- a/sdks/python/apache_beam/yaml/standard_io.yaml
+++ b/sdks/python/apache_beam/yaml/standard_io.yaml
@@ -53,6 +53,8 @@
     # 'WriteToBigQuery': 'apache_beam.yaml.yaml_io.write_to_bigquery'
     'ReadFromText': 'apache_beam.yaml.yaml_io.read_from_text'
     'WriteToText': 'apache_beam.yaml.yaml_io.write_to_text'
+    'ReadFromPubSub': 'apache_beam.yaml.yaml_io.read_from_pubsub'
+    'WriteToPubSub': 'apache_beam.yaml.yaml_io.write_to_pubsub'
 
 # Declared as a renaming transform to avoid exposing all
 # (implementation-specific) pandas arguments and aligning with possible Java
diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py
index 297c07e9abb..4a1d1249005 100644
--- a/sdks/python/apache_beam/yaml/yaml_io.py
+++ b/sdks/python/apache_beam/yaml/yaml_io.py
@@ -24,6 +24,11 @@ implementations of the same transforms, the configs must be kept in sync.
 """
 
 import os
+from typing import Any
+from typing import Iterable
+from typing import List
+from typing import Mapping
+from typing import Optional
 
 import yaml
 
@@ -32,7 +37,9 @@ import apache_beam.io as beam_io
 from apache_beam.io import ReadFromBigQuery
 from apache_beam.io import WriteToBigQuery
 from apache_beam.io.gcp.bigquery import BigQueryDisposition
-from apache_beam.typehints.schemas import named_fields_from_element_type
+from apache_beam.portability.api import schema_pb2
+from apache_beam.typehints import schemas
+from apache_beam.yaml import yaml_mapping
 from apache_beam.yaml import yaml_provider
 
 
@@ -46,7 +53,8 @@ def read_from_text(path: str):
 def write_to_text(pcoll, path: str):
   try:
     field_names = [
-        name for name, _ in named_fields_from_element_type(pcoll.element_type)
+        name for name,
+        _ in schemas.named_fields_from_element_type(pcoll.element_type)
     ]
   except Exception as exn:
     raise ValueError(
@@ -123,6 +131,219 @@ def write_to_bigquery(
   return WriteToBigQueryHandlingErrors()
 
 
+def _create_parser(format, schema):
+  if format == 'raw':
+    if schema:
+      raise ValueError('raw format does not take a schema')
+    return (
+        schema_pb2.Schema(fields=[schemas.schema_field('payload', bytes)]),
+        lambda payload: beam.Row(payload=payload))
+  else:
+    raise ValueError(f'Unknown format: {format}')
+
+
+def _create_formatter(format, schema, beam_schema):
+  if format == 'raw':
+    if schema:
+      raise ValueError('raw format does not take a schema')
+    field_names = [field.name for field in beam_schema.fields]
+    if len(field_names) != 1:
+      raise ValueError(f'Expecting exactly one field, found {field_names}')
+    return lambda row: getattr(row, field_names[0])
+  else:
+    raise ValueError(f'Unknown format: {format}')
+
+
+@beam.ptransform_fn
+@yaml_mapping.maybe_with_exception_handling_transform_fn
+def read_from_pubsub(
+    root,
+    *,
+    topic: Optional[str] = None,
+    subscription: Optional[str] = None,
+    format: str,
+    schema: Optional[Any] = None,
+    attributes: Optional[Iterable[str]] = None,
+    attributes_map: Optional[str] = None,
+    id_attribute: Optional[str] = None,
+    timestamp_attribute: Optional[str] = None):
+  """Reads messages from Cloud Pub/Sub.
+
+  Args:
+    topic: Cloud Pub/Sub topic in the form
+      "projects/<project>/topics/<topic>". If provided, subscription must be
+      None.
+    subscription: Existing Cloud Pub/Sub subscription to use in the
+      form "projects/<project>/subscriptions/<subscription>". If not
+      specified, a temporary subscription will be created from the specified
+      topic. If provided, topic must be None.
+    format: The expected format of the message payload.  Currently supported
+      formats are
+
+        - raw: Produces records with a single `payload` field whose contents
+            are the raw bytes of the pubsub message.
+
+    schema: Schema specification for the given format.
+    attributes: List of attribute keys whose values will be flattened into the
+      output message as additional fields.  For example, if the format is `raw`
+      and attributes is `["a", "b"]` then this read will produce elements of
+      the form `Row(payload=..., a=..., b=...)`.
+    attributes_map: Name of a field in which to store the full set of attributes
+      associated with this message.  For example, if the format is `raw` and
+      `attributes_map` is set to `"attrs"` then this read will produce elements
+      of the form `Row(payload=..., attrs=...)` where `attrs` is a Map type
+      of string to string.
+      If both `attributes` and `attributes_map` are set, the overlapping
+      attribute values will be present in both the flattened structure and the
+      attribute map.
+    id_attribute: The attribute on incoming Pub/Sub messages to use as a unique
+      record identifier. When specified, the value of this attribute (which
+      can be any string that uniquely identifies the record) will be used for
+      deduplication of messages. If not provided, we cannot guarantee
+      that no duplicate data will be delivered on the Pub/Sub stream. In this
+      case, deduplication of the stream will be strictly best effort.
+    timestamp_attribute: Message value to use as element timestamp. If None,
+      uses message publishing time as the timestamp.
+
+      Timestamp values should be in one of two formats:
+
+      - A numerical value representing the number of milliseconds since the
+        Unix epoch.
+      - A string in RFC 3339 format, UTC timezone. Example:
+        ``2015-10-29T23:41:41.123Z``. The sub-second component of the
+        timestamp is optional, and digits beyond the first three (i.e., time
+        units smaller than milliseconds) may be ignored.
+  """
+  if topic and subscription:
+    raise TypeError('Only one of topic and subscription may be specified.')
+  elif not topic and not subscription:
+    raise TypeError('One of topic or subscription may be specified.')
+  payload_schema, parser = _create_parser(format, schema)
+  extra_fields: List[schema_pb2.Field] = []
+  if not attributes and not attributes_map:
+    mapper = lambda msg: parser(msg)
+  else:
+    if isinstance(attributes, str):
+      attributes = [attributes]
+    if attributes:
+      extra_fields.extend(
+          [schemas.schema_field(attr, str) for attr in attributes])
+    if attributes_map:
+      extra_fields.append(
+          schemas.schema_field(attributes_map, Mapping[str, str]))
+
+    def mapper(msg):
+      values = parser(msg.data).as_dict()
+      if attributes:
+        # Should missing attributes be optional or parse errors?
+        for attr in attributes:
+          values[attr] = msg.attributes[attr]
+      if attributes_map:
+        values[attributes_map] = msg.attributes
+      return beam.Row(**values)
+
+  output = (
+      root
+      | beam.io.ReadFromPubSub(
+          topic=topic,
+          subscription=subscription,
+          with_attributes=bool(attributes or attributes_map),
+          id_label=id_attribute,
+          timestamp_attribute=timestamp_attribute)
+      | 'ParseMessage' >> beam.Map(mapper))
+  output.element_type = schemas.named_tuple_from_schema(
+      schema_pb2.Schema(fields=list(payload_schema.fields) + extra_fields))
+  return output
+
+
+@beam.ptransform_fn
+@yaml_mapping.maybe_with_exception_handling_transform_fn
+def write_to_pubsub(
+    pcoll,
+    *,
+    topic: str,
+    format: str,
+    schema: Optional[Any] = None,
+    attributes: Optional[Iterable[str]] = None,
+    attributes_map: Optional[str] = None,
+    id_attribute: Optional[str] = None,
+    timestamp_attribute: Optional[str] = None):
+  """Writes messages to Cloud Pub/Sub.
+
+  Args:
+    topic: Cloud Pub/Sub topic in the form "projects/<project>/topics/<topic>".
+    format: How to format the message payload.  Currently supported
+      formats are
+
+        - raw: Expects a message with a single field (excluding
+            attribute-related fields) whose contents are used as the raw bytes
+            of the pubsub message.
+
+    schema: Schema specification for the given format.
+    attributes: List of attribute keys whose values will be pulled out as
+      PubSub message attributes.  For example, if the format is `raw`
+      and attributes is `["a", "b"]` then elements of the form
+      `Row(any_field=..., a=..., b=...)` will result in PubSub messages whose
+      payload has the contents of any_field and whose attribute will be
+      populated with the values of `a` and `b`.
+    attributes_map: Name of a string-to-string map field in which to pull a set
+      of attributes associated with this message.  For example, if the format
+      is `raw` and `attributes_map` is set to `"attrs"` then elements of the
+      form `Row(any_field=..., attrs=...)` will result in PubSub messages whose
+      payload has the contents of any_field and whose attribute will be
+      populated with the values from attrs.
+      If both `attributes` and `attributes_map` are set, the union of attributes
+      from these two sources will be used to populate the PubSub message
+      attributes.
+    id_attribute: If set, will set an attribute for each Cloud Pub/Sub message
+      with the given name and a unique value. This attribute can then be used
+      in a ReadFromPubSub PTransform to deduplicate messages.
+    timestamp_attribute: If set, will set an attribute for each Cloud Pub/Sub
+      message with the given name and the message's publish time as the value.
+  """
+  input_schema = schemas.schema_from_element_type(pcoll.element_type)
+
+  extra_fields: List[str] = []
+  if isinstance(attributes, str):
+    attributes = [attributes]
+  if attributes:
+    extra_fields.extend(attributes)
+  if attributes_map:
+    extra_fields.append(attributes_map)
+
+  def attributes_extractor(row):
+    if attributes_map:
+      attribute_values = dict(getattr(row, attributes_map))
+    else:
+      attribute_values = {}
+    if attributes:
+      attribute_values.update({attr: getattr(row, attr) for attr in attributes})
+    return attribute_values
+
+  schema_names = set(f.name for f in input_schema.fields)
+  missing_attribute_names = set(extra_fields) - schema_names
+  if missing_attribute_names:
+    raise ValueError(
+        f'Attribute fields {missing_attribute_names} '
+        f'not found in schema fields {schema_names}')
+
+  payload_schema = schema_pb2.Schema(
+      fields=[
+          field for field in input_schema.fields
+          if field.name not in extra_fields
+      ])
+  formatter = _create_formatter(format, schema, payload_schema)
+  return (
+      pcoll | beam.Map(
+          lambda row: beam.io.gcp.pubsub.PubsubMessage(
+              formatter(row), attributes_extractor(row)))
+      | beam.io.WriteToPubSub(
+          topic,
+          with_attributes=True,
+          id_label=id_attribute,
+          timestamp_attribute=timestamp_attribute))
+
+
 def io_providers():
   with open(os.path.join(os.path.dirname(__file__), 'standard_io.yaml')) as fin:
     return yaml_provider.parse_providers(yaml.load(fin, Loader=yaml.SafeLoader))
diff --git a/sdks/python/apache_beam/yaml/yaml_io_test.py b/sdks/python/apache_beam/yaml/yaml_io_test.py
new file mode 100644
index 00000000000..ab6298661c1
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/yaml_io_test.py
@@ -0,0 +1,257 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import mock
+
+import apache_beam as beam
+from apache_beam.io.gcp.pubsub import PubsubMessage
+from apache_beam.testing.util import AssertThat
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+from apache_beam.yaml.yaml_transform import YamlTransform
+
+
+class FakeReadFromPubSub:
+  def __init__(
+      self,
+      topic,
+      messages,
+      subscription=None,
+      id_attribute=None,
+      timestamp_attribute=None):
+    self._topic = topic
+    self._subscription = subscription
+    self._messages = messages
+    self._id_attribute = id_attribute
+    self._timestamp_attribute = timestamp_attribute
+
+  def __call__(
+      self,
+      *,
+      topic,
+      subscription,
+      with_attributes,
+      id_label,
+      timestamp_attribute):
+    assert topic == self._topic
+    assert id_label == self._id_attribute
+    assert timestamp_attribute == self._timestamp_attribute
+    assert subscription == self._subscription
+    if with_attributes:
+      data = self._messages
+    else:
+      data = [x.data for x in self._messages]
+    return beam.Create(data)
+
+
+class FakeWriteToPubSub:
+  def __init__(
+      self, topic, messages, id_attribute=None, timestamp_attribute=None):
+    self._topic = topic
+    self._messages = messages
+    self._id_attribute = id_attribute
+    self._timestamp_attribute = timestamp_attribute
+
+  def __call__(self, topic, *, with_attributes, id_label, timestamp_attribute):
+    assert topic == self._topic
+    assert with_attributes is True
+    assert id_label == self._id_attribute
+    assert timestamp_attribute == self._timestamp_attribute
+    return AssertThat(equal_to(self._messages))
+
+
+class YamlPubSubTest(unittest.TestCase):
+  def test_simple_read(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      with mock.patch('apache_beam.io.ReadFromPubSub',
+                      FakeReadFromPubSub(
+                          topic='my_topic',
+                          messages=[PubsubMessage(b'msg1', {'attr': 'value1'}),
+                                    PubsubMessage(b'msg2',
+                                                  {'attr': 'value2'})])):
+        result = p | YamlTransform(
+            '''
+            type: ReadFromPubSub
+            config:
+              topic: my_topic
+              format: raw
+            ''')
+        assert_that(
+            result,
+            equal_to([beam.Row(payload=b'msg1'), beam.Row(payload=b'msg2')]))
+
+  def test_read_with_attribute(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      with mock.patch('apache_beam.io.ReadFromPubSub',
+                      FakeReadFromPubSub(
+                          topic='my_topic',
+                          messages=[PubsubMessage(b'msg1', {'attr': 'value1'}),
+                                    PubsubMessage(b'msg2',
+                                                  {'attr': 'value2'})])):
+        result = p | YamlTransform(
+            '''
+            type: ReadFromPubSub
+            config:
+              topic: my_topic
+              format: raw
+              attributes: [attr]
+            ''')
+        assert_that(
+            result,
+            equal_to([
+                beam.Row(payload=b'msg1', attr='value1'),
+                beam.Row(payload=b'msg2', attr='value2')
+            ]))
+
+  def test_read_with_attribute_map(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      with mock.patch('apache_beam.io.ReadFromPubSub',
+                      FakeReadFromPubSub(
+                          topic='my_topic',
+                          messages=[PubsubMessage(b'msg1', {'attr': 'value1'}),
+                                    PubsubMessage(b'msg2',
+                                                  {'attr': 'value2'})])):
+        result = p | YamlTransform(
+            '''
+            type: ReadFromPubSub
+            config:
+              topic: my_topic
+              format: raw
+              attributes_map: attrMap
+            ''')
+        assert_that(
+            result,
+            equal_to([
+                beam.Row(payload=b'msg1', attrMap={'attr': 'value1'}),
+                beam.Row(payload=b'msg2', attrMap={'attr': 'value2'})
+            ]))
+
+  def test_read_with_id_attribute(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      with mock.patch('apache_beam.io.ReadFromPubSub',
+                      FakeReadFromPubSub(
+                          topic='my_topic',
+                          messages=[PubsubMessage(b'msg1', {'attr': 'value1'}),
+                                    PubsubMessage(b'msg2', {'attr': 'value2'})],
+                          id_attribute='some_attr')):
+        result = p | YamlTransform(
+            '''
+            type: ReadFromPubSub
+            config:
+              topic: my_topic
+              format: raw
+              id_attribute: some_attr
+            ''')
+        assert_that(
+            result,
+            equal_to([beam.Row(payload=b'msg1'), beam.Row(payload=b'msg2')]))
+
+  def test_simple_write(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      with mock.patch('apache_beam.io.WriteToPubSub',
+                      FakeWriteToPubSub(topic='my_topic',
+                                        messages=[PubsubMessage(b'msg1', {}),
+                                                  PubsubMessage(b'msg2', {})])):
+        _ = (
+            p | beam.Create([beam.Row(a=b'msg1'), beam.Row(a=b'msg2')])
+            | YamlTransform(
+                '''
+            type: WriteToPubSub
+            input: input
+            config:
+              topic: my_topic
+              format: raw
+            '''))
+
+  def test_write_with_attribute(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      with mock.patch('apache_beam.io.WriteToPubSub',
+                      FakeWriteToPubSub(
+                          topic='my_topic',
+                          messages=[PubsubMessage(b'msg1', {'attr': 'value1'}),
+                                    PubsubMessage(b'msg2',
+                                                  {'attr': 'value2'})])):
+        _ = (
+            p | beam.Create([
+                beam.Row(a=b'msg1', attr='value1'),
+                beam.Row(a=b'msg2', attr='value2')
+            ]) | YamlTransform(
+                '''
+            type: WriteToPubSub
+            input: input
+            config:
+              topic: my_topic
+              format: raw
+              attributes: [attr]
+            '''))
+
+  def test_write_with_attribute_map(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      with mock.patch('apache_beam.io.WriteToPubSub',
+                      FakeWriteToPubSub(topic='my_topic',
+                                        messages=[PubsubMessage(b'msg1',
+                                                                {'a': 'b'}),
+                                                  PubsubMessage(b'msg2',
+                                                                {'c': 'd'})])):
+        _ = (
+            p | beam.Create([
+                beam.Row(a=b'msg1', attrMap={'a': 'b'}),
+                beam.Row(a=b'msg2', attrMap={'c': 'd'})
+            ]) | YamlTransform(
+                '''
+            type: WriteToPubSub
+            input: input
+            config:
+              topic: my_topic
+              format: raw
+              attributes_map: attrMap
+            '''))
+
+  def test_write_with_id_attribute(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      with mock.patch('apache_beam.io.WriteToPubSub',
+                      FakeWriteToPubSub(topic='my_topic',
+                                        messages=[PubsubMessage(b'msg1', {}),
+                                                  PubsubMessage(b'msg2', {})],
+                                        id_attribute='some_attr')):
+        _ = (
+            p | beam.Create([beam.Row(a=b'msg1'), beam.Row(a=b'msg2')])
+            | YamlTransform(
+                '''
+            type: WriteToPubSub
+            input: input
+            config:
+              topic: my_topic
+              format: raw
+              id_attribute: some_attr
+            '''))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index 01f8b485343..fa30c183080 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -460,6 +460,8 @@ def expand_leaf_transform(spec, scope):
     return {f'out{ix}': pcoll for (ix, pcoll) in enumerate(outputs)}
   elif isinstance(outputs, beam.PCollection):
     return {'out': outputs}
+  elif outputs is None:
+    return {}
   else:
     raise ValueError(
         f'Transform {identify_object(spec)} returned an unexpected type '

Reply via email to