This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f3e0f6d4962 Updates YAML SDK to replace Kafka read/write transforms 
with equivalent managed transforms (#34755)
f3e0f6d4962 is described below

commit f3e0f6d496218a2025380cdef583f259685b82cf
Author: Chamikara Jayalath <[email protected]>
AuthorDate: Wed Apr 30 12:03:57 2025 -0700

    Updates YAML SDK to replace Kafka read/write transforms with equivalent 
managed transforms (#34755)
    
    * Updates YAML SDK to replace Kafka read/write transforms with equivalent 
managed transforms
    
    * Addressing reviewer comments and adding unit tests
    
    * Resolves conflict
    
    * Fixes a test failure
    
    * Fix lint
    
    * Fixes a test
---
 sdks/python/apache_beam/transforms/external.py     | 79 +++++++++++++++++++++-
 .../python/apache_beam/transforms/external_test.py | 66 +++++++++++++++++-
 sdks/python/apache_beam/transforms/managed.py      | 25 +++----
 sdks/python/apache_beam/yaml/standard_io.yaml      |  6 ++
 sdks/python/apache_beam/yaml/yaml_provider.py      | 37 ++++++++--
 sdks/python/gen_managed_doc.py                     |  8 ++-
 6 files changed, 194 insertions(+), 27 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/external.py 
b/sdks/python/apache_beam/transforms/external.py
index 3fc58f04a78..9b6b4060cb7 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -31,6 +31,7 @@ from collections import OrderedDict
 from collections import namedtuple
 
 import grpc
+import yaml
 
 from apache_beam import pvalue
 from apache_beam.coders import RowCoder
@@ -42,10 +43,12 @@ from apache_beam.portability.api import 
beam_expansion_api_pb2_grpc
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.portability.api import external_transforms_pb2
 from apache_beam.portability.api import schema_pb2
+from apache_beam.portability.common_urns import ManagedTransforms
 from apache_beam.runners import pipeline_context
 from apache_beam.runners.portability import artifact_service
 from apache_beam.transforms import environments
 from apache_beam.transforms import ptransform
+from apache_beam.transforms.util import is_compat_version_prior_to
 from apache_beam.typehints import WithTypeHints
 from apache_beam.typehints import native_type_compatibility
 from apache_beam.typehints import row_type
@@ -61,6 +64,25 @@ from apache_beam.utils import transform_service_launcher
 
 DEFAULT_EXPANSION_SERVICE = 'localhost:8097'
 
+MANAGED_SCHEMA_TRANSFORM_IDENTIFIER = "beam:transform:managed:v1"
+
+_IO_EXPANSION_SERVICE_JAR_TARGET = "sdks:java:io:expansion-service:shadowJar"
+
+_GCP_EXPANSION_SERVICE_JAR_TARGET = (
+    "sdks:java:io:google-cloud-platform:expansion-service:shadowJar")
+
+# A mapping from supported managed transforms URNs to expansion service jars
+# that include the corresponding transforms.
+MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING = {
+    ManagedTransforms.Urns.ICEBERG_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
+    ManagedTransforms.Urns.ICEBERG_WRITE.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
+    ManagedTransforms.Urns.ICEBERG_CDC_READ.urn: 
_IO_EXPANSION_SERVICE_JAR_TARGET,  # pylint: disable=line-too-long
+    ManagedTransforms.Urns.KAFKA_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
+    ManagedTransforms.Urns.KAFKA_WRITE.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
+    ManagedTransforms.Urns.BIGQUERY_READ.urn: 
_GCP_EXPANSION_SERVICE_JAR_TARGET,
+    ManagedTransforms.Urns.BIGQUERY_WRITE.urn: 
_GCP_EXPANSION_SERVICE_JAR_TARGET
+}
+
 
 def convert_to_typing_type(type_):
   if isinstance(type_, row_type.RowTypeConstraint):
@@ -378,6 +400,10 @@ SchemaTransformsConfig = namedtuple(
     'SchemaTransformsConfig',
     ['identifier', 'configuration_schema', 'inputs', 'outputs', 'description'])
 
+ManagedReplacement = namedtuple(
+    'ManagedReplacement',
+    ['underlying_transform_identifier', 'update_compatibility_version'])
+
 
 class SchemaAwareExternalTransform(ptransform.PTransform):
   """A proxy transform for SchemaTransforms implemented in external SDKs.
@@ -396,6 +422,12 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
       the configuration.
  :param classpath: (Optional) A list of paths to additional jars to place on the
      expansion service classpath.
+  :param managed_replacement: (Optional) a 'ManagedReplacement' namedtuple that
+      defines information needed to replace the transform with an equivalent
+      managed transform during the expansion. If an
+      'updateCompatibilityBeamVersion' pipeline option is provided, we will
+      only replace if the managed transform is update compatible with the
+      provided version.
   :kwargs: field name to value mapping for configuring the schema transform.
       keys map to the field names of the schema of the SchemaTransform
       (in-order).
@@ -406,10 +438,14 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
       expansion_service,
       rearrange_based_on_discovery=False,
       classpath=None,
+      managed_replacement=None,
       **kwargs):
     self._expansion_service = expansion_service
     self._kwargs = kwargs
     self._classpath = classpath
+    if managed_replacement:
+      assert isinstance(managed_replacement, ManagedReplacement)
+    self._managed_replacement = managed_replacement
 
     _kwargs = kwargs
     if rearrange_based_on_discovery:
@@ -420,16 +456,55 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
           named_tuple_to_schema(config.configuration_schema),
           **_kwargs)
 
+      if self._managed_replacement:
+        # We have to do the replacement at the expansion instead of at
+        # construction
+        # since we don't have access to the PipelineOptions object at the
+        # construction.
+        underlying_transform_id = (
+            self._managed_replacement.underlying_transform_identifier)
+        if not (underlying_transform_id in
+                MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING):
+          raise ValueError(
+              'Could not find an expansion service jar for the managed ' +
+              'transform ' + underlying_transform_id)
+        managed_expansion_service_jar = (
+            MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING
+        )[underlying_transform_id]
+        self._managed_expansion_service = BeamJarExpansionService(
+            managed_expansion_service_jar)
+        managed_config = SchemaAwareExternalTransform.discover_config(
+            self._managed_expansion_service,
+            MANAGED_SCHEMA_TRANSFORM_IDENTIFIER)
+
+        yaml_config = yaml.dump(kwargs)
+        self._managed_payload_builder = (
+            ExplicitSchemaTransformPayloadBuilder(
+                MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
+                named_tuple_to_schema(managed_config.configuration_schema),
+                transform_identifier=underlying_transform_id,
+                config=yaml_config))
     else:
       self._payload_builder = SchemaTransformPayloadBuilder(
           identifier, **_kwargs)
 
   def expand(self, pcolls):
     # Expand the transform using the expansion service.
+    payload_builder = self._payload_builder
+    expansion_service = self._expansion_service
+
+    if self._managed_replacement:
+      compat_version_prior_to_current = is_compat_version_prior_to(
+          pcolls.pipeline._options,
+          self._managed_replacement.update_compatibility_version)
+      if not compat_version_prior_to_current:
+        payload_builder = self._managed_payload_builder
+        expansion_service = self._managed_expansion_service
+
     return pcolls | self._payload_builder.identifier() >> ExternalTransform(
         common_urns.schematransform_based_expand.urn,
-        self._payload_builder,
-        self._expansion_service)
+        payload_builder,
+        expansion_service)
 
   @classmethod
   @functools.lru_cache
diff --git a/sdks/python/apache_beam/transforms/external_test.py 
b/sdks/python/apache_beam/transforms/external_test.py
index adf44d2286c..84a7025c0a5 100644
--- a/sdks/python/apache_beam/transforms/external_test.py
+++ b/sdks/python/apache_beam/transforms/external_test.py
@@ -29,18 +29,21 @@ import unittest
 import mock
 
 import apache_beam as beam
+from apache_beam import ManagedReplacement
 from apache_beam import Pipeline
 from apache_beam.coders import RowCoder
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.portability.api import beam_expansion_api_pb2
 from apache_beam.portability.api import external_transforms_pb2
 from apache_beam.portability.api import schema_pb2
+from apache_beam.portability.common_urns import ManagedTransforms
 from apache_beam.runners import pipeline_context
 from apache_beam.runners.portability import expansion_service
 from apache_beam.runners.portability.expansion_service_test import FibTransform
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms import external
+from apache_beam.transforms.external import MANAGED_SCHEMA_TRANSFORM_IDENTIFIER
 from apache_beam.transforms.external import AnnotationBasedPayloadBuilder
 from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
 from apache_beam.transforms.external import JavaClassLookupPayloadBuilder
@@ -530,8 +533,28 @@ class SchemaAwareExternalTransformTest(unittest.TestCase):
               id="test-id"),
           input_pcollection_names=["input"],
           output_pcollection_names=["output"])
+
+      test_managed_config = beam_expansion_api_pb2.SchemaTransformConfig(
+          config_schema=schema_pb2.Schema(
+              fields=[
+                  schema_pb2.Field(
+                      name="transform_identifier",
+                      type=schema_pb2.FieldType(atomic_type="STRING")),
+                  schema_pb2.Field(
+                      name="config_url",
+                      type=schema_pb2.FieldType(atomic_type="STRING")),
+                  schema_pb2.Field(
+                      name="config",
+                      type=schema_pb2.FieldType(atomic_type="STRING"))
+              ],
+              id="test-id1"),
+          input_pcollection_names=["input"],
+          output_pcollection_names=["output"])
       return beam_expansion_api_pb2.DiscoverSchemaTransformResponse(
-          schema_transform_configs={"test_schematransform": test_config})
+          schema_transform_configs={
+              "test_schematransform": test_config,
+              MANAGED_SCHEMA_TRANSFORM_IDENTIFIER: test_managed_config
+          })
 
   @mock.patch("apache_beam.transforms.external.ExternalTransform.service")
   def test_discover_one_config(self, mock_service):
@@ -573,6 +596,47 @@ class SchemaAwareExternalTransformTest(unittest.TestCase):
     self.assertNotEqual(tuple(kwargs.keys()), external_config_fields)
     self.assertEqual(tuple(ordered_fields), external_config_fields)
 
+  @mock.patch("apache_beam.transforms.external.ExternalTransform.service")
+  def test_managed_replacement_unknown_id(self, mock_service):
+    mock_service.return_value = self.MockDiscoveryService()
+
+    identifier = "test_schematransform"
+    kwargs = {"int_field": 0, "str_field": "str"}
+
+    managed_replacement = ManagedReplacement(
+        underlying_transform_identifier="unknown_id",
+        update_compatibility_version="2.50.0")
+
+    with self.assertRaises(ValueError):
+      beam.SchemaAwareExternalTransform(
+          identifier=identifier,
+          expansion_service=expansion_service,
+          rearrange_based_on_discovery=True,
+          managed_replacement=managed_replacement,
+          **kwargs)
+
+  @mock.patch("apache_beam.transforms.external.ExternalTransform.service")
+  @mock.patch("apache_beam.transforms.external.BeamJarExpansionService")
+  def test_managed_replacement_known_id(
+      self, mock_service, mock_beam_jar_service):
+    mock_service.return_value = self.MockDiscoveryService()
+    mock_beam_jar_service.return_value = self.MockDiscoveryService()
+
+    identifier = "test_schematransform"
+    kwargs = {"int_field": 0, "str_field": "str"}
+
+    managed_replacement = ManagedReplacement(
+        
underlying_transform_identifier=ManagedTransforms.Urns.ICEBERG_READ.urn,
+        update_compatibility_version="2.50.0")
+
+    external_transform = beam.SchemaAwareExternalTransform(
+        identifier=identifier,
+        expansion_service=expansion_service,
+        rearrange_based_on_discovery=True,
+        managed_replacement=managed_replacement,
+        **kwargs)
+    self.assertIsNotNone(external_transform._managed_payload_builder)
+
 
 class JavaClassLookupPayloadBuilderTest(unittest.TestCase):
   def _verify_row(self, schema, row_payload, expected_values):
diff --git a/sdks/python/apache_beam/transforms/managed.py 
b/sdks/python/apache_beam/transforms/managed.py
index 6113f953aed..609a27b3713 100644
--- a/sdks/python/apache_beam/transforms/managed.py
+++ b/sdks/python/apache_beam/transforms/managed.py
@@ -77,6 +77,8 @@ from typing import Optional
 import yaml
 
 from apache_beam.portability.common_urns import ManagedTransforms
+from apache_beam.transforms.external import MANAGED_SCHEMA_TRANSFORM_IDENTIFIER
+from apache_beam.transforms.external import 
MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING
 from apache_beam.transforms.external import BeamJarExpansionService
 from apache_beam.transforms.external import SchemaAwareExternalTransform
 from apache_beam.transforms.ptransform import PTransform
@@ -87,13 +89,6 @@ ICEBERG = "iceberg"
 _ICEBERG_CDC = "iceberg_cdc"
 KAFKA = "kafka"
 BIGQUERY = "bigquery"
-_MANAGED_IDENTIFIER = "beam:transform:managed:v1"
-_EXPANSION_SERVICE_JAR_TARGETS = {
-    "sdks:java:io:expansion-service:shadowJar": [KAFKA, ICEBERG, _ICEBERG_CDC],
-    "sdks:java:io:google-cloud-platform:expansion-service:shadowJar": [
-        BIGQUERY
-    ]
-}
 
 __all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"]
 
@@ -131,7 +126,7 @@ class Read(PTransform):
 
   def expand(self, input):
     return input | SchemaAwareExternalTransform(
-        identifier=_MANAGED_IDENTIFIER,
+        identifier=MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
         expansion_service=self._expansion_service,
         rearrange_based_on_discovery=True,
         transform_identifier=self._underlying_identifier,
@@ -175,7 +170,7 @@ class Write(PTransform):
 
   def expand(self, input):
     return input | SchemaAwareExternalTransform(
-        identifier=_MANAGED_IDENTIFIER,
+        identifier=MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
         expansion_service=self._expansion_service,
         rearrange_based_on_discovery=True,
         transform_identifier=self._underlying_identifier,
@@ -192,13 +187,11 @@ def _resolve_expansion_service(
   if expansion_service:
     return expansion_service
 
-  default_target = None
-  for gradle_target, transforms in _EXPANSION_SERVICE_JAR_TARGETS.items():
-    if transform_name.lower() in transforms:
-      default_target = gradle_target
-      break
-  if not default_target:
+  gradle_target = None
+  if identifier in MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING:
+    gradle_target = MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING.get(identifier)
+  if not gradle_target:
     raise ValueError(
         "No expansion service was specified and could not find a "
         f"default expansion service for {transform_name}: '{identifier}'.")
-  return BeamJarExpansionService(default_target)
+  return BeamJarExpansionService(gradle_target)
diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml 
b/sdks/python/apache_beam/yaml/standard_io.yaml
index be652a300ef..c6713725025 100644
--- a/sdks/python/apache_beam/yaml/standard_io.yaml
+++ b/sdks/python/apache_beam/yaml/standard_io.yaml
@@ -84,6 +84,12 @@
         'WriteToKafka': 'beam:schematransform:org.apache.beam:kafka_write:v1'
       config:
         gradle_target: 'sdks:java:io:expansion-service:shadowJar'
+        managed_replacement:
+          # The following transforms may be replaced with equivalent managed 
transforms,
+          # if the pipeline's 'updateCompatibilityBeamVersion' matches the 
provided
+          # version.
+          'ReadFromKafka': '2.66.0'
+          'WriteToKafka': '2.66.0'
 
 # PubSub
 - type: renaming
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py 
b/sdks/python/apache_beam/yaml/yaml_provider.py
index 171f229746a..7537e80164c 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -48,6 +48,7 @@ import apache_beam as beam
 import apache_beam.dataframe.io
 import apache_beam.io
 import apache_beam.transforms.util
+from apache_beam import ManagedReplacement
 from apache_beam.io.filesystems import FileSystems
 from apache_beam.portability.api import schema_pb2
 from apache_beam.runners import pipeline_context
@@ -181,10 +182,20 @@ class ExternalProvider(Provider):
   """A Provider implemented via the cross language transform service."""
   _provider_types: dict[str, Callable[..., Provider]] = {}
 
-  def __init__(self, urns, service):
+  def __init__(self, urns, service, managed_replacement=None):
+    """Initializes the ExternalProvider.
+
+    Args:
+      urns: a set of URNs that uniquely identify the transforms supported.
+      service: the gradle target that identifies the expansion service jar.
+      managed_replacement (Optional): a map that defines the transforms
+        which the SDK may replace with an available equivalent managed
+        transform.
+    """
     self._urns = urns
     self._service = service
     self._schema_transforms = None
+    self._managed_replacement = managed_replacement
 
   def provided_transforms(self):
     return self._urns.keys()
@@ -224,8 +235,18 @@ class ExternalProvider(Provider):
       self._service = self._service()
     urn = self._urns[type]
     if urn in self.schema_transforms():
+      managed_replacement = None
+      if self._managed_replacement and type in self._managed_replacement:
+        managed_replacement = ManagedReplacement(
+            underlying_transform_identifier=urn,
+            update_compatibility_version=self._managed_replacement[type])
+
       return external.SchemaAwareExternalTransform(
-          urn, self._service, rearrange_based_on_discovery=True, **args)
+          urn,
+          self._service,
+          rearrange_based_on_discovery=True,
+          managed_replacement=managed_replacement,
+          **args)
     else:
       return type >> self.create_external_transform(urn, args)
 
@@ -318,14 +339,16 @@ def beam_jar(
     urns,
     *,
     gradle_target,
+    managed_replacement=None,
     appendix=None,
     version=beam_version,
     artifact_id=None):
   return ExternalJavaProvider(
       urns,
       lambda: subprocess_server.JavaJarServer.path_to_beam_jar(
-          gradle_target=gradle_target, version=version, 
artifact_id=artifact_id)
-  )
+          gradle_target=gradle_target, version=version, artifact_id=artifact_id
+      ),
+      managed_replacement=managed_replacement)
 
 
 @ExternalProvider.register_provider_type('docker')
@@ -357,11 +380,13 @@ class RemoteProvider(ExternalProvider):
 
 
 class ExternalJavaProvider(ExternalProvider):
-  def __init__(self, urns, jar_provider, classpath=None):
+  def __init__(
+      self, urns, jar_provider, managed_replacement=None, classpath=None):
     super().__init__(
         urns,
         lambda: external.JavaJarExpansionService(
-            jar_provider(), classpath=classpath))
+            jar_provider(), classpath=classpath),
+        managed_replacement)
     self._jar_provider = jar_provider
     self._classpath = classpath
 
diff --git a/sdks/python/gen_managed_doc.py b/sdks/python/gen_managed_doc.py
index d3d3f373d1e..85a7c73679a 100644
--- a/sdks/python/gen_managed_doc.py
+++ b/sdks/python/gen_managed_doc.py
@@ -89,6 +89,7 @@ _DOCUMENTATION_DESTINATION = os.path.join(
 
 
 def generate_managed_doc(output_location):
+  from apache_beam.transforms.external import 
MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING
   from apache_beam.transforms.external import BeamJarExpansionService
   from apache_beam.transforms.external_transform_provider import 
ExternalTransform
   from apache_beam.transforms.external_transform_provider import 
ExternalTransformProvider
@@ -99,13 +100,16 @@ def generate_managed_doc(output_location):
   with open(_MANAGED_CONFIG_ALIASES) as f:
     all_config_aliases: dict = yaml.safe_load(f)
 
-  services_and_names = managed._EXPANSION_SERVICE_JAR_TARGETS
+  # Creating a unique list of expansion service jars.
+  expansion_service_jar_targets = list(
+    dict.fromkeys(MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING.values()))
+
   read_names_and_identifiers = managed.Read._READ_TRANSFORMS
   write_names_and_identifiers = managed.Write._WRITE_TRANSFORMS
 
   all_transforms = {}
 
-  for gradle_target in services_and_names.keys():
+  for gradle_target in expansion_service_jar_targets:
     provider = 
ExternalTransformProvider(BeamJarExpansionService(gradle_target))
     discovered: Dict[str, ExternalTransform] = provider.get_all()
 

Reply via email to