This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 24a4a44  ARROW-9902: [Rust] [DataFusion] Add array() built-in function
24a4a44 is described below

commit 24a4a44ea99f52ed2b36bb5b2d068488a5b6768a
Author: Jorge C. Leitao <[email protected]>
AuthorDate: Sun Sep 20 09:24:27 2020 -0600

    ARROW-9902: [Rust] [DataFusion] Add array() built-in function
    
    This adds an `array()` built-in function for most primitive types. For
    composite types this is more challenging, so I decided to leave them out of
    scope for this PR.
    
    Closes #8102 from jorgecarleitao/array
    
    Authored-by: Jorge C. Leitao <[email protected]>
    Signed-off-by: Andy Grove <[email protected]>
---
 rust/datafusion/README.md                          |   2 +
 rust/datafusion/src/logical_plan/mod.rs            |   8 +
 .../src/physical_plan/array_expressions.rs         | 108 +++++++++++++
 rust/datafusion/src/physical_plan/functions.rs     |  81 +++++++++-
 rust/datafusion/src/physical_plan/mod.rs           |   1 +
 rust/datafusion/src/prelude.rs                     |   2 +-
 rust/datafusion/tests/sql.rs                       | 175 ++++++++++++---------
 7 files changed, 303 insertions(+), 74 deletions(-)

diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md
index bf3c60f..5405f26 100644
--- a/rust/datafusion/README.md
+++ b/rust/datafusion/README.md
@@ -61,6 +61,8 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI
   - [ ] Basic date functions
   - [ ] Basic time functions
   - [x] Basic timestamp functions
+- nested functions
+  - [x] Array of columns
 - [x] Sorting
 - [ ] Nested types
 - [ ] Lists
diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs
index e0f5d9d..e37bd10 100644
--- a/rust/datafusion/src/logical_plan/mod.rs
+++ b/rust/datafusion/src/logical_plan/mod.rs
@@ -623,6 +623,14 @@ pub fn concat(args: Vec<Expr>) -> Expr {
     }
 }
 
+/// returns an array of fixed size with each argument on it.
+pub fn array(args: Vec<Expr>) -> Expr {
+    Expr::ScalarFunction {
+        fun: functions::BuiltinScalarFunction::Array,
+        args,
+    }
+}
+
 /// Creates a new UDF with a specific signature and specific return type.
 /// This is a helper function to create a new UDF.
 /// The function `create_udf` returns a subset of all possible 
`ScalarFunction`:
diff --git a/rust/datafusion/src/physical_plan/array_expressions.rs b/rust/datafusion/src/physical_plan/array_expressions.rs
new file mode 100644
index 0000000..79fb64e
--- /dev/null
+++ b/rust/datafusion/src/physical_plan/array_expressions.rs
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Array expressions
+
+use crate::error::{ExecutionError, Result};
+use arrow::array::*;
+use arrow::datatypes::DataType;
+use std::sync::Arc;
+
+macro_rules! downcast_vec {
+    ($ARGS:expr, $ARRAY_TYPE:ident) => {{
+        $ARGS
+            .iter()
+            .map(|e| match e.as_any().downcast_ref::<$ARRAY_TYPE>() {
+                Some(array) => Ok(array),
+                _ => Err(ExecutionError::General("failed to 
downcast".to_string())),
+            })
+    }};
+}
+
+macro_rules! array {
+    ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{
+        // downcast all arguments to their common format
+        let args =
+            downcast_vec!($ARGS, 
$ARRAY_TYPE).collect::<Result<Vec<&$ARRAY_TYPE>>>()?;
+
+        let mut builder = FixedSizeListBuilder::<$BUILDER_TYPE>::new(
+            <$BUILDER_TYPE>::new(args[0].len()),
+            args.len() as i32,
+        );
+        // for each entry in the array
+        for index in 0..args[0].len() {
+            for arg in &args {
+                if arg.is_null(index) {
+                    builder.values().append_null()?;
+                } else {
+                    builder.values().append_value(arg.value(index))?;
+                }
+            }
+            builder.append(true)?;
+        }
+        Ok(Arc::new(builder.finish()))
+    }};
+}
+
+/// put values in an array.
+pub fn array(args: &[ArrayRef]) -> Result<ArrayRef> {
+    // do not accept 0 arguments.
+    if args.len() == 0 {
+        return Err(ExecutionError::InternalError(
+            "array requires at least one argument".to_string(),
+        ));
+    }
+
+    match args[0].data_type() {
+        DataType::Utf8 => array!(args, StringArray, StringBuilder),
+        DataType::LargeUtf8 => array!(args, LargeStringArray, 
LargeStringBuilder),
+        DataType::Boolean => array!(args, BooleanArray, BooleanBuilder),
+        DataType::Float32 => array!(args, Float32Array, Float32Builder),
+        DataType::Float64 => array!(args, Float64Array, Float64Builder),
+        DataType::Int8 => array!(args, Int8Array, Int8Builder),
+        DataType::Int16 => array!(args, Int16Array, Int16Builder),
+        DataType::Int32 => array!(args, Int32Array, Int32Builder),
+        DataType::Int64 => array!(args, Int64Array, Int64Builder),
+        DataType::UInt8 => array!(args, UInt8Array, UInt8Builder),
+        DataType::UInt16 => array!(args, UInt16Array, UInt16Builder),
+        DataType::UInt32 => array!(args, UInt32Array, UInt32Builder),
+        DataType::UInt64 => array!(args, UInt64Array, UInt64Builder),
+        data_type => Err(ExecutionError::NotImplemented(format!(
+            "Array is not implemented for type '{:?}'.",
+            data_type
+        ))),
+    }
+}
+
+/// Currently supported types by the array function.
+/// The order of these types correspond to the order on which coercion applies
+/// This should thus be from least informative to most informative
+pub static SUPPORTED_ARRAY_TYPES: &'static [DataType] = &[
+    DataType::Boolean,
+    DataType::UInt8,
+    DataType::UInt16,
+    DataType::UInt32,
+    DataType::UInt64,
+    DataType::Int8,
+    DataType::Int16,
+    DataType::Int32,
+    DataType::Int64,
+    DataType::Float32,
+    DataType::Float64,
+    DataType::Utf8,
+    DataType::LargeUtf8,
+];
diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs
index af02c6d..95bd252 100644
--- a/rust/datafusion/src/physical_plan/functions.rs
+++ b/rust/datafusion/src/physical_plan/functions.rs
@@ -34,6 +34,7 @@ use super::{
     PhysicalExpr,
 };
 use crate::error::{ExecutionError, Result};
+use crate::physical_plan::array_expressions;
 use crate::physical_plan::datetime_expressions;
 use crate::physical_plan::math_expressions;
 use crate::physical_plan::string_expressions;
@@ -118,6 +119,8 @@ pub enum BuiltinScalarFunction {
     Concat,
     /// to_timestamp
     ToTimestamp,
+    /// construct an array from columns
+    Array,
 }
 
 impl fmt::Display for BuiltinScalarFunction {
@@ -151,6 +154,7 @@ impl FromStr for BuiltinScalarFunction {
             "length" => BuiltinScalarFunction::Length,
             "concat" => BuiltinScalarFunction::Concat,
             "to_timestamp" => BuiltinScalarFunction::ToTimestamp,
+            "array" => BuiltinScalarFunction::Array,
             _ => {
                 return Err(ExecutionError::General(format!(
                     "There is no built-in function named {}",
@@ -189,6 +193,10 @@ pub fn return_type(
         BuiltinScalarFunction::ToTimestamp => {
             Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
         }
+        BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList(
+            Box::new(arg_types[0].clone()),
+            arg_types.len() as i32,
+        )),
         _ => Ok(DataType::Float64),
     }
 }
@@ -225,6 +233,7 @@ pub fn create_physical_expr(
         BuiltinScalarFunction::ToTimestamp => {
             |args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?))
         }
+        BuiltinScalarFunction::Array => |args| 
Ok(array_expressions::array(args)?),
     });
     // coerce
     let args = coerce(args, input_schema, &signature(fun))?;
@@ -251,6 +260,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
         BuiltinScalarFunction::Length => Signature::Uniform(1, 
vec![DataType::Utf8]),
         BuiltinScalarFunction::Concat => 
Signature::Variadic(vec![DataType::Utf8]),
         BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, 
vec![DataType::Utf8]),
+        BuiltinScalarFunction::Array => {
+            
Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec())
+        }
         // math expressions expect 1 argument of type f64 or f32
         // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real 
numbers) and thus we
         // return the best approximation for it (in f64).
@@ -342,8 +354,8 @@ mod tests {
     };
     use arrow::{
         array::{
-            ArrayRef, Float64Array, Int32Array, PrimitiveArrayOps, StringArray,
-            StringArrayOps,
+            ArrayRef, FixedSizeListArray, Float64Array, Int32Array, 
PrimitiveArrayOps,
+            StringArray, StringArrayOps,
         },
         datatypes::Field,
         record_batch::RecordBatch,
@@ -432,4 +444,69 @@ mod tests {
             Ok(())
         }
     }
+
+    fn generic_test_array(
+        value1: ScalarValue,
+        value2: ScalarValue,
+        expected_type: DataType,
+        expected: &str,
+    ) -> Result<()> {
+        // any type works here: we evaluate against a literal of `value`
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, 
false)]);
+        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
+
+        let expr = create_physical_expr(
+            &BuiltinScalarFunction::Array,
+            &vec![lit(value1.clone()), lit(value2.clone())],
+            &schema,
+        )?;
+
+        // type is correct
+        assert_eq!(
+            expr.data_type(&schema)?,
+            // type equals to a common coercion
+            DataType::FixedSizeList(Box::new(expected_type), 2)
+        );
+
+        // evaluate works
+        let result =
+            expr.evaluate(&RecordBatch::try_new(Arc::new(schema.clone()), 
columns)?)?;
+
+        // downcast works
+        let result = result
+            .as_any()
+            .downcast_ref::<FixedSizeListArray>()
+            .unwrap();
+
+        // value is correct
+        assert_eq!(format!("{:?}", result.value(0)), expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_array() -> Result<()> {
+        generic_test_array(
+            ScalarValue::Utf8("aa".to_string()),
+            ScalarValue::Utf8("aa".to_string()),
+            DataType::Utf8,
+            "StringArray\n[\n  \"aa\",\n  \"aa\",\n]",
+        )?;
+
+        // different types, to validate that casting happens
+        generic_test_array(
+            ScalarValue::UInt32(1),
+            ScalarValue::UInt64(1),
+            DataType::UInt64,
+            "PrimitiveArray<UInt64>\n[\n  1,\n  1,\n]",
+        )?;
+
+        // different types (another order), to validate that casting happens
+        generic_test_array(
+            ScalarValue::UInt64(1),
+            ScalarValue::UInt32(1),
+            DataType::UInt64,
+            "PrimitiveArray<UInt64>\n[\n  1,\n  1,\n]",
+        )
+    }
 }
diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs
index 99ce8d6..f71b279 100644
--- a/rust/datafusion/src/physical_plan/mod.rs
+++ b/rust/datafusion/src/physical_plan/mod.rs
@@ -131,6 +131,7 @@ pub trait Accumulator: Debug {
 }
 
 pub mod aggregates;
+pub mod array_expressions;
 pub mod common;
 pub mod csv;
 pub mod datetime_expressions;
diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs
index 1b68347..aac2ebf 100644
--- a/rust/datafusion/src/prelude.rs
+++ b/rust/datafusion/src/prelude.rs
@@ -28,6 +28,6 @@
 pub use crate::dataframe::DataFrame;
 pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
 pub use crate::logical_plan::{
-    avg, col, concat, count, create_udf, length, lit, max, min, sum,
+    array, avg, col, concat, count, create_udf, length, lit, max, min, sum,
 };
 pub use crate::physical_plan::csv::CsvReadOptions;
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 9de59fb..a87dc79 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -609,6 +609,77 @@ fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<String> {
     result_str(&results)
 }
 
+/// Converts an array's value at `row_index` to a string.
+fn array_str(array: &Arc<dyn Array>, row_index: usize) -> String {
+    if array.is_null(row_index) {
+        return "NULL".to_string();
+    }
+    // beyond this point, we can assume that 
`array...downcast().value(row_index)` is valid,
+    // due to the `if` above.
+
+    match array.data_type() {
+        DataType::Int8 => {
+            let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::Int16 => {
+            let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::Int32 => {
+            let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::Int64 => {
+            let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::UInt8 => {
+            let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::UInt16 => {
+            let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::UInt32 => {
+            let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::UInt64 => {
+            let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::Float32 => {
+            let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::Float64 => {
+            let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::Utf8 => {
+            let array = array.as_any().downcast_ref::<StringArray>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::Boolean => {
+            let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
+            format!("{:?}", array.value(row_index))
+        }
+        DataType::FixedSizeList(_, n) => {
+            let array = 
array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
+            let array = array.value(row_index);
+
+            let mut r = Vec::with_capacity(*n as usize);
+            for i in 0..*n {
+                r.push(array_str(&array, i as usize));
+            }
+            format!("[{}]", r.join(","))
+        }
+        _ => "???".to_string(),
+    }
+}
+
 fn result_str(results: &[RecordBatch]) -> Vec<String> {
     let mut result = vec![];
     for batch in results {
@@ -620,76 +691,7 @@ fn result_str(results: &[RecordBatch]) -> Vec<String> {
                 }
                 let column = batch.column(column_index);
 
-                match column.data_type() {
-                    DataType::Int8 => {
-                        let array = 
column.as_any().downcast_ref::<Int8Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::Int16 => {
-                        let array = 
column.as_any().downcast_ref::<Int16Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::Int32 => {
-                        let array = 
column.as_any().downcast_ref::<Int32Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::Int64 => {
-                        let array = 
column.as_any().downcast_ref::<Int64Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::UInt8 => {
-                        let array = 
column.as_any().downcast_ref::<UInt8Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::UInt16 => {
-                        let array =
-                            
column.as_any().downcast_ref::<UInt16Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::UInt32 => {
-                        let array =
-                            
column.as_any().downcast_ref::<UInt32Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::UInt64 => {
-                        let array =
-                            
column.as_any().downcast_ref::<UInt64Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::Float32 => {
-                        let array =
-                            
column.as_any().downcast_ref::<Float32Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::Float64 => {
-                        let array =
-                            
column.as_any().downcast_ref::<Float64Array>().unwrap();
-                        str.push_str(&format!("{:?}", array.value(row_index)));
-                    }
-                    DataType::Utf8 => {
-                        let array =
-                            
column.as_any().downcast_ref::<StringArray>().unwrap();
-                        let s = if array.is_null(row_index) {
-                            "NULL"
-                        } else {
-                            array.value(row_index)
-                        };
-
-                        str.push_str(&format!("{:?}", s));
-                    }
-                    DataType::Boolean => {
-                        let array =
-                            
column.as_any().downcast_ref::<BooleanArray>().unwrap();
-                        let s = if array.is_null(row_index) {
-                            "NULL".to_string()
-                        } else {
-                            format!("{:?}", array.value(row_index))
-                        };
-
-                        str.push_str(&s);
-                    }
-                    _ => str.push_str("???"),
-                }
+                str.push_str(&array_str(column, row_index));
             }
             result.push(str);
         }
@@ -762,7 +764,38 @@ fn query_concat() -> Result<()> {
     ctx.register_table("test", Box::new(table));
     let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test";
     let actual = execute(&mut ctx, sql);
-    let expected = vec!["\"-hi-0\"", "\"a-hi-1\"", "\"NULL\"", "\"aaa-hi-3\""];
+    let expected = vec!["\"-hi-0\"", "\"a-hi-1\"", "NULL", "\"aaa-hi-3\""];
+    assert_eq!(expected, actual);
+    Ok(())
+}
+
+#[test]
+fn query_array() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Utf8, false),
+        Field::new("c2", DataType::Int32, true),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(StringArray::from(vec!["", "a", "aa", "aaa"])),
+            Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])),
+        ],
+    )?;
+
+    let table = MemTable::new(schema, vec![vec![data]])?;
+
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("test", Box::new(table));
+    let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test";
+    let actual = execute(&mut ctx, sql);
+    let expected = vec![
+        "[\"\",\"0\"]",
+        "[\"a\",\"1\"]",
+        "[\"aa\",NULL]",
+        "[\"aaa\",\"3\"]",
+    ];
     assert_eq!(expected, actual);
     Ok(())
 }

Reply via email to