This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8a609dc20 Add SQL function overload LOG(base, x) for logarithm of x to
base (#5245)
8a609dc20 is described below
commit 8a609dc207d3f08f26b1f8a04e16147d3573525b
Author: comphead <[email protected]>
AuthorDate: Sun Feb 12 05:59:54 2023 -0800
Add SQL function overload LOG(base, x) for logarithm of x to base (#5245)
* support overloaded log scalar function
* support overloaded log scalar function
* fix comments
---
.../core/tests/sqllogictests/test_files/scalar.slt | 70 ++++++++++++++++++++++
datafusion/expr/src/expr_fn.rs | 1 +
datafusion/expr/src/function.rs | 15 ++++-
datafusion/physical-expr/src/functions.rs | 4 +-
datafusion/physical-expr/src/math_expressions.rs | 34 +++++++++++
datafusion/proto/src/logical_plan/from_proto.rs | 6 +-
6 files changed, 127 insertions(+), 3 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt
b/datafusion/core/tests/sqllogictests/test_files/scalar.slt
new file mode 100644
index 000000000..bf7d64e8c
--- /dev/null
+++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#############
+## Scalar Function Tests
+#############
+
+statement ok
+CREATE TABLE t1(
+ a INT,
+ b INT
+) as VALUES
+ (1, 100),
+ (2, 1000),
+ (3, 10000)
+;
+
+# log scalar function
+query IT rowsort
+select log(2, 64) a, log(100) b union all select log(2, 8), log(10);
+----
+3 1
+6 2
+
+# log scalar function
+query IT rowsort
+select log(a, 64) a, log(b), log(10, b) from t1;
+----
+3.7855785 4 4
+6 3 3
+Infinity 2 2
+
+# log scalar nulls
+query IT rowsort
+select log(null, 64) a, log(null) b
+----
+NULL NULL
+
+# log scalar nulls 1
+query IT rowsort
+select log(2, null) a, log(null) b
+----
+NULL NULL
+
+# log scalar nulls 2
+query IT rowsort
+select log(null, null) a, log(null) b
+----
+NULL NULL
+
+# log scalar ops with zero edge cases
+# please see https://github.com/apache/arrow-datafusion/pull/5245#issuecomment-1426828382
+query IT rowsort
+select log(0) a, log(1, 64) b
+----
+-Infinity Infinity
\ No newline at end of file
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 3c048327c..9ec35e238 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -372,6 +372,7 @@ scalar_expr!(
"returns the hexdecimal representation of an integer"
);
scalar_expr!(Uuid, uuid, , "Returns uuid v4 as a string value");
+scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`");
// string functions
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 50cc6bcd6..cfb9b3baa 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -255,6 +255,11 @@ pub fn return_type(
_ => Ok(DataType::Float64),
},
+ BuiltinScalarFunction::Log => match &input_expr_types[0] {
+ DataType::Float32 => Ok(DataType::Float32),
+ _ => Ok(DataType::Float64),
+ },
+
BuiltinScalarFunction::ArrowTypeof => Ok(DataType::Utf8),
BuiltinScalarFunction::Abs
@@ -265,7 +270,6 @@ pub fn return_type(
| BuiltinScalarFunction::Cos
| BuiltinScalarFunction::Exp
| BuiltinScalarFunction::Floor
- | BuiltinScalarFunction::Log
| BuiltinScalarFunction::Ln
| BuiltinScalarFunction::Log10
| BuiltinScalarFunction::Log2
@@ -607,6 +611,15 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature
{
],
fun.volatility(),
),
+ BuiltinScalarFunction::Log => Signature::one_of(
+ vec![
+ TypeSignature::Exact(vec![DataType::Float32]),
+ TypeSignature::Exact(vec![DataType::Float64]),
+ TypeSignature::Exact(vec![DataType::Float32,
DataType::Float32]),
+ TypeSignature::Exact(vec![DataType::Float64,
DataType::Float64]),
+ ],
+ fun.volatility(),
+ ),
BuiltinScalarFunction::ArrowTypeof => Signature::any(1,
fun.volatility()),
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real
numbers) and thus we
diff --git a/datafusion/physical-expr/src/functions.rs
b/datafusion/physical-expr/src/functions.rs
index c72ab161f..fda6e6aa2 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -323,7 +323,6 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos),
BuiltinScalarFunction::Exp => Arc::new(math_expressions::exp),
BuiltinScalarFunction::Floor => Arc::new(math_expressions::floor),
- BuiltinScalarFunction::Log => Arc::new(math_expressions::log10),
BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln),
BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10),
BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2),
@@ -340,6 +339,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Atan2 => {
Arc::new(|args|
make_scalar_function(math_expressions::atan2)(args))
}
+ BuiltinScalarFunction::Log => {
+ Arc::new(|args| make_scalar_function(math_expressions::log)(args))
+ }
// string functions
BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::array),
diff --git a/datafusion/physical-expr/src/math_expressions.rs
b/datafusion/physical-expr/src/math_expressions.rs
index afaff49b8..8993595b8 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -201,6 +201,40 @@ pub fn atan2(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}
+/// Computes the logarithm of `x` for the SQL `log` function.
+///
+/// Two call shapes are accepted:
+/// * `log(x)` -- a single argument; the base defaults to 10.
+/// * `log(base, x)` -- two arguments, with the base first.
+///
+/// The arguments are processed as `Float32Array`s or `Float64Array`s,
+/// selected by the data type of `x`.
+///
+/// # Errors
+/// Returns `DataFusionError::Internal` when the number of arguments is
+/// neither one nor two, or when `x` is neither `Float32` nor `Float64`.
+pub fn log(args: &[ArrayRef]) -> Result<ArrayRef> {
+    // SQL passes the base first, whereas `f64::log(self, base)` takes it
+    // second: sql `log(base, x)` == `f64::log(x, base)`.
+    let (explicit_base, x) = match args {
+        [x] => (None, x),
+        [base, x] => (Some(base), x),
+        _ => {
+            return Err(DataFusionError::Internal(format!(
+                "log expects 1 or 2 arguments, got {}",
+                args.len()
+            )))
+        }
+    };
+
+    // The default base (10) has to be built with the same element type as
+    // `x`: each arm below names one concrete array type for both inputs, so
+    // a Float32 default cannot be paired with a Float64 `x`.
+    let default_base: ArrayRef;
+    let base = match explicit_base {
+        Some(base) => base,
+        None => {
+            default_base = match x.data_type() {
+                DataType::Float64 => {
+                    Arc::new(Float64Array::from_value(10.0, x.len())) as ArrayRef
+                }
+                _ => Arc::new(Float32Array::from_value(10.0, x.len())) as ArrayRef,
+            };
+            &default_base
+        }
+    };
+
+    match x.data_type() {
+        DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
+            x,
+            base,
+            "x",
+            "base",
+            Float64Array,
+            { f64::log }
+        )) as ArrayRef),
+
+        DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
+            x,
+            base,
+            "x",
+            "base",
+            Float32Array,
+            { f32::log }
+        )) as ArrayRef),
+
+        other => Err(DataFusionError::Internal(format!(
+            "Unsupported data type {other:?} for function log"
+        ))),
+    }
+}
+
#[cfg(test)]
mod tests {
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 3dcf2fdb4..a74874586 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -37,7 +37,7 @@ use datafusion_expr::{
character_length, chr, coalesce, concat_expr, concat_ws_expr, cos,
date_bin,
date_part, date_trunc, digest, exp,
expr::{self, Sort, WindowFunction},
- floor, from_unixtime, left, ln, log10, log2,
+ floor, from_unixtime, left, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, now, nullif, octet_length, power, random,
regexp_match,
regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim,
sha224, sha256,
@@ -1303,6 +1303,10 @@ pub fn parse_expr(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
)),
+ ScalarFunction::Log => Ok(log(
+ parse_expr(&args[0], registry)?,
+ parse_expr(&args[1], registry)?,
+ )),
ScalarFunction::FromUnixtime => {
Ok(from_unixtime(parse_expr(&args[0], registry)?))
}