This is an automated email from the ASF dual-hosted git repository.
xqhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new bfa0c59ebcd The Bag Partition is now configurable. (#33805)
bfa0c59ebcd is described below
commit bfa0c59ebcd587dc19f218385b1f9f5aacbaa653
Author: Alex Merose <[email protected]>
AuthorDate: Sat Feb 1 12:00:10 2025 -0800
The Bag Partition is now configurable. (#33805)
* The Bag Partition is now configurable.
Configuring the number of partitions in the Dask runner is very important
to tune performance. This CL gives users control over this parameter.
* Apply formatter.
* Passing lint via the `run_pylint.sh` script.
* Implementing review feedback.
* Attempting to pass lint/fmt check.
* Fixing isort issues by reading CI output.
* More indentation.
* rm blank line for isort.
---
CHANGES.md | 1 +
.../python/apache_beam/runners/dask/dask_runner.py | 39 +++++++++++++++++++---
.../apache_beam/runners/dask/dask_runner_test.py | 19 +++++++++++
.../runners/dask/transform_evaluator.py | 30 +++++++++++++++--
4 files changed, 82 insertions(+), 7 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 799d26dc05e..fde00b9da4c 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -86,6 +86,7 @@
* Support the Process Environment for execution in Prism
([#33651](https://github.com/apache/beam/pull/33651))
* Support the AnyOf Environment for execution in Prism
([#33705](https://github.com/apache/beam/pull/33705))
* This improves support for developing Xlang pipelines, when using a
compatible cross language service.
+* Partitions are now configurable for the DaskRunner in the Python SDK
([#33805](https://github.com/apache/beam/pull/33805)).
## Breaking Changes
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py
b/sdks/python/apache_beam/runners/dask/dask_runner.py
index cc17d9919b8..8975fcf1e13 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -58,6 +58,18 @@ class DaskOptions(PipelineOptions):
import dask
return dask.config.no_default
+ @staticmethod
+ def _extract_bag_kwargs(dask_options: t.Dict) -> t.Dict:
+ """Parse keyword arguments for `dask.Bag`s; used in graph translation."""
+ out = {}
+
+ if npartitions := dask_options.pop('npartitions', None):
+ out['npartitions'] = npartitions
+ if partition_size := dask_options.pop('partition_size', None):
+ out['partition_size'] = partition_size
+
+ return out
+
@classmethod
def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
@@ -93,6 +105,21 @@ class DaskOptions(PipelineOptions):
default=512,
help='The number of open comms to maintain at once in the connection '
'pool.')
+ partitions_parser = parser.add_mutually_exclusive_group()
+ partitions_parser.add_argument(
+ '--dask_npartitions',
+ dest='npartitions',
+ type=int,
+ default=None,
+ help='The desired number of `dask.Bag` partitions. When unspecified, '
+ 'an educated guess is made.')
+ partitions_parser.add_argument(
+ '--dask_partition_size',
+ dest='partition_size',
+ type=int,
+ default=None,
+ help='The length of each `dask.Bag` partition. When unspecified, '
+ 'an educated guess is made.')
@dataclasses.dataclass
@@ -139,9 +166,12 @@ class DaskRunnerResult(PipelineResult):
class DaskRunner(BundleBasedDirectRunner):
"""Executes a pipeline on a Dask distributed client."""
@staticmethod
- def to_dask_bag_visitor() -> PipelineVisitor:
+ def to_dask_bag_visitor(bag_kwargs=None) -> PipelineVisitor:
from dask import bag as db
+ if bag_kwargs is None:
+ bag_kwargs = {}
+
@dataclasses.dataclass
class DaskBagVisitor(PipelineVisitor):
bags: t.Dict[AppliedPTransform, db.Bag] = dataclasses.field(
@@ -149,7 +179,7 @@ class DaskRunner(BundleBasedDirectRunner):
def visit_transform(self, transform_node: AppliedPTransform) -> None:
op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
- op = op_class(transform_node)
+ op = op_class(transform_node, bag_kwargs=bag_kwargs)
op_kws = {"input_bag": None, "side_inputs": None}
inputs = list(transform_node.inputs)
@@ -195,7 +225,7 @@ class DaskRunner(BundleBasedDirectRunner):
def run_pipeline(self, pipeline, options):
import dask
- # TODO(alxr): Create interactive notebook support.
+ # TODO(alxmrs): Create interactive notebook support.
if is_in_notebook():
raise NotImplementedError('interactive support will come later!')
@@ -207,11 +237,12 @@ class DaskRunner(BundleBasedDirectRunner):
dask_options = options.view_as(DaskOptions).get_all_options(
drop_default=True)
+ bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
client = ddist.Client(**dask_options)
pipeline.replace_all(dask_overrides())
- dask_visitor = self.to_dask_bag_visitor()
+ dask_visitor = self.to_dask_bag_visitor(bag_kwargs)
pipeline.visit(dask_visitor)
# The dictionary in this visitor keeps a mapping of every Beam
# PTransform to the equivalent Bag operation. This is highly
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
index 66dda4a984f..afe363ba3ee 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -66,6 +66,25 @@ class DaskOptionsTest(unittest.TestCase):
with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
self.assertIn(opt_name, client_args)
+ def test_parser_extract_bag_kwargs__deletes_dask_kwargs(self):
+ options = PipelineOptions('--dask_npartitions 8'.split())
+ dask_options = options.view_as(DaskOptions).get_all_options()
+
+ self.assertIn('npartitions', dask_options)
+ bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
+ self.assertNotIn('npartitions', dask_options)
+ self.assertEqual(bag_kwargs, {'npartitions': 8})
+
+ def test_parser_extract_bag_kwargs__unconfigured(self):
+ options = PipelineOptions()
+ dask_options = options.view_as(DaskOptions).get_all_options()
+
+ # It's present as a default option.
+ self.assertIn('npartitions', dask_options)
+ bag_kwargs = DaskOptions._extract_bag_kwargs(dask_options)
+ self.assertNotIn('npartitions', dask_options)
+ self.assertEqual(bag_kwargs, {})
+
class DaskRunnerRunPipelineTest(unittest.TestCase):
"""Test class used to introspect the dask runner via a debugger."""
diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
index e72ebcce8b1..7cad1fe4045 100644
--- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -22,6 +22,7 @@ to Dask Bag functions.
"""
import abc
import dataclasses
+import logging
import math
import typing as t
from dataclasses import field
@@ -52,6 +53,8 @@ OpSide = t.Optional[t.Sequence[SideInputMap]]
# Value types for PCollections (possibly Windowed Values).
PCollVal = t.Union[WindowedValue, t.Any]
+_LOGGER = logging.getLogger(__name__)
+
def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
"""Wraps a value (item) inside a Window."""
@@ -127,8 +130,11 @@ class DaskBagOp(abc.ABC):
Attributes
applied: The underlying `AppliedPTransform` which holds the code for the
target operation.
+ bag_kwargs: (optional) Keyword arguments applied to input bags, usually
+ from the pipeline's `DaskOptions`.
"""
applied: AppliedPTransform
+ bag_kwargs: t.Dict = dataclasses.field(default_factory=dict)
@property
def transform(self):
@@ -151,10 +157,28 @@ class Create(DaskBagOp):
assert input_bag is None, 'Create expects no input!'
original_transform = t.cast(_Create, self.transform)
items = original_transform.values
+
+ npartitions = self.bag_kwargs.get('npartitions')
+ partition_size = self.bag_kwargs.get('partition_size')
+ if npartitions and partition_size:
+ raise ValueError(
+ f'Please specify either `dask_npartitions` or '
+ f'`dask_partition_size` but not both: '
+ f'{npartitions=}, {partition_size=}.')
+ if not npartitions and not partition_size:
+ # partition_size is inversely related to `npartitions`.
+ # Ideal "chunk sizes" in dask are around 10-100 MBs.
+ # Let's hope ~128 items per partition is around this
+ # memory overhead.
+ default_size = 128
+ partition_size = max(default_size, math.ceil(math.sqrt(len(items)) / 10))
+ if partition_size == default_size:
+ _LOGGER.warning(
+ 'The new default partition size is %d, it used to be 1 '
+ 'in previous DaskRunner versions.' % default_size)
+
return db.from_sequence(
- items,
- partition_size=max(
- 1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
+ items, npartitions=npartitions, partition_size=partition_size)
def apply_dofn_to_bundle(