This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e939be38587 Windowing Support for the Dask Runner (#32941)
e939be38587 is described below

commit e939be3858730f6d64411c3871d99f17729459d3
Author: Alex Merose <[email protected]>
AuthorDate: Mon Nov 18 16:39:30 2024 -0300

    Windowing Support for the Dask Runner (#32941)
    
    Windowing Support for the Dask Runner
    
    ---------
    
    Co-authored-by: Pablo E <[email protected]>
    Co-authored-by: Pablo <[email protected]>
    Co-authored-by: Charles Stern <[email protected]>
---
 .github/workflows/dask_runner_tests.yml            |   2 +-
 .../python/apache_beam/runners/dask/dask_runner.py |  59 +++-
 .../apache_beam/runners/dask/dask_runner_test.py   | 311 ++++++++++++++++++++-
 sdks/python/apache_beam/runners/dask/overrides.py  |  16 +-
 .../runners/dask/transform_evaluator.py            | 180 ++++++++++--
 sdks/python/scripts/generate_pydoc.sh              |   3 +-
 sdks/python/setup.py                               |  11 +-
 sdks/python/test-suites/tox/common.gradle          |   1 -
 sdks/python/tox.ini                                |  17 +-
 9 files changed, 558 insertions(+), 42 deletions(-)

diff --git a/.github/workflows/dask_runner_tests.yml b/.github/workflows/dask_runner_tests.yml
index f87c70d8b72..0f60c22b6aa 100644
--- a/.github/workflows/dask_runner_tests.yml
+++ b/.github/workflows/dask_runner_tests.yml
@@ -78,7 +78,7 @@ jobs:
         run: pip install tox
       - name: Install SDK with dask
         working-directory: ./sdks/python
-        run: pip install setuptools --upgrade && pip install -e 
.[gcp,dask,test]
+        run: pip install setuptools --upgrade && pip install -e 
.[dask,test,dataframes]
       - name: Run tests basic unix
         if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
         working-directory: ./sdks/python
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py
index 109c4379b45..0f2317074ce 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner.py
@@ -31,12 +31,22 @@ from apache_beam.pipeline import AppliedPTransform
 from apache_beam.pipeline import PipelineVisitor
 from apache_beam.runners.dask.overrides import dask_overrides
 from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
+from apache_beam.runners.dask.transform_evaluator import 
DaskBagWindowedIterator
+from apache_beam.runners.dask.transform_evaluator import Flatten
 from apache_beam.runners.dask.transform_evaluator import NoOp
 from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
 from apache_beam.runners.runner import PipelineResult
 from apache_beam.runners.runner import PipelineState
+from apache_beam.transforms.sideinputs import SideInputMap
 from apache_beam.utils.interactive_utils import is_in_notebook
 
+try:
+  # Added to try to prevent threading related issues, see
+  # https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456
+  import dask.distributed as ddist
+except ImportError:
+  ddist = {}
+
 
 class DaskOptions(PipelineOptions):
   @staticmethod
@@ -86,10 +96,9 @@ class DaskOptions(PipelineOptions):
 
 @dataclasses.dataclass
 class DaskRunnerResult(PipelineResult):
-  from dask import distributed
 
-  client: distributed.Client
-  futures: t.Sequence[distributed.Future]
+  client: ddist.Client
+  futures: t.Sequence[ddist.Future]
 
   def __post_init__(self):
     super().__init__(PipelineState.RUNNING)
@@ -99,8 +108,16 @@ class DaskRunnerResult(PipelineResult):
       if duration is not None:
         # Convert milliseconds to seconds
         duration /= 1000
-      self.client.wait_for_workers(timeout=duration)
-      self.client.gather(self.futures, errors='raise')
+      for _ in ddist.as_completed(self.futures,
+                                  timeout=duration,
+                                  with_results=True):
+        # without gathering results, worker errors are not raised on the client:
+        # https://distributed.dask.org/en/stable/resilience.html#user-code-failures
+        # so we want to gather results to raise errors client-side, but we do
+        # not actually need to use the results here, so we just pass. to gather,
+        # we use the iterative `as_completed(..., with_results=True)`, instead
+        # of aggregate `client.gather`, to minimize memory footprint of results.
+        pass
       self._state = PipelineState.DONE
     except:  # pylint: disable=broad-except
       self._state = PipelineState.FAILED
@@ -133,6 +150,7 @@ class DaskRunner(BundleBasedDirectRunner):
         op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
         op = op_class(transform_node)
 
+        op_kws = {"input_bag": None, "side_inputs": None}
         inputs = list(transform_node.inputs)
         if inputs:
           bag_inputs = []
@@ -144,13 +162,28 @@ class DaskRunner(BundleBasedDirectRunner):
             if prev_op in self.bags:
               bag_inputs.append(self.bags[prev_op])
 
-          if len(bag_inputs) == 1:
-            self.bags[transform_node] = op.apply(bag_inputs[0])
+          # Input to `Flatten` could be of length 1, e.g. a single-element
+          # tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as
+          # an iterable, because `Flatten.apply` always takes an iterable.
+          if len(bag_inputs) == 1 and not isinstance(op, Flatten):
+            op_kws["input_bag"] = bag_inputs[0]
           else:
-            self.bags[transform_node] = op.apply(bag_inputs)
+            op_kws["input_bag"] = bag_inputs
+
+        side_inputs = list(transform_node.side_inputs)
+        if side_inputs:
+          bag_side_inputs = []
+          for si in side_inputs:
+            si_asbag = self.bags.get(si.pvalue.producer)
+            bag_side_inputs.append(
+                SideInputMap(
+                    type(si),
+                    si._view_options(),
+                    DaskBagWindowedIterator(si_asbag, si._window_mapping_fn)))
+
+          op_kws["side_inputs"] = bag_side_inputs
 
-        else:
-          self.bags[transform_node] = op.apply(None)
+        self.bags[transform_node] = op.apply(**op_kws)
 
     return DaskBagVisitor()
 
@@ -159,6 +192,8 @@ class DaskRunner(BundleBasedDirectRunner):
     return False
 
   def run_pipeline(self, pipeline, options):
+    import dask
+
     # TODO(alxr): Create interactive notebook support.
     if is_in_notebook():
       raise NotImplementedError('interactive support will come later!')
@@ -177,6 +212,6 @@ class DaskRunner(BundleBasedDirectRunner):
 
     dask_visitor = self.to_dask_bag_visitor()
     pipeline.visit(dask_visitor)
-
-    futures = client.compute(list(dask_visitor.bags.values()))
+    opt_graph = dask.optimize(*list(dask_visitor.bags.values()))
+    futures = client.compute(opt_graph)
     return DaskRunnerResult(client, futures)
diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
index d8b3e17d8a5..66dda4a984f 100644
--- a/sdks/python/apache_beam/runners/dask/dask_runner_test.py
+++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import datetime
 import inspect
+import typing as t
 import unittest
 
 import apache_beam as beam
@@ -22,12 +24,14 @@ from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.testing import test_pipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.transforms import window
 
 try:
-  from apache_beam.runners.dask.dask_runner import DaskOptions
-  from apache_beam.runners.dask.dask_runner import DaskRunner
   import dask
   import dask.distributed as ddist
+
+  from apache_beam.runners.dask.dask_runner import DaskOptions  # pylint: 
disable=ungrouped-imports
+  from apache_beam.runners.dask.dask_runner import DaskRunner  # pylint: 
disable=ungrouped-imports
 except (ImportError, ModuleNotFoundError):
   raise unittest.SkipTest('Dask must be installed to run tests.')
 
@@ -73,6 +77,11 @@ class DaskRunnerRunPipelineTest(unittest.TestCase):
       pcoll = p | beam.Create([1])
       assert_that(pcoll, equal_to([1]))
 
+  def test_create_multiple(self):
+    with self.pipeline as p:
+      pcoll = p | beam.Create([1, 2, 3, 4])
+      assert_that(pcoll, equal_to([1, 2, 3, 4]))
+
   def test_create_and_map(self):
     def double(x):
       return x * 2
@@ -81,6 +90,22 @@ class DaskRunnerRunPipelineTest(unittest.TestCase):
       pcoll = p | beam.Create([1]) | beam.Map(double)
       assert_that(pcoll, equal_to([2]))
 
+  def test_create_and_map_multiple(self):
+    def double(x):
+      return x * 2
+
+    with self.pipeline as p:
+      pcoll = p | beam.Create([1, 2]) | beam.Map(double)
+      assert_that(pcoll, equal_to([2, 4]))
+
+  def test_create_and_map_many(self):
+    def double(x):
+      return x * 2
+
+    with self.pipeline as p:
+      pcoll = p | beam.Create(list(range(1, 11))) | beam.Map(double)
+      assert_that(pcoll, equal_to(list(range(2, 21, 2))))
+
   def test_create_map_and_groupby(self):
     def double(x):
       return x * 2, x
@@ -89,6 +114,288 @@ class DaskRunnerRunPipelineTest(unittest.TestCase):
       pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
       assert_that(pcoll, equal_to([(2, [1])]))
 
+  def test_create_map_and_groupby_multiple(self):
+    def double(x):
+      return x * 2, x
+
+    with self.pipeline as p:
+      pcoll = (
+          p
+          | beam.Create([1, 2, 1, 2, 3])
+          | beam.Map(double)
+          | beam.GroupByKey())
+      assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])]))
+
+  def test_map_with_positional_side_input(self):
+    def mult_by(x, y):
+      return x * y
+
+    with self.pipeline as p:
+      side = p | "side" >> beam.Create([3])
+      pcoll = (
+          p
+          | "main" >> beam.Create([1])
+          | beam.Map(mult_by, beam.pvalue.AsSingleton(side)))
+      assert_that(pcoll, equal_to([3]))
+
+  def test_map_with_keyword_side_input(self):
+    def mult_by(x, y):
+      return x * y
+
+    with self.pipeline as p:
+      side = p | "side" >> beam.Create([3])
+      pcoll = (
+          p
+          | "main" >> beam.Create([1])
+          | beam.Map(mult_by, y=beam.pvalue.AsSingleton(side)))
+      assert_that(pcoll, equal_to([3]))
+
+  def test_pardo_side_inputs(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create(["a", "b", "c"])
+      side = p | "side" >> beam.Create(["x", "y"])
+      assert_that(
+          main | beam.FlatMap(cross_product, beam.pvalue.AsList(side)),
+          equal_to([
+              ("a", "x"),
+              ("b", "x"),
+              ("c", "x"),
+              ("a", "y"),
+              ("b", "y"),
+              ("c", "y"),
+          ]),
+      )
+
+  def test_pardo_side_input_dependencies(self):
+    with self.pipeline as p:
+      inputs = [p | beam.Create([None])]
+      for k in range(1, 10):
+        inputs.append(
+            inputs[0]
+            | beam.ParDo(
+                ExpectingSideInputsFn(f"Do{k}"),
+                *[beam.pvalue.AsList(inputs[s]) for s in range(1, k)],
+            ))
+
+  def test_pardo_side_input_sparse_dependencies(self):
+    with self.pipeline as p:
+      inputs = []
+
+      def choose_input(s):
+        return inputs[(389 + s * 5077) % len(inputs)]
+
+      for k in range(20):
+        num_inputs = int((k * k % 16)**0.5)
+        if num_inputs == 0:
+          inputs.append(p | f"Create{k}" >> beam.Create([f"Create{k}"]))
+        else:
+          inputs.append(
+              choose_input(0)
+              | beam.ParDo(
+                  ExpectingSideInputsFn(f"Do{k}"),
+                  *[
+                      beam.pvalue.AsList(choose_input(s))
+                      for s in range(1, num_inputs)
+                  ],
+              ))
+
+  @unittest.expectedFailure
+  def test_pardo_windowed_side_inputs(self):
+    with self.pipeline as p:
+      # Now with some windowing.
+      pcoll = (
+          p
+          | beam.Create(list(range(10)))
+          | beam.Map(lambda t: window.TimestampedValue(t, t)))
+      # Intentionally choosing non-aligned windows to highlight the transition.
+      main = pcoll | "WindowMain" >> beam.WindowInto(window.FixedWindows(5))
+      side = pcoll | "WindowSide" >> beam.WindowInto(window.FixedWindows(7))
+      res = main | beam.Map(
+          lambda x, s: (x, sorted(s)), beam.pvalue.AsList(side))
+      assert_that(
+          res,
+          equal_to([
+              # The window [0, 5) maps to the window [0, 7).
+              (0, list(range(7))),
+              (1, list(range(7))),
+              (2, list(range(7))),
+              (3, list(range(7))),
+              (4, list(range(7))),
+              # The window [5, 10) maps to the window [7, 14).
+              (5, list(range(7, 10))),
+              (6, list(range(7, 10))),
+              (7, list(range(7, 10))),
+              (8, list(range(7, 10))),
+              (9, list(range(7, 10))),
+          ]),
+          label="windowed",
+      )
+
+  def test_flattened_side_input(self, with_transcoding=True):
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create([None])
+      side1 = p | "side1" >> beam.Create([("a", 1)])
+      side2 = p | "side2" >> beam.Create([("b", 2)])
+      if with_transcoding:
+        # Also test non-matching coder types (transcoding required)
+        third_element = [("another_type")]
+      else:
+        third_element = [("b", 3)]
+      side3 = p | "side3" >> beam.Create(third_element)
+      side = (side1, side2) | beam.Flatten()
+      assert_that(
+          main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
+          equal_to([(None, {
+              "a": 1, "b": 2
+          })]),
+          label="CheckFlattenAsSideInput",
+      )
+      assert_that(
+          (side, side3) | "FlattenAfter" >> beam.Flatten(),
+          equal_to([("a", 1), ("b", 2)] + third_element),
+          label="CheckFlattenOfSideInput",
+      )
+
+  def test_gbk_side_input(self):
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create([None])
+      side = p | "side" >> beam.Create([("a", 1)]) | beam.GroupByKey()
+      assert_that(
+          main | beam.Map(lambda a, b: (a, b), beam.pvalue.AsDict(side)),
+          equal_to([(None, {
+              "a": [1]
+          })]),
+      )
+
+  def test_multimap_side_input(self):
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create(["a", "b"])
+      side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)])
+      assert_that(
+          main
+          | beam.Map(
+              lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
+          equal_to([("a", [1, 3]), ("b", [2])]),
+      )
+
+  def test_multimap_multiside_input(self):
+    # A test where two transforms in the same stage consume the same PCollection
+    # twice as side input.
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create(["a", "b"])
+      side = p | "side" >> beam.Create([("a", 1), ("b", 2), ("a", 3)])
+      assert_that(
+          main
+          | "first map" >> beam.Map(
+              lambda k,
+              d,
+              l: (k, sorted(d[k]), sorted([e[1] for e in l])),
+              beam.pvalue.AsMultiMap(side),
+              beam.pvalue.AsList(side),
+          )
+          | "second map" >> beam.Map(
+              lambda k,
+              d,
+              l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])),
+              beam.pvalue.AsMultiMap(side),
+              beam.pvalue.AsList(side),
+          ),
+          equal_to([("a", [1, 3], [1, 2, 3]), ("b", [2], [1, 2, 3])]),
+      )
+
+  def test_multimap_side_input_type_coercion(self):
+    with self.pipeline as p:
+      main = p | "main" >> beam.Create(["a", "b"])
+      # The type of this side-input is forced to Any (overriding type
+      # inference). Without type coercion to Tuple[Any, Any], the usage of this
+      # side-input in AsMultiMap() below should fail.
+      side = p | "side" >> beam.Create([("a", 1), ("b", 2),
+                                        ("a", 3)]).with_output_types(t.Any)
+      assert_that(
+          main
+          | beam.Map(
+              lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
+          equal_to([("a", [1, 3]), ("b", [2])]),
+      )
+
+  def test_pardo_unfusable_side_inputs__one(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+
+    with self.pipeline as p:
+      pcoll = p | "Create1" >> beam.Create(["a", "b"])
+      assert_that(
+          pcoll |
+          "FlatMap1" >> beam.FlatMap(cross_product, beam.pvalue.AsList(pcoll)),
+          equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]),
+          label="assert_that1",
+      )
+
+  def test_pardo_unfusable_side_inputs__two(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+
+    with self.pipeline as p:
+      pcoll = p | "Create2" >> beam.Create(["a", "b"])
+
+      derived = ((pcoll, )
+                 | beam.Flatten()
+                 | beam.Map(lambda x: (x, x))
+                 | beam.GroupByKey()
+                 | "Unkey" >> beam.Map(lambda kv: kv[0]))
+      assert_that(
+          pcoll | "FlatMap2" >> beam.FlatMap(
+              cross_product, beam.pvalue.AsList(derived)),
+          equal_to([("a", "a"), ("a", "b"), ("b", "a"), ("b", "b")]),
+          label="assert_that2",
+      )
+
+  def test_groupby_with_fixed_windows(self):
+    def double(x):
+      return x * 2, x
+
+    def add_timestamp(pair):
+      delta = datetime.timedelta(seconds=pair[1] * 60)
+      now = (datetime.datetime.now() + delta).timestamp()
+      return window.TimestampedValue(pair, now)
+
+    with self.pipeline as p:
+      pcoll = (
+          p
+          | beam.Create([1, 2, 1, 2, 3])
+          | beam.Map(double)
+          | beam.WindowInto(window.FixedWindows(60))
+          | beam.Map(add_timestamp)
+          | beam.GroupByKey())
+      assert_that(pcoll, equal_to([(2, [1, 1]), (4, [2, 2]), (6, [3])]))
+
+  def test_groupby_string_keys(self):
+    with self.pipeline as p:
+      pcoll = (
+          p
+          | beam.Create([('a', 1), ('a', 2), ('b', 3), ('b', 4)])
+          | beam.GroupByKey())
+      assert_that(pcoll, equal_to([('a', [1, 2]), ('b', [3, 4])]))
+
+
+class ExpectingSideInputsFn(beam.DoFn):
+  def __init__(self, name):
+    self._name = name
+
+  def default_label(self):
+    return self._name
+
+  def process(self, element, *side_inputs):
+    if not all(list(s) for s in side_inputs):
+      raise ValueError(f"Missing data in side input {side_inputs}")
+    yield self._name
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/runners/dask/overrides.py b/sdks/python/apache_beam/runners/dask/overrides.py
index d07c7cd518a..b952834f12d 100644
--- a/sdks/python/apache_beam/runners/dask/overrides.py
+++ b/sdks/python/apache_beam/runners/dask/overrides.py
@@ -73,7 +73,6 @@ class _GroupByKeyOnly(beam.PTransform):
 @typehints.with_input_types(t.Tuple[K, t.Iterable[V]])
 @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
 class _GroupAlsoByWindow(beam.ParDo):
-  """Not used yet..."""
   def __init__(self, windowing):
     super().__init__(_GroupAlsoByWindowDoFn(windowing))
     self.windowing = windowing
@@ -86,12 +85,23 @@ class _GroupAlsoByWindow(beam.ParDo):
 @typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
 class _GroupByKey(beam.PTransform):
   def expand(self, input_or_inputs):
-    return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly()
+    return (
+        input_or_inputs
+        | "ReifyWindows" >> beam.ParDo(beam.GroupByKey.ReifyWindows())
+        | "GroupByKey" >> _GroupByKeyOnly()
+        | "GroupByWindow" >> _GroupAlsoByWindow(input_or_inputs.windowing))
 
 
 class _Flatten(beam.PTransform):
   def expand(self, input_or_inputs):
-    is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
+    if isinstance(input_or_inputs, beam.PCollection):
+      # NOTE(cisaacstern): I needed this to avoid
+      #   `TypeError: 'PCollection' object is not iterable`
+      # being raised by `all(...)` call below for single-element flattens, i.e.,
+      #   `(pcoll, ) | beam.Flatten() | ...`
+      is_bounded = input_or_inputs.is_bounded
+    else:
+      is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs)
     return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded)
 
 
diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
index d4d58879b7f..e3bd5fd8776 100644
--- a/sdks/python/apache_beam/runners/dask/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -26,19 +26,110 @@ import abc
 import dataclasses
 import math
 import typing as t
+from dataclasses import field
 
 import apache_beam
 import dask.bag as db
+from apache_beam import DoFn
+from apache_beam import TaggedOutput
 from apache_beam.pipeline import AppliedPTransform
+from apache_beam.runners.common import DoFnContext
+from apache_beam.runners.common import DoFnInvoker
+from apache_beam.runners.common import DoFnSignature
+from apache_beam.runners.common import Receiver
+from apache_beam.runners.common import _OutputHandler
 from apache_beam.runners.dask.overrides import _Create
 from apache_beam.runners.dask.overrides import _Flatten
 from apache_beam.runners.dask.overrides import _GroupByKeyOnly
+from apache_beam.transforms.sideinputs import SideInputMap
+from apache_beam.transforms.window import GlobalWindow
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.transforms.window import WindowFn
+from apache_beam.utils.windowed_value import WindowedValue
 
+# Inputs to DaskOps.
 OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None]
+OpSide = t.Optional[t.Sequence[SideInputMap]]
+
+# Value types for PCollections (possibly Windowed Values).
+PCollVal = t.Union[WindowedValue, t.Any]
+
+
+def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
+  """Wraps a value (item) inside a Window."""
+  if isinstance(item, TaggedOutput):
+    item = item.value
+
+  if isinstance(item, WindowedValue):
+    windowed_value = item
+  elif isinstance(item, TimestampedValue):
+    assign_context = WindowFn.AssignContext(item.timestamp, item.value)
+    windowed_value = WindowedValue(
+        item.value, item.timestamp, tuple(window_fn.assign(assign_context)))
+  else:
+    windowed_value = WindowedValue(item, 0, (GlobalWindow(), ))
+
+  return windowed_value
+
+
+def defenestrate(x):
+  """Extracts the underlying item from a Window."""
+  if isinstance(x, WindowedValue):
+    return x.value
+  return x
+
+
[email protected]
+class DaskBagWindowedIterator:
+  """Iterator for `apache_beam.transforms.sideinputs.SideInputMap`"""
+
+  bag: db.Bag
+  window_fn: WindowFn
+
+  def __iter__(self):
+    # FIXME(cisaacstern): list() is likely inefficient, since it presumably
+    # materializes the full result before iterating over it. doing this for
+    # now as a proof-of-concept. can we generate results incrementally?
+    for result in list(self.bag):
+      yield get_windowed_value(result, self.window_fn)
+
+
[email protected]
+class TaggingReceiver(Receiver):
+  """A Receiver that handles tagged `WindowedValue`s."""
+  tag: str
+  values: t.List[PCollVal]
+
+  def receive(self, windowed_value: WindowedValue):
+    if self.tag:
+      output = TaggedOutput(self.tag, windowed_value)
+    else:
+      output = windowed_value
+    self.values.append(output)
+
+
[email protected]
+class OneReceiver(dict):
+  """A Receiver that tags value via dictionary lookup key."""
+  values: t.List[PCollVal] = field(default_factory=list)
+
+  def __missing__(self, key):
+    if key not in self:
+      self[key] = TaggingReceiver(key, self.values)
+    return self[key]
 
 
 @dataclasses.dataclass
 class DaskBagOp(abc.ABC):
+  """Abstract Base Class for all Dask-supported Operations.
+
+  All DaskBagOps must support an `apply()` operation, which invokes the dask
+  bag upon the previous op's input.
+
+  Attributes
+    applied: The underlying `AppliedPTransform` which holds the code for the
+      target operation.
+  """
   applied: AppliedPTransform
 
   @property
@@ -46,17 +137,19 @@ class DaskBagOp(abc.ABC):
     return self.applied.transform
 
   @abc.abstractmethod
-  def apply(self, input_bag: OpInput) -> db.Bag:
+  def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
     pass
 
 
 class NoOp(DaskBagOp):
-  def apply(self, input_bag: OpInput) -> db.Bag:
+  """An identity on a dask bag: returns the input as-is."""
+  def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
     return input_bag
 
 
 class Create(DaskBagOp):
-  def apply(self, input_bag: OpInput) -> db.Bag:
+  """The beginning of a Beam pipeline; the input must be `None`."""
+  def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
     assert input_bag is None, 'Create expects no input!'
     original_transform = t.cast(_Create, self.transform)
     items = original_transform.values
@@ -66,42 +159,95 @@ class Create(DaskBagOp):
             1, math.ceil(math.sqrt(len(items)) / math.sqrt(100))))
 
 
+def apply_dofn_to_bundle(
+    items, do_fn_invoker_args, do_fn_invoker_kwargs, tagged_receivers):
+  """Invokes a DoFn within a bundle, implemented as a Dask partition."""
+
+  do_fn_invoker = DoFnInvoker.create_invoker(
+      *do_fn_invoker_args, **do_fn_invoker_kwargs)
+
+  do_fn_invoker.invoke_setup()
+  do_fn_invoker.invoke_start_bundle()
+
+  for it in items:
+    do_fn_invoker.invoke_process(it)
+
+  results = [v.value for v in tagged_receivers.values]
+
+  do_fn_invoker.invoke_finish_bundle()
+  do_fn_invoker.invoke_teardown()
+
+  return results
+
+
 class ParDo(DaskBagOp):
-  def apply(self, input_bag: db.Bag) -> db.Bag:
-    transform = t.cast(apache_beam.ParDo, self.transform)
-    return input_bag.map(
-        transform.fn.process, *transform.args, **transform.kwargs).flatten()
+  """Apply a pure function in an embarrassingly-parallel way.
 
+  This consumes a sequence of items and returns a sequence of items.
+  """
+  def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag:
+    transform = t.cast(apache_beam.ParDo, self.transform)
 
-class Map(DaskBagOp):
-  def apply(self, input_bag: db.Bag) -> db.Bag:
-    transform = t.cast(apache_beam.Map, self.transform)
-    return input_bag.map(
-        transform.fn.process, *transform.args, **transform.kwargs)
+    args, kwargs = transform.raw_side_inputs
+    args = list(args)
+    main_input = next(iter(self.applied.main_inputs.values()))
+    window_fn = main_input.windowing.windowfn if hasattr(
+        main_input, "windowing") else None
+
+    tagged_receivers = OneReceiver()
+
+    do_fn_invoker_args = [
+        DoFnSignature(transform.fn),
+        _OutputHandler(
+            window_fn=window_fn,
+            main_receivers=tagged_receivers[None],
+            tagged_receivers=tagged_receivers,
+            per_element_output_counter=None,
+            output_batch_converter=None,
+            process_yields_batches=False,
+            process_batch_yields_elements=False),
+    ]
+    do_fn_invoker_kwargs = dict(
+        context=DoFnContext(transform.label, state=None),
+        side_inputs=side_inputs,
+        input_args=args,
+        input_kwargs=kwargs,
+        user_state_context=None,
+        bundle_finalizer_param=DoFn.BundleFinalizerParam(),
+    )
+
+    return input_bag.map(get_windowed_value, window_fn).map_partitions(
+        apply_dofn_to_bundle,
+        do_fn_invoker_args,
+        do_fn_invoker_kwargs,
+        tagged_receivers,
+    )
 
 
 class GroupByKey(DaskBagOp):
-  def apply(self, input_bag: db.Bag) -> db.Bag:
+  """Group a PCollection into a mapping of keys to elements."""
+  def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag:
     def key(item):
       return item[0]
 
     def value(item):
       k, v = item
-      return k, [elm[1] for elm in v]
+      return k, [defenestrate(elm[1]) for elm in v]
 
     return input_bag.groupby(key).map(value)
 
 
 class Flatten(DaskBagOp):
-  def apply(self, input_bag: OpInput) -> db.Bag:
-    assert type(input_bag) is list, 'Must take a sequence of bags!'
+  """Produces a flattened bag from a collection of bags."""
+  def apply(
+      self, input_bag: t.List[db.Bag], side_inputs: OpSide = None) -> db.Bag:
+    assert isinstance(input_bag, list), 'Must take a sequence of bags!'
     return db.concat(input_bag)
 
 
 TRANSLATIONS = {
     _Create: Create,
     apache_beam.ParDo: ParDo,
-    apache_beam.Map: Map,
     _GroupByKeyOnly: GroupByKey,
     _Flatten: Flatten,
 }
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index 21561e1bf6a..4922b61169d 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -64,6 +64,7 @@ excluded_patterns=(
     'apache_beam/runners/portability/'
     'apache_beam/runners/test/'
     'apache_beam/runners/worker/'
+    'apache_beam/runners/dask/transform_evaluator.*'
     'apache_beam/testing/benchmarks/chicago_taxi/'
     'apache_beam/testing/benchmarks/cloudml/'
     'apache_beam/testing/benchmarks/inference/'
@@ -134,7 +135,7 @@ autodoc_member_order = 'bysource'
 autodoc_mock_imports = ["tensorrt", "cuda", "torch",
     "onnxruntime", "onnx", "tensorflow", "tensorflow_hub",
     "tensorflow_transform", "tensorflow_metadata", "transformers", "xgboost", 
"datatable", "transformers",
-    "sentence_transformers", "redis", "tensorflow_text", "feast",
+    "sentence_transformers", "redis", "tensorflow_text", "feast", "dask",
     ]
 
 # Allow a special section for documenting DataFrame API
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 9ae5d3153f5..3b45cbf82fc 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -512,8 +512,15 @@ if __name__ == '__main__':
           ],
           'dataframe': dataframe_dependency,
           'dask': [
-              'dask >= 2022.6',
-              'distributed >= 2022.6',
+              'distributed >= 2024.4.2',
+              'dask >= 2024.4.2',
+              # For development, 'distributed >= 2023.12.1' should work with
+              # the above dask PR, however it can't be installed as part of
+              # a single `pip` call, since distributed releases are pinned to
+              # specific dask releases. As a workaround, distributed can be
+              # installed first, and then `.[dask]` installed second, with the
+              # `--upgrade` / `-U` flag to replace the dask release brought in
+              # by distributed.
           ],
           'yaml': [
               'docstring-parser>=0.15,<1.0',
diff --git a/sdks/python/test-suites/tox/common.gradle b/sdks/python/test-suites/tox/common.gradle
index df42a2c384c..01265a6eeff 100644
--- a/sdks/python/test-suites/tox/common.gradle
+++ b/sdks/python/test-suites/tox/common.gradle
@@ -31,7 +31,6 @@ test.dependsOn "testPy${pythonVersionSuffix}ML"
 
 // toxTask "testPy${pythonVersionSuffix}Dask", "py${pythonVersionSuffix}-dask", "${posargs}"
 // test.dependsOn "testPy${pythonVersionSuffix}Dask"
-
 project.tasks.register("preCommitPy${pythonVersionSuffix}") {
                // Since codecoverage reports will always be generated for py38,
                // all tests will be exercised.
diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini
index c7713498d87..ad5d7ec5505 100644
--- a/sdks/python/tox.ini
+++ b/sdks/python/tox.ini
@@ -109,9 +109,20 @@ commands =
   bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}"
 
 [testenv:py{39,310,311,312}-dask]
-extras = test,dask
+extras = test,dask,dataframes
+commands_pre =
+  pip install 'distributed>=2024.4.2' 'dask>=2024.4.2'
 commands =
-  bash {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}"
+  bash {toxinidir}/scripts/run_pytest.sh {envname} 
{toxinidir}/apache_beam/runners/dask/
+
+[testenv:py{39,310,311,312}-win-dask]
+commands_pre =
+  pip install 'distributed>=2024.4.2' 'dask>=2024.4.2'
+commands =
+  python apache_beam/examples/complete/autocomplete_test.py
+  bash {toxinidir}/scripts/run_pytest.sh {envname} 
{toxinidir}/apache_beam/runners/dask/
+install_command = {envbindir}/python.exe {envbindir}/pip.exe install --retries 
10 {opts} {packages}
+list_dependencies_command = {envbindir}/python.exe {envbindir}/pip.exe freeze
 
 [testenv:py39-cloudcoverage]
 deps =
@@ -394,7 +405,7 @@ commands =
 
 [testenv:py39-tensorflow-212]
 deps =
-  212: 
+  212:
     tensorflow>=2.12rc1,<2.13
     # Help pip resolve conflict with typing-extensions for old version of TF https://github.com/apache/beam/issues/30852
     pydantic<2.7

Reply via email to