This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9ed06d081ec Handle null keys in gbek (#36505)
9ed06d081ec is described below
commit 9ed06d081ec72355d6794d8d76df6dc3569c9a3c
Author: Danny McCormick <[email protected]>
AuthorDate: Wed Oct 15 06:34:46 2025 -0400
Handle null keys in gbek (#36505)
* Handle null keys in gbek
* Allow null values with hashmap
* add a test
* Test + remove check entirely
---
.../beam/sdk/transforms/GroupByEncryptedKey.java | 30 ++++++++++------------
.../sdk/transforms/GroupByEncryptedKeyTest.java | 9 +++++--
2 files changed, 20 insertions(+), 19 deletions(-)
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java
index 1f4b7535d89..85483fd517a 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByEncryptedKey.java
@@ -239,8 +239,9 @@ public class GroupByEncryptedKey<K, V>
}
@ProcessElement
+ @SuppressWarnings("nullness")
public void processElement(ProcessContext c) throws Exception {
- java.util.Map<K, java.util.List<V>> decryptedKvs = new
java.util.HashMap<>();
+ java.util.HashMap<K, java.util.List<V>> decryptedKvs = new
java.util.HashMap<>();
for (KV<byte[], byte[]> encryptedKv : c.element().getValue()) {
byte[] iv = Arrays.copyOfRange(encryptedKv.getKey(), 0, 12);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(128, iv);
@@ -251,24 +252,19 @@ public class GroupByEncryptedKey<K, V>
byte[] decryptedKeyBytes = this.cipher.doFinal(encryptedKey);
K key = decode(this.keyCoder, decryptedKeyBytes);
- if (key != null) {
- if (!decryptedKvs.containsKey(key)) {
- decryptedKvs.put(key, new java.util.ArrayList<>());
- }
+ if (!decryptedKvs.containsKey(key)) {
+ decryptedKvs.put(key, new java.util.ArrayList<>());
+ }
- iv = Arrays.copyOfRange(encryptedKv.getValue(), 0, 12);
- gcmParameterSpec = new GCMParameterSpec(128, iv);
- this.cipher.init(Cipher.DECRYPT_MODE, this.secretKeySpec,
gcmParameterSpec);
+ iv = Arrays.copyOfRange(encryptedKv.getValue(), 0, 12);
+ gcmParameterSpec = new GCMParameterSpec(128, iv);
+ this.cipher.init(Cipher.DECRYPT_MODE, this.secretKeySpec,
gcmParameterSpec);
- byte[] encryptedValue =
- Arrays.copyOfRange(encryptedKv.getValue(), 12,
encryptedKv.getValue().length);
- byte[] decryptedValueBytes = this.cipher.doFinal(encryptedValue);
- V value = decode(this.valueCoder, decryptedValueBytes);
- decryptedKvs.get(key).add(value);
- } else {
- throw new RuntimeException(
- "Found null key when decoding " +
Arrays.toString(decryptedKeyBytes));
- }
+ byte[] encryptedValue =
+ Arrays.copyOfRange(encryptedKv.getValue(), 12,
encryptedKv.getValue().length);
+ byte[] decryptedValueBytes = this.cipher.doFinal(encryptedValue);
+ V value = decode(this.valueCoder, decryptedValueBytes);
+ decryptedKvs.get(key).add(value);
}
for (java.util.Map.Entry<K, java.util.List<V>> entry :
decryptedKvs.entrySet()) {
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java
index 3a2fc2f08c0..31064470bd3 100644
---
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java
@@ -33,6 +33,7 @@ import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.testing.NeedsRunner;
@@ -42,6 +43,7 @@ import org.apache.beam.sdk.util.GcpSecret;
import org.apache.beam.sdk.util.Secret;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.checkerframework.checker.nullness.qual.Nullable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
@@ -141,20 +143,22 @@ public class GroupByEncryptedKeyTest implements
Serializable {
@Test
@Category(NeedsRunner.class)
public void testGroupByKeyGcpSecret() {
- List<KV<String, Integer>> ungroupedPairs =
+ List<KV<@Nullable String, Integer>> ungroupedPairs =
Arrays.asList(
+ KV.of(null, 3),
KV.of("k1", 3),
KV.of("k5", Integer.MAX_VALUE),
KV.of("k5", Integer.MIN_VALUE),
KV.of("k2", 66),
KV.of("k1", 4),
+ KV.of(null, 5),
KV.of("k2", -33),
KV.of("k3", 0));
PCollection<KV<String, Integer>> input =
p.apply(
Create.of(ungroupedPairs)
- .withCoder(KvCoder.of(StringUtf8Coder.of(),
VarIntCoder.of())));
+ .withCoder(KvCoder.of(NullableCoder.of(StringUtf8Coder.of()),
VarIntCoder.of())));
PCollection<KV<String, Iterable<Integer>>> output =
input.apply(GroupByEncryptedKey.<String, Integer>create(gcpSecret));
@@ -162,6 +166,7 @@ public class GroupByEncryptedKeyTest implements
Serializable {
PAssert.that(output.apply("Sort", MapElements.via(new SortValues())))
.containsInAnyOrder(
KV.of("k1", Arrays.asList(3, 4)),
+ KV.of(null, Arrays.asList(3, 5)),
KV.of("k5", Arrays.asList(Integer.MIN_VALUE, Integer.MAX_VALUE)),
KV.of("k2", Arrays.asList(-33, 66)),
KV.of("k3", Arrays.asList(0)));