robertwb commented on a change in pull request #16101:
URL: https://github.com/apache/beam/pull/16101#discussion_r766267071



##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -177,7 +178,7 @@ def default_label(self):
         return '%s:%s' % (self.stage.ops, id(self))
 
       def expand(self, pcolls):
-
+        logging.info('Computing stage %s for %s', self, self.stage)

Review comment:
       I think this is on par with other information we print at the info 
level. Added the word "dataframe" to give this a bit more context.

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -345,29 +357,47 @@ def is_scalar(expr):
 
     @_memoize
     def expr_to_stages(expr):
-      assert expr not in inputs
+      if expr in inputs:
+        # Don't create a stage for each input, but it is still useful to record
+        # which stages inputs are available from.
+        return []
+
       # First attempt to compute this expression as part of an existing stage,
       # if possible.
-      #
-      # If expr does not require partitioning, just grab any stage, else grab
-      # the first stage where all of expr's inputs are partitioned as required.
-      # In either case, use the first such stage because earlier stages are
-      # closer to the inputs (have fewer intermediate stages).
-      required_partitioning = expr.requires_partition_by()
-      for stage in common_stages([expr_to_stages(arg) for arg in expr.args()
-                                  if arg not in inputs]):
-        if is_computable_in_stage(expr, stage):
-          break
+      if all(arg in inputs for arg in expr.args()):

Review comment:
       If there are *any* non-input arguments, we must pick (or create) a 
stage that has *all* of them. (Well, technically one can sometimes add 
expressions to an existing stage, but this requires graph analysis to ensure 
one is not creating cycles.) That's what the code below does, as did the code 
before this change. However, if there aren't any non-input arguments, that 
code won't find any candidates, and we'll create a new stage for each root 
operation, limiting our ability to find stages that have common ancestors 
(even if they would have been fused by the Beam optimizer). 
   
   Maybe an example is clearer. If we have
   
       `df.A + df.B`
   
   we need to compute `df.A` somewhere, so we create a stage (call it stage1) 
with an input of df and a single operation. Now when we compute `df.B`, we'd 
rather reuse stage1 so that we can find a single stage that contains both 
`df.A` and `df.B`. This wasn't important before, as we always introduced a 
shuffle for multi-input operations so close to the roots anyway.
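   The `df.A + df.B` example can be made concrete with a deliberately simplified, hypothetical model of stage assignment. `Stage` and `assign_stages` are illustrative names, not part of Beam; the model ignores partitioning requirements and the cycle checks mentioned above. It only contrasts giving every all-input expression a fresh stage with reusing a stage that already reads those inputs.

```python
import collections


class Stage:
    """Hypothetical stand-in for a fused stage: just a list of op names."""

    def __init__(self):
        self.ops = []


def assign_stages(exprs, inputs, reuse_input_stages):
    """Assigns (name, args) expressions, given in topological order, to stages.

    Simplified model: partitioning and cycle detection are ignored.
    """
    stage_of = {}                              # expression name -> Stage
    readers = collections.defaultdict(list)    # input name -> stages reading it
    stages = []
    for name, args in exprs:
        stage = None
        computed = [a for a in args if a not in inputs]
        if computed:
            # Some args are computed: reuse a stage only if it holds all of them.
            candidates = {stage_of[a] for a in computed}
            if len(candidates) == 1:
                stage = candidates.pop()
        elif reuse_input_stages:
            # All args are inputs: prefer the stage already reading most of them.
            counts = collections.Counter(s for a in args for s in readers[a])
            if counts:
                stage = max(counts, key=counts.get)
        if stage is None:
            stage = Stage()
            stages.append(stage)
        for a in args:
            if a in inputs and stage not in readers[a]:
                readers[a].append(stage)
        stage.ops.append(name)
        stage_of[name] = stage
    return stages


# df.A + df.B, where df is the only pipeline input.
exprs = [('A', ['df']), ('B', ['df']), ('A+B', ['A', 'B'])]
fresh = assign_stages(exprs, {'df'}, reuse_input_stages=False)
shared = assign_stages(exprs, {'df'}, reuse_input_stages=True)
```

   With fresh stages, `A` and `B` land in different stages, so `A+B` finds no stage holding both and gets a third one; with reuse, all three ops end up in a single stage.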

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -299,9 +304,15 @@ def output_partitioning_in_stage(expr, stage):
       """Return the output partitioning of expr when computed in stage,
       or returns None if the expression cannot be computed in this stage.
       """
+      def upgrade_to_join_index(partitioning):
+        if partitioning.is_subpartitioning_of(partitionings.JoinIndex()):
+          return partitionings.JoinIndex(expr)

Review comment:
       Currently, yeah. 

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -299,9 +304,15 @@ def output_partitioning_in_stage(expr, stage):
       """Return the output partitioning of expr when computed in stage,
       or returns None if the expression cannot be computed in this stage.
       """
+      def upgrade_to_join_index(partitioning):

Review comment:
       Renamed to `maybe_upgrade_to_join_index` as it never fails. 
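   A minimal sketch of why the renamed helper "never fails": the diff only shows the upgrade branch, so the fall-through that returns the partitioning unchanged is an assumption consistent with the new name. `Index` and `JoinIndex` here are simplified stand-ins (no `Arbitrary`), and the helper takes `expr` as a parameter instead of closing over it as the nested original does.

```python
class Index:
    """Toy full-index partitioning; nothing in this sketch is stronger."""

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def is_subpartitioning_of(self, other):
        return isinstance(other, Index)


class JoinIndex:
    """Toy JoinIndex following the diff's logic (Arbitrary omitted)."""

    def __init__(self, ancestor=None):
        self._ancestor = ancestor

    def __eq__(self, other):
        return type(self) == type(other) and self._ancestor == other._ancestor

    def __hash__(self):
        return hash((type(self), self._ancestor))

    def is_subpartitioning_of(self, other):
        if isinstance(other, JoinIndex):
            return self._ancestor is None or self == other
        return True  # e.g. Index(): a full-index hash partitioning suffices


def maybe_upgrade_to_join_index(partitioning, expr):
    # Hypothetical standalone version; `expr` is passed explicitly.
    if partitioning.is_subpartitioning_of(JoinIndex()):
        return JoinIndex(expr)
    # Otherwise hand the input back unchanged, so the call never fails.
    return partitioning
```

   Only a bare `JoinIndex()` is tied to the expression; `Index()` and a `JoinIndex` already bound to another ancestor pass through untouched.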

##########
File path: sdks/python/apache_beam/dataframe/transforms_test.py
##########
@@ -348,6 +350,44 @@ def test_rename(self):
               }, errors='raise'))
 
 
+class FusionTest(unittest.TestCase):
+  @staticmethod
+  def fused_stages(p):
+    return p.result.metrics().query(
+        metrics.MetricsFilter().with_name(
+            fn_runner.FnApiRunner.NUM_FUSED_STAGES_COUNTER)
+    )['counters'][0].result

Review comment:
       The fusion logic is not very easy to invoke in isolation. 

##########
File path: sdks/python/apache_beam/dataframe/partitionings.py
##########
@@ -175,6 +175,61 @@ def check(self, dfs):
     return len(dfs) <= 1
 
 
+class JoinIndex(Partitioning):
+  """A partitioning that lets two frames be joined.
+  This can either be a hash partitioning on the full index, or a common
+  ancestor with no intervening re-indexing/re-partitioning.
+
+  It fits into the partial ordering as
+
+      Index() < JoinIndex(x) < JoinIndex() < Arbitrary()
+
+  with
+
+      JoinIndex(x) and JoinIndex(y)
+
+  being incomparable for nontrivial x != y.
+
+  Expressions desiring to make use of this index should simply declare a
+  requirement of JoinIndex().
+  """
+  def __init__(self, ancestor=None):
+    self._ancestor = ancestor
+
+  def __repr__(self):
+    if self._ancestor:
+      return 'JoinIndex[%s]' % self._ancestor
+    else:
+      return 'JoinIndex'
+
+  def __eq__(self, other):
+    if type(self) != type(other):
+      return False
+    elif self._ancestor is None:
+      return other._ancestor is None
+    elif other._ancestor is None:
+      return False
+    else:
+      return self._ancestor == other._ancestor
+
+  def __hash__(self):
+    return hash((type(self), self._ancestor))
+
+  def is_subpartitioning_of(self, other):
+    if isinstance(other, Arbitrary):
+      return False
+    elif isinstance(other, JoinIndex):
+      return self._ancestor is None or self == other
+    else:
+      return True
+
+  def test_partition_fn(self, df):
+    return Index().test_partition_fn(df)
+
+  def check(self, dfs):
+    return True

Review comment:
       Well, if something preserves Arbitrary and is given JoinIndex, shouldn't 
`check` be called with JoinIndex?
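   The partial order in the `JoinIndex` docstring can be checked with a small self-contained toy. `JoinIndex.is_subpartitioning_of` below is copied from the diff, and its `__eq__` is simplified to an equivalent form for string ancestors; `Index` and `Arbitrary` are simplified stand-ins, not Beam's classes. In the diff's convention, `b.is_subpartitioning_of(a)` holds when `b` sits to the right of (or equals) `a` in the chain.

```python
class Index:
    """Toy full-index partitioning: the strongest guarantee in this sketch."""

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def is_subpartitioning_of(self, other):
        return isinstance(other, Index)


class Arbitrary:
    """Toy 'no guarantee' partitioning: the weakest guarantee in this sketch."""

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def is_subpartitioning_of(self, other):
        return True


class JoinIndex:
    """is_subpartitioning_of as in the diff; __eq__ simplified."""

    def __init__(self, ancestor=None):
        self._ancestor = ancestor

    def __eq__(self, other):
        return type(self) == type(other) and self._ancestor == other._ancestor

    def __hash__(self):
        return hash((type(self), self._ancestor))

    def is_subpartitioning_of(self, other):
        if isinstance(other, Arbitrary):
            return False
        elif isinstance(other, JoinIndex):
            return self._ancestor is None or self == other
        else:
            return True


def strictly_left_of(a, b):
    # True if a is a strictly stronger guarantee than b in the chain.
    return b.is_subpartitioning_of(a) and not a.is_subpartitioning_of(b)


def comparable(a, b):
    return a.is_subpartitioning_of(b) or b.is_subpartitioning_of(a)


# Index() < JoinIndex(x) < JoinIndex() < Arbitrary()
chain = [Index(), JoinIndex('x'), JoinIndex(), Arbitrary()]
```

   Each adjacent pair in `chain` is strictly ordered, while `JoinIndex('x')` and `JoinIndex('y')` are incomparable, as the docstring states.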

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -345,29 +357,47 @@ def is_scalar(expr):
 
     @_memoize
     def expr_to_stages(expr):
-      assert expr not in inputs
+      if expr in inputs:
+        # Don't create a stage for each input, but it is still useful to record
+        # which stages inputs are available from.
+        return []
+
       # First attempt to compute this expression as part of an existing stage,
       # if possible.
-      #
-      # If expr does not require partitioning, just grab any stage, else grab
-      # the first stage where all of expr's inputs are partitioned as required.
-      # In either case, use the first such stage because earlier stages are
-      # closer to the inputs (have fewer intermediate stages).
-      required_partitioning = expr.requires_partition_by()
-      for stage in common_stages([expr_to_stages(arg) for arg in expr.args()
-                                  if arg not in inputs]):
-        if is_computable_in_stage(expr, stage):
-          break
+      if all(arg in inputs for arg in expr.args()):
+        # All input arguments; try to pick a stage that already has as many
+        # of the inputs, correctly partitioned, as possible.
+        inputs_by_stage = collections.defaultdict(int)
+        for arg in expr.args():
+          for stage in expr_to_stages(arg):
+            if is_computable_in_stage(expr, stage):
+              inputs_by_stage[stage] += 1 + 100 * (
+                  expr.requires_partition_by() == stage.partitioning)
+        if inputs_by_stage:
+          stage = sorted(inputs_by_stage.items(), key=lambda kv: kv[1])[-1][0]

Review comment:
       Max doesn't let you specify a key. I suppose I could swap the tuple 
ordering (but then I'd have to worry about stages being comparable in case the 
counts were equal). Added a comment 'cause that line is a mouthful. 
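   For reference, Python's built-in `max` does accept a `key=` keyword argument, and because it compares only the key values, the stages themselves never need to be orderable. The sketch below uses a hypothetical `Stage` class that defines no ordering to compare the two spellings. They break ties differently: `sorted` is stable, so `sorted(...)[-1]` returns the last tied entry in insertion order, while `max` returns the first maximal entry it encounters.

```python
import collections


class Stage:
    """Hypothetical stand-in for a stage; deliberately defines no ordering."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return 'Stage(%s)' % self.name


s1, s2, s3 = Stage('s1'), Stage('s2'), Stage('s3')
inputs_by_stage = collections.defaultdict(int)
inputs_by_stage[s1] += 1    # one input available, partitioning mismatch
inputs_by_stage[s2] += 101  # one input available, partitioning match
inputs_by_stage[s3] += 101  # ties with s2

# The spelling used in the diff: sort by score, take the last entry.
via_sorted = sorted(inputs_by_stage.items(), key=lambda kv: kv[1])[-1][0]

# Built-in max with a key function: only the scores are compared.
via_max = max(inputs_by_stage, key=inputs_by_stage.get)

# Swapping the tuple order to (score, stage) compares Stage objects on a tie.
try:
    max((score, stage) for stage, score in inputs_by_stage.items())
    tie_raises = False
except TypeError:
    tie_raises = True
```

   Both spellings pick a highest-scoring stage; the tuple-swap variant fails exactly in the tie case the comment worries about.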

##########
File path: sdks/python/apache_beam/dataframe/partitionings.py
##########
@@ -175,6 +175,61 @@ def check(self, dfs):
     return len(dfs) <= 1
 
 
+class JoinIndex(Partitioning):
+  """A partitioning that lets two frames be joined.
+  This can either be a hash partitioning on the full index, or a common
+  ancestor with no intervening re-indexing/re-partitioning.
+
+  It fits into the partial ordering as
+
+      Index() < JoinIndex(x) < JoinIndex() < Arbitrary()
+
+  with
+
+      JoinIndex(x) and JoinIndex(y)
+
+  being incomparable for nontrivial x != y.
+
+  Expressions desiring to make use of this index should simply declare a
+  requirement of JoinIndex().
+  """
+  def __init__(self, ancestor=None):
+    self._ancestor = ancestor
+
+  def __repr__(self):
+    if self._ancestor:
+      return 'JoinIndex[%s]' % self._ancestor
+    else:
+      return 'JoinIndex'
+
+  def __eq__(self, other):
+    if type(self) != type(other):
+      return False
+    elif self._ancestor is None:
+      return other._ancestor is None
+    elif other._ancestor is None:
+      return False
+    else:
+      return self._ancestor == other._ancestor
+
+  def __hash__(self):
+    return hash((type(self), self._ancestor))
+
+  def is_subpartitioning_of(self, other):
+    if isinstance(other, Arbitrary):
+      return False
+    elif isinstance(other, JoinIndex):
+      return self._ancestor is None or self == other
+    else:
+      return True
+
+  def test_partition_fn(self, df):
+    return Index().test_partition_fn(df)

Review comment:
       I did some thinking about this, and the problem is that JoinIndex is not 
a well-defined partitioning in isolation; it's only meaningful when we start 
talking about whether things have common ancestry within a stage. (E.g. across 
stage boundaries, we upgrade JoinIndex to Index.) That does unfortunately mean 
that our non-pipeline tests aren't going to catch as many issues here (which 
makes sense, as the interesting logic to deal with join indices is in the 
translation to beam steps code). 

##########
File path: sdks/python/apache_beam/dataframe/partitionings.py
##########
@@ -175,6 +175,61 @@ def check(self, dfs):
     return len(dfs) <= 1
 
 
+class JoinIndex(Partitioning):
+  """A partitioning that lets two frames be joined.
+  This can either be a hash partitioning on the full index, or a common
+  ancestor with no intervening re-indexing/re-partitioning.
+
+  It fits into the partial ordering as
+
+      Index() < JoinIndex(x) < JoinIndex() < Arbitrary()
+
+  with
+
+      JoinIndex(x) and JoinIndex(y)
+
+  being incomparable for nontrivial x != y.
+
+  Expressions desiring to make use of this index should simply declare a
+  requirement of JoinIndex().
+  """
+  def __init__(self, ancestor=None):
+    self._ancestor = ancestor
+
+  def __repr__(self):
+    if self._ancestor:
+      return 'JoinIndex[%s]' % self._ancestor
+    else:
+      return 'JoinIndex'
+
+  def __eq__(self, other):
+    if type(self) != type(other):
+      return False
+    elif self._ancestor is None:
+      return other._ancestor is None
+    elif other._ancestor is None:
+      return False
+    else:
+      return self._ancestor == other._ancestor
+
+  def __hash__(self):
+    return hash((type(self), self._ancestor))
+
+  def is_subpartitioning_of(self, other):
+    if isinstance(other, Arbitrary):
+      return False
+    elif isinstance(other, JoinIndex):
+      return self._ancestor is None or self == other
+    else:
+      return True

Review comment:
       Done.

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -345,29 +357,47 @@ def is_scalar(expr):
 
     @_memoize
     def expr_to_stages(expr):
-      assert expr not in inputs
+      if expr in inputs:
+        # Don't create a stage for each input, but it is still useful to record
+        # which stages inputs are available from.
+        return []
+
       # First attempt to compute this expression as part of an existing stage,
       # if possible.
-      #
-      # If expr does not require partitioning, just grab any stage, else grab
-      # the first stage where all of expr's inputs are partitioned as required.
-      # In either case, use the first such stage because earlier stages are
-      # closer to the inputs (have fewer intermediate stages).
-      required_partitioning = expr.requires_partition_by()
-      for stage in common_stages([expr_to_stages(arg) for arg in expr.args()
-                                  if arg not in inputs]):
-        if is_computable_in_stage(expr, stage):
-          break
+      if all(arg in inputs for arg in expr.args()):
+        # All input arguments; try to pick a stage that already has as many
+        # of the inputs, correctly partitioned, as possible.
+        inputs_by_stage = collections.defaultdict(int)
+        for arg in expr.args():
+          for stage in expr_to_stages(arg):
+            if is_computable_in_stage(expr, stage):
+              inputs_by_stage[stage] += 1 + 100 * (
+                  expr.requires_partition_by() == stage.partitioning)

Review comment:
       Poor man's lexicographical ordering. I could add a couple more zeros. (I 
don't know what the right ratio should be if we wanted to make them 
comparable; mostly this is to prefer the most distributed option.)
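   To see what the weighting does in numbers, here is a small sketch comparing the diff's score with an explicit tuple key. `weighted_score`, `lexicographic_key`, and `pick` are hypothetical helpers. The sketch relies on the fact that in the diff the match flag `expr.requires_partition_by() == stage.partitioning` is the same for every argument of a given stage, so a stage's score is its number of available arguments times either 1 or 101.

```python
def weighted_score(num_available_args, partitioning_matches):
    # Mirrors the diff: every available argument adds 1 + 100 * (match flag).
    return num_available_args * (1 + 100 * int(partitioning_matches))


def lexicographic_key(num_available_args, partitioning_matches):
    # A matching partitioning wins outright; argument count breaks ties.
    return (int(partitioning_matches), num_available_args)


def pick(candidates, key):
    return max(candidates, key=lambda c: key(*c))


# Candidates are (number of available args, partitioning matches) pairs.
typical = [(3, False), (1, True), (2, True)]
# A non-matching stage with more than 101x the args of the matching one.
extreme = [(102, False), (1, True)]
```

   The two orderings agree until a non-matching stage offers more than 101 times as many available arguments as the best matching one. Adding zeros raises that threshold; a tuple key is one way to remove it, and since it compares only numbers it also avoids the stage-comparability concern raised earlier.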



