nehsyc commented on a change in pull request #13292:
URL: https://github.com/apache/beam/pull/13292#discussion_r526334463



##########
File path: sdks/python/apache_beam/transforms/util.py
##########
@@ -780,6 +783,48 @@ def expand(self, pcoll):
             self.max_buffering_duration_secs,
             self.clock))
 
+  @typehints.with_input_types(Tuple[K, V])
+  @typehints.with_output_types(Tuple[K, Iterable[V]])
+  class WithShardedKey(PTransform):
+    """A GroupIntoBatches transform that outputs batched elements associated
+    with sharded input keys.
+
+    The sharding is determined by the runner to balance the load during the

Review comment:
       Yes, that is totally correct. I updated the documentation. How does it
sound now?

##########
File path: sdks/python/apache_beam/transforms/util.py
##########
@@ -780,6 +783,48 @@ def expand(self, pcoll):
             self.max_buffering_duration_secs,
             self.clock))
 
+  @typehints.with_input_types(Tuple[K, V])
+  @typehints.with_output_types(Tuple[K, Iterable[V]])
+  class WithShardedKey(PTransform):
+    """A GroupIntoBatches transform that outputs batched elements associated
+    with sharded input keys.
+
+    The sharding is determined by the runner to balance the load during the
+    execution time. By default, it spreads the input elements with the same key
+    to all available threads executing the transform.
+    """
+    def __init__(self, batch_size, max_buffering_duration_secs=None):
+      """Create a new GroupIntoBatches.WithShardedKey.
+
+      Arguments:
+        batch_size: (required) How many elements should be in a batch
+        max_buffering_duration_secs: (optional) How long in seconds at most an
+          incomplete batch of elements is allowed to be buffered in the states.
+          The duration must be a positive second duration and should be given as
+          an int or float.
+      """
+      self.batch_size = batch_size
+
+      if max_buffering_duration_secs is not None:
+        assert max_buffering_duration_secs > 0, (
+            'max buffering duration should be a positive value')
+      self.max_buffering_duration_secs = max_buffering_duration_secs
+
+    _pid = os.getpid()

Review comment:
       The parallelism of the `GroupIntoBatchesDoFn` is tied to the number of
keys, due to the per-key state semantics and the implementation of keyed state
management. So choosing one shard per key (i.e., no key sharding) effectively
means that we cannot have more parallelism than the number of input keys. By
default, we try to spread the input elements across all available threads on
all workers, which is definitely not ideal but is somewhat better than no
sharding.
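To make the parallelism argument concrete, here is a minimal, framework-free Python sketch (not Beam code). It models each distinct key as one unit of keyed state, batches values separately per key, and compares a single hot key with the same key split into shards. The helpers `group_into_batches` and `shard_keys` and the round-robin worker assignment are hypothetical illustrations. The PR appears to derive shard ids from the executing process/thread (see the `_pid = os.getpid()` line), which this sketch does not reproduce.

```python
from collections import defaultdict
from typing import Dict, Hashable, List, Sequence, Tuple

KV = Tuple[Hashable, object]


def group_into_batches(elements: Sequence[KV],
                       batch_size: int) -> Dict[Hashable, List[list]]:
    """Batches values separately for each state key.

    Each distinct key models one unit of keyed state: different keys can be
    processed in parallel, but the elements of one key are handled together.
    """
    buffers: Dict[Hashable, list] = defaultdict(list)
    batches: Dict[Hashable, List[list]] = defaultdict(list)
    for key, value in elements:
        buffers[key].append(value)
        if len(buffers[key]) == batch_size:
            # Emit a full batch and reset this key's buffer.
            batches[key].append(buffers.pop(key))
    # Flush incomplete batches, as a buffering timeout or window end would.
    for key, rest in buffers.items():
        batches[key].append(rest)
    return dict(batches)


def shard_keys(elements: Sequence[KV], num_workers: int) -> List[KV]:
    """Replaces each key with (key, shard_id).

    Hypothetical assumption: element i arrives at worker i % num_workers and
    the shard id is that worker's index.
    """
    return [((key, i % num_workers), value)
            for i, (key, value) in enumerate(elements)]


# One hot key with 8 elements.
hot_key = [("user-1", v) for v in range(8)]

unsharded = group_into_batches(hot_key, batch_size=2)
sharded = group_into_batches(shard_keys(hot_key, num_workers=4), batch_size=2)

print(len(unsharded))  # 1: a single state key bounds parallelism to 1
print(len(sharded))    # 4: four state keys can be processed in parallel
```

The sketch also shows why sharding is "not ideal": values for one key are spread over several buffers, so with few elements per shard the batches tend to be flushed only partially filled (for example, 6 elements over 4 shards with `batch_size=4` yields no full batch at all).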




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
