HyukjinKwon commented on a change in pull request #28745:
URL: https://github.com/apache/spark/pull/28745#discussion_r436560561



##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
##########
@@ -60,42 +60,51 @@ case class FlatMapCoGroupsInPandasExec(
   private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
   private val pandasFunction = func.asInstanceOf[PythonUDF].func
   private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+  private val inputExprs =
+    func.asInstanceOf[PythonUDF].children.map(_.asInstanceOf[NamedExpression])
+  private val leftExprs =
+    left.output.filter(e => inputExprs.exists(_.semanticEquals(e)))
+  private val rightExprs =
+    right.output.filter(e => inputExprs.exists(_.semanticEquals(e)))
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
   override def outputPartitioning: Partitioning = left.outputPartitioning
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val leftDist = if (leftGroup.isEmpty) AllTuples else 
ClusteredDistribution(leftGroup)
-    val rightDist = if (rightGroup.isEmpty) AllTuples else 
ClusteredDistribution(rightGroup)
+    val leftDist =
+      if (leftGroupingExprs.isEmpty) AllTuples else 
ClusteredDistribution(leftGroupingExprs)
+    val rightDist =
+      if (rightGroupingExprs.isEmpty) AllTuples else 
ClusteredDistribution(rightGroupingExprs)
     leftDist :: rightDist :: Nil
   }
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    leftGroup
-      .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) 
:: Nil
+    leftGroupingExprs
+      .map(SortOrder(_, Ascending)) :: rightGroupingExprs.map(SortOrder(_, 
Ascending)) :: Nil
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
 
-    val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup)
-    val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup)
+    val (leftDedup, leftArgOffsets) = resolveArgOffsets(leftExprs, 
leftGroupingExprs)
+    val (rightDedup, rightArgOffsets) = resolveArgOffsets(rightExprs, 
rightGroupingExprs)
 
     // Map cogrouped rows to ArrowPythonRunner results, Only execute if 
partition is not empty
     left.execute().zipPartitions(right.execute())  { (leftData, rightData) =>
       if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else {
 
-        val leftGrouped = groupAndProject(leftData, leftGroup, left.output, 
leftDedup)
-        val rightGrouped = groupAndProject(rightData, rightGroup, 
right.output, rightDedup)
-        val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup)
+        val leftGrouped = groupAndProject(leftData, leftGroupingExprs, 
left.output, leftDedup)
+        val rightGrouped = groupAndProject(rightData, rightGroupingExprs, 
right.output, rightDedup)

Review comment:
       Oh, that's actually a very good point. Let me check how other grouping
expressions work, and then either update the PR or come back with some answers.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to