This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 459b2b0c fix: window function range offset should be long instead of int (#733)
459b2b0c is described below

commit 459b2b0c07c8cb6192d0febcd87061c991bf84b6
Author: Huaxin Gao <huaxin_...@apple.com>
AuthorDate: Mon Sep 23 14:12:41 2024 -0700

    fix: window function range offset should be long instead of int (#733)
    
    * fix: window function range offset should be long instead of int
    
    * fix error
    
    * fall back to Spark if range offset is not int or long
    
    * uncomment tests
    
    * rebase
    
    * fix offset datatype
    
    * fix data type
    
    * address comments
    
    * throw Err for WindowFrameUnits::Groups
    
    * formatting
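
    For context, a minimal sketch (illustrative only, not part of this patch) of the
    kind of query the int64 range offset enables Comet to plan natively. It mirrors
    the updated range-frame test in CometExecSuite and assumes a SparkSession named
    `spark` with Comet enabled:

        import org.apache.spark.sql.expressions.Window
        import org.apache.spark.sql.functions.count
        import spark.implicits._

        val df = Seq((1L, "1"), (2147483650L, "1"), (3L, "2")).toDF("key", "value")

        // A RANGE frame whose upper boundary exceeds Int.MaxValue. Before this change,
        // range frames fell back to Spark and the proto offset field was int32.
        df.select(
          $"key",
          count("key").over(
            Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L)))
          .show()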
---
 native/core/src/execution/datafusion/planner.rs    | 84 ++++++++++++++++++----
 native/proto/src/proto/operator.proto              |  4 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 57 +++++++++++++--
 .../org/apache/comet/exec/CometExecSuite.scala     | 39 ++++++----
 4 files changed, 151 insertions(+), 33 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index d7c8d745..663db0d1 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -1692,16 +1692,46 @@ impl PhysicalPlanner {
             .and_then(|inner| inner.lower_frame_bound_struct.as_ref())
         {
             Some(l) => match l {
-                LowerFrameBoundStruct::UnboundedPreceding(_) => {
-                    WindowFrameBound::Preceding(ScalarValue::UInt64(None))
-                }
+                LowerFrameBoundStruct::UnboundedPreceding(_) => match units {
+                    WindowFrameUnits::Rows => {
+                        WindowFrameBound::Preceding(ScalarValue::UInt64(None))
+                    }
+                    WindowFrameUnits::Range => {
+                        WindowFrameBound::Preceding(ScalarValue::Int64(None))
+                    }
+                    WindowFrameUnits::Groups => {
+                        return Err(ExecutionError::GeneralError(
+                            "WindowFrameUnits::Groups is not supported.".to_string(),
+                        ));
+                    }
+                },
                 LowerFrameBoundStruct::Preceding(offset) => {
-                    let offset_value = offset.offset.unsigned_abs() as u64;
-                    WindowFrameBound::Preceding(ScalarValue::UInt64(Some(offset_value)))
+                    let offset_value = offset.offset.abs();
+                    match units {
+                        WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(
+                            Some(offset_value as u64),
+                        )),
+                        WindowFrameUnits::Range => {
+                            WindowFrameBound::Preceding(ScalarValue::Int64(Some(offset_value)))
+                        }
+                        WindowFrameUnits::Groups => {
+                            return Err(ExecutionError::GeneralError(
+                                "WindowFrameUnits::Groups is not supported.".to_string(),
+                            ));
+                        }
+                    }
                 }
                 LowerFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
             },
-            None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+            None => match units {
+                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::Int64(None)),
+                WindowFrameUnits::Groups => {
+                    return Err(ExecutionError::GeneralError(
+                        "WindowFrameUnits::Groups is not supported.".to_string(),
+                    ));
+                }
+            },
         };
 
         let upper_bound: WindowFrameBound = match spark_window_frame
@@ -1710,15 +1740,43 @@ impl PhysicalPlanner {
             .and_then(|inner| inner.upper_frame_bound_struct.as_ref())
         {
             Some(u) => match u {
-                UpperFrameBoundStruct::UnboundedFollowing(_) => {
-                    WindowFrameBound::Following(ScalarValue::UInt64(None))
-                }
-                UpperFrameBoundStruct::Following(offset) => {
-                    WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
-                }
+                UpperFrameBoundStruct::UnboundedFollowing(_) => match units {
+                    WindowFrameUnits::Rows => {
+                        WindowFrameBound::Following(ScalarValue::UInt64(None))
+                    }
+                    WindowFrameUnits::Range => {
+                        WindowFrameBound::Following(ScalarValue::Int64(None))
+                    }
+                    WindowFrameUnits::Groups => {
+                        return Err(ExecutionError::GeneralError(
+                            "WindowFrameUnits::Groups is not supported.".to_string(),
+                        ));
+                    }
+                },
+                UpperFrameBoundStruct::Following(offset) => match units {
+                    WindowFrameUnits::Rows => {
+                        WindowFrameBound::Following(ScalarValue::UInt64(Some(offset.offset as u64)))
+                    }
+                    WindowFrameUnits::Range => {
+                        WindowFrameBound::Following(ScalarValue::Int64(Some(offset.offset)))
+                    }
+                    WindowFrameUnits::Groups => {
+                        return Err(ExecutionError::GeneralError(
+                            "WindowFrameUnits::Groups is not supported.".to_string(),
+                        ));
+                    }
+                },
                 UpperFrameBoundStruct::CurrentRow(_) => WindowFrameBound::CurrentRow,
             },
-            None => WindowFrameBound::Following(ScalarValue::UInt64(None)),
+            None => match units {
+                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
+                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::Int64(None)),
+                WindowFrameUnits::Groups => {
+                    return Err(ExecutionError::GeneralError(
+                        "WindowFrameUnits::Groups is not supported.".to_string(),
+                    ));
+                }
+            },
         };
 
         let window_frame = WindowFrame::new_bounds(units, lower_bound, upper_bound);
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 6a29e633..533d504c 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -161,11 +161,11 @@ message UpperWindowFrameBound {
 }
 
 message Preceding {
-  int32 offset = 1;
+  int64 offset = 1;
 }
 
 message Following {
-  int32 offset = 1;
+  int64 offset = 1;
 }
 
 message UnboundedPreceding {}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 4fde2fd1..0a949a31 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -255,15 +255,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       (None, exprToProto(windowExpr.windowFunction, output))
     }
 
+    if (aggExpr.isEmpty && builtinFunc.isEmpty) {
+      return None
+    }
+
     val f = windowExpr.windowSpec.frameSpecification
 
     val (frameType, lowerBound, upperBound) = f match {
       case SpecifiedWindowFrame(frameType, lBound, uBound) =>
         val frameProto = frameType match {
           case RowFrame => OperatorOuterClass.WindowFrameType.Rows
-          case RangeFrame =>
-            withInfo(windowExpr, "Range frame is not supported")
-            return None
+          case RangeFrame => OperatorOuterClass.WindowFrameType.Range
         }
 
         val lBoundProto = lBound match {
@@ -278,12 +280,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
               .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
               .build()
           case e =>
+            val offset = e.eval() match {
+              case i: Integer => i.toLong
+              case l: Long => l
+              case _ => return None
+            }
             OperatorOuterClass.LowerWindowFrameBound
               .newBuilder()
               .setPreceding(
                 OperatorOuterClass.Preceding
                   .newBuilder()
-                  .setOffset(e.eval().asInstanceOf[Int])
+                  .setOffset(offset)
                   .build())
               .build()
         }
@@ -300,12 +307,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
               .setCurrentRow(OperatorOuterClass.CurrentRow.newBuilder().build())
               .build()
           case e =>
+            val offset = e.eval() match {
+              case i: Integer => i.toLong
+              case l: Long => l
+              case _ => return None
+            }
+
             OperatorOuterClass.UpperWindowFrameBound
               .newBuilder()
               .setFollowing(
                 OperatorOuterClass.Following
                   .newBuilder()
-                  .setOffset(e.eval().asInstanceOf[Int])
+                  .setOffset(offset)
                   .build())
               .build()
         }
@@ -2774,6 +2787,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           return None
         }
 
+        if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
+          !validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) {
+          return None
+        }
+
         val windowExprProto = winExprs.map(windowExprToProto(_, output, 
op.conf))
         val partitionExprs = partitionSpec.map(exprToProto(_, child.output))
 
@@ -3280,4 +3298,33 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
       true
     }
   }
+
+  private def validatePartitionAndSortSpecsForWindowFunc(
+      partitionSpec: Seq[Expression],
+      orderSpec: Seq[SortOrder],
+      op: SparkPlan): Boolean = {
+    if (partitionSpec.length != orderSpec.length) {
+      withInfo(op, "Partitioning and sorting specifications do not match")
+      return false
+    }
+
+    val partitionColumnNames = partitionSpec.collect { case a: AttributeReference =>
+      a.name
+    }
+
+    val orderColumnNames = orderSpec.collect { case s: SortOrder =>
+      s.child match {
+        case a: AttributeReference => a.name
+      }
+    }
+
+    if (partitionColumnNames.zip(orderColumnNames).exists { case (partCol, orderCol) =>
+        partCol != orderCol
+      }) {
+      withInfo(op, "Partitioning and sorting specifications must be the same.")
+      return false
+    }
+
+    true
+  }
 }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index c054d02d..05aa2372 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -149,6 +149,22 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test(
+    "fall back to Spark when the partition spec and order spec are not the 
same for window function") {
+    withTempView("test") {
+      sql("""
+          |CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
+          | (1, true), (1, false),
+          |(2, true), (3, false), (4, true) AS test(k, v)
+          |""".stripMargin)
+
+      val df = sql("""
+          SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
+          |""".stripMargin)
+      checkSparkAnswer(df)
+    }
+  }
+
   test("Native window operator should be CometUnaryExec") {
     withTempView("testData") {
       sql("""
@@ -164,11 +180,11 @@ class CometExecSuite extends CometTestBase {
           |(3, 1L, 1.0D, date("2017-08-01"), timestamp_seconds(1501545600), null)
           |AS testData(val, val_long, val_double, val_date, val_timestamp, cate)
           |""".stripMargin)
-      val df = sql("""
+      val df1 = sql("""
           |SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW)
           |FROM testData ORDER BY cate, val
           |""".stripMargin)
-      checkSparkAnswer(df)
+      checkSparkAnswer(df1)
     }
   }
 
@@ -193,23 +209,21 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
-  test("Window range frame should fall back to Spark") {
+  test("Window range frame with long boundary should not fail") {
     val df =
       Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), 
(2147483650L, "2"))
         .toDF("key", "value")
 
-    checkAnswer(
+    checkSparkAnswer(
       df.select(
         $"key",
         count("key").over(
-          Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))),
-      Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)))
-    checkAnswer(
+          Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))))
+    checkSparkAnswer(
       df.select(
         $"key",
         count("key").over(
-          Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))),
-      Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)))
+          Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))))
   }
 
   test("Unsupported window expression should fall back to Spark") {
@@ -1777,10 +1791,9 @@ class CometExecSuite extends CometTestBase {
           aggregateFunctions.foreach { function =>
             val queries = Seq(
               s"SELECT $function OVER() FROM t1",
-              // TODO: Range frame is not supported yet.
-              // s"SELECT $function OVER(order by _2) FROM t1",
-              // s"SELECT $function OVER(order by _2 desc) FROM t1",
-              // s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
+              s"SELECT $function OVER(order by _2) FROM t1",
+              s"SELECT $function OVER(order by _2 desc) FROM t1",
+              s"SELECT $function OVER(partition by _2 order by _2) FROM t1",
               s"SELECT $function OVER(rows between 1 preceding and 1 
following) FROM t1",
               s"SELECT $function OVER(order by _2 rows between 1 preceding and 
current row) FROM t1",
               s"SELECT $function OVER(order by _2 rows between current row and 
1 following) FROM t1")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org
