This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 24a4a44 ARROW-9902: [Rust] [DataFusion] Add array() built-in function
24a4a44 is described below
commit 24a4a44ea99f52ed2b36bb5b2d068488a5b6768a
Author: Jorge C. Leitao <[email protected]>
AuthorDate: Sun Sep 20 09:24:27 2020 -0600
ARROW-9902: [Rust] [DataFusion] Add array() built-in function
This adds the `array()` built-in function for most primitive types. Composite
types are more challenging, so I decided to scope them out of this PR.
Closes #8102 from jorgecarleitao/array
Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
---
rust/datafusion/README.md | 2 +
rust/datafusion/src/logical_plan/mod.rs | 8 +
.../src/physical_plan/array_expressions.rs | 108 +++++++++++++
rust/datafusion/src/physical_plan/functions.rs | 81 +++++++++-
rust/datafusion/src/physical_plan/mod.rs | 1 +
rust/datafusion/src/prelude.rs | 2 +-
rust/datafusion/tests/sql.rs | 175 ++++++++++++---------
7 files changed, 303 insertions(+), 74 deletions(-)
diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md
index bf3c60f..5405f26 100644
--- a/rust/datafusion/README.md
+++ b/rust/datafusion/README.md
@@ -61,6 +61,8 @@ DataFusion includes a simple command-line interactive SQL
utility. See the [CLI
- [ ] Basic date functions
- [ ] Basic time functions
- [x] Basic timestamp functions
+- nested functions
+ - [x] Array of columns
- [x] Sorting
- [ ] Nested types
- [ ] Lists
diff --git a/rust/datafusion/src/logical_plan/mod.rs
b/rust/datafusion/src/logical_plan/mod.rs
index e0f5d9d..e37bd10 100644
--- a/rust/datafusion/src/logical_plan/mod.rs
+++ b/rust/datafusion/src/logical_plan/mod.rs
@@ -623,6 +623,14 @@ pub fn concat(args: Vec<Expr>) -> Expr {
}
}
+/// Creates an expression that calls the built-in `array` scalar function.
+///
+/// `args` are passed through unchanged as the arguments of the resulting
+/// `Expr::ScalarFunction`.
+pub fn array(args: Vec<Expr>) -> Expr {
+    let fun = functions::BuiltinScalarFunction::Array;
+    Expr::ScalarFunction { fun, args }
+}
+
/// Creates a new UDF with a specific signature and specific return type.
/// This is a helper function to create a new UDF.
/// The function `create_udf` returns a subset of all possible
`ScalarFunction`:
diff --git a/rust/datafusion/src/physical_plan/array_expressions.rs
b/rust/datafusion/src/physical_plan/array_expressions.rs
new file mode 100644
index 0000000..79fb64e
--- /dev/null
+++ b/rust/datafusion/src/physical_plan/array_expressions.rs
@@ -0,0 +1,108 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Array expressions
+
+use crate::error::{ExecutionError, Result};
+use arrow::array::*;
+use arrow::datatypes::DataType;
+use std::sync::Arc;
+
+// Lazily downcasts every array in `$ARGS` to `$ARRAY_TYPE`.
+// Expands to an iterator of results: `Ok(&$ARRAY_TYPE)` for each element that
+// has that concrete type, and an `ExecutionError::General` for any that does not.
+macro_rules! downcast_vec {
+    ($ARGS:expr, $ARRAY_TYPE:ident) => {{
+        $ARGS.iter().map(|e| {
+            e.as_any()
+                .downcast_ref::<$ARRAY_TYPE>()
+                .ok_or_else(|| ExecutionError::General("failed to downcast".to_string()))
+        })
+    }};
+}
+
+// Builds a fixed-size list array out of `$ARGS`: entry `i` of the result holds
+// the `i`-th value (or null) of every argument, in argument order.
+// `$ARRAY_TYPE` is the concrete type all arguments are downcast to and
+// `$BUILDER_TYPE` is the builder used for the list's values.
+macro_rules! array {
+    ($ARGS:expr, $ARRAY_TYPE:ident, $BUILDER_TYPE:ident) => {{
+        // every argument must downcast to the same concrete array type
+        let columns = downcast_vec!($ARGS, $ARRAY_TYPE)
+            .collect::<Result<Vec<&$ARRAY_TYPE>>>()?;
+
+        // the first argument dictates how many entries are produced
+        let rows = columns[0].len();
+        let values_builder = <$BUILDER_TYPE>::new(rows);
+        let mut builder = FixedSizeListBuilder::<$BUILDER_TYPE>::new(
+            values_builder,
+            columns.len() as i32,
+        );
+
+        for row in 0..rows {
+            // copy this row's value from each argument into the current list slot
+            for column in &columns {
+                if !column.is_null(row) {
+                    builder.values().append_value(column.value(row))?;
+                } else {
+                    builder.values().append_null()?;
+                }
+            }
+            // the list slot itself is always valid, even if all its values are null
+            builder.append(true)?;
+        }
+        Ok(Arc::new(builder.finish()))
+    }};
+}
+
+/// Puts the values of `args` in an array: entry `i` of the result is a
+/// fixed-size list holding the `i`-th value of every argument, in argument order.
+///
+/// # Errors
+///
+/// Returns an error when `args` is empty, when the arguments do not all have
+/// the same length, when an argument cannot be downcast to the array type
+/// selected by the first argument, or when that type is not supported.
+pub fn array(args: &[ArrayRef]) -> Result<ArrayRef> {
+    // do not accept 0 arguments.
+    let first = match args.first() {
+        Some(first) => first,
+        None => {
+            return Err(ExecutionError::InternalError(
+                "array requires at least one argument".to_string(),
+            ))
+        }
+    };
+
+    // every argument is indexed using the first argument's length, so an
+    // argument of a different length must be rejected before values are read.
+    let rows = first.len();
+    if args.iter().any(|arg| arg.len() != rows) {
+        return Err(ExecutionError::InternalError(
+            "array requires all arguments to have the same length".to_string(),
+        ));
+    }
+
+    // the first argument selects the concrete array and builder types
+    match first.data_type() {
+        DataType::Utf8 => array!(args, StringArray, StringBuilder),
+        DataType::LargeUtf8 => array!(args, LargeStringArray, LargeStringBuilder),
+        DataType::Boolean => array!(args, BooleanArray, BooleanBuilder),
+        DataType::Float32 => array!(args, Float32Array, Float32Builder),
+        DataType::Float64 => array!(args, Float64Array, Float64Builder),
+        DataType::Int8 => array!(args, Int8Array, Int8Builder),
+        DataType::Int16 => array!(args, Int16Array, Int16Builder),
+        DataType::Int32 => array!(args, Int32Array, Int32Builder),
+        DataType::Int64 => array!(args, Int64Array, Int64Builder),
+        DataType::UInt8 => array!(args, UInt8Array, UInt8Builder),
+        DataType::UInt16 => array!(args, UInt16Array, UInt16Builder),
+        DataType::UInt32 => array!(args, UInt32Array, UInt32Builder),
+        DataType::UInt64 => array!(args, UInt64Array, UInt64Builder),
+        data_type => Err(ExecutionError::NotImplemented(format!(
+            "Array is not implemented for type '{:?}'.",
+            data_type
+        ))),
+    }
+}
+
+/// Data types currently supported by the `array` function.
+/// The order of these types corresponds to the order in which coercion is
+/// applied, so it should go from least informative to most informative.
+pub static SUPPORTED_ARRAY_TYPES: &[DataType] = &[
+    DataType::Boolean,
+    DataType::UInt8,
+    DataType::UInt16,
+    DataType::UInt32,
+    DataType::UInt64,
+    DataType::Int8,
+    DataType::Int16,
+    DataType::Int32,
+    DataType::Int64,
+    DataType::Float32,
+    DataType::Float64,
+    DataType::Utf8,
+    DataType::LargeUtf8,
+];
diff --git a/rust/datafusion/src/physical_plan/functions.rs
b/rust/datafusion/src/physical_plan/functions.rs
index af02c6d..95bd252 100644
--- a/rust/datafusion/src/physical_plan/functions.rs
+++ b/rust/datafusion/src/physical_plan/functions.rs
@@ -34,6 +34,7 @@ use super::{
PhysicalExpr,
};
use crate::error::{ExecutionError, Result};
+use crate::physical_plan::array_expressions;
use crate::physical_plan::datetime_expressions;
use crate::physical_plan::math_expressions;
use crate::physical_plan::string_expressions;
@@ -118,6 +119,8 @@ pub enum BuiltinScalarFunction {
Concat,
/// to_timestamp
ToTimestamp,
+ /// construct an array from columns
+ Array,
}
impl fmt::Display for BuiltinScalarFunction {
@@ -151,6 +154,7 @@ impl FromStr for BuiltinScalarFunction {
"length" => BuiltinScalarFunction::Length,
"concat" => BuiltinScalarFunction::Concat,
"to_timestamp" => BuiltinScalarFunction::ToTimestamp,
+ "array" => BuiltinScalarFunction::Array,
_ => {
return Err(ExecutionError::General(format!(
"There is no built-in function named {}",
@@ -189,6 +193,10 @@ pub fn return_type(
BuiltinScalarFunction::ToTimestamp => {
Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
}
+ BuiltinScalarFunction::Array => Ok(DataType::FixedSizeList(
+ Box::new(arg_types[0].clone()),
+ arg_types.len() as i32,
+ )),
_ => Ok(DataType::Float64),
}
}
@@ -225,6 +233,7 @@ pub fn create_physical_expr(
BuiltinScalarFunction::ToTimestamp => {
|args| Ok(Arc::new(datetime_expressions::to_timestamp(args)?))
}
+ BuiltinScalarFunction::Array => |args|
Ok(array_expressions::array(args)?),
});
// coerce
let args = coerce(args, input_schema, &signature(fun))?;
@@ -251,6 +260,9 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
BuiltinScalarFunction::Length => Signature::Uniform(1,
vec![DataType::Utf8]),
BuiltinScalarFunction::Concat =>
Signature::Variadic(vec![DataType::Utf8]),
BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1,
vec![DataType::Utf8]),
+ BuiltinScalarFunction::Array => {
+
Signature::Variadic(array_expressions::SUPPORTED_ARRAY_TYPES.to_vec())
+ }
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real
numbers) and thus we
// return the best approximation for it (in f64).
@@ -342,8 +354,8 @@ mod tests {
};
use arrow::{
array::{
- ArrayRef, Float64Array, Int32Array, PrimitiveArrayOps, StringArray,
- StringArrayOps,
+ ArrayRef, FixedSizeListArray, Float64Array, Int32Array,
PrimitiveArrayOps,
+ StringArray, StringArrayOps,
},
datatypes::Field,
record_batch::RecordBatch,
@@ -432,4 +444,69 @@ mod tests {
Ok(())
}
}
+
+    /// Builds `array(value1, value2)` from two literals and checks that the
+    /// expression's type is a `FixedSizeList` of `expected_type` with 2 items
+    /// and that the debug rendering of its first entry equals `expected`.
+    fn generic_test_array(
+        value1: ScalarValue,
+        value2: ScalarValue,
+        expected_type: DataType,
+        expected: &str,
+    ) -> Result<()> {
+        // the batch content is irrelevant: both arguments are literals
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from(vec![1]))];
+
+        let args = vec![lit(value1), lit(value2)];
+        let expr = create_physical_expr(&BuiltinScalarFunction::Array, &args, &schema)?;
+
+        // the reported type is a list of the common coerced type
+        let expected_list_type = DataType::FixedSizeList(Box::new(expected_type), 2);
+        assert_eq!(expr.data_type(&schema)?, expected_list_type);
+
+        // evaluation succeeds
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), columns)?;
+        let evaluated = expr.evaluate(&batch)?;
+
+        // the result is a fixed-size list array
+        let list = evaluated
+            .as_any()
+            .downcast_ref::<FixedSizeListArray>()
+            .unwrap();
+
+        // the first entry renders as expected
+        assert_eq!(format!("{:?}", list.value(0)), expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_array() -> Result<()> {
+        // (first value, second value, expected item type, expected first entry)
+        let cases = vec![
+            // same type on both sides
+            (
+                ScalarValue::Utf8("aa".to_string()),
+                ScalarValue::Utf8("aa".to_string()),
+                DataType::Utf8,
+                "StringArray\n[\n \"aa\",\n \"aa\",\n]",
+            ),
+            // different types, to validate that casting happens
+            (
+                ScalarValue::UInt32(1),
+                ScalarValue::UInt64(1),
+                DataType::UInt64,
+                "PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
+            ),
+            // different types (another order), to validate that casting happens
+            (
+                ScalarValue::UInt64(1),
+                ScalarValue::UInt32(1),
+                DataType::UInt64,
+                "PrimitiveArray<UInt64>\n[\n 1,\n 1,\n]",
+            ),
+        ];
+
+        for (value1, value2, expected_type, expected) in cases {
+            generic_test_array(value1, value2, expected_type, expected)?;
+        }
+        Ok(())
+    }
}
diff --git a/rust/datafusion/src/physical_plan/mod.rs
b/rust/datafusion/src/physical_plan/mod.rs
index 99ce8d6..f71b279 100644
--- a/rust/datafusion/src/physical_plan/mod.rs
+++ b/rust/datafusion/src/physical_plan/mod.rs
@@ -131,6 +131,7 @@ pub trait Accumulator: Debug {
}
pub mod aggregates;
+pub mod array_expressions;
pub mod common;
pub mod csv;
pub mod datetime_expressions;
diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs
index 1b68347..aac2ebf 100644
--- a/rust/datafusion/src/prelude.rs
+++ b/rust/datafusion/src/prelude.rs
@@ -28,6 +28,6 @@
pub use crate::dataframe::DataFrame;
pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
pub use crate::logical_plan::{
- avg, col, concat, count, create_udf, length, lit, max, min, sum,
+ array, avg, col, concat, count, create_udf, length, lit, max, min, sum,
};
pub use crate::physical_plan::csv::CsvReadOptions;
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 9de59fb..a87dc79 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -609,6 +609,77 @@ fn execute(ctx: &mut ExecutionContext, sql: &str) ->
Vec<String> {
result_str(&results)
}
+/// Converts an array's value at `row_index` to a string.
+fn array_str(array: &Arc<dyn Array>, row_index: usize) -> String {
+ if array.is_null(row_index) {
+ return "NULL".to_string();
+ }
+ // beyond this point, we can assume that
`array...downcast().value(row_index)` is valid,
+ // due to the `if` above.
+
+ match array.data_type() {
+ DataType::Int8 => {
+ let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::Int16 => {
+ let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::Int32 => {
+ let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::Int64 => {
+ let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::UInt8 => {
+ let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::UInt16 => {
+ let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::UInt32 => {
+ let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::UInt64 => {
+ let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::Float32 => {
+ let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::Float64 => {
+ let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::Utf8 => {
+ let array = array.as_any().downcast_ref::<StringArray>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::Boolean => {
+ let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
+ format!("{:?}", array.value(row_index))
+ }
+ DataType::FixedSizeList(_, n) => {
+ let array =
array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
+ let array = array.value(row_index);
+
+ let mut r = Vec::with_capacity(*n as usize);
+ for i in 0..*n {
+ r.push(array_str(&array, i as usize));
+ }
+ format!("[{}]", r.join(","))
+ }
+ _ => "???".to_string(),
+ }
+}
+
fn result_str(results: &[RecordBatch]) -> Vec<String> {
let mut result = vec![];
for batch in results {
@@ -620,76 +691,7 @@ fn result_str(results: &[RecordBatch]) -> Vec<String> {
}
let column = batch.column(column_index);
- match column.data_type() {
- DataType::Int8 => {
- let array =
column.as_any().downcast_ref::<Int8Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::Int16 => {
- let array =
column.as_any().downcast_ref::<Int16Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::Int32 => {
- let array =
column.as_any().downcast_ref::<Int32Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::Int64 => {
- let array =
column.as_any().downcast_ref::<Int64Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::UInt8 => {
- let array =
column.as_any().downcast_ref::<UInt8Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::UInt16 => {
- let array =
-
column.as_any().downcast_ref::<UInt16Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::UInt32 => {
- let array =
-
column.as_any().downcast_ref::<UInt32Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::UInt64 => {
- let array =
-
column.as_any().downcast_ref::<UInt64Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::Float32 => {
- let array =
-
column.as_any().downcast_ref::<Float32Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::Float64 => {
- let array =
-
column.as_any().downcast_ref::<Float64Array>().unwrap();
- str.push_str(&format!("{:?}", array.value(row_index)));
- }
- DataType::Utf8 => {
- let array =
-
column.as_any().downcast_ref::<StringArray>().unwrap();
- let s = if array.is_null(row_index) {
- "NULL"
- } else {
- array.value(row_index)
- };
-
- str.push_str(&format!("{:?}", s));
- }
- DataType::Boolean => {
- let array =
-
column.as_any().downcast_ref::<BooleanArray>().unwrap();
- let s = if array.is_null(row_index) {
- "NULL".to_string()
- } else {
- format!("{:?}", array.value(row_index))
- };
-
- str.push_str(&s);
- }
- _ => str.push_str("???"),
- }
+ str.push_str(&array_str(column, row_index));
}
result.push(str);
}
@@ -762,7 +764,38 @@ fn query_concat() -> Result<()> {
ctx.register_table("test", Box::new(table));
let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test";
let actual = execute(&mut ctx, sql);
- let expected = vec!["\"-hi-0\"", "\"a-hi-1\"", "\"NULL\"", "\"aaa-hi-3\""];
+ let expected = vec!["\"-hi-0\"", "\"a-hi-1\"", "NULL", "\"aaa-hi-3\""];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+#[test]
+fn query_array() -> Result<()> {
+    // c1 is a non-nullable string column, c2 a nullable integer column
+    let fields = vec![
+        Field::new("c1", DataType::Utf8, false),
+        Field::new("c2", DataType::Int32, true),
+    ];
+    let schema = Arc::new(Schema::new(fields));
+
+    // four rows; the third row has a null in c2
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(StringArray::from(vec!["", "a", "aa", "aaa"])),
+            Arc::new(Int32Array::from(vec![Some(0), Some(1), None, Some(3)])),
+        ],
+    )?;
+
+    let mem_table = MemTable::new(schema, vec![vec![batch]])?;
+
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("test", Box::new(mem_table));
+
+    // both arguments are strings once c2 is cast
+    let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test";
+    let actual = execute(&mut ctx, sql);
+    let expected = vec![
+        "[\"\",\"0\"]",
+        "[\"a\",\"1\"]",
+        "[\"aa\",NULL]",
+        "[\"aaa\",\"3\"]",
+    ];
     assert_eq!(expected, actual);
     Ok(())
 }