This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch users/damccorm/kmsSecret in repository https://gitbox.apache.org/repos/asf/beam.git
commit 1d98933a0eea73cbb7f707fbd026d84db5c19b5c Author: Danny Mccormick <[email protected]> AuthorDate: Mon Nov 24 15:53:54 2025 -0500 Add new method of generating key for GBEK --- sdks/python/apache_beam/transforms/core_it_test.py | 67 ++++++++++++ sdks/python/apache_beam/transforms/util.py | 118 +++++++++++++++++++- sdks/python/apache_beam/transforms/util_test.py | 119 +++++++++++++++++++++ sdks/python/setup.py | 1 + 4 files changed, 304 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/core_it_test.py b/sdks/python/apache_beam/transforms/core_it_test.py index 18ae3f30f57..a4658d51a61 100644 --- a/sdks/python/apache_beam/transforms/core_it_test.py +++ b/sdks/python/apache_beam/transforms/core_it_test.py @@ -38,6 +38,11 @@ try: except ImportError: secretmanager = None # type: ignore[assignment] +try: + from google.cloud import kms +except ImportError: + kms = None # type: ignore[assignment] + class GbekIT(unittest.TestCase): @classmethod @@ -74,10 +79,56 @@ class GbekIT(unittest.TestCase): cls.gcp_secret = GcpSecret(version_name) cls.secret_option = f'type:GcpSecret;version_name:{version_name}' + if kms is not None: + cls.kms_client = kms.KeyManagementServiceClient() + cls.location_id = 'global' + py_version = f'_py{sys.version_info.major}{sys.version_info.minor}' + secret_postfix = datetime.now().strftime('%m%d_%H%M%S') + py_version + cls.key_ring_id = 'gbekit_key_ring_tests_' + secret_postfix + cls.key_ring_path = cls.kms_client.key_ring_path( + cls.project_id, cls.location_id, cls.key_ring_id) + try: + cls.kms_client.get_key_ring(request={'name': cls.key_ring_path}) + except Exception: + cls.kms_client.create_key_ring( + request={ + 'parent': + f'projects/{cls.project_id}/locations/{cls.location_id}', + 'key_ring_id': cls.key_ring_id, + }) + cls.key_id = 'gbekit_key_tests_' + secret_postfix + cls.key_path = cls.kms_client.crypto_key_path( + cls.project_id, cls.location_id, cls.key_ring_id, cls.key_id) + try: + cls.kms_client.get_crypto_key(request={'name': cls.key_path}) + except Exception: + cls.kms_client.create_crypto_key( + request={ + 'parent': cls.key_ring_path, + 'crypto_key_id': cls.key_id, + 'crypto_key': { + 'purpose': kms.CryptoKey.CryptoKeyPurpose.ENCRYPT_DECRYPT + } + }) + cls.hsm_secret_option = ( + f'type:GcpHsmGeneratedSecret;project_id:{cls.project_id};' + f'location_id:{cls.location_id};key_ring_id:{cls.key_ring_id};' + f'key_id:{cls.key_id};job_name:{secret_postfix}') + @classmethod def tearDownClass(cls): if secretmanager is not None: cls.client.delete_secret(request={'name': cls.secret_path}) + if kms is not None and hasattr(cls, 'kms_client') and hasattr( + cls, 'key_path'): + for version in cls.kms_client.list_crypto_key_versions( + request={'parent': cls.key_path}): + if version.state in [ + kms.CryptoKeyVersion.CryptoKeyVersionState.ENABLED, + kms.CryptoKeyVersion.CryptoKeyVersionState.DISABLED + ]: + cls.kms_client.destroy_crypto_key_version( + request={'name': version.name}) @pytest.mark.it_postcommit @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed') @@ -94,6 +145,22 @@ class GbekIT(unittest.TestCase): pipeline.run().wait_until_finish() + @pytest.mark.it_postcommit + @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed') + @unittest.skipIf(kms is None, 'GCP dependencies are not installed') + def test_gbk_with_gbek_hsm_it(self): + pipeline = TestPipeline(is_integration_test=True) + pipeline.options.view_as(SetupOptions).gbek = self.hsm_secret_option + + pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), ('b', 3), + ('c', 4)]) + result = (pcoll_1) | beam.GroupByKey() + sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1]))) + assert_that( + sorted_result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + + pipeline.run().wait_until_finish() + @pytest.mark.it_postcommit @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed') def test_combineValues_with_gbek_it(self): diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index ba79d4ddf31..7e96a1093c1 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -30,6 +30,7 @@ import re import threading import time import uuid +from datetime import datetime from collections.abc import Callable from collections.abc import Iterable from typing import TYPE_CHECKING @@ -366,10 +367,15 @@ class Secret(): if secret_type == 'gcpsecret': secret_class = GcpSecret secret_params = ['version_name'] + elif secret_type == 'gcphsmgeneratedsecret': + secret_class = GcpHsmGeneratedSecret + secret_params = [ + 'project_id', 'location_id', 'key_ring_id', 'key_id', 'job_name' + ] else: raise ValueError( f'Invalid secret type {secret_type}, currently only ' - 'GcpSecret is supported') + 'GcpSecret and GcpHsmGeneratedSecret are supported') for param_name in param_map.keys(): if param_name not in secret_params: @@ -413,6 +419,116 @@ class GcpSecret(Secret): return self._version_name == getattr(secret, '_version_name', None) +class GcpHsmGeneratedSecret(Secret): + def __init__( + self, project_id: str, location_id: str, key_ring_id: str, key_id: str, job_name: str): + self._project_id = project_id + self._location_id = location_id + self._key_ring_id = key_ring_id + self._key_id = key_id + self._secret_version_name = f'HsmGeneratedSecret_{job_name}' + + def get_secret_bytes(self) -> bytes: + try: + from google.cloud import secretmanager + from google.api_core import exceptions as api_exceptions + client = secretmanager.SecretManagerServiceClient() + + project_path = f"projects/{self._project_id}" + secret_path = f"{project_path}/secrets/{self._secret_version_name}" + # Since we may generate multiple versions when doing this on workers, + # just always take the first version added to maintain consistency. + secret_version_path = f"{secret_path}/versions/1" + + try: + response = client.access_secret_version( + request={"name": secret_version_path}) + return response.payload.data + except api_exceptions.NotFound: + # Don't bother logging yet, we'll only log if we actually add the + # secret version below + pass + + try: + client.create_secret( + request={ + "parent": project_path, + "secret_id": self._secret_version_name, + "secret": {"replication": {"automatic": {}}}, + }) + except api_exceptions.AlreadyExists: + # Don't bother logging yet, we'll only log if we actually add the + # secret version below + pass + + new_key = self.generate_dek() + try: + # Try one more time in case it was created while we were generating the + # DEK. + response = client.access_secret_version( + request={"name": secret_version_path}) + return response.payload.data + except api_exceptions.NotFound: + logging.info( + f"Secret version {secret_version_path} not found. " + "Creating new secret and version.") + client.add_secret_version( + request={"parent": secret_path, "payload": {"data": new_key}}) + response = client.access_secret_version( + request={"name": secret_version_path}) + return response.payload.data + + except Exception as e: + raise RuntimeError( + f'Failed to retrieve or create secret bytes for secret ' + f'{self._secret_version_name} with exception {e}') + + def generate_dek(self, dek_size: int = 32) -> bytes: + """Generates a new Data Encryption Key (DEK) using an HSM-backed key. + + This function follows a key derivation process that incorporates entropy + from the HSM-backed key into the nonce used for key derivation. + + Returns: + A new DEK of the specified size. + """ + try: + import base64 + import os + from cryptography.hazmat.primitives.kdf.hkdf import HKDF + from cryptography.hazmat.primitives import hashes + from google.cloud import kms + + # 1. Generate a random nonce (nonce_one) + nonce_one = os.urandom(dek_size) + + # 2. Use the HSM-backed key to encrypt nonce_one to create nonce_two + kms_client = kms.KeyManagementServiceClient() + key_path = kms_client.crypto_key_path( + self._project_id, + self._location_id, + self._key_ring_id, + self._key_id) + response = kms_client.encrypt( + request={'name': key_path, 'plaintext': nonce_one}) + nonce_two = response.ciphertext + + # 3. Generate a Derivation Key (DK) + dk = os.urandom(dek_size) + + # 4. Use a KDF to derive the DEK using DK and nonce_two + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=dek_size, + salt=nonce_two, + info=None, + ) + dek = hkdf.derive(dk) + return base64.urlsafe_b64encode(dek) + except Exception as e: + raise RuntimeError(f'Failed to generate DEK with exception {e}') + + class _EncryptMessage(DoFn): """A DoFn that encrypts the key and value of each element.""" def __init__( diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 34e251fad1c..a6270e9f6a0 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -73,6 +73,7 @@ from apache_beam.transforms.core import FlatMapTuple from apache_beam.transforms.trigger import AfterCount from apache_beam.transforms.trigger import Repeatedly from apache_beam.transforms.util import GcpSecret +from apache_beam.transforms.util import GcpHsmGeneratedSecret from apache_beam.transforms.util import Secret from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import GlobalWindow @@ -439,6 +440,124 @@ class GroupByEncryptedKeyTest(unittest.TestCase): result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) [email protected](secretmanager is None, 'GCP dependencies are not installed') +class GcpHsmGeneratedSecretTest(unittest.TestCase): + def setUp(self): + self.mock_secret_manager_client = mock.MagicMock() + self.mock_kms_client = mock.MagicMock() + + # Patch the clients + self.secretmanager_patcher = mock.patch( + 'google.cloud.secretmanager.SecretManagerServiceClient', + return_value=self.mock_secret_manager_client) + self.kms_patcher = mock.patch( + 'google.cloud.kms.KeyManagementServiceClient', + return_value=self.mock_kms_client) + self.os_urandom_patcher = mock.patch('os.urandom', return_value=b'0' * 32) + self.hkdf_patcher = mock.patch( + 'cryptography.hazmat.primitives.kdf.hkdf.HKDF.derive', + return_value=b'derived_key') + + self.secretmanager_patcher.start() + self.kms_patcher.start() + self.os_urandom_patcher.start() + self.hkdf_patcher.start() + + def tearDown(self): + self.secretmanager_patcher.stop() + self.kms_patcher.stop() + self.os_urandom_patcher.stop() + self.hkdf_patcher.stop() + + def test_happy_path_secret_creation(self): + from google.api_core import exceptions as api_exceptions + + project_id = 'test-project' + location_id = 'global' + key_ring_id = 'test-key-ring' + key_id = 'test-key' + job_name = 'test-job' + + secret = GcpHsmGeneratedSecret( + project_id, location_id, key_ring_id, key_id, job_name) + + # Mock responses for secret creation path + self.mock_secret_manager_client.access_secret_version.side_effect = [ + api_exceptions.NotFound('not found'), # first check + api_exceptions.NotFound('not found'), # second check + mock.MagicMock(payload=mock.MagicMock(data=b'derived_key')) + ] + self.mock_kms_client.encrypt.return_value = mock.MagicMock( + ciphertext=b'encrypted_nonce') + + secret_bytes = secret.get_secret_bytes() + self.assertEqual(secret_bytes, b'derived_key') + + # Assertions on mocks + secret_version_path = ( + f'projects/{project_id}/secrets/{secret._secret_version_name}/versions/1' + ) + self.mock_secret_manager_client.access_secret_version.assert_any_call( + request={'name': secret_version_path}) + self.assertEqual( + self.mock_secret_manager_client.access_secret_version.call_count, 3) + self.mock_secret_manager_client.create_secret.assert_called_once() + self.mock_kms_client.encrypt.assert_called_once() + self.mock_secret_manager_client.add_secret_version.assert_called_once() + + def test_secret_already_exists(self): + from google.api_core import exceptions as api_exceptions + + project_id = 'test-project' + location_id = 'global' + key_ring_id = 'test-key-ring' + key_id = 'test-key' + job_name = 'test-job' + + secret = GcpHsmGeneratedSecret( + project_id, location_id, key_ring_id, key_id, job_name) + + # Mock responses for secret creation path + self.mock_secret_manager_client.access_secret_version.side_effect = [ + api_exceptions.NotFound('not found'), + api_exceptions.NotFound('not found'), + mock.MagicMock(payload=mock.MagicMock(data=b'derived_key')) + ] + self.mock_secret_manager_client.create_secret.side_effect = ( + api_exceptions.AlreadyExists('exists')) + self.mock_kms_client.encrypt.return_value = mock.MagicMock( + ciphertext=b'encrypted_nonce') + + secret_bytes = secret.get_secret_bytes() + self.assertEqual(secret_bytes, b'derived_key') + + # Assertions on mocks + self.mock_secret_manager_client.create_secret.assert_called_once() + self.mock_secret_manager_client.add_secret_version.assert_called_once() + + def test_secret_version_already_exists(self): + project_id = 'test-project' + location_id = 'global' + key_ring_id = 'test-key-ring' + key_id = 'test-key' + job_name = 'test-job' + + secret = GcpHsmGeneratedSecret( + project_id, location_id, key_ring_id, key_id, job_name) + + self.mock_secret_manager_client.access_secret_version.return_value = ( + mock.MagicMock(payload=mock.MagicMock(data=b'existing_dek'))) + + secret_bytes = secret.get_secret_bytes() + self.assertEqual(secret_bytes, b'existing_dek') + + # Assertions + self.mock_secret_manager_client.access_secret_version.assert_called_once() + self.mock_secret_manager_client.create_secret.assert_not_called() + self.mock_secret_manager_client.add_secret_version.assert_not_called() + self.mock_kms_client.encrypt.assert_not_called() + + class FakeClock(object): def __init__(self, now=time.time()): self._now = now diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 074d64ae892..80b121fd9e3 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -485,6 +485,7 @@ if __name__ == '__main__': 'google-cloud-spanner>=3.0.0,<4', # GCP Packages required by ML functionality 'google-cloud-dlp>=3.0.0,<4', + 'google-cloud-kms>=3.0.0,<4', 'google-cloud-language>=2.0,<3', 'google-cloud-secret-manager>=2.0,<3', 'google-cloud-videointelligence>=2.0,<3',
