This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9ed06d081ec Handle null keys in gbek (#36505)
9ed06d081ec is described below

commit 9ed06d081ec72355d6794d8d76df6dc3569c9a3c
Author: Danny McCormick <[email protected]>
AuthorDate: Wed Oct 15 06:34:46 2025 -0400

    Handle null keys in gbek (#36505)
    
    * Handle null keys in gbek
    
    * Allow null values with hashmap
    
    * add a test
    
    * Test + remove check entirely
---
 .../beam/sdk/transforms/GroupByEncryptedKey.java   | 30 ++++++++++------------
 .../sdk/transforms/GroupByEncryptedKeyTest.java    |  9 +++++--
 2 files changed, 20 insertions(+), 19 deletions(-)

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java
index 1f4b7535d89..85483fd517a 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java
@@ -239,8 +239,9 @@ public class GroupByEncryptedKey<K, V>
     }
 
     @ProcessElement
+    @SuppressWarnings("nullness")
     public void processElement(ProcessContext c) throws Exception {
-      java.util.Map<K, java.util.List<V>> decryptedKvs = new 
java.util.HashMap<>();
+      java.util.HashMap<K, java.util.List<V>> decryptedKvs = new 
java.util.HashMap<>();
       for (KV<byte[], byte[]> encryptedKv : c.element().getValue()) {
         byte[] iv = Arrays.copyOfRange(encryptedKv.getKey(), 0, 12);
         GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(128, iv);
@@ -251,24 +252,19 @@ public class GroupByEncryptedKey<K, V>
         byte[] decryptedKeyBytes = this.cipher.doFinal(encryptedKey);
         K key = decode(this.keyCoder, decryptedKeyBytes);
 
-        if (key != null) {
-          if (!decryptedKvs.containsKey(key)) {
-            decryptedKvs.put(key, new java.util.ArrayList<>());
-          }
+        if (!decryptedKvs.containsKey(key)) {
+          decryptedKvs.put(key, new java.util.ArrayList<>());
+        }
 
-          iv = Arrays.copyOfRange(encryptedKv.getValue(), 0, 12);
-          gcmParameterSpec = new GCMParameterSpec(128, iv);
-          this.cipher.init(Cipher.DECRYPT_MODE, this.secretKeySpec, 
gcmParameterSpec);
+        iv = Arrays.copyOfRange(encryptedKv.getValue(), 0, 12);
+        gcmParameterSpec = new GCMParameterSpec(128, iv);
+        this.cipher.init(Cipher.DECRYPT_MODE, this.secretKeySpec, 
gcmParameterSpec);
 
-          byte[] encryptedValue =
-              Arrays.copyOfRange(encryptedKv.getValue(), 12, 
encryptedKv.getValue().length);
-          byte[] decryptedValueBytes = this.cipher.doFinal(encryptedValue);
-          V value = decode(this.valueCoder, decryptedValueBytes);
-          decryptedKvs.get(key).add(value);
-        } else {
-          throw new RuntimeException(
-              "Found null key when decoding " + 
Arrays.toString(decryptedKeyBytes));
-        }
+        byte[] encryptedValue =
+            Arrays.copyOfRange(encryptedKv.getValue(), 12, 
encryptedKv.getValue().length);
+        byte[] decryptedValueBytes = this.cipher.doFinal(encryptedValue);
+        V value = decode(this.valueCoder, decryptedValueBytes);
+        decryptedKvs.get(key).add(value);
       }
 
       for (java.util.Map.Entry<K, java.util.List<V>> entry : 
decryptedKvs.entrySet()) {
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java
 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java
index 3a2fc2f08c0..31064470bd3 100644
--- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java
+++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java
@@ -33,6 +33,7 @@ import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.StreamSupport;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.testing.NeedsRunner;
@@ -42,6 +43,7 @@ import org.apache.beam.sdk.util.GcpSecret;
 import org.apache.beam.sdk.util.Secret;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
@@ -141,20 +143,22 @@ public class GroupByEncryptedKeyTest implements 
Serializable {
   @Test
   @Category(NeedsRunner.class)
   public void testGroupByKeyGcpSecret() {
-    List<KV<String, Integer>> ungroupedPairs =
+    List<KV<@Nullable String, Integer>> ungroupedPairs =
         Arrays.asList(
+            KV.of(null, 3),
             KV.of("k1", 3),
             KV.of("k5", Integer.MAX_VALUE),
             KV.of("k5", Integer.MIN_VALUE),
             KV.of("k2", 66),
             KV.of("k1", 4),
+            KV.of(null, 5),
             KV.of("k2", -33),
             KV.of("k3", 0));
 
     PCollection<KV<String, Integer>> input =
         p.apply(
             Create.of(ungroupedPairs)
-                .withCoder(KvCoder.of(StringUtf8Coder.of(), 
VarIntCoder.of())));
+                .withCoder(KvCoder.of(NullableCoder.of(StringUtf8Coder.of()), 
VarIntCoder.of())));
 
     PCollection<KV<String, Iterable<Integer>>> output =
         input.apply(GroupByEncryptedKey.<String, Integer>create(gcpSecret));
@@ -162,6 +166,7 @@ public class GroupByEncryptedKeyTest implements 
Serializable {
     PAssert.that(output.apply("Sort", MapElements.via(new SortValues())))
         .containsInAnyOrder(
             KV.of("k1", Arrays.asList(3, 4)),
+            KV.of(null, Arrays.asList(3, 5)),
             KV.of("k5", Arrays.asList(Integer.MIN_VALUE, Integer.MAX_VALUE)),
             KV.of("k2", Arrays.asList(-33, 66)),
             KV.of("k3", Arrays.asList(0)));

Reply via email to