TheNeuralBit commented on a change in pull request #16101:
URL: https://github.com/apache/beam/pull/16101#discussion_r766185559



##########
File path: sdks/python/apache_beam/dataframe/partitionings.py
##########
@@ -175,6 +175,61 @@ def check(self, dfs):
     return len(dfs) <= 1
 
 
+class JoinIndex(Partitioning):
+  """A partitioning that lets two frames be joined.
+  This can either be a hash partitioning on the full index, or a common
+  ancestor with no intervening re-indexing/re-partitioning.
+
+  It fits into the partial ordering as
+
+      Index() < JoinIndex(x) < JoinIndex() < Arbitrary()
+
+  with
+
+      JoinIndex(x) and JoinIndex(y)
+
+  being incomparable for nontrivial x != y.
+
+  Expressions desiring to make use of this index should simply declare a
+  requirement of JoinIndex().
+  """
+  def __init__(self, ancestor=None):
+    self._ancestor = ancestor
+
+  def __repr__(self):
+    if self._ancestor:
+      return 'JoinIndex[%s]' % self._ancestor
+    else:
+      return 'JoinIndex'
+
+  def __eq__(self, other):
+    if type(self) != type(other):
+      return False
+    elif self._ancestor is None:
+      return other._ancestor is None
+    elif other._ancestor is None:
+      return False
+    else:
+      return self._ancestor == other._ancestor
+
+  def __hash__(self):
+    return hash((type(self), self._ancestor))
+
+  def is_subpartitioning_of(self, other):
+    if isinstance(other, Arbitrary):
+      return False
+    elif isinstance(other, JoinIndex):
+      return self._ancestor is None or self == other
+    else:
+      return True
+
+  def test_partition_fn(self, df):
+    return Index().test_partition_fn(df)
+
+  def check(self, dfs):
+    return True

Review comment:
       This is used for verifying a preserves_partition_by=JoinIndex() spec.
Maybe we should raise NotImplementedError here instead, since preserving
JoinIndex doesn't really make sense?

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -345,29 +357,47 @@ def is_scalar(expr):
 
     @_memoize
     def expr_to_stages(expr):
-      assert expr not in inputs
+      if expr in inputs:
+        # Don't create a stage for each input, but it is still useful to record
+        # what which stages inputs are available from.
+        return []
+
       # First attempt to compute this expression as part of an existing stage,
       # if possible.
-      #
-      # If expr does not require partitioning, just grab any stage, else grab
-      # the first stage where all of expr's inputs are partitioned as required.
-      # In either case, use the first such stage because earlier stages are
-      # closer to the inputs (have fewer intermediate stages).

Review comment:
       +1 on removing this comment; it looks like it wasn't updated when the
logic was clarified.

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -299,9 +304,15 @@ def output_partitioning_in_stage(expr, stage):
       """Return the output partitioning of expr when computed in stage,
       or returns None if the expression cannot be computed in this stage.
       """
+      def upgrade_to_join_index(partitioning):
+        if partitioning.is_subpartitioning_of(partitionings.JoinIndex()):
+          return partitionings.JoinIndex(expr)

Review comment:
       This is upgrading just in the Arbitrary case, correct?

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -177,7 +178,7 @@ def default_label(self):
         return '%s:%s' % (self.stage.ops, id(self))
 
       def expand(self, pcolls):
-
+        logging.info('Computing stage %s for %s', self, self.stage)

Review comment:
       Should this be removed, or dropped to debug?

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -345,29 +357,47 @@ def is_scalar(expr):
 
     @_memoize
     def expr_to_stages(expr):
-      assert expr not in inputs
+      if expr in inputs:
+        # Don't create a stage for each input, but it is still useful to record
+        # what which stages inputs are available from.
+        return []
+
       # First attempt to compute this expression as part of an existing stage,
       # if possible.
-      #
-      # If expr does not require partitioning, just grab any stage, else grab
-      # the first stage where all of expr's inputs are partitioned as required.
-      # In either case, use the first such stage because earlier stages are
-      # closer to the inputs (have fewer intermediate stages).
-      required_partitioning = expr.requires_partition_by()
-      for stage in common_stages([expr_to_stages(arg) for arg in expr.args()
-                                  if arg not in inputs]):
-        if is_computable_in_stage(expr, stage):
-          break
+      if all(arg in inputs for arg in expr.args()):
+        # All input arguments;  try to pick a stage that already has as many
+        # of the inputs, correctly partitioned, as possible.
+        inputs_by_stage = collections.defaultdict(int)
+        for arg in expr.args():
+          for stage in expr_to_stages(arg):
+            if is_computable_in_stage(expr, stage):
+              inputs_by_stage[stage] += 1 + 100 * (
+                  expr.requires_partition_by() == stage.partitioning)
+        if inputs_by_stage:
+          stage = sorted(inputs_by_stage.items(), key=lambda kv: kv[1])[-1][0]

Review comment:
       nit: I think `max` would work here.
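
For illustration, a minimal sketch of the `max` form next to the `sorted` form
from the diff. The stage names and scores are made up (plain strings standing in
for stage objects); only the name `inputs_by_stage` comes from the diff. One
difference to keep in mind: on ties, `sorted(...)[-1]` returns the last tied
entry (the sort is stable), while `max` returns the first, so the stage chosen
could change if two stages score equally.

```python
import collections

# Hypothetical scores keyed by made-up stage names (stand-ins for stage objects).
inputs_by_stage = collections.defaultdict(int)
inputs_by_stage['stage_a'] += 1
inputs_by_stage['stage_b'] += 101
inputs_by_stage['stage_c'] += 2

# Form used in the diff: sort ascending by score, take the last entry.
via_sorted = sorted(inputs_by_stage.items(), key=lambda kv: kv[1])[-1][0]

# Single pass with max, no intermediate sorted list.
via_max = max(inputs_by_stage.items(), key=lambda kv: kv[1])[0]

# Tie behavior: sorted()[-1] yields the last tied key, max() the first.
tied = {'first': 5, 'second': 5}
tie_sorted = sorted(tied.items(), key=lambda kv: kv[1])[-1][0]
tie_max = max(tied.items(), key=lambda kv: kv[1])[0]
```

If the tie-breaking order matters here, it is probably worth stating which
stage should win rather than relying on either behavior implicitly.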

##########
File path: sdks/python/apache_beam/dataframe/partitionings.py
##########
@@ -175,6 +175,61 @@ def check(self, dfs):
     return len(dfs) <= 1
 
 
+class JoinIndex(Partitioning):
+  """A partitioning that lets two frames be joined.
+  This can either be a hash partitioning on the full index, or a common
+  ancestor with no intervening re-indexing/re-partitioning.
+
+  It fits into the partial ordering as
+
+      Index() < JoinIndex(x) < JoinIndex() < Arbitrary()
+
+  with
+
+      JoinIndex(x) and JoinIndex(y)
+
+  being incomparable for nontrivial x != y.
+
+  Expressions desiring to make use of this index should simply declare a
+  requirement of JoinIndex().
+  """
+  def __init__(self, ancestor=None):
+    self._ancestor = ancestor
+
+  def __repr__(self):
+    if self._ancestor:
+      return 'JoinIndex[%s]' % self._ancestor
+    else:
+      return 'JoinIndex'
+
+  def __eq__(self, other):
+    if type(self) != type(other):
+      return False
+    elif self._ancestor is None:
+      return other._ancestor is None
+    elif other._ancestor is None:
+      return False
+    else:
+      return self._ancestor == other._ancestor
+
+  def __hash__(self):
+    return hash((type(self), self._ancestor))
+
+  def is_subpartitioning_of(self, other):
+    if isinstance(other, Arbitrary):
+      return False
+    elif isinstance(other, JoinIndex):
+      return self._ancestor is None or self == other
+    else:
+      return True

Review comment:
       Could you add JoinIndex to the ordering tests in partitionings_test.py?
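
As a rough idea of what such a test could assert, here is a self-contained
sketch. It does not import Beam: `Arbitrary` is an empty stand-in, `JoinIndex`
copies only the `__eq__`, `__hash__` and `is_subpartitioning_of` logic quoted
above, and the string ancestors `'x'` and `'y'` plus the function name
`check_join_index_ordering` are hypothetical. The real test would use the
classes in partitionings.py and whatever structure partitionings_test.py
already has.

```python
class Arbitrary:
  """Stand-in for partitionings.Arbitrary; only its type matters here."""


class JoinIndex:
  """Copy of the equality and ordering logic quoted in the diff."""
  def __init__(self, ancestor=None):
    self._ancestor = ancestor

  def __eq__(self, other):
    if type(self) != type(other):
      return False
    elif self._ancestor is None:
      return other._ancestor is None
    elif other._ancestor is None:
      return False
    else:
      return self._ancestor == other._ancestor

  def __hash__(self):
    return hash((type(self), self._ancestor))

  def is_subpartitioning_of(self, other):
    if isinstance(other, Arbitrary):
      return False
    elif isinstance(other, JoinIndex):
      return self._ancestor is None or self == other
    else:
      return True


def check_join_index_ordering():
  # Hypothetical ancestors; in Beam these would be expressions, not strings.
  x, y = JoinIndex('x'), JoinIndex('y')
  bare = JoinIndex()

  # Every JoinIndex relates to itself.
  assert x.is_subpartitioning_of(x)
  assert bare.is_subpartitioning_of(bare)

  # JoinIndex() and JoinIndex(x) are related in one direction only.
  assert bare.is_subpartitioning_of(x)
  assert not x.is_subpartitioning_of(bare)

  # JoinIndex(x) and JoinIndex(y) are incomparable for x != y.
  assert not x.is_subpartitioning_of(y)
  assert not y.is_subpartitioning_of(x)

  # As quoted, no JoinIndex is a subpartitioning of Arbitrary.
  assert not bare.is_subpartitioning_of(Arbitrary())
  assert not x.is_subpartitioning_of(Arbitrary())
  return True
```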

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -345,29 +357,47 @@ def is_scalar(expr):
 
     @_memoize
     def expr_to_stages(expr):
-      assert expr not in inputs
+      if expr in inputs:
+        # Don't create a stage for each input, but it is still useful to record
+        # what which stages inputs are available from.
+        return []
+
       # First attempt to compute this expression as part of an existing stage,
       # if possible.
-      #
-      # If expr does not require partitioning, just grab any stage, else grab
-      # the first stage where all of expr's inputs are partitioned as required.
-      # In either case, use the first such stage because earlier stages are
-      # closer to the inputs (have fewer intermediate stages).
-      required_partitioning = expr.requires_partition_by()
-      for stage in common_stages([expr_to_stages(arg) for arg in expr.args()
-                                  if arg not in inputs]):
-        if is_computable_in_stage(expr, stage):
-          break
+      if all(arg in inputs for arg in expr.args()):
+        # All input arguments;  try to pick a stage that already has as many
+        # of the inputs, correctly partitioned, as possible.
+        inputs_by_stage = collections.defaultdict(int)
+        for arg in expr.args():
+          for stage in expr_to_stages(arg):
+            if is_computable_in_stage(expr, stage):
+              inputs_by_stage[stage] += 1 + 100 * (
+                  expr.requires_partition_by() == stage.partitioning)

Review comment:
       Is this 100:1 ratio arbitrary?
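
To make the question concrete, a small sketch of how I read the scoring,
assuming each stage is counted once per input it provides. `stage_score` and its
parameters are hypothetical names; only the `1 + 100 * (...)` term comes from
the diff. Under that reading, a stage with matching partitioning beats a
non-matching one until the non-matching stage holds about 100 more inputs.

```python
def stage_score(num_inputs, partitioning_matches, weight=100):
  """Score accumulated by a stage that provides num_inputs of expr's inputs.

  Mirrors `+= 1 + 100 * (requires == stage.partitioning)` applied once per
  input; the boolean counts as 0 or 1.
  """
  return num_inputs * (1 + weight * bool(partitioning_matches))


# One correctly partitioned input outweighs 100 inputs with other partitioning.
matching_one = stage_score(1, True)
other_hundred = stage_score(100, False)

# The preference flips only once the non-matching stage holds 102 inputs.
other_many = stage_score(102, False)
```

If the intent is that a matching partitioning should always win, comparing
tuples such as `(matches, num_inputs)` would express that without a magic
constant.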

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -345,29 +357,47 @@ def is_scalar(expr):
 
     @_memoize
     def expr_to_stages(expr):
-      assert expr not in inputs
+      if expr in inputs:
+        # Don't create a stage for each input, but it is still useful to record
+        # what which stages inputs are available from.
+        return []
+
       # First attempt to compute this expression as part of an existing stage,
       # if possible.
-      #
-      # If expr does not require partitioning, just grab any stage, else grab
-      # the first stage where all of expr's inputs are partitioned as required.
-      # In either case, use the first such stage because earlier stages are
-      # closer to the inputs (have fewer intermediate stages).
-      required_partitioning = expr.requires_partition_by()
-      for stage in common_stages([expr_to_stages(arg) for arg in expr.args()
-                                  if arg not in inputs]):
-        if is_computable_in_stage(expr, stage):
-          break
+      if all(arg in inputs for arg in expr.args()):

Review comment:
       Could you explain why we need separate logic for the case where all
arguments are inputs?

##########
File path: sdks/python/apache_beam/dataframe/partitionings.py
##########
@@ -175,6 +175,61 @@ def check(self, dfs):
     return len(dfs) <= 1
 
 
+class JoinIndex(Partitioning):
+  """A partitioning that lets two frames be joined.
+  This can either be a hash partitioning on the full index, or a common
+  ancestor with no intervening re-indexing/re-partitioning.
+
+  It fits into the partial ordering as
+
+      Index() < JoinIndex(x) < JoinIndex() < Arbitrary()
+
+  with
+
+      JoinIndex(x) and JoinIndex(y)
+
+  being incomparable for nontrivial x != y.
+
+  Expressions desiring to make use of this index should simply declare a
+  requirement of JoinIndex().
+  """
+  def __init__(self, ancestor=None):
+    self._ancestor = ancestor
+
+  def __repr__(self):
+    if self._ancestor:
+      return 'JoinIndex[%s]' % self._ancestor
+    else:
+      return 'JoinIndex'
+
+  def __eq__(self, other):
+    if type(self) != type(other):
+      return False
+    elif self._ancestor is None:
+      return other._ancestor is None
+    elif other._ancestor is None:
+      return False
+    else:
+      return self._ancestor == other._ancestor
+
+  def __hash__(self):
+    return hash((type(self), self._ancestor))
+
+  def is_subpartitioning_of(self, other):
+    if isinstance(other, Arbitrary):
+      return False
+    elif isinstance(other, JoinIndex):
+      return self._ancestor is None or self == other
+    else:
+      return True
+
+  def test_partition_fn(self, df):
+    return Index().test_partition_fn(df)

Review comment:
       It would be nice if this could replicate JoinIndex partitioning somehow. 
Any thoughts on how we could do that?
   
   I suppose we could use a slightly modified hash function to generate a 
different partitioning. I guess that doesn't get us much, though... it just 
verifies that we're not overfitting for Index's hashing technique. 

##########
File path: sdks/python/apache_beam/dataframe/transforms.py
##########
@@ -299,9 +304,15 @@ def output_partitioning_in_stage(expr, stage):
       """Return the output partitioning of expr when computed in stage,
       or returns None if the expression cannot be computed in this stage.
       """
+      def upgrade_to_join_index(partitioning):

Review comment:
       nit:
   ```suggestion
         def try_upgrade_to_join_index(partitioning):
   ```





