This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
     new 7067606  fix: SortMergeJoin with unsupported key type should fall back to Spark (#355)
7067606 is described below
commit 7067606e68c1042347505b34cc21f4df74da4968
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Apr 30 10:59:07 2024 -0700
    fix: SortMergeJoin with unsupported key type should fall back to Spark (#355)
* fix: SortMergeJoin with unsupported key type should fall back to Spark
* Fix
* For review
* For review
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 26 ++++++++++++++++++++++
.../org/apache/comet/exec/CometJoinSuite.scala | 19 ++++++++++++++++
.../scala/org/apache/spark/sql/CometTestBase.scala | 14 ++++++++++--
3 files changed, 57 insertions(+), 2 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index e1e7a71..6eda054 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2104,6 +2104,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
       expression
   }

+  /**
+   * Returns true if the given data type is supported as a key in a DataFusion sort merge join.
+   */
+  def supportedSortMergeJoinEqualType(dataType: DataType): Boolean = dataType match {
+    case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
+        _: DoubleType | _: StringType | _: DateType | _: DecimalType | _: BooleanType =>
+      true
+    // `TimestampNTZType` is private in Spark 3.2/3.3.
+    case dt if dt.typeName == "timestamp_ntz" => true
+    case _ => false
+  }
+
  /**
   * Convert a Spark plan operator to a protobuf Comet operator.
   *
@@ -2410,6 +2422,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
         return None
       }
+      // Checks if the join keys are supported by DataFusion SortMergeJoin.
+      val errorMsgs = join.leftKeys.flatMap { key =>
+        if (!supportedSortMergeJoinEqualType(key.dataType)) {
+          Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
+        } else {
+          None
+        }
+      }
+
+      if (errorMsgs.nonEmpty) {
+        withInfo(op, errorMsgs.mkString("\n"))
+        return None
+      }
+
       val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
       val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
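
For context, a minimal sketch (not part of the commit) of how the new guard classifies key types: TimestampType is absent from the supported list, so a timestamp join key makes the helper return false and the SortMergeJoin falls back to Spark.

    // Illustrative sketch only; assumes Comet and Spark on the classpath.
    import org.apache.comet.serde.QueryPlanSerde
    import org.apache.spark.sql.types.{IntegerType, TimestampType}

    QueryPlanSerde.supportedSortMergeJoinEqualType(IntegerType)   // true
    QueryPlanSerde.supportedSortMergeJoinEqualType(TimestampType) // false => fall back to Spark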
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 54c0baf..91d88c7 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -40,6 +40,25 @@ class CometJoinSuite extends CometTestBase {
     }
   }

+ test("SortMergeJoin with unsupported key type should fall back to Spark") {
+ withSQLConf(
+ SQLConf.SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1(name STRING, time TIMESTAMP) USING PARQUET")
+ sql("INSERT OVERWRITE t1 VALUES('a', timestamp'2019-01-01 11:11:11')")
+
+ sql("CREATE TABLE t2(name STRING, time TIMESTAMP) USING PARQUET")
+ sql("INSERT OVERWRITE t2 VALUES('a', timestamp'2019-01-01 11:11:11')")
+
+ val df = sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
+ val (sparkPlan, cometPlan) = checkSparkAnswer(df)
+ assert(sparkPlan.canonicalized === cometPlan.canonicalized)
+ }
+ }
+ }
+
test("Broadcast HashJoin without join filter") {
assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark
3.4+")
withSQLConf(
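
The new test checks the fallback by comparing canonicalized plans. A hedged way to observe the same behavior interactively (assumes a Comet-enabled spark-shell and the two tables created in the test; illustrative only):

    // Illustrative: with Comet enabled, the TIMESTAMP-keyed join should keep
    // Spark's SortMergeJoinExec rather than being replaced by a Comet operator.
    val df = spark.sql("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
    assert(df.queryExecution.executedPlan.toString.contains("SortMergeJoin"))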
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index ef64d66..27428b8 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -127,18 +127,28 @@ abstract class CometTestBase
     }
   }

-  protected def checkSparkAnswer(query: String): Unit = {
+  protected def checkSparkAnswer(query: String): (SparkPlan, SparkPlan) = {
     checkSparkAnswer(sql(query))
   }

-  protected def checkSparkAnswer(df: => DataFrame): Unit = {
+  /**
+   * Check the answer of a Comet SQL query against the Spark result.
+   * @param df
+   *   The DataFrame of the query.
+   * @return
+   *   A tuple of the executed Spark plan and the executed Comet plan of the query.
+   */
+  protected def checkSparkAnswer(df: => DataFrame): (SparkPlan, SparkPlan) = {
     var expected: Array[Row] = Array.empty
+    var sparkPlan: SparkPlan = null
     withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
       val dfSpark = Dataset.ofRows(spark, df.logicalPlan)
       expected = dfSpark.collect()
+      sparkPlan = dfSpark.queryExecution.executedPlan
     }
     val dfComet = Dataset.ofRows(spark, df.logicalPlan)
     checkAnswer(dfComet, expected)
+    (sparkPlan, dfComet.queryExecution.executedPlan)
   }

   protected def checkSparkAnswerAndOperator(query: String, excludedClasses: Class[_]*): Unit = {
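
Since checkSparkAnswer now returns both executed plans, callers can assert on plan shape after the answers are verified. A minimal usage sketch inside a CometTestBase suite (the query and tables are hypothetical):

    // Sketch: verify Comet returns the same rows as Spark, then compare the
    // executed plans; equal canonicalized plans mean Comet fell back entirely.
    val (sparkPlan, cometPlan) = checkSparkAnswer("SELECT * FROM t1 JOIN t2 ON t1.time = t2.time")
    assert(sparkPlan.canonicalized === cometPlan.canonicalized)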