This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9aa42cd  fix: bitwise shift with different left/right types (#135)
9aa42cd is described below

commit 9aa42cd2ab5a308befb1372bf3e5270c2d0d0e6d
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Thu Feb 29 10:12:21 2024 -0800

    fix: bitwise shift with different left/right types (#135)
---
 .../org/apache/comet/serde/QueryPlanSerde.scala      | 16 ++++++++++++++--
 .../org/apache/comet/CometExpressionSuite.scala      | 20 ++++++++++++++++++++
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 46eb1b0..75a2ff9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1376,7 +1376,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
 
       case ShiftRight(left, right) =>
         val leftExpr = exprToProtoInternal(left, inputs)
-        val rightExpr = exprToProtoInternal(right, inputs)
+        val rightExpr = if (left.dataType == LongType) {
+          // DataFusion bitwise shift right expression requires
+          // same data type between left and right side
+          exprToProtoInternal(Cast(right, LongType), inputs)
+        } else {
+          exprToProtoInternal(right, inputs)
+        }
 
         if (leftExpr.isDefined && rightExpr.isDefined) {
           val builder = ExprOuterClass.BitwiseShiftRight.newBuilder()
@@ -1394,7 +1400,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
 
       case ShiftLeft(left, right) =>
         val leftExpr = exprToProtoInternal(left, inputs)
-        val rightExpr = exprToProtoInternal(right, inputs)
+        val rightExpr = if (left.dataType == LongType) {
+          // DataFusion bitwise shift left expression requires
+          // same data type between left and right side
+          exprToProtoInternal(Cast(right, LongType), inputs)
+        } else {
+          exprToProtoInternal(right, inputs)
+        }
 
         if (leftExpr.isDefined && rightExpr.isDefined) {
           val builder = ExprOuterClass.BitwiseShiftLeft.newBuilder()
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 3f29e95..2609bd3 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -34,6 +34,26 @@ import org.apache.comet.CometSparkSessionExtensions.{isSpark32, isSpark34Plus}
 class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   import testImplicits._
 
+  test("bitwise shift with different left/right types") { // long left operand, int right operand
+    Seq(false, true).foreach { dictionary => // Parquet dictionary encoding off, then on
+      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
+        val table = "test"
+        withTable(table) {
+          sql(s"create table $table(col1 long, col2 int) using parquet") // mixed types
+          sql(s"insert into $table values(1111, 2)")
+          sql(s"insert into $table values(1111, 2)") // duplicate row; presumably intentional
+          sql(s"insert into $table values(3333, 4)")
+          sql(s"insert into $table values(5555, 6)")
+
+          checkSparkAnswerAndOperator( // shift amount as int literal and as int column
+            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
+          checkSparkAnswerAndOperator( // same two forms for the left shift
+            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
+        }
+      }
+    }
+  }
+
   test("basic data type support") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>

Reply via email to