This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/kmsSecret
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 1d98933a0eea73cbb7f707fbd026d84db5c19b5c
Author: Danny Mccormick <[email protected]>
AuthorDate: Mon Nov 24 15:53:54 2025 -0500

    Add new method of generating key for GBEK
---
 sdks/python/apache_beam/transforms/core_it_test.py |  67 ++++++++++++
 sdks/python/apache_beam/transforms/util.py         | 118 +++++++++++++++++++-
 sdks/python/apache_beam/transforms/util_test.py    | 119 +++++++++++++++++++++
 sdks/python/setup.py                               |   1 +
 4 files changed, 304 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/transforms/core_it_test.py b/sdks/python/apache_beam/transforms/core_it_test.py
index 18ae3f30f57..a4658d51a61 100644
--- a/sdks/python/apache_beam/transforms/core_it_test.py
+++ b/sdks/python/apache_beam/transforms/core_it_test.py
@@ -38,6 +38,11 @@ try:
 except ImportError:
   secretmanager = None  # type: ignore[assignment]
 
+try:
+  from google.cloud import kms
+except ImportError:
+  kms = None  # type: ignore[assignment]
+
 
 class GbekIT(unittest.TestCase):
   @classmethod
@@ -74,10 +79,56 @@ class GbekIT(unittest.TestCase):
       cls.gcp_secret = GcpSecret(version_name)
       cls.secret_option = f'type:GcpSecret;version_name:{version_name}'
 
+    if kms is not None:
+      cls.kms_client = kms.KeyManagementServiceClient()
+      cls.location_id = 'global'
+      py_version = f'_py{sys.version_info.major}{sys.version_info.minor}'
+      secret_postfix = datetime.now().strftime('%m%d_%H%M%S') + py_version
+      cls.key_ring_id = 'gbekit_key_ring_tests_' + secret_postfix
+      cls.key_ring_path = cls.kms_client.key_ring_path(
+          cls.project_id, cls.location_id, cls.key_ring_id)
+      try:
+        cls.kms_client.get_key_ring(request={'name': cls.key_ring_path})
+      except Exception:
+        cls.kms_client.create_key_ring(
+            request={
+                'parent':
+                    f'projects/{cls.project_id}/locations/{cls.location_id}',
+                'key_ring_id': cls.key_ring_id,
+            })
+      cls.key_id = 'gbekit_key_tests_' + secret_postfix
+      cls.key_path = cls.kms_client.crypto_key_path(
+          cls.project_id, cls.location_id, cls.key_ring_id, cls.key_id)
+      try:
+        cls.kms_client.get_crypto_key(request={'name': cls.key_path})
+      except Exception:
+        cls.kms_client.create_crypto_key(
+            request={
+                'parent': cls.key_ring_path,
+                'crypto_key_id': cls.key_id,
+                'crypto_key': {
+                    'purpose': kms.CryptoKey.CryptoKeyPurpose.ENCRYPT_DECRYPT
+                }
+            })
+      cls.hsm_secret_option = (
+          f'type:GcpHsmGeneratedSecret;project_id:{cls.project_id};'
+          f'location_id:{cls.location_id};key_ring_id:{cls.key_ring_id};'
+          f'key_id:{cls.key_id};job_name:{secret_postfix}')
+
   @classmethod
   def tearDownClass(cls):
     if secretmanager is not None:
       cls.client.delete_secret(request={'name': cls.secret_path})
+    if kms is not None and hasattr(cls, 'kms_client') and hasattr(
+        cls, 'key_path'):
+      for version in cls.kms_client.list_crypto_key_versions(
+          request={'parent': cls.key_path}):
+        if version.state in [
+            kms.CryptoKeyVersion.CryptoKeyVersionState.ENABLED,
+            kms.CryptoKeyVersion.CryptoKeyVersionState.DISABLED
+        ]:
+          cls.kms_client.destroy_crypto_key_version(
+              request={'name': version.name})
 
   @pytest.mark.it_postcommit
   @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
@@ -94,6 +145,22 @@ class GbekIT(unittest.TestCase):
 
     pipeline.run().wait_until_finish()
 
+  @pytest.mark.it_postcommit
+  @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
+  @unittest.skipIf(kms is None, 'GCP dependencies are not installed')
+  def test_gbk_with_gbek_hsm_it(self):
+    pipeline = TestPipeline(is_integration_test=True)
+    pipeline.options.view_as(SetupOptions).gbek = self.hsm_secret_option
+
+    pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), ('b', 3),
+                                                   ('c', 4)])
+    result = (pcoll_1) | beam.GroupByKey()
+    sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1])))
+    assert_that(
+        sorted_result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
+
+    pipeline.run().wait_until_finish()
+
   @pytest.mark.it_postcommit
   @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
   def test_combineValues_with_gbek_it(self):
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index ba79d4ddf31..7e96a1093c1 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -30,6 +30,7 @@ import re
 import threading
 import time
 import uuid
+from datetime import datetime
 from collections.abc import Callable
 from collections.abc import Iterable
 from typing import TYPE_CHECKING
@@ -366,10 +367,15 @@ class Secret():
     if secret_type == 'gcpsecret':
       secret_class = GcpSecret
       secret_params = ['version_name']
+    elif secret_type == 'gcphsmgeneratedsecret':
+      secret_class = GcpHsmGeneratedSecret
+      secret_params = [
+          'project_id', 'location_id', 'key_ring_id', 'key_id', 'job_name'
+      ]
     else:
       raise ValueError(
           f'Invalid secret type {secret_type}, currently only '
-          'GcpSecret is supported')
+          'GcpSecret and GcpHsmGeneratedSecret are supported')
 
     for param_name in param_map.keys():
       if param_name not in secret_params:
@@ -413,6 +419,116 @@ class GcpSecret(Secret):
     return self._version_name == getattr(secret, '_version_name', None)
 
 
+class GcpHsmGeneratedSecret(Secret):
+  def __init__(
+      self, project_id: str, location_id: str, key_ring_id: str, key_id: str, job_name: str):
+    self._project_id = project_id
+    self._location_id = location_id
+    self._key_ring_id = key_ring_id
+    self._key_id = key_id
+    self._secret_version_name = f'HsmGeneratedSecret_{job_name}'
+
+  def get_secret_bytes(self) -> bytes:
+    try:
+      from google.cloud import secretmanager
+      from google.api_core import exceptions as api_exceptions
+      client = secretmanager.SecretManagerServiceClient()
+
+      project_path = f"projects/{self._project_id}"
+      secret_path = f"{project_path}/secrets/{self._secret_version_name}"
+      # Since we may generate multiple versions when doing this on workers,
+      # just always take the first version added to maintain consistency.
+      secret_version_path = f"{secret_path}/versions/1"
+
+      try:
+        response = client.access_secret_version(
+            request={"name": secret_version_path})
+        return response.payload.data
+      except api_exceptions.NotFound:
+        # Don't bother logging yet, we'll only log if we actually add the
+        # secret version below
+        pass
+
+      try:
+        client.create_secret(
+            request={
+                "parent": project_path,
+                "secret_id": self._secret_version_name,
+                "secret": {"replication": {"automatic": {}}},
+            })
+      except api_exceptions.AlreadyExists:
+        # Don't bother logging yet, we'll only log if we actually add the
+        # secret version below
+        pass
+
+      new_key = self.generate_dek()
+      try:
+        # Try one more time in case it was created while we were generating the
+        # DEK.
+        response = client.access_secret_version(
+            request={"name": secret_version_path})
+        return response.payload.data
+      except api_exceptions.NotFound:
+        logging.info(
+            f"Secret version {secret_version_path} not found. "
+            "Creating new secret and version.")
+      client.add_secret_version(
+          request={"parent": secret_path, "payload": {"data": new_key}})
+      response = client.access_secret_version(
+          request={"name": secret_version_path})
+      return response.payload.data
+
+    except Exception as e:
+      raise RuntimeError(
+          f'Failed to retrieve or create secret bytes for secret '
+          f'{self._secret_version_name} with exception {e}')
+    
+  def generate_dek(self, dek_size: int = 32) -> bytes:
+    """Generates a new Data Encryption Key (DEK) using an HSM-backed key.
+
+    This function follows a key derivation process that incorporates entropy
+    from the HSM-backed key into the nonce used for key derivation.
+
+    Returns:
+        A new DEK of the specified size.
+    """
+    try:
+      import base64
+      import os
+      from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+      from cryptography.hazmat.primitives import hashes
+      from google.cloud import kms
+
+      # 1. Generate a random nonce (nonce_one)
+      nonce_one = os.urandom(dek_size)
+
+      # 2. Use the HSM-backed key to encrypt nonce_one to create nonce_two
+      kms_client = kms.KeyManagementServiceClient()
+      key_path = kms_client.crypto_key_path(
+          self._project_id,
+          self._location_id,
+          self._key_ring_id,
+          self._key_id)
+      response = kms_client.encrypt(
+          request={'name': key_path, 'plaintext': nonce_one})
+      nonce_two = response.ciphertext
+
+      # 3. Generate a Derivation Key (DK)
+      dk = os.urandom(dek_size)
+
+      # 4. Use a KDF to derive the DEK using DK and nonce_two
+      hkdf = HKDF(
+          algorithm=hashes.SHA256(),
+          length=dek_size,
+          salt=nonce_two,
+          info=None,
+      )
+      dek = hkdf.derive(dk)
+      return base64.urlsafe_b64encode(dek)
+    except Exception as e:
+      raise RuntimeError(f'Failed to generate DEK with exception {e}')
+
+
 class _EncryptMessage(DoFn):
   """A DoFn that encrypts the key and value of each element."""
   def __init__(
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 34e251fad1c..a6270e9f6a0 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -73,6 +73,7 @@ from apache_beam.transforms.core import FlatMapTuple
 from apache_beam.transforms.trigger import AfterCount
 from apache_beam.transforms.trigger import Repeatedly
 from apache_beam.transforms.util import GcpSecret
+from apache_beam.transforms.util import GcpHsmGeneratedSecret
 from apache_beam.transforms.util import Secret
 from apache_beam.transforms.window import FixedWindows
 from apache_beam.transforms.window import GlobalWindow
@@ -439,6 +440,124 @@ class GroupByEncryptedKeyTest(unittest.TestCase):
             result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
 
 
+@unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
+class GcpHsmGeneratedSecretTest(unittest.TestCase):
+  def setUp(self):
+    self.mock_secret_manager_client = mock.MagicMock()
+    self.mock_kms_client = mock.MagicMock()
+
+    # Patch the clients
+    self.secretmanager_patcher = mock.patch(
+        'google.cloud.secretmanager.SecretManagerServiceClient',
+        return_value=self.mock_secret_manager_client)
+    self.kms_patcher = mock.patch(
+        'google.cloud.kms.KeyManagementServiceClient',
+        return_value=self.mock_kms_client)
+    self.os_urandom_patcher = mock.patch('os.urandom', return_value=b'0' * 32)
+    self.hkdf_patcher = mock.patch(
+        'cryptography.hazmat.primitives.kdf.hkdf.HKDF.derive',
+        return_value=b'derived_key')
+
+    self.secretmanager_patcher.start()
+    self.kms_patcher.start()
+    self.os_urandom_patcher.start()
+    self.hkdf_patcher.start()
+
+  def tearDown(self):
+    self.secretmanager_patcher.stop()
+    self.kms_patcher.stop()
+    self.os_urandom_patcher.stop()
+    self.hkdf_patcher.stop()
+
+  def test_happy_path_secret_creation(self):
+    from google.api_core import exceptions as api_exceptions
+
+    project_id = 'test-project'
+    location_id = 'global'
+    key_ring_id = 'test-key-ring'
+    key_id = 'test-key'
+    job_name = 'test-job'
+
+    secret = GcpHsmGeneratedSecret(
+        project_id, location_id, key_ring_id, key_id, job_name)
+
+    # Mock responses for secret creation path
+    self.mock_secret_manager_client.access_secret_version.side_effect = [
+        api_exceptions.NotFound('not found'),  # first check
+        api_exceptions.NotFound('not found'),  # second check
+        mock.MagicMock(payload=mock.MagicMock(data=b'derived_key'))
+    ]
+    self.mock_kms_client.encrypt.return_value = mock.MagicMock(
+        ciphertext=b'encrypted_nonce')
+
+    secret_bytes = secret.get_secret_bytes()
+    self.assertEqual(secret_bytes, b'derived_key')
+
+    # Assertions on mocks
+    secret_version_path = (
+        f'projects/{project_id}/secrets/{secret._secret_version_name}/versions/1'
+    )
+    self.mock_secret_manager_client.access_secret_version.assert_any_call(
+        request={'name': secret_version_path})
+    self.assertEqual(
+        self.mock_secret_manager_client.access_secret_version.call_count, 3)
+    self.mock_secret_manager_client.create_secret.assert_called_once()
+    self.mock_kms_client.encrypt.assert_called_once()
+    self.mock_secret_manager_client.add_secret_version.assert_called_once()
+
+  def test_secret_already_exists(self):
+    from google.api_core import exceptions as api_exceptions
+
+    project_id = 'test-project'
+    location_id = 'global'
+    key_ring_id = 'test-key-ring'
+    key_id = 'test-key'
+    job_name = 'test-job'
+
+    secret = GcpHsmGeneratedSecret(
+        project_id, location_id, key_ring_id, key_id, job_name)
+
+    # Mock responses for secret creation path
+    self.mock_secret_manager_client.access_secret_version.side_effect = [
+        api_exceptions.NotFound('not found'),
+        api_exceptions.NotFound('not found'),
+        mock.MagicMock(payload=mock.MagicMock(data=b'derived_key'))
+    ]
+    self.mock_secret_manager_client.create_secret.side_effect = (
+        api_exceptions.AlreadyExists('exists'))
+    self.mock_kms_client.encrypt.return_value = mock.MagicMock(
+        ciphertext=b'encrypted_nonce')
+
+    secret_bytes = secret.get_secret_bytes()
+    self.assertEqual(secret_bytes, b'derived_key')
+
+    # Assertions on mocks
+    self.mock_secret_manager_client.create_secret.assert_called_once()
+    self.mock_secret_manager_client.add_secret_version.assert_called_once()
+
+  def test_secret_version_already_exists(self):
+    project_id = 'test-project'
+    location_id = 'global'
+    key_ring_id = 'test-key-ring'
+    key_id = 'test-key'
+    job_name = 'test-job'
+
+    secret = GcpHsmGeneratedSecret(
+        project_id, location_id, key_ring_id, key_id, job_name)
+
+    self.mock_secret_manager_client.access_secret_version.return_value = (
+        mock.MagicMock(payload=mock.MagicMock(data=b'existing_dek')))
+
+    secret_bytes = secret.get_secret_bytes()
+    self.assertEqual(secret_bytes, b'existing_dek')
+
+    # Assertions
+    self.mock_secret_manager_client.access_secret_version.assert_called_once()
+    self.mock_secret_manager_client.create_secret.assert_not_called()
+    self.mock_secret_manager_client.add_secret_version.assert_not_called()
+    self.mock_kms_client.encrypt.assert_not_called()
+
+
 class FakeClock(object):
   def __init__(self, now=time.time()):
     self._now = now
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 074d64ae892..80b121fd9e3 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -485,6 +485,7 @@ if __name__ == '__main__':
               'google-cloud-spanner>=3.0.0,<4',
               # GCP Packages required by ML functionality
               'google-cloud-dlp>=3.0.0,<4',
+              'google-cloud-kms>=3.0.0,<4',
               'google-cloud-language>=2.0,<3',
               'google-cloud-secret-manager>=2.0,<3',
               'google-cloud-videointelligence>=2.0,<3',

Reply via email to