This is an automated email from the ASF dual-hosted git repository.

xqhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new d26dbacc8e4 Add support for PROTO format in YAML Pub/Sub transform (#36185)
d26dbacc8e4 is described below

commit d26dbacc8e463bac1274f890bd8b740d2cfb09ec
Author: liferoad <[email protected]>
AuthorDate: Mon Sep 22 11:25:31 2025 -0400

    Add support for PROTO format in YAML Pub/Sub transform (#36185)
    
    * Add support for PROTO format in YAML Pub/Sub transform
    
    * Remove unused import of schema_utils in yaml_io.py and update YamlPubSubTest to use named_fields_to_schema for RowCoder.
    
    * Rename test_rw_proto to test_write_proto and add test_read_proto for PROTO format handling in YamlPubSubTest.
    
    * lints
---
 sdks/python/apache_beam/yaml/yaml_io.py      | 13 ++++++--
 sdks/python/apache_beam/yaml/yaml_io_test.py | 45 ++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py
index ffbc2b8db6b..ddf39935ebd 100644
--- a/sdks/python/apache_beam/yaml/yaml_io.py
+++ b/sdks/python/apache_beam/yaml/yaml_io.py
@@ -35,6 +35,7 @@ import fastavro
 import apache_beam as beam
 import apache_beam.io as beam_io
 from apache_beam import coders
+from apache_beam.coders.row_coder import RowCoder
 from apache_beam.io import ReadFromBigQuery
 from apache_beam.io import ReadFromTFRecord
 from apache_beam.io import WriteToBigQuery
@@ -247,6 +248,10 @@ def _create_parser(
         beam_schema,
         lambda record: covert_to_row(
             fastavro.schemaless_reader(io.BytesIO(record), schema)))  # type: ignore[call-arg]
+  elif format == 'PROTO':
+    _validate_schema()
+    beam_schema = json_utils.json_schema_to_beam_schema(schema)
+    return beam_schema, RowCoder(beam_schema).decode
   else:
     raise ValueError(f'Unknown format: {format}')
 
@@ -291,6 +296,8 @@ def _create_formatter(
       return buffer.read()
 
     return formatter
+  elif format == 'PROTO':
+    return RowCoder(beam_schema).encode
   else:
     raise ValueError(f'Unknown format: {format}')
 
@@ -416,7 +423,7 @@ def write_to_pubsub(
 
   Args:
     topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
-    format: How to format the message payload.  Currently suported
+    format: How to format the message payload.  Currently supported
       formats are
 
         - RAW: Expects a message with a single field (excluding
@@ -426,6 +433,8 @@ def write_to_pubsub(
             from the input PCollection schema.
         - JSON: Formats records with a given JSON schema, which may be inferred
             from the input PCollection schema.
+        - PROTO: Encodes records with a given Protobuf schema, which may be
+            inferred from the input PCollection schema.
 
     schema: Schema specification for the given format.
     attributes: List of attribute keys whose values will be pulled out as
@@ -633,7 +642,7 @@ def read_from_tfrecord(
     compression_type (CompressionTypes): Used to handle compressed input files.
       Default value is CompressionTypes.AUTO, in which case the file_path's
       extension will be used to detect the compression.
-    validate (bool): Boolean flag to verify that the files exist during the 
+    validate (bool): Boolean flag to verify that the files exist during the
       pipeline creation time.
   """
   return ReadFromTFRecord(
diff --git a/sdks/python/apache_beam/yaml/yaml_io_test.py b/sdks/python/apache_beam/yaml/yaml_io_test.py
index 3ae9f19b9b8..a19dfd694a8 100644
--- a/sdks/python/apache_beam/yaml/yaml_io_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_io_test.py
@@ -24,10 +24,12 @@ import fastavro
 import mock
 
 import apache_beam as beam
+from apache_beam.coders.row_coder import RowCoder
 from apache_beam.io.gcp.pubsub import PubsubMessage
 from apache_beam.testing.util import AssertThat
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.typehints import schemas as schema_utils
 from apache_beam.yaml.yaml_transform import YamlTransform
 
 
@@ -491,6 +493,49 @@ class YamlPubSubTest(unittest.TestCase):
               attributes_map: other
             '''))
 
+  def test_write_proto(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      data = [beam.Row(label='37a', rank=1), beam.Row(label='389a', rank=2)]
+      coder = RowCoder(
+          schema_utils.named_fields_to_schema([('label', str), ('rank', int)]))
+      expected_messages = [PubsubMessage(coder.encode(r), {}) for r in data]
+      with mock.patch('apache_beam.io.WriteToPubSub',
+                      FakeWriteToPubSub(topic='my_topic',
+                                        messages=expected_messages)):
+        _ = (
+            p | beam.Create(data) | YamlTransform(
+                '''
+            type: WriteToPubSub
+            config:
+              topic: my_topic
+              format: PROTO
+            '''))
+
+  def test_read_proto(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      data = [beam.Row(label='37a', rank=1), beam.Row(label='389a', rank=2)]
+      coder = RowCoder(
+          schema_utils.named_fields_to_schema([('label', str), ('rank', int)]))
+      expected_messages = [PubsubMessage(coder.encode(r), {}) for r in data]
+      with mock.patch('apache_beam.io.ReadFromPubSub',
+                      FakeReadFromPubSub(topic='my_topic',
+                                         messages=expected_messages)):
+        result = p | YamlTransform(
+            '''
+            type: ReadFromPubSub
+            config:
+              topic: my_topic
+              format: PROTO
+              schema:
+                type: object
+                properties:
+                  label: {type: string}
+                  rank: {type: integer}
+            ''')
+        assert_that(result, equal_to(data))
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to