gemini-code-assist[bot] commented on code in PR #38783:
URL: https://github.com/apache/beam/pull/38783#discussion_r3348625276


##########
sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java:
##########
@@ -457,6 +469,37 @@ private void addErrorCollections(
     }
   }
 
+  private class AddShardKeyFn<DestT, ElemT>
+      extends DoFn<
+          KV<DestT, StorageApiWritePayload>, KV<Integer, KV<DestT, StorageApiWritePayload>>> {
+
+    private final StorageApiDynamicDestinations<ElemT, DestT> dynamicDestinations;
+    private final int numShards;
+
+    public AddShardKeyFn(
+        StorageApiDynamicDestinations<ElemT, DestT> dynamicDestinations, int numShards) {
+      this.dynamicDestinations = dynamicDestinations;
+      this.numShards = numShards;
+    }
+
+    @ProcessElement
+    public void processElement(
+        ProcessContext c,
+        @Element KV<DestT, StorageApiWritePayload> element,
+        OutputReceiver<KV<Integer, KV<DestT, StorageApiWritePayload>>> outputReceiver) {
+      dynamicDestinations.setSideInputAccessorFromProcessContext(c);
+
+      String tableUrn = dynamicDestinations.getTable(element.getKey()).getShortTableUrn();
+
+      int hash = Hashing.murmur3_32_fixed().hashString(tableUrn, StandardCharsets.UTF_8).asInt();
+
+      int shardKey =
+          Math.floorMod(hash ^ ThreadLocalRandom.current().nextInt(numShards), numShards);
+
+      outputReceiver.output(KV.of(shardKey, element));
+    }
+  }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   There are two important improvements for the `AddShardKeyFn` class:
   
   1. **Make the class `static`**: `AddShardKeyFn` is currently a non-static inner class. In Apache Beam, `DoFn`s should be declared as `static` nested classes so that they do not implicitly capture the enclosing instance (`StorageApiLoads`). Capturing it can lead to a `NotSerializableException` or to unnecessary serialization overhead.
   2. **Fix the sharding logic**: Calling `ThreadLocalRandom.current().nextInt(numShards)` per element re-randomizes the shard key for every single element, so the `tableUrn` hash no longer pins a destination to a shard. When `numShards` is a power of two, `Math.floorMod(hash ^ random, numShards)` is exactly uniform over `[0, numShards - 1]` for each element; for other values the distribution is skewed rather than uniform, but the key still changes from element to element. Either way, elements for the same destination are scattered across shards, which defeats the goal of keeping elements for the same destination close to each other to reduce connections.
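   
   The capture described in point 1 can be reproduced without Beam. The sketch below is only an illustration: `Outer`, `InnerFn`, and `NestedFn` are hypothetical names, and `Outer` is deliberately not `Serializable` to show the failure case. It does not model `StorageApiLoads` itself.
   
   ```java
   import java.io.ByteArrayOutputStream;
   import java.io.IOException;
   import java.io.NotSerializableException;
   import java.io.ObjectOutputStream;
   import java.io.Serializable;
   
   final class InnerClassCaptureDemo {
     // Hypothetical enclosing class; intentionally NOT Serializable.
     static final class Outer {
       private final String name = "outer";
   
       // Non-static inner class: holds a hidden reference to the enclosing Outer.
       final class InnerFn implements Serializable {
         String describe() {
           return name; // uses the enclosing instance
         }
       }
   
       // Static nested class: no reference to any Outer instance.
       static final class NestedFn implements Serializable {
         String describe() {
           return "nested";
         }
       }
     }
   
     // Returns true if Java serialization accepts the object, false on NotSerializableException.
     static boolean serializes(Object o) {
       try (ObjectOutputStream out = new ObjectOutputStream(new ByteArrayOutputStream())) {
         out.writeObject(o);
         return true;
       } catch (NotSerializableException e) {
         return false;
       } catch (IOException e) {
         throw new IllegalStateException(e);
       }
     }
   }
   ```
   
   Here `serializes(new Outer().new InnerFn())` fails because the hidden reference drags the non-serializable `Outer` into the object graph, while `serializes(new Outer.NestedFn())` succeeds.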
   
   To achieve both goals, allowing parallelism for a single hot destination while grouping elements of the same destination together, you can use a random offset that is constant **per bundle** (initialized in `@StartBundle`) instead of per element. All elements for a given destination within the same bundle are then routed to the same shard, while different bundles and workers can still route to different shards.
   
   ```java
     private static class AddShardKeyFn<DestT, ElemT>
         extends DoFn<
             KV<DestT, StorageApiWritePayload>, KV<Integer, KV<DestT, StorageApiWritePayload>>> {
   
       private final StorageApiDynamicDestinations<ElemT, DestT> dynamicDestinations;
       private final int numShards;
       // Re-drawn for every bundle, so it does not need to be serialized with the DoFn.
       private transient int shardOffset;
   
       public AddShardKeyFn(
           StorageApiDynamicDestinations<ElemT, DestT> dynamicDestinations, int numShards) {
         this.dynamicDestinations = dynamicDestinations;
         this.numShards = numShards;
       }
   
       @DoFn.StartBundle
       public void startBundle() {
         // One random offset per bundle instead of one per element.
         shardOffset = ThreadLocalRandom.current().nextInt(numShards);
       }
   
       @ProcessElement
       public void processElement(
           ProcessContext c,
           @Element KV<DestT, StorageApiWritePayload> element,
           OutputReceiver<KV<Integer, KV<DestT, StorageApiWritePayload>>> outputReceiver) {
         dynamicDestinations.setSideInputAccessorFromProcessContext(c);
   
         String tableUrn = dynamicDestinations.getTable(element.getKey()).getShortTableUrn();
   
         int hash = Hashing.murmur3_32_fixed().hashString(tableUrn, StandardCharsets.UTF_8).asInt();
   
         // Same destination + same bundle => same shard key; floorMod keeps it non-negative.
         int shardKey = Math.floorMod(hash + shardOffset, numShards);
   
         outputReceiver.output(KV.of(shardKey, element));
       }
     }
   ```
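   
   To make the difference between the two strategies concrete, here is a small standalone sketch that counts how many distinct shards a single destination ends up on. It is an illustration under assumptions, not the PR's code: `ShardKeyDemo` and its methods are hypothetical, `String.hashCode()` stands in for the murmur3 hash, and a seeded `java.util.Random` replaces `ThreadLocalRandom` so the run is repeatable.
   
   ```java
   import java.util.Random;
   import java.util.Set;
   import java.util.TreeSet;
   
   final class ShardKeyDemo {
     // Per-element randomness: a fresh random value is XORed into the hash for every element.
     static Set<Integer> perElementShards(String tableUrn, int numShards, int elements, long seed) {
       Random random = new Random(seed);
       int hash = tableUrn.hashCode(); // stand-in for the murmur3 hash (assumption)
       Set<Integer> shards = new TreeSet<>();
       for (int i = 0; i < elements; i++) {
         shards.add(Math.floorMod(hash ^ random.nextInt(numShards), numShards));
       }
       return shards;
     }
   
     // Per-bundle offset: drawn once (as @StartBundle would) and reused for every element.
     static Set<Integer> perBundleShards(String tableUrn, int numShards, int elements, long seed) {
       Random random = new Random(seed);
       int hash = tableUrn.hashCode(); // stand-in for the murmur3 hash (assumption)
       int shardOffset = random.nextInt(numShards);
       Set<Integer> shards = new TreeSet<>();
       for (int i = 0; i < elements; i++) {
         shards.add(Math.floorMod(hash + shardOffset, numShards));
       }
       return shards;
     }
   }
   ```
   
   With `numShards = 8` and a thousand elements for one table, `perElementShards` should report every shard in use, whereas `perBundleShards` reports exactly one shard for the bundle. A different seed (that is, a different bundle) may pick a different single shard, which is where the parallelism for a hot destination comes from.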


