This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 7f04c4f07f2 Deduplicate common environments. (#30681)
7f04c4f07f2 is described below

commit 7f04c4f07f2698f823953902bfb79fc7cb6e1584
Author: Robert Bradshaw <rober...@gmail.com>
AuthorDate: Tue Mar 26 16:54:00 2024 -0700

    Deduplicate common environments. (#30681)
    
    We deduplicate both on proto construction (as before, but fixed) and again after more environments have been resolved.
---
 sdks/python/apache_beam/pipeline.py                | 31 +---------
 sdks/python/apache_beam/runners/common.py          | 67 ++++++++++++++++++++++
 sdks/python/apache_beam/runners/common_test.py     | 59 +++++++++++++++++++
 .../runners/dataflow/dataflow_runner.py            |  2 +
 .../runners/portability/fn_api_runner/fn_runner.py |  4 +-
 5 files changed, 133 insertions(+), 30 deletions(-)

diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 53044982a06..11bc74d27ec 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -86,6 +86,7 @@ from apache_beam.options.pipeline_options_validator import PipelineOptionsValida
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners import PipelineRunner
+from apache_beam.runners import common
 from apache_beam.runners import create_runner
 from apache_beam.transforms import ParDo
 from apache_beam.transforms import ptransform
@@ -967,35 +968,7 @@ class Pipeline(HasDisplayData):
 
     Mutates proto as contexts may have references to proto.components.
     """
-    env_map = {}
-    canonical_env = {}
-    files_by_hash = {}
-    for env_id, env in proto.components.environments.items():
-      # First deduplicate any file dependencies by their hash.
-      for dep in env.dependencies:
-        if dep.type_urn == common_urns.artifact_types.FILE.urn:
-          file_payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
-              dep.type_payload)
-          if file_payload.sha256:
-            if file_payload.sha256 in files_by_hash:
-              file_payload.path = files_by_hash[file_payload.sha256]
-              dep.type_payload = file_payload.SerializeToString()
-            else:
-              files_by_hash[file_payload.sha256] = file_payload.path
-      # Next check if we've ever seen this environment before.
-      normalized = env.SerializeToString(deterministic=True)
-      if normalized in canonical_env:
-        env_map[env_id] = canonical_env[normalized]
-      else:
-        canonical_env[normalized] = env_id
-    for old_env, new_env in env_map.items():
-      for transform in proto.components.transforms.values():
-        if transform.environment_id == old_env:
-          transform.environment_id = new_env
-      for windowing_strategy in proto.components.windowing_strategies.values():
-        if windowing_strategy.environment_id == old_env:
-          windowing_strategy.environment_id = new_env
-      del proto.components.environments[old_env]
+    common.merge_common_environments(proto, inplace=True)
 
   @staticmethod
   def from_runner_api(
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 1cd0a304466..630ed7910c8 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -24,6 +24,8 @@ For internal use only; no backwards-compatibility guarantees.
 
 # pytype: skip-file
 
+import collections
+import copy
 import logging
 import sys
 import threading
@@ -43,6 +45,7 @@ from apache_beam.coders import coders
 from apache_beam.internal import util
 from apache_beam.options.value_provider import RuntimeValueProvider
 from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.pvalue import TaggedOutput
 from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
 from apache_beam.runners.sdf_utils import RestrictionTrackerView
@@ -52,6 +55,7 @@ from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
 from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import core
+from apache_beam.transforms import environments
 from apache_beam.transforms import userstate
 from apache_beam.transforms.core import RestrictionProvider
 from apache_beam.transforms.core import WatermarkEstimatorProvider
@@ -1941,3 +1945,66 @@ def validate_pipeline_graph(pipeline_proto):
 
   for t in pipeline_proto.root_transform_ids:
     validate_transform(t)
+
+
+def merge_common_environments(pipeline_proto, inplace=False):
+  def dep_key(dep):
+    if dep.type_urn == common_urns.artifact_types.FILE.urn:
+      payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
+          dep.type_payload)
+      if payload.sha256:
+        type_info = 'sha256', payload.sha256
+      else:
+        type_info = 'path', payload.path
+    elif dep.type_urn == common_urns.artifact_types.URL.urn:
+      payload = beam_runner_api_pb2.ArtifactUrlPayload.FromString(
+          dep.type_payload)
+      if payload.sha256:
+        type_info = 'sha256', payload.sha256
+      else:
+        type_info = 'url', payload.url
+    else:
+      type_info = dep.type_urn, dep.type_payload
+    return type_info, dep.role_urn, dep.role_payload
+
+  def base_env_key(env):
+    return (
+        env.urn,
+        env.payload,
+        tuple(sorted(env.capabilities)),
+        tuple(sorted(env.resource_hints.items())),
+        tuple(sorted(dep_key(dep) for dep in env.dependencies)))
+
+  def env_key(env):
+    return tuple(
+        sorted(
+            base_env_key(e)
+            for e in environments.expand_anyof_environments(env)))
+
+  cannonical_enviornments = collections.defaultdict(list)
+  for env_id, env in pipeline_proto.components.environments.items():
+    cannonical_enviornments[env_key(env)].append(env_id)
+
+  if len(cannonical_enviornments) == len(
+      pipeline_proto.components.environments):
+    # All environments are already sufficiently distinct.
+    return pipeline_proto
+
+  environment_remappings = {
+      e: es[0]
+      for es in cannonical_enviornments.values() for e in es
+  }
+
+  if not inplace:
+    pipeline_proto = copy.copy(pipeline_proto)
+
+  for t in pipeline_proto.components.transforms.values():
+    if t.environment_id:
+      t.environment_id = environment_remappings[t.environment_id]
+  for w in pipeline_proto.components.windowing_strategies.values():
+    if w.environment_id:
+      w.environment_id = environment_remappings[w.environment_id]
+  for e in set(pipeline_proto.components.environments.keys()) - set(
+      environment_remappings.values()):
+    del pipeline_proto.components.environments[e]
+  return pipeline_proto
diff --git a/sdks/python/apache_beam/runners/common_test.py b/sdks/python/apache_beam/runners/common_test.py
index 00645948c3e..ca2cd2539a8 100644
--- a/sdks/python/apache_beam/runners/common_test.py
+++ b/sdks/python/apache_beam/runners/common_test.py
@@ -26,8 +26,11 @@ from apache_beam.io.restriction_trackers import OffsetRange
 from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
 from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
 from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners.common import DoFnSignature
 from apache_beam.runners.common import PerWindowInvoker
+from apache_beam.runners.common import merge_common_environments
+from apache_beam.runners.portability.expansion_service_test import FibTransform
 from apache_beam.runners.sdf_utils import SplitResultPrimary
 from apache_beam.runners.sdf_utils import SplitResultResidual
 from apache_beam.testing.test_pipeline import TestPipeline
@@ -584,5 +587,61 @@ class PerWindowInvokerSplitTest(unittest.TestCase):
     self.assertEqual(stop_index, 2)
 
 
+class UtilitiesTest(unittest.TestCase):
+  def test_equal_environments_merged(self):
+    pipeline_proto = merge_common_environments(
+        beam_runner_api_pb2.Pipeline(
+            components=beam_runner_api_pb2.Components(
+                environments={
+                    'a1': beam_runner_api_pb2.Environment(urn='A'),
+                    'a2': beam_runner_api_pb2.Environment(urn='A'),
+                    'b1': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'x'),
+                    'b2': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'x'),
+                    'b3': beam_runner_api_pb2.Environment(
+                        urn='B', payload=b'y'),
+                },
+                transforms={
+                    't1': beam_runner_api_pb2.PTransform(
+                        unique_name='t1', environment_id='a1'),
+                    't2': beam_runner_api_pb2.PTransform(
+                        unique_name='t2', environment_id='a2'),
+                },
+                windowing_strategies={
+                    'w1': beam_runner_api_pb2.WindowingStrategy(
+                        environment_id='b1'),
+                    'w2': beam_runner_api_pb2.WindowingStrategy(
+                        environment_id='b2'),
+                })))
+    self.assertEqual(len(pipeline_proto.components.environments), 3)
+    self.assertTrue(('a1' in pipeline_proto.components.environments)
+                    ^ ('a2' in pipeline_proto.components.environments))
+    self.assertTrue(('b1' in pipeline_proto.components.environments)
+                    ^ ('b2' in pipeline_proto.components.environments))
+    self.assertEqual(
+        len(
+            set(
+                t.environment_id
+                for t in pipeline_proto.components.transforms.values())),
+        1)
+    self.assertEqual(
+        len(
+            set(
+                w.environment_id for w in
+                pipeline_proto.components.windowing_strategies.values())),
+        1)
+
+  def test_external_merged(self):
+    p = beam.Pipeline()
+    # This transform recursively creates several external environments.
+    _ = p | FibTransform(4)
+    pipeline_proto = p.to_runner_api()
+    # All our external environments are equal and consolidated.
+    # We also have a placeholder "default" environment that has not been
+    # resolved to anything concrete yet.
+    self.assertEqual(len(pipeline_proto.components.environments), 2)
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index db6a5235ac9..e428551ef02 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -43,6 +43,7 @@ from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.options.pipeline_options import WorkerOptions
 from apache_beam.portability import common_urns
 from apache_beam.runners.common import group_by_key_input_visitor
+from apache_beam.runners.common import merge_common_environments
 from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
 from apache_beam.runners.runner import PipelineResult
 from apache_beam.runners.runner import PipelineRunner
@@ -419,6 +420,7 @@ class DataflowRunner(PipelineRunner):
       self.proto_pipeline.components.environments[env_id].CopyFrom(
           environments.resolve_anyof_environment(
               env, common_urns.environments.DOCKER.urn))
+    self.proto_pipeline = merge_common_environments(self.proto_pipeline)
 
     # Optimize the pipeline if it not streaming and the pre_optimize
     # experiment is set.
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index b3dd124216b..07569fe328d 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -62,6 +62,7 @@ from apache_beam.portability.api import beam_provision_api_pb2
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners import runner
 from apache_beam.runners.common import group_by_key_input_visitor
+from apache_beam.runners.common import merge_common_environments
 from apache_beam.runners.common import validate_pipeline_graph
 from apache_beam.runners.portability import portable_metrics
 from apache_beam.runners.portability.fn_api_runner import execution
@@ -221,7 +222,8 @@ class FnApiRunner(runner.PipelineRunner):
     ]
     if direct_options.direct_embed_docker_python:
       pipeline_proto = self.embed_default_docker_image(pipeline_proto)
-    pipeline_proto = self.resolve_any_environments(pipeline_proto)
+    pipeline_proto = merge_common_environments(
+        self.resolve_any_environments(pipeline_proto))
     stage_context, stages = self.create_stages(pipeline_proto)
     return self.run_stages(stage_context, stages)
 

Reply via email to