This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/enforce_gbek
in repository https://gitbox.apache.org/repos/asf/beam.git

commit dfb53d078aa4a8e0779a965329bc4807ad15c757
Author: Danny Mccormick <[email protected]>
AuthorDate: Mon Sep 29 13:28:14 2025 -0400

    Add --gbek pipeline option to enforce GroupByEncryptedKey
---
 .../python/apache_beam/options/pipeline_options.py | 10 ++++
 sdks/python/apache_beam/transforms/core.py         | 19 +++++++
 sdks/python/apache_beam/transforms/util.py         | 54 +++++++++++++++++--
 sdks/python/apache_beam/transforms/util_test.py    | 62 +++++++++++++++++++++-
 4 files changed, 141 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 6595d683911..034073a41fe 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -1716,6 +1716,16 @@ class SetupOptions(PipelineOptions):
         help=(
             'Docker registry url to use for tagging and pushing the prebuilt '
             'sdk worker container image.'))
+    parser.add_argument(
+        '--gbek',
+        default=None,
+        help=(
+            'When set, will replace all GroupByKey transforms in the pipeline '
+            'with GroupByEncryptedKey transforms using the secret passed in '
+            'the option. Beam will infer the secret type and value based on '
+            'the secret itself. The option should be structured like: '
+            '--gbek=type:<secret_type>;<secret_param>:<value>, for example '
+            '--gbek=type:GcpSecret;version_name:my_secret/versions/latest'))
 
   def validate(self, validator):
     errors = []
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 2304faf478f..1f514b2f989 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -39,6 +39,7 @@ from apache_beam import typehints
 from apache_beam.coders import typecoders
 from apache_beam.internal import pickler
 from apache_beam.internal import util
+from apache_beam.options.pipeline_options import SetupOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
 from apache_beam.portability import python_urns
@@ -3324,6 +3325,15 @@ class GroupByKey(PTransform):
 
   The implementation here is used only when run on the local direct runner.
   """
+  # Flags for the --gbek replacement. Declared as class attributes instead of
+  # in an __init__ so GroupByKey keeps the constructor it inherits from
+  # PTransform (an __init__ that skips super().__init__() would break it).
+  # Set by expand() once this transform has been replaced by GBEK.
+  _replaced_by_gbek = False
+  # Set on the GroupByKey inside GroupByEncryptedKey so it is not replaced
+  # again, which would recurse.
+  _inside_gbek = False
+
   class ReifyWindows(DoFn):
     def process(
         self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam):
@@ -3342,6 +3347,16 @@ class GroupByKey(PTransform):
           key_type, typehints.WindowedValue[value_type]]  # type: ignore[misc]
 
   def expand(self, pcoll):
+    # With --gbek set, expand into GroupByEncryptedKey instead (unless this
+    # is the GroupByKey nested inside GroupByEncryptedKey itself).
+    gbek_option = pcoll.pipeline._options.view_as(SetupOptions).gbek
+    if not self._inside_gbek and gbek_option is not None:
+      from apache_beam.transforms.util import GroupByEncryptedKey
+      from apache_beam.transforms.util import Secret
+      self._replaced_by_gbek = True
+      secret = Secret.parse_secret_option(gbek_option)
+      return pcoll | "Group by encrypted key" >> GroupByEncryptedKey(secret)
+
     from apache_beam.transforms.trigger import DataLossReason
     from apache_beam.transforms.trigger import DefaultTrigger
     windowing = pcoll.windowing
@@ -3389,6 +3404,12 @@ class GroupByKey(PTransform):
 
   def to_runner_api_parameter(self, unused_context):
     # type: (PipelineContext) -> typing.Tuple[str, None]
+    # When expand() rewrote this transform into a GroupByEncryptedKey, report
+    # it as an ordinary composite PTransform instead of the GroupByKey
+    # primitive, so runners don't substitute their native GroupByKey and
+    # bypass the encrypt/decrypt steps it expanded into.
+    if self._replaced_by_gbek:
+      return super().to_runner_api_parameter(unused_context)
     return common_urns.primitives.GROUP_BY_KEY.urn, None
 
   @staticmethod
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index c63478dc0cf..8bce3dfc55f 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -341,6 +341,55 @@ class Secret():
     """Generates a new secret key."""
     return Fernet.generate_key()
 
+  @staticmethod
+  def parse_secret_option(secret) -> 'Secret':
+    """Parses a secret string and returns the appropriate secret type.
+
+    The secret string should be formatted like:
+    'type:<secret_type>;<secret_param>:<value>'
+
+    For example, 'type:GcpSecret;version_name:my_secret/versions/latest'
+    would return a GcpSecret initialized with 'my_secret/versions/latest'.
+
+    Each ';'-separated segment is split on its first ':' only, so parameter
+    values may themselves contain ':'. Empty segments (e.g. from a trailing
+    ';') are ignored. The type name is matched case-insensitively.
+
+    Raises:
+      RuntimeError: if a segment has no ':', the type parameter is missing,
+        the type is unsupported, or an unknown parameter is given.
+      TypeError: if a parameter required by the secret class is missing.
+    """
+    param_map = {}
+    for param in secret.split(';'):
+      if not param:
+        continue
+      name, sep, value = param.partition(':')
+      if not sep:
+        raise RuntimeError(
+            f'Invalid secret parameter {param}, expected <name>:<value>')
+      param_map[name] = value
+
+    if 'type' not in param_map:
+      raise RuntimeError('Secret string must contain a valid type parameter')
+
+    secret_type = param_map.pop('type').lower()
+    if secret_type == 'gcpsecret':
+      secret_class = GcpSecret
+      secret_params = ['version_name']
+    else:
+      raise RuntimeError(
+          f'Invalid secret type {secret_type}, currently only '
+          'GcpSecret is supported')
+
+    for param_name in param_map:
+      if param_name not in secret_params:
+        raise RuntimeError(
+            f'Invalid secret parameter {param_name}, '
+            f'{secret_type} only supports the following '
+            f'parameters: {secret_params}')
+    return secret_class(**param_map)
+
 
 class GcpSecret(Secret):
   """A secret manager implementation that retrieves secrets from Google Cloud
@@ -367,7 +405,19 @@ class GcpSecret(Secret):
       secret = response.payload.data
       return secret
     except Exception as e:
-      raise RuntimeError(f'Failed to retrieve secret bytes with excetion {e}')
+      raise RuntimeError(
+          'Failed to retrieve secret bytes for secret '
+          f'{self._version_name} with exception {e}') from e
+
+  def __eq__(self, secret):
+    # Two GcpSecrets are equal when they reference the same version name.
+    if not isinstance(secret, GcpSecret):
+      return NotImplemented
+    return self._version_name == secret._version_name
+
+  def __hash__(self):
+    # Defining __eq__ would otherwise make instances unhashable.
+    return hash(self._version_name)
 
 
 class _EncryptMessage(DoFn):
@@ -499,7 +542,11 @@ class GroupByEncryptedKey(PTransform):
     self._hmac_key = hmac_key
 
   def expand(self, pcoll):
-    kv_type_hint = pcoll.element_type
+    # NOTE(review): if coerce_to_kv_type maps Any to KV[Any, Any], the
+    # non-KV fallback branch further down can no longer be reached; confirm.
+    kv_hint = typehints.typehints.coerce_to_kv_type(pcoll.element_type)
+    key_type, value_type = kv_hint.tuple_types
+    kv_type_hint = typehints.KV[key_type, value_type]
     if kv_type_hint and kv_type_hint != typehints.Any:
       coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder(
           f'GroupByEncryptedKey {self.label}'
@@ -518,10 +563,13 @@ class GroupByEncryptedKey(PTransform):
       key_coder = coders.registry.get_coder(typehints.Any)
       value_coder = key_coder
 
+    inner_gbk = beam.GroupByKey()
+    inner_gbk._inside_gbek = True  # Exempt from --gbek replacement.
+
     return (
         pcoll
         | beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
-        | beam.GroupByKey()
+        | inner_gbk
         | beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
 
 
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 6cd8d5fcba7..5f7a383b395 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -50,6 +50,7 @@ from apache_beam import WindowInto
 from apache_beam.coders import coders
 from apache_beam.metrics import MetricsFilter
 from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
 from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.options.pipeline_options import TypeOptions
 from apache_beam.portability import common_urns
@@ -289,6 +290,41 @@ class MockNoOpDecrypt(beam.transforms.util._DecryptMessage):
     return super().process(element)
 
 
+class SecretTest(unittest.TestCase):
+  @parameterized.expand([
+      param(
+          secret_string='type:GcpSecret;version_name:my_secret/versions/latest',
+          secret=GcpSecret('my_secret/versions/latest')),
+      param(
+          secret_string='type:GcpSecret;version_name:foo',
+          secret=GcpSecret('foo')),
+      param(
+          secret_string='type:gcpsecreT;version_name:my_secret/versions/latest',
+          secret=GcpSecret('my_secret/versions/latest')),
+  ])
+  def test_secret_manager_parses_correctly(self, secret_string, secret):
+    self.assertEqual(secret, Secret.parse_secret_option(secret_string))
+
+  @parameterized.expand([
+      param(
+          secret_string='version_name:foo',
+          exception_str='must contain a valid type parameter'),
+      param(
+          secret_string='type:gcpsecreT',
+          exception_str='missing 1 required positional argument'),
+      param(
+          secret_string='type:gcpsecreT;version_name:foo;extra:val',
+          exception_str='Invalid secret parameter extra'),
+      param(
+          secret_string='type:FooSecret;version_name:foo',
+          exception_str='Invalid secret type foosecret'),
+  ])
+  def test_secret_manager_throws_on_invalid(
+      self, secret_string, exception_str):
+    with self.assertRaisesRegex(Exception, exception_str):
+      Secret.parse_secret_option(secret_string)
+
+
 class GroupByEncryptedKeyTest(unittest.TestCase):
   def setUp(self):
     if secretmanager is not None:
@@ -318,7 +350,9 @@ class GroupByEncryptedKeyTest(unittest.TestCase):
                     'data': Secret.generate_secret_bytes()
                 }
             })
-      self.gcp_secret = GcpSecret(f'{self.secret_path}/versions/latest')
+      version_name = f'{self.secret_path}/versions/latest'
+      self.gcp_secret = GcpSecret(version_name)
+      self.secret_option = f'type:GcpSecret;version_name:{version_name}'
 
   def tearDown(self):
     if secretmanager is not None:
@@ -334,6 +368,20 @@ class GroupByEncryptedKeyTest(unittest.TestCase):
       assert_that(
           result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
 
+  @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
+  def test_gbk_with_gbek_option_fake_secret_manager_roundtrips(self):
+    options = PipelineOptions()
+    options.view_as(SetupOptions).gbek = self.secret_option
+
+    with beam.Pipeline(options=options) as pipeline:
+      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
+                                                     ('b', 3), ('c', 4)])
+      result = pcoll_1 | beam.GroupByKey()
+      sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1])))
+      assert_that(
+          sorted_result,
+          equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
+
   @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
   def test_gbek_fake_secret_manager_actually_does_encryption(self):
     fakeSecret = FakeSecret()
@@ -345,6 +392,22 @@ class GroupByEncryptedKeyTest(unittest.TestCase):
       assert_that(
           result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
 
+  @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
+  @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
+  def test_gbk_actually_does_encryption(self):
+    options = PipelineOptions()
+    options.view_as(SetupOptions).gbek = self.secret_option
+
+    with TestPipeline(options=options) as pipeline:
+      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
+                                                     ('b', 3), ('c', 4)],
+                                                    reshuffle=False)
+      result = pcoll_1 | beam.GroupByKey()
+      sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1])))
+      assert_that(
+          sorted_result,
+          equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
+
   def test_gbek_fake_secret_manager_throws(self):
     fakeSecret = FakeSecret(True)
 

Reply via email to