This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 295054013 fix: query tolerance= in SQL file tests now also asserts
Comet native execution (#3797)
295054013 is described below
commit 295054013b0f5aae86483354e111437f74823594
Author: Andy Grove <[email protected]>
AuthorDate: Mon Mar 30 13:26:30 2026 -0600
fix: query tolerance= in SQL file tests now also asserts Comet native
execution (#3797)
---
native/spark-expr/src/comet_scalar_funcs.rs | 5 +
native/spark-expr/src/lib.rs | 4 +-
native/spark-expr/src/math_funcs/log.rs | 227 +++++++++++++++++++++
native/spark-expr/src/math_funcs/mod.rs | 2 +
.../org/apache/comet/serde/QueryPlanSerde.scala | 1 +
.../main/scala/org/apache/comet/serde/math.scala | 19 +-
.../resources/sql-tests/expressions/math/log.sql | 10 +-
.../resources/sql-tests/expressions/math/log10.sql | 2 +-
.../resources/sql-tests/expressions/math/log2.sql | 2 +-
.../resources/sql-tests/expressions/math/tan.sql | 1 +
.../org/apache/comet/CometSqlFileTestSuite.scala | 2 +-
.../scala/org/apache/spark/sql/CometTestBase.scala | 11 +
12 files changed, 278 insertions(+), 8 deletions(-)
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs
b/native/spark-expr/src/comet_scalar_funcs.rs
index ff75de763..1eaf0b2a9 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -18,6 +18,7 @@
use crate::hash_funcs::*;
use crate::math_funcs::abs::abs;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div,
checked_mul, checked_sub};
+use crate::math_funcs::log::spark_log;
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
spark_isnan,
@@ -177,6 +178,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(abs);
make_comet_scalar_udf!("abs", func, without data_type)
}
+ "spark_log" => {
+ let func = Arc::new(spark_log);
+ make_comet_scalar_udf!("spark_log", func, without data_type)
+ }
"split" => {
let func = Arc::new(crate::string_funcs::spark_split);
make_comet_scalar_udf!("split", func, without data_type)
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index a7711d642..342ef7361 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -79,8 +79,8 @@ pub use hash_funcs::*;
pub use json_funcs::{FromJson, ToJson};
pub use math_funcs::{
create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
- spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round,
spark_unhex,
- spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow,
NegativeExpr,
+ spark_decimal_integral_div, spark_floor, spark_log, spark_make_decimal,
spark_round,
+ spark_unhex, spark_unscaled_value, CheckOverflow,
DecimalRescaleCheckOverflow, NegativeExpr,
NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
};
pub use query_context::{create_query_context_map, QueryContext,
QueryContextMap};
diff --git a/native/spark-expr/src/math_funcs/log.rs
b/native/spark-expr/src/math_funcs/log.rs
new file mode 100644
index 000000000..499d4f33e
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/log.rs
@@ -0,0 +1,227 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, Float64Array};
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::physical_plan::ColumnarValue;
+use std::sync::Arc;
+
+/// Spark-compatible two-argument logarithm: `log(base, value)`.
+///
+/// Returns `log(value) / log(base)`, matching Spark's `Logarithm` expression.
+/// Returns null when `base <= 0` or `value <= 0`, matching Spark's `nullSafeEval`.
+pub fn spark_log(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
+ if args.len() != 2 {
+ return Err(DataFusionError::Internal(format!(
+ "spark_log requires 2 arguments, got {}",
+ args.len()
+ )));
+ }
+
+ // Spark's Logarithm: log(base, value) = ln(value) / ln(base)
+ // Returns null when base <= 0 or value <= 0
+ fn compute(base: f64, value: f64) -> Option<f64> {
+ if base <= 0.0 || value <= 0.0 {
+ None
+ } else {
+ Some(value.ln() / base.ln())
+ }
+ }
+
+ match (&args[0], &args[1]) {
+ (ColumnarValue::Array(base_arr), ColumnarValue::Array(val_arr)) => {
+ let bases = base_arr
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "spark_log expected Float64 for base, got {:?}",
+ base_arr.data_type()
+ ))
+ })?;
+ let values = val_arr
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "spark_log expected Float64 for value, got {:?}",
+ val_arr.data_type()
+ ))
+ })?;
+ let result: Float64Array = bases
+ .iter()
+ .zip(values.iter())
+ .map(|(b, v)| match (b, v) {
+ (Some(base), Some(value)) => compute(base, value),
+ _ => None,
+ })
+ .collect();
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+ (ColumnarValue::Scalar(base_scalar), ColumnarValue::Array(val_arr)) =>
{
+ let base = match base_scalar {
+ ScalarValue::Float64(Some(b)) => *b,
+ ScalarValue::Float64(None) => {
+ let result = Float64Array::new_null(val_arr.len());
+ return Ok(ColumnarValue::Array(Arc::new(result)));
+ }
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "spark_log expected Float64 scalar for base, got
{base_scalar:?}",
+ )));
+ }
+ };
+ let values = val_arr
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "spark_log expected Float64 for value, got {:?}",
+ val_arr.data_type()
+ ))
+ })?;
+ let result: Float64Array = values
+ .iter()
+ .map(|v| v.and_then(|value| compute(base, value)))
+ .collect();
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+ (ColumnarValue::Array(base_arr), ColumnarValue::Scalar(val_scalar)) =>
{
+ let value = match val_scalar {
+ ScalarValue::Float64(Some(v)) => *v,
+ ScalarValue::Float64(None) => {
+ let result = Float64Array::new_null(base_arr.len());
+ return Ok(ColumnarValue::Array(Arc::new(result)));
+ }
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "spark_log expected Float64 scalar for value, got
{val_scalar:?}",
+ )));
+ }
+ };
+ let bases = base_arr
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "spark_log expected Float64 for base, got {:?}",
+ base_arr.data_type()
+ ))
+ })?;
+ let result: Float64Array = bases
+ .iter()
+ .map(|b| b.and_then(|base| compute(base, value)))
+ .collect();
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+ (ColumnarValue::Scalar(base_scalar),
ColumnarValue::Scalar(val_scalar)) => {
+ let result = match (base_scalar, val_scalar) {
+ (ScalarValue::Float64(Some(base)),
ScalarValue::Float64(Some(value))) => {
+ ScalarValue::Float64(compute(*base, *value))
+ }
+ (ScalarValue::Float64(_), ScalarValue::Float64(_)) =>
ScalarValue::Float64(None),
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "spark_log expected Float64 scalars, got
{base_scalar:?} and {val_scalar:?}",
+ )));
+ }
+ };
+ Ok(ColumnarValue::Scalar(result))
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use arrow::array::Array;
+
+ #[test]
+ fn test_spark_log_basic() {
+ let bases = Float64Array::from(vec![10.0, 2.0, 10.0]);
+ let values = Float64Array::from(vec![100.0, 8.0, 1.0]);
+ let result = spark_log(&[
+ ColumnarValue::Array(Arc::new(bases)),
+ ColumnarValue::Array(Arc::new(values)),
+ ])
+ .unwrap();
+ if let ColumnarValue::Array(arr) = result {
+ let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+ assert!((arr.value(0) - 2.0).abs() < 1e-10);
+ assert!((arr.value(1) - 3.0).abs() < 1e-10);
+ assert!((arr.value(2) - 0.0).abs() < 1e-10);
+ } else {
+ panic!("expected array result");
+ }
+ }
+
+ #[test]
+ fn test_spark_log_non_positive_returns_null() {
+ let bases = Float64Array::from(vec![Some(0.0), Some(-1.0), Some(10.0),
Some(10.0)]);
+ let values = Float64Array::from(vec![Some(10.0), Some(10.0),
Some(0.0), Some(-1.0)]);
+ let result = spark_log(&[
+ ColumnarValue::Array(Arc::new(bases)),
+ ColumnarValue::Array(Arc::new(values)),
+ ])
+ .unwrap();
+ if let ColumnarValue::Array(arr) = result {
+ let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+ assert!(arr.is_null(0));
+ assert!(arr.is_null(1));
+ assert!(arr.is_null(2));
+ assert!(arr.is_null(3));
+ } else {
+ panic!("expected array result");
+ }
+ }
+
+ #[test]
+ fn test_spark_log_null_propagation() {
+ let bases = Float64Array::from(vec![Some(10.0), None]);
+ let values = Float64Array::from(vec![None, Some(10.0)]);
+ let result = spark_log(&[
+ ColumnarValue::Array(Arc::new(bases)),
+ ColumnarValue::Array(Arc::new(values)),
+ ])
+ .unwrap();
+ if let ColumnarValue::Array(arr) = result {
+ let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+ assert!(arr.is_null(0));
+ assert!(arr.is_null(1));
+ } else {
+ panic!("expected array result");
+ }
+ }
+
+ #[test]
+ fn test_spark_log_base_one_returns_nan() {
+ // log(1, 1) = ln(1) / ln(1) = 0/0 = NaN
+ let bases = Float64Array::from(vec![1.0]);
+ let values = Float64Array::from(vec![1.0]);
+ let result = spark_log(&[
+ ColumnarValue::Array(Arc::new(bases)),
+ ColumnarValue::Array(Arc::new(values)),
+ ])
+ .unwrap();
+ if let ColumnarValue::Array(arr) = result {
+ let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+ assert!(arr.value(0).is_nan());
+ } else {
+ panic!("expected array result");
+ }
+ }
+}
diff --git a/native/spark-expr/src/math_funcs/mod.rs
b/native/spark-expr/src/math_funcs/mod.rs
index 1219bc720..f66c584e2 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -21,6 +21,7 @@ pub(crate) mod checked_arithmetic;
mod div;
mod floor;
pub mod internal;
+pub(crate) mod log;
pub mod modulo_expr;
mod negative;
mod round;
@@ -33,6 +34,7 @@ pub use div::spark_decimal_div;
pub use div::spark_decimal_integral_div;
pub use floor::spark_floor;
pub use internal::*;
+pub use log::spark_log;
pub use modulo_expr::create_modulo_expr;
pub use negative::{create_negate_expr, NegativeExpr};
pub use round::spark_round;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 2ce398c8f..59fb0f981 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -105,6 +105,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[Log] -> CometLog,
classOf[Log2] -> CometLog2,
classOf[Log10] -> CometLog10,
+ classOf[Logarithm] -> CometLogarithm,
classOf[Multiply] -> CometMultiply,
classOf[Pow] -> CometScalarFunction("pow"),
classOf[Rand] -> CometRand,
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala
b/spark/src/main/scala/org/apache/comet/serde/math.scala
index 5a0393142..45c60b822 100644
--- a/spark/src/main/scala/org/apache/comet/serde/math.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -19,8 +19,8 @@
package org.apache.comet.serde
-import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil,
CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log,
Log10, Log2, Tan, Unhex}
-import org.apache.spark.sql.types.{DecimalType, NumericType}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil,
CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log,
Log10, Log2, Logarithm, Tan, Unhex}
+import org.apache.spark.sql.types.{DecimalType, DoubleType, NumericType}
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal,
optExprWithInfo, scalarFunctionExprToProto,
scalarFunctionExprToProtoWithReturnType, serializeDataType}
@@ -138,6 +138,21 @@ object CometLog2 extends CometExpressionSerde[Log2] with
MathExprBase {
}
}
+object CometLogarithm extends CometExpressionSerde[Logarithm] {
+ override def convert(
+ expr: Logarithm,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
+ // Uses custom spark_log UDF that returns null when base <= 0 or value <= 0,
+ // matching Spark's Logarithm.nullSafeEval behavior.
+ val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
+ val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
+ val optExpr =
+ scalarFunctionExprToProtoWithReturnType("spark_log", DoubleType, false,
leftExpr, rightExpr)
+ optExprWithInfo(optExpr, expr, expr.left, expr.right)
+ }
+}
+
object CometHex extends CometExpressionSerde[Hex] with MathExprBase {
override def convert(
expr: Hex,
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log.sql
b/spark/src/test/resources/sql-tests/expressions/math/log.sql
index e7420954c..8ee5282cd 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log.sql
@@ -21,7 +21,7 @@ statement
CREATE TABLE test_log(d double) USING parquet
statement
-INSERT INTO test_log VALUES (1.0), (2.718281828459045), (10.0), (0.5), (NULL),
(cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log VALUES (1.0), (2.718281828459045), (10.0), (0.5), (NULL),
(cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
query tolerance=1e-6
SELECT ln(d) FROM test_log
@@ -40,3 +40,11 @@ SELECT ln(1.0), ln(2.718281828459045), ln(10.0), ln(NULL)
-- literal + literal (2-arg form)
query tolerance=1e-6
SELECT log(10.0, 100.0), log(2.0, 8.0), log(10.0, 1.0), log(NULL, 10.0)
+
+-- edge cases: base or value <= 0 should return null
+query tolerance=1e-6
+SELECT log(0.0, 10.0), log(-1.0, 10.0), log(10.0, 0.0), log(10.0, -1.0),
log(0.0, 0.0), log(-1.0, -1.0)
+
+-- edge case: log(1, 1) produces NaN (0/0) which Spark preserves as NaN
+query tolerance=1e-6
+SELECT log(1.0, 1.0)
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log10.sql
b/spark/src/test/resources/sql-tests/expressions/math/log10.sql
index 1b3c9417f..77019c9b6 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log10.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log10.sql
@@ -21,7 +21,7 @@ statement
CREATE TABLE test_log10(d double) USING parquet
statement
-INSERT INTO test_log10 VALUES (1.0), (10.0), (100.0), (0.1), (NULL),
(cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log10 VALUES (1.0), (10.0), (100.0), (0.1), (NULL),
(cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
query tolerance=1e-6
SELECT log10(d) FROM test_log10
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log2.sql
b/spark/src/test/resources/sql-tests/expressions/math/log2.sql
index 5db0ca484..01ff6f75b 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log2.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log2.sql
@@ -21,7 +21,7 @@ statement
CREATE TABLE test_log2(d double) USING parquet
statement
-INSERT INTO test_log2 VALUES (1.0), (2.0), (4.0), (8.0), (0.5), (NULL),
(cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log2 VALUES (1.0), (2.0), (4.0), (8.0), (0.5), (NULL),
(cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
query tolerance=1e-6
SELECT log2(d) FROM test_log2
diff --git a/spark/src/test/resources/sql-tests/expressions/math/tan.sql
b/spark/src/test/resources/sql-tests/expressions/math/tan.sql
index 21bd44f90..949684480 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/tan.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/tan.sql
@@ -16,6 +16,7 @@
-- under the License.
-- ConfigMatrix: parquet.enable.dictionary=false,true
+-- Config: spark.comet.expression.Tan.allowIncompatible=true
statement
CREATE TABLE test_tan(d double) USING parquet
diff --git a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
index 020759a7a..5a0b34e05 100644
--- a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
@@ -102,7 +102,7 @@ class CometSqlFileTestSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
case SparkAnswerOnly =>
checkSparkAnswer(sql)
case WithTolerance(tol) =>
- checkSparkAnswerWithTolerance(sql, tol)
+ checkSparkAnswerAndOperatorWithTolerance(sql, tol)
case ExpectFallback(reason) =>
checkSparkAnswerAndFallbackReason(sql, reason)
case Ignore(reason) =>
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 33c1d444b..a540c61d3 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -190,6 +190,17 @@ abstract class CometTestBase
internalCheckSparkAnswer(df, assertCometNative = false, withTol =
Some(absTol))
}
+ /**
+ * Check that the query returns the correct results when Comet is enabled and that Comet
+ * replaced all possible operators. Use the provided `absTol` when comparing floating-point
+ * results.
+ */
+ protected def checkSparkAnswerAndOperatorWithTolerance(
+ query: String,
+ absTol: Double = 1e-6): (SparkPlan, SparkPlan) = {
+ checkSparkAnswerAndOperatorWithTol(sql(query), absTol)
+ }
+
/**
* Check that the query returns the correct results when Comet is enabled
and that Comet
* replaced all possible operators except for those specified in the
excluded list.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]