This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 658206a668 Support fixed_size_list for make_array (#6759)
658206a668 is described below

commit 658206a66825e229304ee744715a88908d281b9b
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Jul 6 01:35:20 2023 +0800

    Support fixed_size_list for make_array (#6759)
    
    * support make_array for fixed_size_list
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add arrow-typeof in test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix schema mismatch
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup code
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * create array data with correct len
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/common/src/scalar.rs                    |  62 ++++++++-------
 .../core/tests/data/fixed_size_list_array.parquet  | Bin 0 -> 718 bytes
 .../core/tests/sqllogictests/test_files/array.slt  |  39 +++++++++-
 datafusion/optimizer/src/analyzer/type_coercion.rs |  86 +++++++++++++++++++--
 datafusion/physical-expr/src/array_expressions.rs  |   4 +-
 datafusion/proto/src/logical_plan/to_proto.rs      |   4 +
 datafusion/sql/src/expr/arrow_cast.rs              |  19 ++++-
 7 files changed, 174 insertions(+), 40 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 4fef60020f..b0769df1e9 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -101,7 +101,9 @@ pub enum ScalarValue {
     FixedSizeBinary(i32, Option<Vec<u8>>),
     /// large binary
     LargeBinary(Option<Vec<u8>>),
-    /// list of nested ScalarValue
+    /// Fixed size list of nested ScalarValue
+    Fixedsizelist(Option<Vec<ScalarValue>>, FieldRef, i32),
+    /// List of nested ScalarValue
     List(Option<Vec<ScalarValue>>, FieldRef),
     /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01
     Date32(Option<i32>),
@@ -196,6 +198,10 @@ impl PartialEq for ScalarValue {
             (FixedSizeBinary(_, _), _) => false,
             (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2),
             (LargeBinary(_), _) => false,
+            (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => {
+                v1.eq(v2) && t1.eq(t2) && l1.eq(l2)
+            }
+            (Fixedsizelist(_, _, _), _) => false,
             (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2),
             (List(_, _), _) => false,
             (Date32(v1), Date32(v2)) => v1.eq(v2),
@@ -315,6 +321,14 @@ impl PartialOrd for ScalarValue {
             (FixedSizeBinary(_, _), _) => None,
             (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
             (LargeBinary(_), _) => None,
+            (Fixedsizelist(v1, t1, l1), Fixedsizelist(v2, t2, l2)) => {
+                if t1.eq(t2) && l1.eq(l2) {
+                    v1.partial_cmp(v2)
+                } else {
+                    None
+                }
+            }
+            (Fixedsizelist(_, _, _), _) => None,
             (List(v1, t1), List(v2, t2)) => {
                 if t1.eq(t2) {
                     v1.partial_cmp(v2)
@@ -1518,6 +1532,11 @@ impl std::hash::Hash for ScalarValue {
             Binary(v) => v.hash(state),
             FixedSizeBinary(_, v) => v.hash(state),
             LargeBinary(v) => v.hash(state),
+            Fixedsizelist(v, t, l) => {
+                v.hash(state);
+                t.hash(state);
+                l.hash(state);
+            }
             List(v, t) => {
                 v.hash(state);
                 t.hash(state);
@@ -1994,6 +2013,10 @@ impl ScalarValue {
             ScalarValue::Binary(_) => DataType::Binary,
             ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz),
             ScalarValue::LargeBinary(_) => DataType::LargeBinary,
+            ScalarValue::Fixedsizelist(_, field, length) => DataType::FixedSizeList(
+                Arc::new(Field::new("item", field.data_type().clone(), true)),
+                *length,
+            ),
             ScalarValue::List(_, field) => DataType::List(Arc::new(Field::new(
                 "item",
                 field.data_type().clone(),
@@ -2142,6 +2165,7 @@ impl ScalarValue {
             ScalarValue::Binary(v) => v.is_none(),
             ScalarValue::FixedSizeBinary(_, v) => v.is_none(),
             ScalarValue::LargeBinary(v) => v.is_none(),
+            ScalarValue::Fixedsizelist(v, ..) => v.is_none(),
             ScalarValue::List(v, _) => v.is_none(),
             ScalarValue::Date32(v) => v.is_none(),
             ScalarValue::Date64(v) => v.is_none(),
@@ -2847,6 +2871,9 @@ impl ScalarValue {
                         .collect::<LargeBinaryArray>(),
                 ),
             },
+            ScalarValue::Fixedsizelist(..) => {
+                unimplemented!("FixedSizeList is not supported yet")
+            }
             ScalarValue::List(values, field) => Arc::new(match field.data_type() {
                 DataType::Boolean => build_list!(BooleanBuilder, Boolean, values, size),
                 DataType::Int8 => build_list!(Int8Builder, Int8, values, size),
@@ -3294,6 +3321,7 @@ impl ScalarValue {
             ScalarValue::LargeBinary(val) => {
                 eq_array_primitive!(array, index, LargeBinaryArray, val)
             }
+            ScalarValue::Fixedsizelist(..) => unimplemented!(),
             ScalarValue::List(_, _) => unimplemented!(),
             ScalarValue::Date32(val) => {
                 eq_array_primitive!(array, index, Date32Array, val)
@@ -3414,7 +3442,8 @@ impl ScalarValue {
                 | ScalarValue::LargeBinary(b) => {
                     b.as_ref().map(|b| b.capacity()).unwrap_or_default()
                 }
-                ScalarValue::List(vals, field) => {
+                ScalarValue::Fixedsizelist(vals, field, _)
+                | ScalarValue::List(vals, field) => {
                     vals.as_ref()
                         .map(|vals| Self::size_of_vec(vals) - std::mem::size_of_val(vals))
                         .unwrap_or_default()
@@ -3732,29 +3761,9 @@ impl fmt::Display for ScalarValue {
             ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?,
             ScalarValue::Utf8(e) => format_option!(f, e)?,
             ScalarValue::LargeUtf8(e) => format_option!(f, e)?,
-            ScalarValue::Binary(e) => match e {
-                Some(l) => write!(
-                    f,
-                    "{}",
-                    l.iter()
-                        .map(|v| format!("{v}"))
-                        .collect::<Vec<_>>()
-                        .join(",")
-                )?,
-                None => write!(f, "NULL")?,
-            },
-            ScalarValue::FixedSizeBinary(_, e) => match e {
-                Some(l) => write!(
-                    f,
-                    "{}",
-                    l.iter()
-                        .map(|v| format!("{v}"))
-                        .collect::<Vec<_>>()
-                        .join(",")
-                )?,
-                None => write!(f, "NULL")?,
-            },
-            ScalarValue::LargeBinary(e) => match e {
+            ScalarValue::Binary(e)
+            | ScalarValue::FixedSizeBinary(_, e)
+            | ScalarValue::LargeBinary(e) => match e {
                 Some(l) => write!(
                     f,
                     "{}",
@@ -3765,7 +3774,7 @@ impl fmt::Display for ScalarValue {
                 )?,
                 None => write!(f, "NULL")?,
             },
-            ScalarValue::List(e, _) => match e {
+            ScalarValue::Fixedsizelist(e, ..) | ScalarValue::List(e, _) => match e {
                 Some(l) => write!(
                     f,
                     "{}",
@@ -3849,6 +3858,7 @@ impl fmt::Debug for ScalarValue {
             }
             ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"),
             ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"),
+            ScalarValue::Fixedsizelist(..) => write!(f, "FixedSizeList([{self}])"),
             ScalarValue::List(_, _) => write!(f, "List([{self}])"),
             ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"),
             ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"),
diff --git a/datafusion/core/tests/data/fixed_size_list_array.parquet b/datafusion/core/tests/data/fixed_size_list_array.parquet
new file mode 100644
index 0000000000..aafc5ce62f
Binary files /dev/null and b/datafusion/core/tests/data/fixed_size_list_array.parquet differ
diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt
index 0d99e6cbb3..1f43c5f8e1 100644
--- a/datafusion/core/tests/sqllogictests/test_files/array.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/array.slt
@@ -417,8 +417,6 @@ select make_array(x, y) from foo2;
 
 # array_contains
 
-
-
 # array_contains scalar function #1
 query BBB rowsort
 select array_contains(make_array(1, 2, 3), make_array(1, 1, 2, 3)), array_contains([1, 2, 3], [1, 1, 2]), array_contains([1, 2, 3], [2, 1, 3, 1]);
@@ -531,3 +529,40 @@ SELECT
 FROM t
 ----
 true true
+
+statement ok
+CREATE EXTERNAL TABLE fixed_size_list_array STORED AS PARQUET LOCATION 'tests/data/fixed_size_list_array.parquet';
+
+query T
+select arrow_typeof(f0) from fixed_size_list_array;
+----
+FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2)
+FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 2)
+
+query ?
+select * from fixed_size_list_array;
+----
+[1, 2]
+[3, 4]
+
+query ?
+select f0 from fixed_size_list_array;
+----
+[1, 2]
+[3, 4]
+
+query ?
+select arrow_cast(f0, 'List(Int64)') from fixed_size_list_array;
+----
+[1, 2]
+[3, 4]
+
+query ?
+select make_array(arrow_cast(f0, 'List(Int64)')) from fixed_size_list_array
+----
+[[1, 2], [3, 4]]
+
+query ?
+select make_array(f0) from fixed_size_list_array
+----
+[[1, 2], [3, 4]]
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 5d1fef5352..7cf4a233f7 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -330,8 +330,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                     &self.schema,
                     &fun.signature,
                 )?;
-                let expr = Expr::ScalarUDF(ScalarUDF::new(fun, new_expr));
-                Ok(expr)
+                Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr)))
             }
             Expr::ScalarFunction(ScalarFunction { fun, args }) => {
                 let new_args = coerce_arguments_for_signature(
@@ -520,7 +519,7 @@ fn coerce_window_frame(
 fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr> {
     let left_type = expr.get_type(schema)?;
     get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
-    expr.clone().cast_to(&DataType::Boolean, schema)
+    cast_expr(expr, &DataType::Boolean, schema)
 }
 
 /// Returns `expressions` coerced to types compatible with
@@ -559,6 +558,25 @@ fn coerce_arguments_for_fun(
         return Ok(vec![]);
     }
 
+    let mut expressions: Vec<Expr> = expressions.to_vec();
+
+    // Cast Fixedsizelist to List for array functions
+    if *fun == BuiltinScalarFunction::MakeArray {
+        expressions = expressions
+            .into_iter()
+            .map(|expr| {
+                let data_type = expr.get_type(schema).unwrap();
+                if let DataType::FixedSizeList(field, _) = data_type {
+                    let field = field.as_ref().clone();
+                    let to_type = DataType::List(Arc::new(field));
+                    expr.cast_to(&to_type, schema)
+                } else {
+                    Ok(expr)
+                }
+            })
+            .collect::<Result<Vec<_>>>()?;
+    }
+
     if *fun == BuiltinScalarFunction::MakeArray {
         // Find the final data type for the function arguments
         let current_types = expressions
@@ -579,8 +597,7 @@ fn coerce_arguments_for_fun(
             .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema))
             .collect();
     }
-
-    Ok(expressions.to_vec())
+    Ok(expressions)
 }
 
 /// Cast `expr` to the specified type, if possible
@@ -598,7 +615,7 @@ fn cast_array_expr(
     if from_type.equals_datatype(&DataType::Null) {
         Ok(expr.clone())
     } else {
-        expr.clone().cast_to(to_type, schema)
+        cast_expr(expr, to_type, schema)
     }
 }
 
@@ -625,7 +642,7 @@ fn coerce_agg_exprs_for_signature(
     input_exprs
         .iter()
         .enumerate()
-        .map(|(i, expr)| expr.clone().cast_to(&coerced_types[i], schema))
+        .map(|(i, expr)| cast_expr(expr, &coerced_types[i], schema))
         .collect::<Result<Vec<_>>>()
 }
 
@@ -746,6 +763,7 @@ mod test {
 
     use arrow::datatypes::{DataType, TimeUnit};
 
+    use arrow::datatypes::Field;
     use datafusion_common::tree_node::TreeNode;
     use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue};
     use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction};
@@ -763,7 +781,7 @@ mod test {
     use datafusion_physical_expr::expressions::AvgAccumulator;
 
     use crate::analyzer::type_coercion::{
-        coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
+        cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
     };
     use crate::test::assert_analyzed_plan_eq;
 
@@ -1220,6 +1238,58 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn test_casting_for_fixed_size_list() -> Result<()> {
+        let val = lit(ScalarValue::Fixedsizelist(
+            Some(vec![
+                ScalarValue::from(1i32),
+                ScalarValue::from(2i32),
+                ScalarValue::from(3i32),
+            ]),
+            Arc::new(Field::new("item", DataType::Int32, true)),
+            3,
+        ));
+        let expr = Expr::ScalarFunction(ScalarFunction {
+            fun: BuiltinScalarFunction::MakeArray,
+            args: vec![val.clone()],
+        });
+        let schema = Arc::new(DFSchema::new_with_metadata(
+            vec![DFField::new_unqualified(
+                "item",
+                DataType::FixedSizeList(
+                    Arc::new(Field::new("a", DataType::Int32, true)),
+                    3,
+                ),
+                true,
+            )],
+            std::collections::HashMap::new(),
+        )?);
+        let mut rewriter = TypeCoercionRewriter { schema };
+        let result = expr.rewrite(&mut rewriter)?;
+
+        let schema = Arc::new(DFSchema::new_with_metadata(
+            vec![DFField::new_unqualified(
+                "item",
+                DataType::List(Arc::new(Field::new("a", DataType::Int32, true))),
+                true,
+            )],
+            std::collections::HashMap::new(),
+        )?);
+        let expected_casted_expr = cast_expr(
+            &val,
+            &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
+            &schema,
+        )?;
+
+        let expected = Expr::ScalarFunction(ScalarFunction {
+            fun: BuiltinScalarFunction::MakeArray,
+            args: vec![expected_casted_expr],
+        });
+
+        assert_eq!(result, expected);
+        Ok(())
+    }
+
     #[test]
     fn test_type_coercion_rewrite() -> Result<()> {
         // gt
diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs
index 911c94b06d..bddeef526a 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -111,7 +111,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
         DataType::List(..) => {
             let arrays =
                 downcast_vec!(args, ListArray).collect::<Result<Vec<&ListArray>>>()?;
-            let len: i32 = arrays.len() as i32;
+            let len = arrays.iter().map(|arr| arr.len() as i32).sum();
             let capacity =
                 Capacities::Array(arrays.iter().map(|a| a.get_array_memory_size()).sum());
             let array_data: Vec<_> =
@@ -125,7 +125,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
             }
 
             let list_data_type =
-                DataType::List(Arc::new(Field::new("item", data_type, false)));
+                DataType::List(Arc::new(Field::new("item", data_type, true)));
 
             let list_data = ArrayData::builder(list_data_type)
                 .len(1)
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index a046be35d4..4a4b16db80 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1068,6 +1068,10 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue {
                     Value::LargeUtf8Value(s.to_owned())
                 })
             }
+            ScalarValue::Fixedsizelist(..) => Err(Error::General(
+                "Proto serialization error: ScalarValue::Fixedsizelist not supported"
+                    .to_string(),
+            )),
             ScalarValue::List(values, boxed_field) => {
                 let is_null = values.is_none();
 
diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs
index 91a42f4736..46957a9cdd 100644
--- a/datafusion/sql/src/expr/arrow_cast.rs
+++ b/datafusion/sql/src/expr/arrow_cast.rs
@@ -18,9 +18,9 @@
 //! Implementation of the `arrow_cast` function that allows
 //! casting to arbitrary arrow types (rather than SQL types)
 
-use std::{fmt::Display, iter::Peekable, str::Chars};
+use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc};
 
-use arrow_schema::{DataType, IntervalUnit, TimeUnit};
+use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit};
 use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
 
 use datafusion_expr::{Expr, ExprSchemable};
@@ -150,6 +150,7 @@ impl<'a> Parser<'a> {
             Token::Decimal128 => self.parse_decimal_128(),
             Token::Decimal256 => self.parse_decimal_256(),
             Token::Dictionary => self.parse_dictionary(),
+            Token::List => self.parse_list(),
             tok => Err(make_error(
                 self.val,
                 &format!("finding next type, got unexpected '{tok}'"),
@@ -157,6 +158,16 @@ impl<'a> Parser<'a> {
         }
     }
 
+    /// Parses the List type
+    fn parse_list(&mut self) -> Result<DataType> {
+        self.expect_token(Token::LParen)?;
+        let data_type = self.parse_next_type()?;
+        self.expect_token(Token::RParen)?;
+        Ok(DataType::List(Arc::new(Field::new(
+            "item", data_type, true,
+        ))))
+    }
+
     /// Parses the next timeunit
     fn parse_time_unit(&mut self, context: &str) -> Result<TimeUnit> {
         match self.next_token()? {
@@ -486,6 +497,8 @@ impl<'a> Tokenizer<'a> {
             "Date32" => Token::SimpleType(DataType::Date32),
             "Date64" => Token::SimpleType(DataType::Date64),
 
+            "List" => Token::List,
+
             "Second" => Token::TimeUnit(TimeUnit::Second),
             "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond),
             "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond),
@@ -573,12 +586,14 @@ enum Token {
     None,
     Integer(i64),
     DoubleQuotedString(String),
+    List,
 }
 
 impl Display for Token {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
             Token::SimpleType(t) => write!(f, "{t}"),
+            Token::List => write!(f, "List"),
             Token::Timestamp => write!(f, "Timestamp"),
             Token::Time32 => write!(f, "Time32"),
             Token::Time64 => write!(f, "Time64"),

Reply via email to