This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 7067606  fix: SortMergeJoin with unsupported key type should fall back to Spark (#355)
7067606 is described below

commit 7067606e68c1042347505b34cc21f4df74da4968
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Apr 30 10:59:07 2024 -0700

    fix: SortMergeJoin with unsupported key type should fall back to Spark (#355)
    
    * fix: SortMergeJoin with unsupported key type should fall back to Spark
    
    * Fix
    
    * For review
    
    * For review
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 26 ++++++++++++++++++++++
 .../org/apache/comet/exec/CometJoinSuite.scala     | 19 ++++++++++++++++
 .../scala/org/apache/spark/sql/CometTestBase.scala | 14 ++++++++++--
 3 files changed, 57 insertions(+), 2 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e1e7a71..6eda054 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2104,6 +2104,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
     expression
   }
 
+  /**
+   * Returns true if given datatype is supported as a key in DataFusion sort merge join.
+   */
+  def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
+        _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
+      true
+    // `TimestampNTZType` is private in Spark 3.2/3.3.
+    case dt if dt.typeName == "timestamp_ntz" => true
+    case _ => false
+  }
+
   /**
    * Convert a Spark plan operator to a protobuf Comet operator.
    *
@@ -2410,6 +2422,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
             return None
         }
 
+        // Checks if the join keys are supported by DataFusion SortMergeJoin.
+        val errorMsgs = join.leftKeys.flatMap { key =>
+          if (!supportedSortMergeJoinEqualType(key.dataType)) {
+            Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
+          } else {
+            None
+          }
+        }
+
+        if (errorMsgs.nonEmpty) {
+          withInfo(op, errorMsgs.mkString("\n"))
+          return None
+        }
+
         val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
         val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 54c0baf..91d88c7 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -40,6 +40,25 @@ class CometJoinSuite extends CometTestBase {
     }
   }
 
+  test("SortMergeJoin with unsupported key type should fall back to Spark") {
+    withSQLConf(
+      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
+      SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTable("t1", "t2") {
+        sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
+        sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')")
+
+        sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
+        sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')")
+
+        val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
+        val (sparkPlan, cometPlan) = checkSparkAnswer(df)
+        assert(sparkPlan.canonicalized === cometPlan.canonicalized)
+      }
+    }
+  }
+
   test("Broadcast HashJoin without join filter") {
     assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+")
     withSQLConf(
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index ef64d66..27428b8 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -127,18 +127,28 @@ abstract class CometTestBase
     }
   }
 
-  protected def checkSparkAnswer(query: String): Unit = {
+  protected def checkSparkAnswer(query: String): (SparkPlan, SparkPlan) = {
     checkSparkAnswer(sql(query))
   }
 
-  protected def checkSparkAnswer(df: => DataFrame): Unit = {
+  /**
+   * Check the answer of a Comet SQL query with Spark result.
+   * @param df
+   *   The DataFrame of the query.
+   * @return
+   *   A tuple of the SparkPlan of the query and the SparkPlan of the Comet query.
+   */
+  protected def checkSparkAnswer(df: => DataFrame): (SparkPlan, SparkPlan) = {
     var expected: Array[Row] = Array.empty
+    var sparkPlan = null.asInstanceOf[SparkPlan]
     withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
       val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
       expected = dfSpark.collect()
+      sparkPlan = dfSpark.queryExecution.executedPlan
     }
     val dfComet = Dataset.ofRows(spark, df.logicalPlan)
     checkAnswer(dfComet, expected)
+    (sparkPlan, dfComet.queryExecution.executedPlan)
   }
 
   protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to