robertwb commented on a change in pull request #16101:
URL: https://github.com/apache/beam/pull/16101#discussion_r776875484



##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -345,29 +357,47 @@ def is_scalar(expr):
 
     @_memoize
     def expr_to_stages(expr):
-      assert expr not in inputs
+      if expr in inputs:
+        # Don't create a stage for each input, but it is still useful to record
+        # which stages inputs are available from.
+        return []
+
       # First attempt to compute this expression as part of an existing stage,
       # if possible.
-      #
-      # If expr does not require partitioning, just grab any stage, else grab
-      # the first stage where all of expr's inputs are partitioned as required.
-      # In either case, use the first such stage because earlier stages are
-      # closer to the inputs (have fewer intermediate stages).
-      required_partitioning = expr.requires_partition_by()
-      for stage in common_stages([expr_to_stages(arg) for arg in expr.args()
-                                  if arg not in inputs]):
-        if is_computable_in_stage(expr, stage):
-          break
+      if all(arg in inputs for arg in expr.args()):
+        # All arguments are inputs; try to pick a stage that already has as many
+        # of the inputs, correctly partitioned, as possible.
+        inputs_by_stage = collections.defaultdict(int)
+        for arg in expr.args():
+          for stage in expr_to_stages(arg):
+            if is_computable_in_stage(expr, stage):
+              inputs_by_stage[stage] += 1 + 100 * (
+                  expr.requires_partition_by() == stage.partitioning)
+        if inputs_by_stage:
+          stage = sorted(inputs_by_stage.items(), key=lambda kv: kv[1])[-1][0]

Review comment:
       Cool. Done.

##########
File path: sdks/python/apache_beam/dataframe/partitionings.py
##########
@@ -175,6 +175,61 @@ def check(self, dfs):
     return len(dfs) <= 1
 
 
+class JoinIndex(Partitioning):
+  """A partitioning that lets two frames be joined.
+  This can either be a hash partitioning on the full index, or a common
+  ancestor with no intervening re-indexing/re-partitioning.
+
+  It fits into the partial ordering as
+
+      Index() < JoinIndex(x) < JoinIndex() < Arbitrary()
+
+  with
+
+      JoinIndex(x) and JoinIndex(y)
+
+  being incomparable for nontrivial x != y.
+
+  Expressions desiring to make use of this index should simply declare a
+  requirement of JoinIndex().
+  """
+  def __init__(self, ancestor=None):
+    self._ancestor = ancestor
+
+  def __repr__(self):
+    if self._ancestor:
+      return 'JoinIndex[%s]' % self._ancestor
+    else:
+      return 'JoinIndex'
+
+  def __eq__(self, other):
+    if type(self) != type(other):
+      return False
+    elif self._ancestor is None:
+      return other._ancestor is None
+    elif other._ancestor is None:
+      return False
+    else:
+      return self._ancestor == other._ancestor
+
+  def __hash__(self):
+    return hash((type(self), self._ancestor))
+
+  def is_subpartitioning_of(self, other):
+    if isinstance(other, Arbitrary):
+      return False
+    elif isinstance(other, JoinIndex):
+      return self._ancestor is None or self == other
+    else:
+      return True
+
+  def test_partition_fn(self, df):
+    return Index().test_partition_fn(df)
+
+  def check(self, dfs):
+    return True

Review comment:
       Done.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to