viirya commented on a change in pull request #28745:
URL: https://github.com/apache/spark/pull/28745#discussion_r436436434



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
##########
@@ -608,10 +608,14 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
         execution.MapPartitionsInRWithArrowExec(
           f, p, b, is, ot, planLater(child)) :: Nil
       case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
-        execution.python.FlatMapGroupsInPandasExec(grouping, func, output, 
planLater(child)) :: Nil
-      case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, 
output, left, right) =>
+        val groupingExprs = grouping.map(NamedExpression.fromExpression)
+        execution.python.FlatMapGroupsInPandasExec(
+          groupingExprs, func, output, planLater(child)) :: Nil
+      case logical.FlatMapCoGroupsInPandas(leftExprs, rightExprs, func, 
output, left, right) =>
+        val leftAttrs = leftExprs.map(NamedExpression.fromExpression)
+        val rightAttrs = rightExprs.map(NamedExpression.fromExpression)
         execution.python.FlatMapCoGroupsInPandasExec(
-          leftGroup, rightGroup, func, output, planLater(left), 
planLater(right)) :: Nil
+          leftAttrs, rightAttrs, func, output, planLater(left), 
planLater(right)) :: Nil

Review comment:
       Should these be named leftNamedExprs/rightNamedExprs or leftGroupingExprs/rightGroupingExprs? They are not actually attributes.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala
##########
@@ -59,65 +59,65 @@ private[python] object PandasGroupUtils {
    */
   def groupAndProject(
       input: Iterator[InternalRow],
-      groupingAttributes: Seq[Attribute],
+      groupingExprs: Seq[NamedExpression],
       inputSchema: Seq[Attribute],
-      dedupSchema: Seq[Attribute]): Iterator[(InternalRow, 
Iterator[InternalRow])] = {
-    val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema)
+      dedupSchema: Seq[NamedExpression]): Iterator[(InternalRow, 
Iterator[InternalRow])] = {
+    val groupedIter = GroupedIterator(input, groupingExprs, inputSchema)
     val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema)
     groupedIter.map {
       case (k, groupedRowIter) => (k, groupedRowIter.map(dedupProj))
     }
   }
 
   /**
-   * Returns a the deduplicated attributes of the spark plan and the arg 
offsets of the
+   * Returns a the deduplicated named expressions of the spark plan and the 
arg offsets of the
    * keys and values.
    *
-   * The deduplicated attributes are needed because the spark plan may contain 
an attribute
-   * twice; once in the key and once in the value.  For any such attribute we 
need to
+   * The deduplicated expressions are needed because the spark plan may 
contain an expression
+   * twice; once in the key and once in the value.  For any such expression we 
need to
    * deduplicate.
    *
-   * The arg offsets are used to distinguish grouping grouping attributes and 
data attributes
+   * The arg offsets are used to distinguish grouping expressions and data 
expressions
    * as following:
    *
    * argOffsets[0] is the length of the argOffsets array
    *
-   * argOffsets[1] is the length of grouping attribute
-   * argOffsets[2 .. argOffsets[0]+2] is the arg offsets for grouping 
attributes
+   * argOffsets[1] is the length of grouping expression
+   * argOffsets[2 .. argOffsets[0]+2] is the arg offsets for grouping 
expressions
    *
-   * argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes
+   * argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data expressions
    */
   def resolveArgOffsets(
-    child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], 
Array[Int]) = {
+      dataExprs: Seq[NamedExpression], groupingExprs: Seq[NamedExpression])
+    : (Seq[NamedExpression], Array[Int]) = {
 
-    val dataAttributes = child.output.drop(groupingAttributes.length)
-    val groupingIndicesInData = groupingAttributes.map { attribute =>
-      dataAttributes.indexWhere(attribute.semanticEquals)
+    val groupingIndicesInData = groupingExprs.map { expression =>
+      dataExprs.indexWhere(expression.semanticEquals)
     }

Review comment:
       OK, this looks good after re-checking.

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
##########
@@ -60,42 +60,51 @@ case class FlatMapCoGroupsInPandasExec(
   private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
   private val pandasFunction = func.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+  private val inputExprs =
+    func.asInstanceOf[PythonUDF].children.map(_.asInstanceOf[NamedExpression])
+  private val leftExprs =
+    left.output.filter(e => inputExprs.exists(_.semanticEquals(e)))
+  private val rightExprs =
+    right.output.filter(e => inputExprs.exists(_.semanticEquals(e)))

Review comment:
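       Should these be named leftAttributes and rightAttributes? They are filtered from `left.output` / `right.output`, so they are attributes.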

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
##########
@@ -60,42 +60,51 @@ case class FlatMapCoGroupsInPandasExec(
   private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
   private val pandasFunction = func.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+  private val inputExprs =
+    func.asInstanceOf[PythonUDF].children.map(_.asInstanceOf[NamedExpression])
+  private val leftExprs =
+    left.output.filter(e => inputExprs.exists(_.semanticEquals(e)))
+  private val rightExprs =
+    right.output.filter(e => inputExprs.exists(_.semanticEquals(e)))
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val leftDist = if (leftGroup.isEmpty) AllTuples else 
ClusteredDistribution(leftGroup)
-    val rightDist = if (rightGroup.isEmpty) AllTuples else 
ClusteredDistribution(rightGroup)
+    val leftDist =
+      if (leftGroupingExprs.isEmpty) AllTuples else 
ClusteredDistribution(leftGroupingExprs)
+    val rightDist =
+      if (rightGroupingExprs.isEmpty) AllTuples else 
ClusteredDistribution(rightGroupingExprs)
     leftDist :: rightDist :: Nil
   }
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    leftGroup
-      .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) 
:: Nil
+    leftGroupingExprs
+      .map(SortOrder(_, Ascending)) :: rightGroupingExprs.map(SortOrder(_, 
Ascending)) :: Nil
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
 
-    val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup)
-    val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup)
+    val (leftDedup, leftArgOffsets) = resolveArgOffsets(leftExprs, 
leftGroupingExprs)
+    val (rightDedup, rightArgOffsets) = resolveArgOffsets(rightExprs, 
rightGroupingExprs)
 
     // Map cogrouped rows to ArrowPythonRunner results, Only execute if 
partition is not empty
     left.execute().zipPartitions(right.execute())  { (leftData, rightData) =>
       if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else {
 
-        val leftGrouped = groupAndProject(leftData, leftGroup, left.output, 
leftDedup)
-        val rightGrouped = groupAndProject(rightData, rightGroup, 
right.output, rightDedup)
-        val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
+        val leftGrouped = groupAndProject(leftData, leftGroupingExprs, 
left.output, leftDedup)
+        val rightGrouped = groupAndProject(rightData, rightGroupingExprs, 
right.output, rightDedup)

Review comment:
       One disadvantage I can think of: previously, the grouping expressions were evaluated in the underlying projection. With this change, their evaluation moves into the execution of `FlatMapCoGroupsInPandasExec` itself.
   
   Because `requiredChildDistribution` requires the children to be distributed by `leftGroupingExprs` and `rightGroupingExprs`, a shuffle may be added below `FlatMapCoGroupsInPandasExec`. The grouping expressions would then be evaluated twice: once by the shuffle, and again inside this operator. If they contain any non-deterministic expressions, the two evaluations can disagree, and we may get incorrect results.
   
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to