This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/enforce_gbek in repository https://gitbox.apache.org/repos/asf/beam.git
commit dfb53d078aa4a8e0779a965329bc4807ad15c757 Author: Danny Mccormick <[email protected]> AuthorDate: Mon Sep 29 13:28:14 2025 -0400 Add pipeline option to enforce gbek --- .../python/apache_beam/options/pipeline_options.py | 10 ++++ sdks/python/apache_beam/transforms/core.py | 19 +++++++ sdks/python/apache_beam/transforms/util.py | 54 +++++++++++++++++-- sdks/python/apache_beam/transforms/util_test.py | 62 +++++++++++++++++++++- 4 files changed, 141 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 6595d683911..034073a41fe 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1716,6 +1716,16 @@ class SetupOptions(PipelineOptions): help=( 'Docker registry url to use for tagging and pushing the prebuilt ' 'sdk worker container image.')) + parser.add_argument( + '--gbek', + default=None, + help=( + 'When set, will replace all GroupByKey transforms in the pipeline ' + 'with EncryptedGroupByKey transforms using the secret passed in ' + 'the option. Beam will infer the secret type and value based on ' + 'secret itself. The option should be structured like: ' + '--encrypt=type:<secret_type>;<secret_param>:<value>, for example ' + '--encrypt=type:GcpSecret;version_name:my_secret/versions/latest')) def validate(self, validator): errors = [] diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 2304faf478f..1f514b2f989 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -39,6 +39,7 @@ from apache_beam import typehints from apache_beam.coders import typecoders from apache_beam.internal import pickler from apache_beam.internal import util +from apache_beam.options.pipeline_options import SetupOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.portability import common_urns from apache_beam.portability import python_urns @@ -3324,6 +3325,10 @@ class GroupByKey(PTransform): The implementation here is used only when run on the local direct runner. """ + def __init__(self): + self._replaced_by_gbek = False + self._inside_gbek = False + class ReifyWindows(DoFn): def process( self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam): @@ -3342,6 +3347,16 @@ class GroupByKey(PTransform): key_type, typehints.WindowedValue[value_type]] # type: ignore[misc] def expand(self, pcoll): + replace_with_gbek_secret = ( + pcoll.pipeline._options.view_as(SetupOptions).gbek) + if replace_with_gbek_secret is not None and not self._inside_gbek: + self._replaced_by_gbek = True + from apache_beam.transforms.util import GroupByEncryptedKey + from apache_beam.transforms.util import Secret + + secret = Secret.parse_secret_option(replace_with_gbek_secret) + return (pcoll | "Group by encrypted key" >> GroupByEncryptedKey(secret)) + from apache_beam.transforms.trigger import DataLossReason from apache_beam.transforms.trigger import DefaultTrigger windowing = pcoll.windowing @@ -3389,6 +3404,10 @@ class GroupByKey(PTransform): def to_runner_api_parameter(self, unused_context): # type: (PipelineContext) -> typing.Tuple[str, None] + # if we're containing a GroupByEncryptedKey, don't allow runners to + # recognize this transform as a GBEK so that it doesn't get replaced. + if self._replaced_by_gbek: + return super().to_runner_api_parameter(unused_context) return common_urns.primitives.GROUP_BY_KEY.urn, None @staticmethod diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index c63478dc0cf..8bce3dfc55f 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -341,6 +341,44 @@ class Secret(): """Generates a new secret key.""" return Fernet.generate_key() + @staticmethod + def parse_secret_option(secret) -> 'Secret': + """Parses a secret string and returns the appropriate secret type. + + The secret string should be formatted like: + 'type:<secret_type>;<secret_param>:<value>' + + For example, 'type:GcpSecret;version_name:my_secret/versions/latest' + would return a GcpSecret initialized with 'my_secret/versions/latest'. + """ + param_map = {} + for param in secret.split(';'): + parts = param.split(':') + param_map[parts[0]] = parts[1] + + if 'type' not in param_map: + raise RuntimeError('Secret string must contain a valid type parameter') + + secret_type = param_map['type'].lower() + del param_map['type'] + secret_class = None + secret_params = None + if secret_type == 'gcpsecret': + secret_class = GcpSecret + secret_params = ['version_name'] + else: + raise RuntimeError( + f'Invalid secret type {secret_type}, currently only ' + 'GcpSecret is supported') + + for param_name in param_map.keys(): + if param_name not in secret_params: + raise RuntimeError( + f'Invalid secret parameter {param_name}, ' + f'{secret_type} only supports the following ' + f'parameters: {secret_params}') + return secret_class(**param_map) + class GcpSecret(Secret): """A secret manager implementation that retrieves secrets from Google Cloud @@ -367,7 +405,12 @@ class GcpSecret(Secret): secret = response.payload.data return secret except Exception as e: - raise RuntimeError(f'Failed to retrieve secret bytes with excetion {e}') + raise RuntimeError( + 'Failed to retrieve secret bytes for secret ' + f'{self._version_name} with exception {e}') + + def __eq__(self, secret): + return self._version_name == getattr(secret, '_version_name', None) class _EncryptMessage(DoFn): @@ -499,7 +542,9 @@ class GroupByEncryptedKey(PTransform): self._hmac_key = hmac_key def expand(self, pcoll): - kv_type_hint = pcoll.element_type + key_type, value_type = (typehints.typehints.coerce_to_kv_type( + pcoll.element_type).tuple_types) + kv_type_hint = typehints.KV[key_type, value_type] if kv_type_hint and kv_type_hint != typehints.Any: coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder( f'GroupByEncryptedKey {self.label}' @@ -518,10 +563,13 @@ class GroupByEncryptedKey(PTransform): key_coder = coders.registry.get_coder(typehints.Any) value_coder = key_coder + gbk = beam.GroupByKey() + gbk._inside_gbek = True + return ( pcoll | beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder)) - | beam.GroupByKey() + | gbk | beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder))) diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 6cd8d5fcba7..5f7a383b395 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -50,6 +50,7 @@ from apache_beam import WindowInto from apache_beam.coders import coders from apache_beam.metrics import MetricsFilter from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.portability import common_urns @@ -289,6 +290,37 @@ class MockNoOpDecrypt(beam.transforms.util._DecryptMessage): return super().process(element) +class SecretTest(unittest.TestCase): + @parameterized.expand([ + param( + secret_string='type:GcpSecret;version_name:my_secret/versions/latest', + secret=GcpSecret('my_secret/versions/latest')), + param( + secret_string='type:GcpSecret;version_name:foo', + secret=GcpSecret('foo')), + param( + secret_string='type:gcpsecreT;version_name:my_secret/versions/latest', + secret=GcpSecret('my_secret/versions/latest')), + ]) + def test_secret_manager_parses_correctly(self, secret_string, secret): + self.assertEqual(secret, Secret.parse_secret_option(secret_string)) + + @parameterized.expand([ + param( + secret_string='version_name:foo', + exception_str='must contain a valid type parameter'), + param( + secret_string='type:gcpsecreT', + exception_str='missing 1 required positional argument'), + param( + secret_string='type:gcpsecreT;version_name:foo;extra:val', + exception_str='Invalid secret parameter extra'), + ]) + def test_secret_manager_throws_on_invalid(self, secret_string, exception_str): + with self.assertRaisesRegex(Exception, exception_str): + Secret.parse_secret_option(secret_string) + + class GroupByEncryptedKeyTest(unittest.TestCase): def setUp(self): if secretmanager is not None: @@ -318,7 +350,9 @@ class GroupByEncryptedKeyTest(unittest.TestCase): 'data': Secret.generate_secret_bytes() } }) - self.gcp_secret = GcpSecret(f'{self.secret_path}/versions/latest') + version_name = f'{self.secret_path}/versions/latest' + self.gcp_secret = GcpSecret(version_name) + self.secret_option = f'type:GcpSecret;version_name:{version_name}' def tearDown(self): if secretmanager is not None: @@ -334,6 +368,19 @@ class GroupByEncryptedKeyTest(unittest.TestCase): assert_that( result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + def test_gbk_with_gbek_option_fake_secret_manager_roundtrips(self): + options = PipelineOptions() + options.view_as(SetupOptions).gbek = self.secret_option + + with beam.Pipeline(options=options) as pipeline: + pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), + ('b', 3), ('c', 4)]) + result = (pcoll_1) | beam.GroupByKey() + sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1]))) + assert_that( + sorted_result, + equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt) def test_gbek_fake_secret_manager_actually_does_encryption(self): fakeSecret = FakeSecret() @@ -345,6 +392,19 @@ class GroupByEncryptedKeyTest(unittest.TestCase): assert_that( result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt) + def test_gbk_actually_does_encryption(self): + options = PipelineOptions() + options.view_as(SetupOptions).gbek = self.secret_option + + with TestPipeline(options=options) as pipeline: + pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), + ('b', 3), ('c', 4)], + reshuffle=False) + result = pcoll_1 | beam.GroupByKey() + # assert_that( + # result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + def test_gbek_fake_secret_manager_throws(self): fakeSecret = FakeSecret(True)
