This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3d78bf471f Keep output as scalar for scalar function if all inputs are scalar (#7967)
3d78bf471f is described below
commit 3d78bf471f8b2a66c4298e1aca230bc87e7b4c25
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Oct 31 00:05:03 2023 -0700
Keep output as scalar for scalar function if all inputs are scalar (#7967)
* Keep output as scalar for scalar function if all inputs are scalar
* Add end-to-end tests
---
datafusion/physical-expr/src/functions.rs | 11 +++++-
datafusion/physical-expr/src/planner.rs | 34 +++++++++++++++++++
datafusion/sqllogictest/test_files/scalar.slt | 48 +++++++++++++++++++++++++++
3 files changed, 92 insertions(+), 1 deletion(-)
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 8422862043..b66bac4101 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -357,6 +357,8 @@ where
ColumnarValue::Array(a) => Some(a.len()),
});
+ let is_scalar = len.is_none();
+
let inferred_length = len.unwrap_or(1);
let args = args
.iter()
@@ -373,7 +375,14 @@ where
.collect::<Vec<ArrayRef>>();
let result = (inner)(&args);
- result.map(ColumnarValue::Array)
+
+ if is_scalar {
+ // If all inputs are scalar, keeps output as scalar
+ let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
+ result.map(ColumnarValue::Scalar)
+ } else {
+ result.map(ColumnarValue::Array)
+ }
})
}
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 9a74c2ca64..64c1d0be04 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -448,3 +448,37 @@ pub fn create_physical_expr(
}
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
+ use arrow_schema::{DataType, Field, Schema};
+ use datafusion_common::{DFSchema, Result};
+ use datafusion_expr::{col, left, Literal};
+
+ #[test]
+ fn test_create_physical_expr_scalar_input_output() -> Result<()> {
+ let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit()));
+
+ let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]);
+ let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?;
+ let p = create_physical_expr(&expr, &df_schema, &schema, &ExecutionProps::new())?;
+
+ let batch = RecordBatch::try_new(
+ Arc::new(schema),
+ vec![Arc::new(StringArray::from_iter_values(vec![
+ "A", "B", "C", "D",
+ ]))],
+ )?;
+ let result = p.evaluate(&batch)?;
+ let result = result.into_array(4);
+
+ assert_eq!(
+ &result,
+ &(Arc::new(BooleanArray::from(vec![true, false, false, false,])) as ArrayRef)
+ );
+
+ Ok(())
+ }
+}
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index e5c1a82849..ecb7fe13fc 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1878,3 +1878,51 @@ query T
SELECT CONCAT('Hello', 'World')
----
HelloWorld
+
+statement ok
+CREATE TABLE simple_string(
+ letter STRING,
+ letter2 STRING
+) as VALUES
+ ('A', 'APACHE'),
+ ('B', 'APACHE'),
+ ('C', 'APACHE'),
+ ('D', 'APACHE')
+;
+
+query TT
+EXPLAIN SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string;
+----
+logical_plan
+Projection: simple_string.letter, simple_string.letter = Utf8("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1))
+--TableScan: simple_string projection=[letter]
+physical_plan
+ProjectionExec: expr=[letter@0 as letter, letter@0 = A as simple_string.letter = left(Utf8("APACHE"),Int64(1))]
+--MemoryExec: partitions=1, partition_sizes=[1]
+
+query TB
+SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string;
+ ----
+----
+A true
+B false
+C false
+D false
+
+query TT
+EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string;
+----
+logical_plan
+Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1))
+--TableScan: simple_string projection=[letter, letter2]
+physical_plan
+ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))]
+--MemoryExec: partitions=1, partition_sizes=[1]
+
+query TB
+SELECT letter, letter = LEFT(letter2, 1) FROM simple_string;
+----
+A true
+B false
+C false
+D false