This is an automated email from the ASF dual-hosted git repository.

xqhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new bfa0c59ebcd The Bag Partition is now configurable. (#33805)
bfa0c59ebcd is described below

commit bfa0c59ebcd587dc19f218385b1f9f5aacbaa653
Author: Alex Merose <[email protected]>
AuthorDate: Sat Feb 1 12:00:10 2025 -0800

    The Bag Partition is now configurable. (#33805)
    
    * The Bag Partition is now configurable.
    
    Configuring the number of partitions in the Dask runner is important
    for tuning performance. This CL gives users control over this
    parameter (a usage sketch follows the change list below).
    
    * Apply formatter.
    
    * Passing lint via the `run_pylint.sh` script.
    
    * Implementing review feedback.
    
    * Attempting to pass lint/fmt check.
    
    * Fixing isort issues by reading CI output.
    
    * More indentation.
    
    * rm blank line for isort.
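
    A minimal usage sketch, assuming Dask and `dask.distributed` are
    installed; the option names come from this change, while the example
    pipeline itself is hypothetical:

        import apache_beam as beam
        from apache_beam.options.pipeline_options import PipelineOptions
        from apache_beam.runners.dask.dask_runner import DaskRunner

        # `--dask_npartitions` (or, mutually exclusively,
        # `--dask_partition_size`) controls how inputs are split into
        # dask.Bag partitions.
        options = PipelineOptions(['--dask_npartitions', '8'])

        with beam.Pipeline(runner=DaskRunner(), options=options) as p:
          _ = (p
               | beam.Create(list(range(1000)))
               | beam.Map(lambda x: x * x))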
---
 CHANGES.md                                         |  1 +
 .../python/apache_beam/runners/dask/dask_runner.py | 39 +++++++++++++++++++---
 .../apache_beam/runners/dask/dask_runner_test.py   | 19 +++++++++++
 .../runners/dask/transform_evaluator.py            | 30 +++++++++++++++--
 4 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 799d26dc05e..fde00b9da4c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -86,6 +86,7 @@
  * Support the Process Environment for execution in Prism ([#33651](https://github.com/apache/beam/pull/33651))
  * Support the AnyOf Environment for execution in Prism ([#33705](https://github.com/apache/beam/pull/33705))
     * This improves support for developing Xlang pipelines, when using a compatible cross language service.
+* Partitions are now configurable for the DaskRunner in the Python SDK ([#33805](https://github.com/apache/beam/pull/33805)).
 
 ## Breaking Changes
 
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py
index cc17d9919b8..8975fcf1e13 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -58,6 +58,18 @@ class DaskOptions(PipelineOptions):
       import dask
       return dask.config.no_default
 
+  @staticmethod
+  def _extract_bag_kwargs(dask_options: t.Dict) -> t.Dict:
+    """Parse keyword arguments for `dask.Bag`s; used in graph translation."""
+    out = {}
+
+    if npartitions := dask_options.pop('npartitions', None):
+      out['npartitions'] = npartitions
+    if partition_size := dask_options.pop('partition_size', None):
+      out['partition_size'] = partition_size
+
+    return out
+
   @classmethod
   def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
@@ -93,6 +105,21 @@ class DaskOptions(PipelineOptions):
         default=512,
         help='The number of open comms to maintain at once in the connection '
         'pool.')
+    partitions_parser = parser.add_mutually_exclusive_group()
+    partitions_parser.add_argument(
+        '--dask_npartitions',
+        dest='npartitions',
+        type=int,
+        default=None,
+        help='The desired number of `dask.Bag` partitions. When unspecified, '
+        'an educated guess is made.')
+    partitions_parser.add_argument(
+        '--dask_partition_size',
+        dest='partition_size',
+        type=int,
+        default=None,
+        help='The length of each `dask.Bag` partition. When unspecified, '
+        'an educated guess is made.')
 
 
 @dataclasses.dataclass
@@ -139,9 +166,12 @@ class DaskRunnerResult(PipelineResult):
 class DaskRunner(BundleBasedDirectRunner):
   """Executes a pipeline on a Dask distributed client."""
   @staticmethod
-  def to_dask_bag_visitor() -> PipelineVisitor:
+  def to_dask_bag_visitor(bag_kwargs=None) -> PipelineVisitor:
     from dask import bag as db
 
+    if bag_kwargs is None:
+      bag_kwargs = {}
+
     @dataclasses.dataclass
     class DaskBagVisitor(PipelineVisitor):
       bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
@@ -149,7 +179,7 @@ class DaskRunner(BundleBasedDirectRunner):
 
       def visit_transform(self, transform_node: AppliedPTransform) -> None:
         op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
-        op = op_class(transform_node)
+        op = op_class(transform_node, bag_kwargs=bag_kwargs)
 
         op_kws = {"input_bag": None, "side_inputs": None}
         inputs = list(transform_node.inputs)
@@ -195,7 +225,7 @@ class DaskRunner(BundleBasedDirectRunner):
   def run_pipeline(self, pipeline, options):
     import dask
 
-    # TODO(alxr): Create interactive notebook support.
+    # TODO(alxmrs): Create interactive notebook support.
     if is_in_notebook():
       raise NotImplementedError('interactive support will come later!')
 
@@ -207,11 +237,12 @@ class DaskRunner(BundleBasedDirectRunner):
 
     dask_options = options.view_as(DaskOptions).get_all_options(
         drop_default=True)
+    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
     client = ddist.Client(**dask_options)
 
     pipeline.replace_all(dask_overrides())
 
-    dask_visitor = self.to_dask_bag_visitor()
+    dask_visitor = self.to_dask_bag_visitor(bag_kwargs)
     pipeline.visit(dask_visitor)
     # The dictionary in this visitor keeps a mapping of every Beam
     # PTransform to the equivalent Bag operation. This is highly
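
A small sketch of how these two sets of keyword arguments are separated,
assuming the imports resolve as in this repository (the behavior mirrors
the tests added below):

    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.runners.dask.dask_runner import DaskOptions

    options = PipelineOptions(['--dask_npartitions', '8'])
    dask_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)

    # The bag kwargs are popped so that only Client-compatible keyword
    # arguments remain when `ddist.Client(**dask_options)` is called.
    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
    assert bag_kwargs == {'npartitions': 8}
    assert 'npartitions' not in dask_options
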
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
index 66dda4a984f..afe363ba3ee 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -66,6 +66,25 @@ class DaskOptionsTest(unittest.TestCase):
       with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
         self.assertIn(opt_name, client_args)
 
+  def test_parser_extract_bag_kwargs__deletes_dask_kwargs(self):
+    options = PipelineOptions('--dask_npartitions 8'.split())
+    dask_options = options.view_as(DaskOptions).get_all_options()
+
+    self.assertIn('npartitions', dask_options)
+    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
+    self.assertNotIn('npartitions', dask_options)
+    self.assertEqual(bag_kwargs, {'npartitions': 8})
+
+  def test_parser_extract_bag_kwargs__unconfigured(self):
+    options = PipelineOptions()
+    dask_options = options.view_as(DaskOptions).get_all_options()
+
+    # It's present as a default option.
+    self.assertIn('npartitions', dask_options)
+    bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
+    self.assertNotIn('npartitions', dask_options)
+    self.assertEqual(bag_kwargs, {})
+
 
 class DaskRunnerRunPipelineTest(unittest.TestCase):
   """Test class used to introspect the dask runner via a debugger."""
diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
index e72ebcce8b1..7cad1fe4045 100644
--- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -22,6 +22,7 @@ to Dask Bag functions.
 """
 import abc
 import dataclasses
+import logging
 import math
 import typing as t
 from dataclasses import field
@@ -52,6 +53,8 @@ OpSide = t.Optional[t.Sequence[SideInputMap]]
 # Value types for PCollections (possibly Windowed Values).
 PCollVal = t.Union[WindowedValue, t.Any]
 
+_LOGGER = logging.getLogger(__name__)
+
 
 def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
   """Wraps a value (item) inside a Window."""
@@ -127,8 +130,11 @@ class DaskBagOp(abc.ABC):
   Attributes
     applied: The underlying `AppliedPTransform` which holds the code for the
       target operation.
+    bag_kwargs: (optional) Keyword arguments applied to input bags, usually
+      from the pipeline's `DaskOptions`.
   """
   applied: AppliedPTransform
+  bag_kwargs: t.Dict = dataclasses.field(default_factory=dict)
 
   @property
   def transform(self):
@@ -151,10 +157,28 @@ class Create(DaskBagOp):
     assert input_bag is None, 'Create expects no input!'
     original_transform = t.cast(_Create, self.transform)
     items = original_transform.values
+
+    npartitions = self.bag_kwargs.get('npartitions')
+    partition_size = self.bag_kwargs.get('partition_size')
+    if npartitions and partition_size:
+      raise ValueError(
+          f'Please specify either `dask_npartitions` or '
+          f'`dask_partition_size` but not both: '
+          f'{npartitions=}, {partition_size=}.')
+    if not npartitions and not partition_size:
+      # partition_size is inversely related to `npartitions`.
+      # Ideal "chunk sizes" in dask are around 10-100 MBs.
+      # Let's hope ~128 items per partition is around this
+      # memory overhead.
+      default_size = 128
+      partition_size = max(default_size, math.ceil(math.sqrt(len(items)) / 10))
+      if partition_size == default_size:
+        _LOGGER.warning(
+            'The new default partition size is %d; it used to be 1 '
+            'in previous DaskRunner versions.' % default_size)
+
     return db.from_sequence(
-        items,
-        partition_size=max(
-            1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
+        items, npartitions=npartitions, partition_size=partition_size)
 
 
 def apply_dofn_to_bundle(
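
When neither flag is set, `Create` falls back to the heuristic added above.
A standalone sketch of that fallback, reproducing the formula from the diff
(the helper name `default_partition_size` is hypothetical):

    import math

    def default_partition_size(n_items, default_size=128):
      # Mirrors the fallback in Create.apply: roughly 128 items per
      # partition, unless the input is so large that sqrt(n) / 10 exceeds
      # the default.
      return max(default_size, math.ceil(math.sqrt(n_items) / 10))

    print(default_partition_size(100))         # 128
    print(default_partition_size(10_000))      # 128
    print(default_partition_size(10_000_000))  # 317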
