This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d78bf471f Keep output as scalar for scalar function if all inputs are scalar (#7967)
3d78bf471f is described below

commit 3d78bf471f8b2a66c4298e1aca230bc87e7b4c25
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Tue Oct 31 00:05:03 2023 -0700

    Keep output as scalar for scalar function if all inputs are scalar (#7967)
    
    * Keep output as scalar for scalar function if all inputs are scalar
    
    * Add end-to-end tests
---
 datafusion/physical-expr/src/functions.rs     | 11 +++++-
 datafusion/physical-expr/src/planner.rs       | 34 +++++++++++++++++++
 datafusion/sqllogictest/test_files/scalar.slt | 48 +++++++++++++++++++++++++++
 3 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 8422862043..b66bac4101 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -357,6 +357,8 @@ where
                 ColumnarValue::Array(a) => Some(a.len()),
             });
 
+        let is_scalar = len.is_none();
+
         let inferred_length = len.unwrap_or(1);
         let args = args
             .iter()
@@ -373,7 +375,14 @@ where
             .collect::<Vec<ArrayRef>>();
 
         let result = (inner)(&args);
-        result.map(ColumnarValue::Array)
+
+        if is_scalar {
+            // If all inputs are scalar, keeps output as scalar
+            let result = result.and_then(|arr| 
ScalarValue::try_from_array(&arr, 0));
+            result.map(ColumnarValue::Scalar)
+        } else {
+            result.map(ColumnarValue::Array)
+        }
     })
 }
 
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 9a74c2ca64..64c1d0be04 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -448,3 +448,37 @@ pub fn create_physical_expr(
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
+    use arrow_schema::{DataType, Field, Schema};
+    use datafusion_common::{DFSchema, Result};
+    use datafusion_expr::{col, left, Literal};
+
+    /// Compares a Utf8 column with a function of literals across a four-row batch.
+    #[test]
+    fn test_create_physical_expr_scalar_input_output() -> Result<()> {
+        // `left` has only literal inputs, so its result must stay a scalar, not a 1-row array.
+        let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit()));
+        let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]);
+        let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?;
+        let p = create_physical_expr(&expr, &df_schema, &schema, &ExecutionProps::new())?;
+
+        let batch = RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(StringArray::from_iter_values(vec![
+                "A", "B", "C", "D",
+            ]))],
+        )?;
+        let result = p.evaluate(&batch)?.into_array(batch.num_rows());
+
+        assert_eq!(
+            &result,
+            &(Arc::new(BooleanArray::from(vec![true, false, false, false])) as ArrayRef)
+        );
+
+        Ok(())
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index e5c1a82849..ecb7fe13fc 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1878,3 +1878,51 @@ query T
 SELECT CONCAT('Hello', 'World')
 ----
 HelloWorld
+
+statement ok
+CREATE TABLE simple_string(
+  letter STRING,
+  letter2 STRING
+) as VALUES
+  ('A', 'APACHE'),
+  ('B', 'APACHE'),
+  ('C', 'APACHE'),
+  ('D', 'APACHE')
+;
+
+query TT
+EXPLAIN SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string;
+----
+logical_plan
+Projection: simple_string.letter, simple_string.letter = Utf8("A") AS simple_string.letter = left(Utf8("APACHE"),Int64(1))
+--TableScan: simple_string projection=[letter]
+physical_plan
+ProjectionExec: expr=[letter@0 as letter, letter@0 = A as simple_string.letter = left(Utf8("APACHE"),Int64(1))]
+--MemoryExec: partitions=1, partition_sizes=[1]
+
+query TB
+SELECT letter, letter = LEFT('APACHE', 1) FROM simple_string;
+ ----
+----
+A true
+B false
+C false
+D false
+
+query TT
+EXPLAIN SELECT letter, letter = LEFT(letter2, 1) FROM simple_string;
+----
+logical_plan
+Projection: simple_string.letter, simple_string.letter = left(simple_string.letter2, Int64(1))
+--TableScan: simple_string projection=[letter, letter2]
+physical_plan
+ProjectionExec: expr=[letter@0 as letter, letter@0 = left(letter2@1, 1) as simple_string.letter = left(simple_string.letter2,Int64(1))]
+--MemoryExec: partitions=1, partition_sizes=[1]
+
+query TB
+SELECT letter, letter = LEFT(letter2, 1) FROM simple_string;
+----
+A true
+B false
+C false
+D false

Reply via email to