This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new e939be38587 Windowing Support for the Dask Runner (#32941)
e939be38587 is described below
commit e939be3858730f6d64411c3871d99f17729459d3
Author: Alex Merose <[email protected]>
AuthorDate: Mon Nov 18 16:39:30 2024 -0300
Windowing Support for the Dask Runner (#32941)
Windowing Support for the Dask Runner
---------
Co-authored-by: Pablo E <[email protected]>
Co-authored-by: Pablo <[email protected]>
Co-authored-by: Charles Stern <[email protected]>
---
.github/workflows/dask_runner_tests.yml | 2 +-
.../python/apache_beam/runners/dask/dask_runner.py | 59 +++-
.../apache_beam/runners/dask/dask_runner_test.py | 311 ++++++++++++++++++++-
sdks/python/apache_beam/runners/dask/overrides.py | 16 +-
.../runners/dask/transform_evaluator.py | 180 ++++++++++--
sdks/python/scripts/generate_pydoc.sh | 3 +-
sdks/python/setup.py | 11 +-
sdks/python/test-suites/tox/common.gradle | 1 -
sdks/python/tox.ini | 17 +-
9 files changed, 558 insertions(+), 42 deletions(-)
diff --git a/.github/workflows/dask_runner_tests.yml
b/.github/workflows/dask_runner_tests.yml
index f87c70d8b72..0f60c22b6aa 100644
--- a/.github/workflows/dask_runner_tests.yml
+++ b/.github/workflows/dask_runner_tests.yml
@@ -78,7 +78,7 @@ jobs:
run: pip install tox
- name: Install SDK with dask
working-directory: ./sdks/python
- run: pip install setuptools --upgrade && pip install -e .[gcp,dask,test]
+ # Extra names must match setup.py's extras_require keys; the DataFrame
+ # extra is named `dataframe` (singular).
+ run: pip install setuptools --upgrade && pip install -e .[dask,test,dataframe]
- name: Run tests basic unix
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
working-directory: ./sdks/python
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py
b/sdks/python/apache_beam/runners/dask/dask_runner.py
index 109c4379b45..0f2317074ce 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -31,12 +31,22 @@ from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
+from apache_beam.runners.dask.transform_evaluator import
DaskBagWindowedIterator
+from apache_beam.runners.dask.transform_evaluator import Flatten
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
+from apache_beam.transforms.sideinputs import SideInputMap
from apache_beam.utils.interactive_utils import is_in_notebook
+try:
+  # Added to try to prevent threading related issues, see
+  # https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456
+  import dask.distributed as ddist
+except ImportError:
+  # `dask.distributed` is packaged separately from `dask`. Fall back to a
+  # placeholder that still provides the attributes referenced by class-level
+  # annotations in this module (`ddist.Client`, `ddist.Future`); a plain `{}`
+  # lacks them, so evaluating those annotations would raise AttributeError
+  # and hide the real cause (the missing package).
+  import types
+  ddist = types.SimpleNamespace(Client=t.Any, Future=t.Any)
+
class DaskOptions(PipelineOptions):
@staticmethod
@@ -86,10 +96,9 @@ class DaskOptions(PipelineOptions):
@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
- from dask import distributed
- client: distributed.Client
- futures: t.Sequence[distributed.Future]
+ client: ddist.Client
+ futures: t.Sequence[ddist.Future]
def __post_init__(self):
super().__init__(PipelineState.RUNNING)
@@ -99,8 +108,16 @@ class DaskRunnerResult(PipelineResult):
if duration is not None:
# Convert milliseconds to seconds
duration /= 1000
- self.client.wait_for_workers(timeout=duration)
- self.client.gather(self.futures, errors='raise')
+ for _ in ddist.as_completed(self.futures,
+ timeout=duration,
+ with_results=True):
+ # Without gathering results, worker errors are not raised on the client:
+ # https://distributed.dask.org/en/stable/resilience.html#user-code-failures
+ # so we gather results to raise errors client-side, even though the results
+ # themselves are unused here (hence the bare `pass`). To gather, we use the
+ # iterative `as_completed(..., with_results=True)` instead of the aggregate
+ # `client.gather`, to minimize the memory footprint of results.
+ pass
self._state = PipelineState.DONE
except: # pylint: disable=broad-except
self._state = PipelineState.FAILED
@@ -133,6 +150,7 @@ class DaskRunner(BundleBasedDirectRunner):
op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
op = op_class(transform_node)
+ op_kws = {"input_bag": None, "side_inputs": None}
inputs = list(transform_node.inputs)
if inputs:
bag_inputs = []
@@ -144,13 +162,28 @@ class DaskRunner(BundleBasedDirectRunner):
if prev_op in self.bags:
bag_inputs.append(self.bags[prev_op])
- if len(bag_inputs) == 1:
- self.bags[transform_node] = op.apply(bag_inputs[0])
+ # Input to `Flatten` could be of length 1, e.g. a single-element
+ # tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as
+ # an iterable, because `Flatten.apply` always takes an iterable.
+ if len(bag_inputs) == 1 and not isinstance(op, Flatten):
+ op_kws["input_bag"] = bag_inputs[0]
else:
- self.bags[transform_node] = op.apply(bag_inputs)
+ op_kws["input_bag"] = bag_inputs
+
+ side_inputs = list(transform_node.side_inputs)
+ if side_inputs:
+ bag_side_inputs = []
+ for si in side_inputs:
+ si_asbag = self.bags.get(si.pvalue.producer)
+ bag_side_inputs.append(
+ SideInputMap(
+ type(si),
+ si._view_options(),
+ DaskBagWindowedIterator(si_asbag, si._window_mapping_fn)))
+
+ op_kws["side_inputs"] = bag_side_inputs
- else:
- self.bags[transform_node] = op.apply(None)
+ self.bags[transform_node] = op.apply(**op_kws)
return DaskBagVisitor()
@@ -159,6 +192,8 @@ class DaskRunner(BundleBasedDirectRunner):
return False
def run_pipeline(self, pipeline, options):
+ import dask
+
# TODO(alxr): Create interactive notebook support.
if is_in_notebook():
raise NotImplementedError('interactive support will come later!')
@@ -177,6 +212,6 @@ class DaskRunner(BundleBasedDirectRunner):
dask_visitor = self.to_dask_bag_visitor()
pipeline.visit(dask_visitor)
-
- futures = client.compute(list(dask_visitor.bags.values()))
+ opt_graph = dask.optimize(*list(dask_visitor.bags.values()))
+ futures = client.compute(opt_graph)
return DaskRunnerResult(client, futures)
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
index d8b3e17d8a5..66dda4a984f 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -14,7 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import datetime
import inspect
+import typing as t
import unittest
import apache_beam as beam
@@ -22,12 +24,14 @@ from apache_beam.options.pipeline_options import
PipelineOptions
from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
+from apache_beam.transforms import window
try:
- from apache_beam.runners.dask.dask_runner import DaskOptions
- from apache_beam.runners.dask.dask_runner import DaskRunner
import dask
import dask.distributed as ddist
+
+ from apache_beam.runners.dask.dask_runner import DaskOptions # pylint:
disable=ungrouped-imports
+ from apache_beam.runners.dask.dask_runner import DaskRunner # pylint:
disable=ungrouped-imports
except (ImportError, ModuleNotFoundError):
raise unittest.SkipTest('Dask must be installed to run tests.')
@@ -73,6 +77,11 @@ class DaskRunnerRunPipelineTest(unittest.TestCase):
pcoll = p | beam.Create([1])
assert_that(pcoll, equal_to([1]))
+ def test_create_multiple(self):
+ # Create with several elements should emit all of them.
+ with self.pipeline as p:
+ pcoll = p | beam.Create([1, 2, 3, 4])
+ assert_that(pcoll, equal_to([1, 2, 3, 4]))
+
def test_create_and_map(self):
def double(x):
return x * 2
@@ -81,6 +90,22 @@ class DaskRunnerRunPipelineTest(unittest.TestCase):
pcoll = p | beam.Create([1]) | beam.Map(double)
assert_that(pcoll, equal_to([2]))
+ def test_create_and_map_multiple(self):
+ # Map applied to a small multi-element Create doubles every element.
+ def double(x):
+ return x * 2
+
+ with self.pipeline as p:
+ pcoll = p | beam.Create([1, 2]) | beam.Map(double)
+ assert_that(pcoll, equal_to([2, 4]))
+
+ def test_create_and_map_many(self):
+ # Map applied to the ten elements 1..10, doubling each.
+ def double(x):
+ return x * 2
+
+ with self.pipeline as p:
+ pcoll = p | beam.Create(list(range(1, 11))) | beam.Map(double)
+ assert_that(pcoll, equal_to(list(range(2, 21, 2))))
+
def test_create_map_and_groupby(self):
def double(x):
return x * 2, x
@@ -89,6 +114,288 @@ class DaskRunnerRunPipelineTest(unittest.TestCase):
pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
assert_that(pcoll, equal_to([(2, [1])]))
+ def test_create_map_and_groupby_multiple(self):
+ # GroupByKey with repeated keys collects every value for each key.
+ def double(x):
+ return x * 2, x
+
+ with self.pipeline as p:
+ pcoll = (
+ p
+ | beam.Create([1, 2, 1, 2, 3])
+ | beam.Map(double)
+ | beam.GroupByKey())
+ assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])]))
+
+ def test_map_with_positional_side_input(self):
+ # A singleton side input passed positionally to Map.
+ def mult_by(x, y):
+ return x * y
+
+ with self.pipeline as p:
+ side = p | "side" >> beam.Create([3])
+ pcoll = (
+ p
+ | "main" >> beam.Create([1])
+ | beam.Map(mult_by, beam.pvalue.AsSingleton(side)))
+ assert_that(pcoll, equal_to([3]))
+
+ def test_map_with_keyword_side_input(self):
+ # A singleton side input passed by keyword (`y=...`) to Map.
+ def mult_by(x, y):
+ return x * y
+
+ with self.pipeline as p:
+ side = p | "side" >> beam.Create([3])
+ pcoll = (
+ p
+ | "main" >> beam.Create([1])
+ | beam.Map(mult_by, y=beam.pvalue.AsSingleton(side)))
+ assert_that(pcoll, equal_to([3]))
+
+ def test_pardo_side_inputs(self):
+ # A list side input consumed by FlatMap yields the full cross product.
+ def cross_product(elem, sides):
+ for side in sides:
+ yield elem, side
+
+ with self.pipeline as p:
+ main = p | "main" >> beam.Create(["a", "b", "c"])
+ side = p | "side" >> beam.Create(["x", "y"])
+ assert_that(
+ main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)),
+ equal_to([
+ ("a", "x"),
+ ("b", "x"),
+ ("c", "x"),
+ ("a", "y"),
+ ("b", "y"),
+ ("c", "y"),
+ ]),
+ )
+
+ def test_pardo_side_input_dependencies(self):
+ # Step k (1..9) reads the outputs of steps 1..k-1 as side inputs; there are
+ # no assertions because ExpectingSideInputsFn raises if any side input is
+ # empty.
+ with self.pipeline as p:
+ inputs = [p | beam.Create([None])]
+ for k in range(1, 10):
+ inputs.append(
+ inputs[0]
+ | beam.ParDo(
+ ExpectingSideInputsFn(f"Do{k}"),
+ *[beam.pvalue.AsList(inputs[s]) for s in range(1, k)],
+ ))
+
+ def test_pardo_side_input_sparse_dependencies(self):
+ # Builds 20 transforms. Each ParDo picks its main input and side inputs from
+ # earlier transforms through a fixed modular index formula, so dependencies
+ # are sparse but deterministic. ExpectingSideInputsFn raises on empty side
+ # inputs.
+ with self.pipeline as p:
+ inputs = []
+
+ def choose_input(s):
+ return inputs[(389 + s * 5077) % len(inputs)]
+
+ for k in range(20):
+ num_inputs = int((k * k % 16)**0.5)
+ if num_inputs == 0:
+ inputs.append(p | f"Create{k}" >> beam.Create([f"Create{k}"]))
+ else:
+ inputs.append(
+ choose_input(0)
+ | beam.ParDo(
+ ExpectingSideInputsFn(f"Do{k}"),
+ *[
+ beam.pvalue.AsList(choose_input(s))
+ for s in range(1, num_inputs)
+ ],
+ ))
+
+ @unittest.expectedFailure
+ def test_pardo_windowed_side_inputs(self):
+ # Side input windows that differ from the main input's windows.
+ # NOTE(review): marked expectedFailure, presumably because this runner does
+ # not yet map side inputs per window; confirm before removing the marker.
+ with self.pipeline as p:
+ # Now with some windowing.
+ # (Each element i is timestamped at i seconds.)
+ pcoll = (
+ p
+ | beam.Create(list(range(10)))
+ | beam.Map(lambda t: window.TimestampedValue(t, t)))
+ # Intentionally choosing non-aligned windows to highlight the transition.
+ main = pcoll | "WindowMain" >> beam.WindowInto(window.FixedWindows(5))
+ side = pcoll | "WindowSide" >> beam.WindowInto(window.FixedWindows(7))
+ res = main | beam.Map(
+ lambda x, s: (x, sorted(s)), beam.pvalue.AsList(side))
+ assert_that(
+ res,
+ equal_to([
+ # The window [0, 5) maps to the window [0, 7).
+ (0, list(range(7))),
+ (1, list(range(7))),
+ (2, list(range(7))),
+ (3, list(range(7))),
+ (4, list(range(7))),
+ # The window [5, 10) maps to the window [7, 14).
+ (5, list(range(7, 10))),
+ (6, list(range(7, 10))),
+ (7, list(range(7, 10))),
+ (8, list(range(7, 10))),
+ (9, list(range(7, 10))),
+ ]),
+ label="windowed",
+ )
+
+ def test_flattened_side_input(self, with_transcoding=True):
+ # A Flatten output is used both as an AsDict side input and as an input to
+ # a second Flatten.
+ with self.pipeline as p:
+ main = p | "main" >> beam.Create([None])
+ side1 = p | "side1" >> beam.Create([("a", 1)])
+ side2 = p | "side2" >> beam.Create([("b", 2)])
+ if with_transcoding:
+ # Also test non-matching coder types (transcoding required)
+ # NOTE: ("another_type") is a plain str, not a tuple; that is what gives
+ # this element a different type from the (str, int) pairs above.
+ third_element = [("another_type")]
+ else:
+ third_element = [("b", 3)]
+ side3 = p | "side3" >> beam.Create(third_element)
+ side = (side1, side2) | beam.Flatten()
+ assert_that(
+ main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
+ equal_to([(None, {
+ "a": 1, "b": 2
+ })]),
+ label="CheckFlattenAsSideInput",
+ )
+ assert_that(
+ (side, side3) | "FlattenAfter" >> beam.Flatten(),
+ equal_to([("a", 1), ("b", 2)] + third_element),
+ label="CheckFlattenOfSideInput",
+ )
+
+ def test_gbk_side_input(self):
+ # A GroupByKey output used as an AsDict side input maps keys to value lists.
+ with self.pipeline as p:
+ main = p | "main" >> beam.Create([None])
+ side = p | "side" >> beam.Create([("a", 1)]) | beam.GroupByKey()
+ assert_that(
+ main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
+ equal_to([(None, {
+ "a": [1]
+ })]),
+ )
+
+ def test_multimap_side_input(self):
+ # An AsMultiMap side input maps each key to all of its values.
+ with self.pipeline as p:
+ main = p | "main" >> beam.Create(["a", "b"])
+ side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)])
+ assert_that(
+ main
+ | beam.Map(
+ lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
+ equal_to([("a", [1, 3]), ("b", [2])]),
+ )
+
+ def test_multimap_multiside_input(self):
+ # A test where two transforms in the same stage consume the same
+ # PCollection twice as side input (once as AsMultiMap, once as AsList).
+ with self.pipeline as p:
+ main = p | "main" >> beam.Create(["a", "b"])
+ side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)])
+ assert_that(
+ main
+ | "first map" >> beam.Map(
+ lambda k,
+ d,
+ l: (k, sorted(d[k]), sorted([e[1] for e in l])),
+ beam.pvalue.AsMultiMap(side),
+ beam.pvalue.AsList(side),
+ )
+ | "second map" >> beam.Map(
+ lambda k,
+ d,
+ l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])),
+ beam.pvalue.AsMultiMap(side),
+ beam.pvalue.AsList(side),
+ ),
+ equal_to([("a", [1, 3], [1, 2, 3]), ("b", [2], [1, 2, 3])]),
+ )
+
+ def test_multimap_side_input_type_coercion(self):
+ # AsMultiMap over a side input whose declared element type is Any.
+ with self.pipeline as p:
+ main = p | "main" >> beam.Create(["a", "b"])
+ # The type of this side-input is forced to Any (overriding type
+ # inference). Without type coercion to Tuple[Any, Any], the usage of this
+ # side-input in AsMultiMap() below should fail.
+ side = p | "side" >> beam.Create([("a", 1), ("b", 2),
+ ("a", 3)]).with_output_types(t.Any)
+ assert_that(
+ main
+ | beam.Map(
+ lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
+ equal_to([("a", [1, 3]), ("b", [2])]),
+ )
+
+ def test_pardo_unfusable_side_inputs__one(self):
+ # The same PCollection is both the main input and a side input of FlatMap.
+ def cross_product(elem, sides):
+ for side in sides:
+ yield elem, side
+
+ with self.pipeline as p:
+ pcoll = p | "Create1" >> beam.Create(["a", "b"])
+ assert_that(
+ pcoll |
+ "FlatMap1" >> beam.FlatMap(cross_product, beam.pvalue.AsList(pcoll)),
+ equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]),
+ label="assert_that1",
+ )
+
+ def test_pardo_unfusable_side_inputs__two(self):
+ # The side input is derived from the main input through a single-element
+ # Flatten, a GroupByKey and an un-keying Map.
+ def cross_product(elem, sides):
+ for side in sides:
+ yield elem, side
+
+ with self.pipeline as p:
+ pcoll = p | "Create2" >> beam.Create(["a", "b"])
+
+ derived = ((pcoll, )
+ | beam.Flatten()
+ | beam.Map(lambda x: (x, x))
+ | beam.GroupByKey()
+ | "Unkey" >> beam.Map(lambda kv: kv[0]))
+ assert_that(
+ pcoll | "FlatMap2" >> beam.FlatMap(
+ cross_product, beam.pvalue.AsList(derived)),
+ equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]),
+ label="assert_that2",
+ )
+
+ def test_groupby_with_fixed_windows(self):
+ # GroupByKey after WindowInto(FixedWindows(60)) and re-timestamping.
+ # NOTE(review): timestamps derive from datetime.now() and are set after
+ # WindowInto; confirm the grouping cannot depend on when the test runs.
+ def double(x):
+ return x * 2, x
+
+ def add_timestamp(pair):
+ delta = datetime.timedelta(seconds=pair[1] * 60)
+ now = (datetime.datetime.now() + delta).timestamp()
+ return window.TimestampedValue(pair, now)
+
+ with self.pipeline as p:
+ pcoll = (
+ p
+ | beam.Create([1, 2, 1, 2, 3])
+ | beam.Map(double)
+ | beam.WindowInto(window.FixedWindows(60))
+ | beam.Map(add_timestamp)
+ | beam.GroupByKey())
+ assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])]))
+
+ def test_groupby_string_keys(self):
+ # GroupByKey with string keys.
+ with self.pipeline as p:
+ pcoll = (
+ p
+ | beam.Create([('a', 1), ('a', 2), ('b', 3), ('b', 4)])
+ | beam.GroupByKey())
+ assert_that(pcoll, equal_to([('a', [1, 2]), ('b', [3, 4])]))
+
+
+class ExpectingSideInputsFn(beam.DoFn):
+  """Emits its name per element, but only when every side input has data.
+
+  The name doubles as the transform's default label, so tests can create many
+  instances in one pipeline without label clashes.
+  """
+  def __init__(self, name):
+    self._name = name
+
+  def default_label(self):
+    return self._name
+
+  def process(self, element, *side_inputs):
+    for side_input in side_inputs:
+      if not list(side_input):
+        raise ValueError(f"Missing data in side input {side_inputs}")
+    yield self._name
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/dask/overrides.py
b/sdks/python/apache_beam/runners/dask/overrides.py
index d07c7cd518a..b952834f12d 100644
--- a/sdks/python/apache_beam/runners/dask/overrides.py
+++ b/sdks/python/apache_beam/runners/dask/overrides.py
@@ -73,7 +73,6 @@ class _GroupByKeyOnly(beam.PTransform):
@typehints.with_input_types(t.Tuple[K, t.Iterable[V]])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupAlsoByWindow(beam.ParDo):
- """Not used yet..."""
def __init__(self, windowing):
super().__init__(_GroupAlsoByWindowDoFn(windowing))
self.windowing = windowing
@@ -86,12 +85,23 @@ class _GroupAlsoByWindow(beam.ParDo):
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKey(beam.PTransform):
def expand(self, input_or_inputs):
- return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly()
+ # Window-aware GroupByKey: reify each element's windows into its value,
+ # group by key alone, then regroup each key's values per window using the
+ # input's windowing strategy.
+ return (
+ input_or_inputs
+ | "ReifyWindows" >> beam.ParDo(beam.GroupByKey.ReifyWindows())
+ | "GroupByKey" >> _GroupByKeyOnly()
+ | "GroupByWindow" >> _GroupAlsoByWindow(input_or_inputs.windowing))
class _Flatten(beam.PTransform):
def expand(self, input_or_inputs):
- is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
+ # The flattened output is bounded only if every input is bounded.
+ if isinstance(input_or_inputs, beam.PCollection):
+ # NOTE(cisaacstern): needed for single-element flattens, i.e.
+ # `(pcoll, ) | beam.Flatten() | ...`, which otherwise make the `all(...)`
+ # call below raise `TypeError: 'PCollection' object is not iterable`.
+ is_bounded = input_or_inputs.is_bounded
+ else:
+ is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded)
diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
index d4d58879b7f..e3bd5fd8776 100644
--- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -26,19 +26,110 @@ import abc
import dataclasses
import math
import typing as t
+from dataclasses import field
import apache_beam
import dask.bag as db
+from apache_beam import DoFn
+from apache_beam import TaggedOutput
from apache_beam.pipeline import AppliedPTransform
+from apache_beam.runners.common import DoFnContext
+from apache_beam.runners.common import DoFnInvoker
+from apache_beam.runners.common import DoFnSignature
+from apache_beam.runners.common import Receiver
+from apache_beam.runners.common import _OutputHandler
from apache_beam.runners.dask.overrides import _Create
from apache_beam.runners.dask.overrides import _Flatten
from apache_beam.runners.dask.overrides import _GroupByKeyOnly
+from apache_beam.transforms.sideinputs import SideInputMap
+from apache_beam.transforms.window import GlobalWindow
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.transforms.window import WindowFn
+from apache_beam.utils.windowed_value import WindowedValue
+# Main input to a DaskBagOp: a single bag, a list of bags (Flatten), or None
+# (Create, which has no input).
OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None]
+# Side inputs to a DaskBagOp: one SideInputMap per side input, or None.
+OpSide = t.Optional[t.Sequence[SideInputMap]]
+
+# Value types for PCollections (possibly Windowed Values).
+PCollVal = t.Union[WindowedValue, t.Any]
+
+
+def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
+  """Returns `item` as a `WindowedValue`.
+
+  A `TaggedOutput` is unwrapped first (its tag is dropped). An existing
+  `WindowedValue` is returned as-is; a `TimestampedValue` keeps its timestamp
+  and gets its windows from `window_fn`; anything else is placed in the
+  `GlobalWindow` with timestamp 0.
+  """
+  value = item.value if isinstance(item, TaggedOutput) else item
+
+  if isinstance(value, WindowedValue):
+    return value
+
+  if isinstance(value, TimestampedValue):
+    windows = window_fn.assign(
+        WindowFn.AssignContext(value.timestamp, value.value))
+    return WindowedValue(value.value, value.timestamp, tuple(windows))
+
+  return WindowedValue(value, 0, (GlobalWindow(), ))
+
+
+def defenestrate(x):
+  """Returns the payload of a `WindowedValue`; any other `x` is returned as-is."""
+  return x.value if isinstance(x, WindowedValue) else x
+
+
+@dataclasses.dataclass
+class DaskBagWindowedIterator:
+  """Iterates a dask bag's elements, each wrapped as a `WindowedValue`.
+
+  Serves as the element source of an
+  `apache_beam.transforms.sideinputs.SideInputMap`.
+
+  Attributes:
+    bag: The dask bag holding the side-input elements.
+    window_fn: Passed to `get_windowed_value` for every element; it is only
+      consulted for elements that are `TimestampedValue`s.
+      NOTE(review): DaskRunner passes a side input's `_window_mapping_fn`
+      here, which may not implement `WindowFn.assign`; confirm before relying
+      on it for timestamped elements.
+  """
+
+  bag: db.Bag
+  window_fn: WindowFn
+
+  def __iter__(self):
+    # FIXME(cisaacstern): iterating a dask bag computes (materializes) the
+    # whole bag before anything is yielded. Acceptable as a proof-of-concept;
+    # can we generate results incrementally instead?
+    # Iterate the bag directly: wrapping it in `list(...)` only made a second
+    # copy of the already-computed results.
+    for result in self.bag:
+      yield get_windowed_value(result, self.window_fn)
+
+
+@dataclasses.dataclass
+class TaggingReceiver(Receiver):
+  """Receiver that appends everything it receives to a shared list.
+
+  When `tag` is truthy, each value is wrapped as `TaggedOutput(tag, value)`
+  before being appended; otherwise it is appended unchanged.
+  """
+  tag: t.Optional[str]  # None for the main (untagged) output.
+  values: t.List[PCollVal]
+
+  def receive(self, windowed_value: WindowedValue):
+    self.values.append(
+        TaggedOutput(self.tag, windowed_value) if self.tag else windowed_value)
+
+
+@dataclasses.dataclass
+class OneReceiver(dict):
+  """Dict of per-tag receivers that all append into one shared `values` list.
+
+  Looking up a tag with no receiver yet creates a `TaggingReceiver` bound to
+  `values`, stores it under that tag, and returns it.
+  """
+  values: t.List[PCollVal] = field(default_factory=list)
+
+  def __missing__(self, key):
+    # dict.__getitem__ only calls this for absent keys, so no membership check
+    # is needed before inserting.
+    receiver = TaggingReceiver(key, self.values)
+    self[key] = receiver
+    return receiver
@dataclasses.dataclass
class DaskBagOp(abc.ABC):
+ """Abstract Base Class for all Dask-supported Operations.
+
+ All DaskBagOps must support an `apply()` operation, which invokes the dask
+ bag upon the previous op's input.
+
+ Attributes
+ applied: The underlying `AppliedPTransform` which holds the code for the
+ target operation.
+ """
applied: AppliedPTransform
@property
@@ -46,17 +137,19 @@ class DaskBagOp(abc.ABC):
return self.applied.transform
@abc.abstractmethod
- def apply(self, input_bag: OpInput) -> db.Bag:
+ def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
pass
class NoOp(DaskBagOp):
- def apply(self, input_bag: OpInput) -> db.Bag:
+ """Identity op: returns `input_bag` unchanged and ignores `side_inputs`.
+
+ DaskRunner uses it for any transform class missing from TRANSLATIONS.
+ """
+ def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
return input_bag
class Create(DaskBagOp):
- def apply(self, input_bag: OpInput) -> db.Bag:
+ """The beginning of a Beam pipeline; the input must be `None`."""
+ def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
assert input_bag is None, 'Create expects no input!'
original_transform = t.cast(_Create, self.transform)
items = original_transform.values
@@ -66,42 +159,95 @@ class Create(DaskBagOp):
1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
+def apply_dofn_to_bundle(
+    items, do_fn_invoker_args, do_fn_invoker_kwargs, tagged_receivers):
+  """Invokes a DoFn within a bundle, implemented as a Dask partition.
+
+  Args:
+    items: The elements of one Dask partition. The ParDo op maps them through
+      `get_windowed_value` before calling this function.
+    do_fn_invoker_args: Positional arguments for `DoFnInvoker.create_invoker`.
+    do_fn_invoker_kwargs: Keyword arguments for `DoFnInvoker.create_invoker`.
+    tagged_receivers: The `OneReceiver` whose `values` list collects the
+      DoFn's outputs for every tag.
+
+  Returns:
+    The `.value` of every output produced for this bundle, including outputs
+    emitted by `finish_bundle`.
+  """
+  do_fn_invoker = DoFnInvoker.create_invoker(
+      *do_fn_invoker_args, **do_fn_invoker_kwargs)
+
+  # The ParDo op hands the same `tagged_receivers` object (and thus the same
+  # `values` list) to every partition, so it may be shared between bundles.
+  # Only return what this bundle appended, so earlier bundles' outputs are
+  # not emitted again.
+  # NOTE(review): partitions running concurrently on threads could still
+  # interleave in the shared list; per-bundle receivers would isolate them.
+  outputs = tagged_receivers.values
+  start = len(outputs)
+  try:
+    do_fn_invoker.invoke_setup()
+    do_fn_invoker.invoke_start_bundle()
+
+    for it in items:
+      do_fn_invoker.invoke_process(it)
+
+    # finish_bundle may emit outputs of its own, so collect only afterwards.
+    do_fn_invoker.invoke_finish_bundle()
+    results = [v.value for v in outputs[start:]]
+  finally:
+    # Keep the shared list from growing across bundles, and always release
+    # the DoFn's resources, even if processing failed.
+    del outputs[start:]
+    do_fn_invoker.invoke_teardown()
+
+  return results
+
+
class ParDo(DaskBagOp):
- def apply(self, input_bag: db.Bag) -> db.Bag:
- transform = t.cast(apache_beam.ParDo, self.transform)
- return input_bag.map(
- transform.fn.process, *transform.args, **transform.kwargs).flatten()
+ """Runs a ParDo's DoFn over a bag, one Dask partition per bundle.
+
+ Elements are first wrapped with `get_windowed_value`; each partition is then
+ passed to `apply_dofn_to_bundle`, which returns the values the DoFn output.
+ """
+ def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag:
+ transform = t.cast(apache_beam.ParDo, self.transform)
-class Map(DaskBagOp):
- def apply(self, input_bag: db.Bag) -> db.Bag:
- transform = t.cast(apache_beam.Map, self.transform)
- return input_bag.map(
- transform.fn.process, *transform.args, **transform.kwargs)
+ # Extra positional/keyword arguments declared on the ParDo; forwarded to the
+ # invoker below as `input_args` / `input_kwargs`.
+ args, kwargs = transform.raw_side_inputs
+ args = list(args)
+ # Window fn of the (single) main input, or None if it has no `windowing`.
+ main_input = next(iter(self.applied.main_inputs.values()))
+ window_fn = main_input.windowing.windowfn if hasattr(
+ main_input, "windowing") else None
+
+ # Collects every output of the DoFn, main and tagged alike.
+ # NOTE(review): this single object is passed to every partition through
+ # `map_partitions` below, so partitions may share it; confirm isolation.
+ tagged_receivers = OneReceiver()
+
+ # Arguments for DoFnInvoker.create_invoker, built once per op and rebuilt
+ # into an invoker inside each partition.
+ do_fn_invoker_args = [
+ DoFnSignature(transform.fn),
+ _OutputHandler(
+ window_fn=window_fn,
+ main_receivers=tagged_receivers[None],
+ tagged_receivers=tagged_receivers,
+ per_element_output_counter=None,
+ output_batch_converter=None,
+ process_yields_batches=False,
+ process_batch_yields_elements=False),
+ ]
+ do_fn_invoker_kwargs = dict(
+ context=DoFnContext(transform.label, state=None),
+ side_inputs=side_inputs,
+ input_args=args,
+ input_kwargs=kwargs,
+ user_state_context=None,
+ bundle_finalizer_param=DoFn.BundleFinalizerParam(),
+ )
+
+ # Wrap each element as a WindowedValue, then run the DoFn per partition.
+ return input_bag.map(get_windowed_value, window_fn).map_partitions(
+ apply_dofn_to_bundle,
+ do_fn_invoker_args,
+ do_fn_invoker_kwargs,
+ tagged_receivers,
+ )
class GroupByKey(DaskBagOp):
- def apply(self, input_bag: db.Bag) -> db.Bag:
+ """Groups `(key, value)` elements into `(key, [values])` pairs.
+
+ Values are unwrapped from any `WindowedValue` via `defenestrate`.
+ """
+ def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag:
def key(item):
return item[0]
def value(item):
+ # `groupby` yields `(key, [elements])`; keep each element's value part.
k, v = item
- return k, [elm[1] for elm in v]
+ return k, [defenestrate(elm[1]) for elm in v]
return input_bag.groupby(key).map(value)
class Flatten(DaskBagOp):
- def apply(self, input_bag: OpInput) -> db.Bag:
- assert type(input_bag) is list, 'Must take a sequence of bags!'
+ """Concatenates a list of bags into one bag; `side_inputs` is ignored."""
+ def apply(
+ self, input_bag: t.List[db.Bag], side_inputs: OpSide = None) -> db.Bag:
+ # DaskRunner passes a list even for a single input, e.g. `(pcoll, )`.
+ assert isinstance(input_bag, list), 'Must take a sequence of bags!'
return db.concat(input_bag)
+# Beam transform class -> DaskBagOp that executes it. DaskRunner falls back to
+# NoOp for any transform class not listed here.
TRANSLATIONS = {
_Create: Create,
apache_beam.ParDo: ParDo,
- apache_beam.Map: Map,
_GroupByKeyOnly: GroupByKey,
_Flatten: Flatten,
}
diff --git a/sdks/python/scripts/generate_pydoc.sh
b/sdks/python/scripts/generate_pydoc.sh
index 21561e1bf6a..4922b61169d 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -64,6 +64,7 @@ excluded_patterns=(
'apache_beam/runners/portability/'
'apache_beam/runners/test/'
'apache_beam/runners/worker/'
+ 'apache_beam/runners/dask/transform_evaluator.*'
'apache_beam/testing/benchmarks/chicago_taxi/'
'apache_beam/testing/benchmarks/cloudml/'
'apache_beam/testing/benchmarks/inference/'
@@ -134,7 +135,7 @@ autodoc_member_order = 'bysource'
autodoc_mock_imports = ["tensorrt", "cuda", "torch",
"onnxruntime", "onnx", "tensorflow", "tensorflow_hub",
"tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost",
"datatable", "transformers",
- "sentence_transformers", "redis", "tensorflow_text", "feast",
+ "sentence_transformers", "redis", "tensorflow_text", "feast", "dask",
]
# Allow a special section for documenting DataFrame API
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 9ae5d3153f5..3b45cbf82fc 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -512,8 +512,15 @@ if __name__ == '__main__':
],
'dataframe': dataframe_dependency,
'dask': [
- 'dask >= 2022.6',
- 'distributed >= 2022.6',
+ 'distributed >= 2024.4.2',
+ 'dask >= 2024.4.2',
+ # For development, 'distributed >= 2023.12.1' should also work with the
+ # dask requirement above; however, it can't be installed as part of a
+ # single `pip` call, since distributed releases are pinned to specific
+ # dask releases. As a workaround, install distributed first, and then
+ # install `.[dask]` with the `--upgrade` / `-U` flag to replace the dask
+ # release brought in by distributed.
],
'yaml': [
'docstring-parser>=0.15,<1.0',
diff --git a/sdks/python/test-suites/tox/common.gradle
b/sdks/python/test-suites/tox/common.gradle
index df42a2c384c..01265a6eeff 100644
--- a/sdks/python/test-suites/tox/common.gradle
+++ b/sdks/python/test-suites/tox/common.gradle
@@ -31,7 +31,6 @@ test.dependsOn "testPy${pythonVersionSuffix}ML"
// toxTask "testPy${pythonVersionSuffix}Dask",
"py${pythonVersionSuffix}-dask", "${posargs}"
// test.dependsOn "testPy${pythonVersionSuffix}Dask"
-
project.tasks.register("preCommitPy${pythonVersionSuffix}") {
// Since codecoverage reports will always be generated for py38,
// all tests will be exercised.
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index c7713498d87..ad5d7ec5505 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -109,9 +109,20 @@ commands =
bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}"
[testenv:py{39,310,311,312}-dask]
-extras = test,dask
+# `dataframe` (singular) is the extra name defined in setup.py.
+extras = test,dask,dataframe
+commands_pre =
+ pip install 'distributed>=2024.4.2' 'dask>=2024.4.2'
commands =
- bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}"
+ bash {toxinidir}/scripts/run_pytest.sh {envname} {toxinidir}/apache_beam/runners/dask/
+
+[testenv:py{39,310,311,312}-win-dask]
+commands_pre =
+ pip install 'distributed>=2024.4.2' 'dask>=2024.4.2'
+commands =
+ python apache_beam/examples/complete/autocomplete_test.py
+ bash {toxinidir}/scripts/run_pytest.sh {envname} {toxinidir}/apache_beam/runners/dask/
+install_command = {envbindir}/python.exe {envbindir}/pip.exe install --retries 10 {opts} {packages}
+list_dependencies_command = {envbindir}/python.exe {envbindir}/pip.exe freeze
[testenv:py39-cloudcoverage]
deps =
@@ -394,7 +405,7 @@ commands =
[testenv:py39-tensorflow-212]
deps =
- 212:
+ 212:
tensorflow>=2.12rc1,<2.13
# Help pip resolve conflict with typing-extensions for old version of TF
https://github.com/apache/beam/issues/30852
pydantic<2.7