This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 295054013 fix: query tolerance= in SQL file tests now also asserts Comet native execution (#3797)
295054013 is described below

commit 295054013b0f5aae86483354e111437f74823594
Author: Andy Grove <[email protected]>
AuthorDate: Mon Mar 30 13:26:30 2026 -0600

    fix: query tolerance= in SQL file tests now also asserts Comet native execution (#3797)
---
 native/spark-expr/src/comet_scalar_funcs.rs        |   5 +
 native/spark-expr/src/lib.rs                       |   4 +-
 native/spark-expr/src/math_funcs/log.rs            | 227 +++++++++++++++++++++
 native/spark-expr/src/math_funcs/mod.rs            |   2 +
 .../org/apache/comet/serde/QueryPlanSerde.scala    |   1 +
 .../main/scala/org/apache/comet/serde/math.scala   |  19 +-
 .../resources/sql-tests/expressions/math/log.sql   |  10 +-
 .../resources/sql-tests/expressions/math/log10.sql |   2 +-
 .../resources/sql-tests/expressions/math/log2.sql  |   2 +-
 .../resources/sql-tests/expressions/math/tan.sql   |   1 +
 .../org/apache/comet/CometSqlFileTestSuite.scala   |   2 +-
 .../scala/org/apache/spark/sql/CometTestBase.scala |  11 +
 12 files changed, 278 insertions(+), 8 deletions(-)

diff --git a/native/spark-expr/src/comet_scalar_funcs.rs 
b/native/spark-expr/src/comet_scalar_funcs.rs
index ff75de763..1eaf0b2a9 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -18,6 +18,7 @@
 use crate::hash_funcs::*;
 use crate::math_funcs::abs::abs;
 use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, 
checked_mul, checked_sub};
+use crate::math_funcs::log::spark_log;
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, 
spark_isnan,
@@ -177,6 +178,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(abs);
             make_comet_scalar_udf!("abs", func, without data_type)
         }
+        "spark_log" => {
+            let func = Arc::new(spark_log);
+            make_comet_scalar_udf!("spark_log", func, without data_type)
+        }
         "split" => {
             let func = Arc::new(crate::string_funcs::spark_split);
             make_comet_scalar_udf!("split", func, without data_type)
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index a7711d642..342ef7361 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -79,8 +79,8 @@ pub use hash_funcs::*;
 pub use json_funcs::{FromJson, ToJson};
 pub use math_funcs::{
     create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, 
spark_unhex,
-    spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, 
NegativeExpr,
+    spark_decimal_integral_div, spark_floor, spark_log, spark_make_decimal, 
spark_round,
+    spark_unhex, spark_unscaled_value, CheckOverflow, 
DecimalRescaleCheckOverflow, NegativeExpr,
     NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
 };
 pub use query_context::{create_query_context_map, QueryContext, 
QueryContextMap};
diff --git a/native/spark-expr/src/math_funcs/log.rs 
b/native/spark-expr/src/math_funcs/log.rs
new file mode 100644
index 000000000..499d4f33e
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/log.rs
@@ -0,0 +1,227 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, Float64Array};
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::physical_plan::ColumnarValue;
+use std::sync::Arc;
+
+/// Spark-compatible two-argument logarithm: `log(base, value)`.
+///
+/// Returns `log(value) / log(base)`, matching Spark's `Logarithm` expression.
+/// Returns null when `base <= 0` or `value <= 0`, matching Spark's `nullSafeEval`.
+pub fn spark_log(args: &[ColumnarValue]) -> Result<ColumnarValue, 
DataFusionError> {
+    if args.len() != 2 {
+        return Err(DataFusionError::Internal(format!(
+            "spark_log requires 2 arguments, got {}",
+            args.len()
+        )));
+    }
+
+    // Spark's Logarithm: log(base, value) = ln(value) / ln(base)
+    // Returns null when base <= 0 or value <= 0
+    fn compute(base: f64, value: f64) -> Option<f64> {
+        if base <= 0.0 || value <= 0.0 {
+            None
+        } else {
+            Some(value.ln() / base.ln())
+        }
+    }
+
+    match (&args[0], &args[1]) {
+        (ColumnarValue::Array(base_arr), ColumnarValue::Array(val_arr)) => {
+            let bases = base_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for base, got {:?}",
+                        base_arr.data_type()
+                    ))
+                })?;
+            let values = val_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for value, got {:?}",
+                        val_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = bases
+                .iter()
+                .zip(values.iter())
+                .map(|(b, v)| match (b, v) {
+                    (Some(base), Some(value)) => compute(base, value),
+                    _ => None,
+                })
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Scalar(base_scalar), ColumnarValue::Array(val_arr)) => 
{
+            let base = match base_scalar {
+                ScalarValue::Float64(Some(b)) => *b,
+                ScalarValue::Float64(None) => {
+                    let result = Float64Array::new_null(val_arr.len());
+                    return Ok(ColumnarValue::Array(Arc::new(result)));
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalar for base, got 
{base_scalar:?}",
+                    )));
+                }
+            };
+            let values = val_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for value, got {:?}",
+                        val_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = values
+                .iter()
+                .map(|v| v.and_then(|value| compute(base, value)))
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Array(base_arr), ColumnarValue::Scalar(val_scalar)) => 
{
+            let value = match val_scalar {
+                ScalarValue::Float64(Some(v)) => *v,
+                ScalarValue::Float64(None) => {
+                    let result = Float64Array::new_null(base_arr.len());
+                    return Ok(ColumnarValue::Array(Arc::new(result)));
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalar for value, got 
{val_scalar:?}",
+                    )));
+                }
+            };
+            let bases = base_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for base, got {:?}",
+                        base_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = bases
+                .iter()
+                .map(|b| b.and_then(|base| compute(base, value)))
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Scalar(base_scalar), 
ColumnarValue::Scalar(val_scalar)) => {
+            let result = match (base_scalar, val_scalar) {
+                (ScalarValue::Float64(Some(base)), 
ScalarValue::Float64(Some(value))) => {
+                    ScalarValue::Float64(compute(*base, *value))
+                }
+                (ScalarValue::Float64(_), ScalarValue::Float64(_)) => 
ScalarValue::Float64(None),
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalars, got 
{base_scalar:?} and {val_scalar:?}",
+                    )));
+                }
+            };
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use arrow::array::Array;
+
+    #[test]
+    fn test_spark_log_basic() {
+        let bases = Float64Array::from(vec![10.0, 2.0, 10.0]);
+        let values = Float64Array::from(vec![100.0, 8.0, 1.0]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!((arr.value(0) - 2.0).abs() < 1e-10);
+            assert!((arr.value(1) - 3.0).abs() < 1e-10);
+            assert!((arr.value(2) - 0.0).abs() < 1e-10);
+        } else {
+            panic!("expected array result");
+        }
+    }
+
+    #[test]
+    fn test_spark_log_non_positive_returns_null() {
+        let bases = Float64Array::from(vec![Some(0.0), Some(-1.0), Some(10.0), 
Some(10.0)]);
+        let values = Float64Array::from(vec![Some(10.0), Some(10.0), 
Some(0.0), Some(-1.0)]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.is_null(0));
+            assert!(arr.is_null(1));
+            assert!(arr.is_null(2));
+            assert!(arr.is_null(3));
+        } else {
+            panic!("expected array result");
+        }
+    }
+
+    #[test]
+    fn test_spark_log_null_propagation() {
+        let bases = Float64Array::from(vec![Some(10.0), None]);
+        let values = Float64Array::from(vec![None, Some(10.0)]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.is_null(0));
+            assert!(arr.is_null(1));
+        } else {
+            panic!("expected array result");
+        }
+    }
+
+    #[test]
+    fn test_spark_log_base_one_returns_nan() {
+        // log(1, 1) = ln(1) / ln(1) = 0/0 = NaN
+        let bases = Float64Array::from(vec![1.0]);
+        let values = Float64Array::from(vec![1.0]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.value(0).is_nan());
+        } else {
+            panic!("expected array result");
+        }
+    }
+}
diff --git a/native/spark-expr/src/math_funcs/mod.rs 
b/native/spark-expr/src/math_funcs/mod.rs
index 1219bc720..f66c584e2 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -21,6 +21,7 @@ pub(crate) mod checked_arithmetic;
 mod div;
 mod floor;
 pub mod internal;
+pub(crate) mod log;
 pub mod modulo_expr;
 mod negative;
 mod round;
@@ -33,6 +34,7 @@ pub use div::spark_decimal_div;
 pub use div::spark_decimal_integral_div;
 pub use floor::spark_floor;
 pub use internal::*;
+pub use log::spark_log;
 pub use modulo_expr::create_modulo_expr;
 pub use negative::{create_negate_expr, NegativeExpr};
 pub use round::spark_round;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 2ce398c8f..59fb0f981 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -105,6 +105,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Log] -> CometLog,
     classOf[Log2] -> CometLog2,
     classOf[Log10] -> CometLog10,
+    classOf[Logarithm] -> CometLogarithm,
     classOf[Multiply] -> CometMultiply,
     classOf[Pow] -> CometScalarFunction("pow"),
     classOf[Rand] -> CometRand,
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala 
b/spark/src/main/scala/org/apache/comet/serde/math.scala
index 5a0393142..45c60b822 100644
--- a/spark/src/main/scala/org/apache/comet/serde/math.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -19,8 +19,8 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, 
CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, 
Log10, Log2, Tan, Unhex}
-import org.apache.spark.sql.types.{DecimalType, NumericType}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, 
CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, 
Log10, Log2, Logarithm, Tan, Unhex}
+import org.apache.spark.sql.types.{DecimalType, DoubleType, NumericType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, 
optExprWithInfo, scalarFunctionExprToProto, 
scalarFunctionExprToProtoWithReturnType, serializeDataType}
@@ -138,6 +138,21 @@ object CometLog2 extends CometExpressionSerde[Log2] with 
MathExprBase {
   }
 }
 
+object CometLogarithm extends CometExpressionSerde[Logarithm] {
+  override def convert(
+      expr: Logarithm,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    // Uses custom spark_log UDF that returns null when base <= 0 or value <= 0,
+    // matching Spark's Logarithm.nullSafeEval behavior.
+    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
+    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
+    val optExpr =
+      scalarFunctionExprToProtoWithReturnType("spark_log", DoubleType, false, 
leftExpr, rightExpr)
+    optExprWithInfo(optExpr, expr, expr.left, expr.right)
+  }
+}
+
 object CometHex extends CometExpressionSerde[Hex] with MathExprBase {
   override def convert(
       expr: Hex,
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log.sql 
b/spark/src/test/resources/sql-tests/expressions/math/log.sql
index e7420954c..8ee5282cd 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log(d double) USING parquet
 
 statement
-INSERT INTO test_log VALUES (1.0), (2.718281828459045), (10.0), (0.5), (NULL), 
(cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log VALUES (1.0), (2.718281828459045), (10.0), (0.5), (NULL), 
(cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT ln(d) FROM test_log
@@ -40,3 +40,11 @@ SELECT ln(1.0), ln(2.718281828459045), ln(10.0), ln(NULL)
 -- literal + literal (2-arg form)
 query tolerance=1e-6
 SELECT log(10.0, 100.0), log(2.0, 8.0), log(10.0, 1.0), log(NULL, 10.0)
+
+-- edge cases: base or value <= 0 should return null
+query tolerance=1e-6
+SELECT log(0.0, 10.0), log(-1.0, 10.0), log(10.0, 0.0), log(10.0, -1.0), 
log(0.0, 0.0), log(-1.0, -1.0)
+
+-- edge case: log(1, 1) produces NaN (0/0) which Spark preserves as NaN
+query tolerance=1e-6
+SELECT log(1.0, 1.0)
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log10.sql 
b/spark/src/test/resources/sql-tests/expressions/math/log10.sql
index 1b3c9417f..77019c9b6 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log10.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log10.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log10(d double) USING parquet
 
 statement
-INSERT INTO test_log10 VALUES (1.0), (10.0), (100.0), (0.1), (NULL), 
(cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log10 VALUES (1.0), (10.0), (100.0), (0.1), (NULL), 
(cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT log10(d) FROM test_log10
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log2.sql 
b/spark/src/test/resources/sql-tests/expressions/math/log2.sql
index 5db0ca484..01ff6f75b 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log2.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log2.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log2(d double) USING parquet
 
 statement
-INSERT INTO test_log2 VALUES (1.0), (2.0), (4.0), (8.0), (0.5), (NULL), 
(cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log2 VALUES (1.0), (2.0), (4.0), (8.0), (0.5), (NULL), 
(cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT log2(d) FROM test_log2
diff --git a/spark/src/test/resources/sql-tests/expressions/math/tan.sql 
b/spark/src/test/resources/sql-tests/expressions/math/tan.sql
index 21bd44f90..949684480 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/tan.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/tan.sql
@@ -16,6 +16,7 @@
 -- under the License.
 
 -- ConfigMatrix: parquet.enable.dictionary=false,true
+-- Config: spark.comet.expression.Tan.allowIncompatible=true
 
 statement
 CREATE TABLE test_tan(d double) USING parquet
diff --git a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
index 020759a7a..5a0b34e05 100644
--- a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
@@ -102,7 +102,7 @@ class CometSqlFileTestSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                   case SparkAnswerOnly =>
                     checkSparkAnswer(sql)
                   case WithTolerance(tol) =>
-                    checkSparkAnswerWithTolerance(sql, tol)
+                    checkSparkAnswerAndOperatorWithTolerance(sql, tol)
                   case ExpectFallback(reason) =>
                     checkSparkAnswerAndFallbackReason(sql, reason)
                   case Ignore(reason) =>
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 33c1d444b..a540c61d3 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -190,6 +190,17 @@ abstract class CometTestBase
     internalCheckSparkAnswer(df, assertCometNative = false, withTol = 
Some(absTol))
   }
 
+  /**
+   * Check that the query returns the correct results when Comet is enabled and that Comet
+   * replaced all possible operators. Use the provided `absTol` when comparing floating-point
+   * results.
+   */
+  protected def checkSparkAnswerAndOperatorWithTolerance(
+      query: String,
+      absTol: Double = 1e-6): (SparkPlan, SparkPlan) = {
+    checkSparkAnswerAndOperatorWithTol(sql(query), absTol)
+  }
+
   /**
    * Check that the query returns the correct results when Comet is enabled 
and that Comet
    * replaced all possible operators except for those specified in the 
excluded list.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to