This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new e9649477c fix: Fall back to Spark when hashing decimals with precision > 18 (#1325)
e9649477c is described below

commit e9649477c4f8b4c6906244c3cc6828b83f32f735
Author: Andy Grove <agr...@apache.org>
AuthorDate: Wed Jan 29 12:07:20 2025 -0700

    fix: Fall back to Spark when hashing decimals with precision > 18 (#1325)
    
    * fall back to Spark when hashing decimals with precision > 18
    
    * murmur3 checks
    
    * refactor
    
    * fix
    
    * address feedback
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 32 +-------
 .../main/scala/org/apache/comet/serde/hash.scala   | 85 ++++++++++++++++++++++
 .../org/apache/comet/CometExpressionSuite.scala    | 52 +++++++++----
 3 files changed, 127 insertions(+), 42 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index cb4fffc1a..350aaf7ad 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2176,35 +2176,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
         }
 
-      case Murmur3Hash(children, seed) =>
-        val firstUnSupportedInput = children.find(c => 
!supportedDataType(c.dataType))
-        if (firstUnSupportedInput.isDefined) {
-          withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
-          return None
-        }
-        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
-        val seedBuilder = ExprOuterClass.Literal
-          .newBuilder()
-          .setDatatype(serializeDataType(IntegerType).get)
-          .setIntVal(seed)
-        val seedExpr = 
Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-        // the seed is put at the end of the arguments
-        scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ 
seedExpr: _*)
-
-      case XxHash64(children, seed) =>
-        val firstUnSupportedInput = children.find(c => 
!supportedDataType(c.dataType))
-        if (firstUnSupportedInput.isDefined) {
-          withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
-          return None
-        }
-        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
-        val seedBuilder = ExprOuterClass.Literal
-          .newBuilder()
-          .setDatatype(serializeDataType(LongType).get)
-          .setLongVal(seed)
-        val seedExpr = 
Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-        // the seed is put at the end of the arguments
-        scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ 
seedExpr: _*)
+      case _: Murmur3Hash => CometMurmur3Hash.convert(expr, inputs, binding)
+
+      case _: XxHash64 => CometXxHash64.convert(expr, inputs, binding)
 
       case Sha2(left, numBits) =>
         if (!numBits.foldable) {
diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala
new file mode 100644
index 000000000..226c4bab0
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
Murmur3Hash, XxHash64}
+import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType}
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, 
scalarExprToProtoWithReturnType, serializeDataType, supportedDataType}
+
+/**
+ * Serializes Spark's `XxHash64` expression into a call to the native `xxhash64` scalar
+ * function with a `LongType` result. Yields `None`, so the expression stays in Spark, when
+ * `HashUtils.isSupportedType` rejects one of the hashed inputs.
+ */
+object CometXxHash64 extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (HashUtils.isSupportedType(expr)) {
+      val xxHash = expr.asInstanceOf[XxHash64]
+      val childArgs = xxHash.children.map(exprToProtoInternal(_, inputs, binding))
+      val seedLiteral = ExprOuterClass.Literal
+        .newBuilder()
+        .setDatatype(serializeDataType(LongType).get)
+        .setLongVal(xxHash.seed)
+      val seedArg = ExprOuterClass.Expr.newBuilder().setLiteral(seedLiteral).build()
+      // the native function expects the seed after all hashed columns
+      val args = childArgs :+ Some(seedArg)
+      scalarExprToProtoWithReturnType("xxhash64", LongType, args: _*)
+    } else {
+      None
+    }
+  }
+}
+
+/**
+ * Serializes Spark's `Murmur3Hash` expression into a call to the native `murmur3_hash`
+ * scalar function with an `IntegerType` result. Yields `None`, so the expression stays in
+ * Spark, when `HashUtils.isSupportedType` rejects one of the hashed inputs.
+ */
+object CometMurmur3Hash extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (HashUtils.isSupportedType(expr)) {
+      val murmur = expr.asInstanceOf[Murmur3Hash]
+      val childArgs = murmur.children.map(exprToProtoInternal(_, inputs, binding))
+      val seedLiteral = ExprOuterClass.Literal
+        .newBuilder()
+        .setDatatype(serializeDataType(IntegerType).get)
+        .setIntVal(murmur.seed)
+      val seedArg = ExprOuterClass.Expr.newBuilder().setLiteral(seedLiteral).build()
+      // the native function expects the seed after all hashed columns
+      val args = childArgs :+ Some(seedArg)
+      scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, args: _*)
+    } else {
+      None
+    }
+  }
+}
+
+private object HashUtils {
+
+  /**
+   * Checks whether every child of the hash expression `expr` has a data type that Comet can
+   * hash natively with Spark-compatible results.
+   *
+   * Decimals with precision > 18 are rejected. Spark converts those into a Java BigDecimal
+   * and hashes its bytes, rather than hashing the unscaled long, and that path is not
+   * implemented natively. All other types are delegated to `supportedDataType`.
+   *
+   * Only the first unsupported child is reported, as a fallback reason attached to `expr`
+   * via `withInfo`. Children after it are not inspected.
+   *
+   * @param expr the hash expression whose children are checked
+   * @return true when all children are supported, false otherwise
+   */
+  def isSupportedType(expr: Expression): Boolean = {
+    // A lazy iterator keeps the original short-circuit behaviour. It avoids the nonlocal
+    // `return` from inside a `for` lambda, which throws NonLocalReturnControl under the hood.
+    val unsupportedReason: Option[String] =
+      expr.children.iterator.map(_.dataType).collectFirst {
+        case dt: DecimalType if dt.precision > 18 =>
+          s"Unsupported datatype: $dt (precision > 18)"
+        case dt if !supportedDataType(dt) =>
+          s"Unsupported datatype $dt"
+      }
+    unsupportedReason.foreach(reason => withInfo(expr, reason))
+    unsupportedReason.isEmpty
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index f82101b3a..f226afeaa 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1929,19 +1929,45 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("hash functions with decimal input") {
-    withTable("t1", "t2") {
-      // Apache Spark: if it's a small decimal, i.e. precision <= 18, turn it 
into long and hash it.
-      // Else, turn it into bytes and hash it.
-      sql("create table t1(c1 decimal(18, 2)) using parquet")
-      sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
-      checkSparkAnswerAndOperator("select c1, hash(c1), xxhash64(c1) from t1 
order by c1")
-
-      // TODO: comet hash function is not compatible with spark for decimal 
with precision greater than 18.
-      // https://github.com/apache/datafusion-comet/issues/1294
-//       sql("create table t2(c1 decimal(20, 2)) using parquet")
-//       sql("insert into t2 values(1.23), (-1.23), (0.0), (null)")
-//       checkSparkAnswerAndOperator("select c1, hash(c1), xxhash64(c1) from 
t2 order by c1")
+  test("hash function with decimal input") {
+    // Comet hashes decimals natively only up to precision 18; wider decimals fall back to Spark.
+    val precisionScalePairs = Seq(
+      1 -> 0,
+      17 -> 2,
+      18 -> 2,
+      19 -> 2,
+      DecimalType.MAX_PRECISION -> (DecimalType.MAX_SCALE - 1))
+    precisionScalePairs.foreach { case (precision, scale) =>
+      withTable("t1") {
+        sql(s"create table t1(c1 decimal($precision, $scale)) using parquet")
+        sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
+        val query = "select c1, hash(c1) from t1 order by c1"
+        if (precision > 18) {
+          // not supported natively yet: only compare results with Spark
+          checkSparkAnswer(query)
+        } else {
+          checkSparkAnswerAndOperator(query)
+        }
+      }
+    }
+  }
+
+  test("xxhash64 function with decimal input") {
+    // Comet hashes decimals natively only up to precision 18; wider decimals fall back to Spark.
+    val precisionScalePairs = Seq(
+      1 -> 0,
+      17 -> 2,
+      18 -> 2,
+      19 -> 2,
+      DecimalType.MAX_PRECISION -> (DecimalType.MAX_SCALE - 1))
+    precisionScalePairs.foreach { case (precision, scale) =>
+      withTable("t1") {
+        sql(s"create table t1(c1 decimal($precision, $scale)) using parquet")
+        sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
+        val query = "select c1, xxhash64(c1) from t1 order by c1"
+        if (precision > 18) {
+          // not supported natively yet: only compare results with Spark
+          checkSparkAnswer(query)
+        } else {
+          checkSparkAnswerAndOperator(query)
+        }
+      }
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to