icexelloss closed pull request #23279: [SPARK-26328][SQL] Use GenerateOrdering 
for group key comparison in WindowExec
URL: https://github.com/apache/spark/pull/23279
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a forked repository, GitHub does not
display its diff once it is merged, so the diff is reproduced below:

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index fede0f3e92d67..5cd5f50c44451 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -24,8 +24,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, 
SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.types.{CalendarIntervalType, DateType, 
IntegerType, TimestampType}
 
@@ -304,20 +304,18 @@ case class WindowExec(
 
         // Get all relevant projections.
         val result = createResultProjection(expressions)
-        val grouping = UnsafeProjection.create(partitionSpec, child.output)
+        val groupOrdering = GenerateOrdering.generate(
+          partitionSpec.map(SortOrder(_, Ascending)), child.output)
 
         // Manage the stream and the grouping.
         var nextRow: UnsafeRow = null
-        var nextGroup: UnsafeRow = null
         var nextRowAvailable: Boolean = false
         private[this] def fetchNextRow() {
           nextRowAvailable = stream.hasNext
           if (nextRowAvailable) {
             nextRow = stream.next().asInstanceOf[UnsafeRow]
-            nextGroup = grouping(nextRow)
           } else {
             nextRow = null
-            nextGroup = null
           }
         }
         fetchNextRow()
@@ -333,13 +331,13 @@ case class WindowExec(
         val numFrames = frames.length
         private[this] def fetchNextPartition() {
           // Collect all the rows in the current partition.
-          // Before we start to fetch new input rows, make a copy of nextGroup.
-          val currentGroup = nextGroup.copy()
+          // Before we start to fetch new input rows, make a copy of nextRow.
+          val currentRow = nextRow.copy()
 
           // clear last partition
           buffer.clear()
 
-          while (nextRowAvailable && nextGroup == currentGroup) {
+          while (nextRowAvailable && groupOrdering.compare(currentRow, 
nextRow) == 0) {
             buffer.add(nextRow)
             fetchNextRow()
           }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to