This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 4929590  Box ScalarValue::List, reduce size by half (#788)
4929590 is described below

commit 4929590eea506608e1d8d425a8801fc21d8c8f45
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Jul 28 14:44:02 2021 -0400

    Box ScalarValue::List, reduce size by half (#788)
    
    * Box ScalarValue::List DataType, reduce to half size
    
    * Fixup ballista
---
 .../rust/core/src/serde/logical_plan/from_proto.rs | 16 ++--
 ballista/rust/core/src/serde/logical_plan/mod.rs   | 76 ++++++++++-------
 .../rust/core/src/serde/logical_plan/to_proto.rs   | 33 +++++---
 .../src/physical_plan/distinct_expressions.rs      |  8 +-
 datafusion/src/scalar.rs                           | 98 +++++++++++++---------
 5 files changed, 140 insertions(+), 91 deletions(-)

diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs 
b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 38b5257..2665e33 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -536,7 +536,7 @@ impl TryInto<datafusion::scalar::ScalarValue> for 
&protobuf::scalar_value::Value
             }
             protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
             protobuf::scalar_value::Value::NullListValue(v) => {
-                ScalarValue::List(None, v.try_into()?)
+                ScalarValue::List(None, Box::new(v.try_into()?))
             }
             protobuf::scalar_value::Value::NullValue(null_enum) => {
                 PrimitiveScalarType::from_i32(*null_enum)
@@ -580,8 +580,8 @@ impl TryInto<datafusion::scalar::ScalarValue> for 
&protobuf::ScalarListValue {
                     })
                     .collect::<Result<Vec<_>, _>>()?;
                 datafusion::scalar::ScalarValue::List(
-                    Some(typechecked_values),
-                    leaf_scalar_type.into(),
+                    Some(Box::new(typechecked_values)),
+                    Box::new(leaf_scalar_type.into()),
                 )
             }
             Datatype::List(list_type) => {
@@ -625,9 +625,9 @@ impl TryInto<datafusion::scalar::ScalarValue> for 
&protobuf::ScalarListValue {
                 datafusion::scalar::ScalarValue::List(
                     match typechecked_values.len() {
                         0 => None,
-                        _ => Some(typechecked_values),
+                        _ => Some(Box::new(typechecked_values)),
                     },
-                    list_type.try_into()?,
+                    Box::new(list_type.try_into()?),
                 )
             }
         };
@@ -765,14 +765,16 @@ impl TryInto<datafusion::scalar::ScalarValue> for 
&protobuf::ScalarValue {
                     .map(|val| val.try_into())
                     .collect::<Result<Vec<_>, _>>()?;
                 let scalar_type: DataType = pb_scalar_type.try_into()?;
-                ScalarValue::List(Some(typechecked_values), scalar_type)
+                let scalar_type = Box::new(scalar_type);
+                ScalarValue::List(Some(Box::new(typechecked_values)), 
scalar_type)
             }
             protobuf::scalar_value::Value::NullListValue(v) => {
                 let pb_datatype = v
                     .datatype
                     .as_ref()
                     .ok_or_else(|| proto_error("Protobuf deserialization 
error: NullListValue message missing required field 'datatyp'"))?;
-                ScalarValue::List(None, pb_datatype.try_into()?)
+                let pb_datatype = Box::new(pb_datatype.try_into()?);
+                ScalarValue::List(None, pb_datatype)
             }
             protobuf::scalar_value::Value::NullValue(v) => {
                 let null_type_enum = 
protobuf::PrimitiveScalarType::from_i32(*v)
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs 
b/ballista/rust/core/src/serde/logical_plan/mod.rs
index f6dbeaf..e4e4383 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -126,49 +126,57 @@ mod roundtrip_tests {
         let should_fail_on_seralize: Vec<ScalarValue> = vec![
             //Should fail due to inconsistent types
             ScalarValue::List(
-                Some(vec![
+                Some(Box::new(vec![
                     ScalarValue::Int16(None),
                     ScalarValue::Float32(Some(32.0)),
-                ]),
-                DataType::List(new_box_field("item", DataType::Int16, true)),
+                ])),
+                Box::new(DataType::List(new_box_field("item", DataType::Int16, 
true))),
             ),
             ScalarValue::List(
-                Some(vec![
+                Some(Box::new(vec![
                     ScalarValue::Float32(None),
                     ScalarValue::Float32(Some(32.0)),
-                ]),
-                DataType::List(new_box_field("item", DataType::Int16, true)),
+                ])),
+                Box::new(DataType::List(new_box_field("item", DataType::Int16, 
true))),
             ),
             ScalarValue::List(
-                Some(vec![
+                Some(Box::new(vec![
                     ScalarValue::List(
                         None,
-                        DataType::List(new_box_field("level2", 
DataType::Float32, true)),
+                        Box::new(DataType::List(new_box_field(
+                            "level2",
+                            DataType::Float32,
+                            true,
+                        ))),
                     ),
                     ScalarValue::List(
-                        Some(vec![
+                        Some(Box::new(vec![
                             ScalarValue::Float32(Some(-213.1)),
                             ScalarValue::Float32(None),
                             ScalarValue::Float32(Some(5.5)),
                             ScalarValue::Float32(Some(2.0)),
                             ScalarValue::Float32(Some(1.0)),
-                        ]),
-                        DataType::List(new_box_field("level2", 
DataType::Float32, true)),
+                        ])),
+                        Box::new(DataType::List(new_box_field(
+                            "level2",
+                            DataType::Float32,
+                            true,
+                        ))),
                     ),
                     ScalarValue::List(
                         None,
-                        DataType::List(new_box_field(
+                        Box::new(DataType::List(new_box_field(
                             "lists are typed inconsistently",
                             DataType::Int16,
                             true,
-                        )),
+                        ))),
                     ),
-                ]),
-                DataType::List(new_box_field(
+                ])),
+                Box::new(DataType::List(new_box_field(
                     "level1",
                     DataType::List(new_box_field("level2", DataType::Float32, 
true)),
                     true,
-                )),
+                ))),
             ),
         ];
 
@@ -200,7 +208,7 @@ mod roundtrip_tests {
             ScalarValue::UInt64(None),
             ScalarValue::Utf8(None),
             ScalarValue::LargeUtf8(None),
-            ScalarValue::List(None, DataType::Boolean),
+            ScalarValue::List(None, Box::new(DataType::Boolean)),
             ScalarValue::Date32(None),
             ScalarValue::TimestampMicrosecond(None),
             ScalarValue::TimestampNanosecond(None),
@@ -248,37 +256,49 @@ mod roundtrip_tests {
             ScalarValue::TimestampMicrosecond(Some(i64::MAX)),
             ScalarValue::TimestampMicrosecond(None),
             ScalarValue::List(
-                Some(vec![
+                Some(Box::new(vec![
                     ScalarValue::Float32(Some(-213.1)),
                     ScalarValue::Float32(None),
                     ScalarValue::Float32(Some(5.5)),
                     ScalarValue::Float32(Some(2.0)),
                     ScalarValue::Float32(Some(1.0)),
-                ]),
-                DataType::List(new_box_field("level1", DataType::Float32, 
true)),
+                ])),
+                Box::new(DataType::List(new_box_field(
+                    "level1",
+                    DataType::Float32,
+                    true,
+                ))),
             ),
             ScalarValue::List(
-                Some(vec![
+                Some(Box::new(vec![
                     ScalarValue::List(
                         None,
-                        DataType::List(new_box_field("level2", 
DataType::Float32, true)),
+                        Box::new(DataType::List(new_box_field(
+                            "level2",
+                            DataType::Float32,
+                            true,
+                        ))),
                     ),
                     ScalarValue::List(
-                        Some(vec![
+                        Some(Box::new(vec![
                             ScalarValue::Float32(Some(-213.1)),
                             ScalarValue::Float32(None),
                             ScalarValue::Float32(Some(5.5)),
                             ScalarValue::Float32(Some(2.0)),
                             ScalarValue::Float32(Some(1.0)),
-                        ]),
-                        DataType::List(new_box_field("level2", 
DataType::Float32, true)),
+                        ])),
+                        Box::new(DataType::List(new_box_field(
+                            "level2",
+                            DataType::Float32,
+                            true,
+                        ))),
                     ),
-                ]),
-                DataType::List(new_box_field(
+                ])),
+                Box::new(DataType::List(new_box_field(
                     "level1",
                     DataType::List(new_box_field("level2", DataType::Float32, 
true)),
                     true,
-                )),
+                ))),
             ),
         ];
 
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs 
b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 07d7a59..87f26a1 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -565,13 +565,13 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for 
protobuf::ScalarValue {
                             protobuf::ScalarValue {
                                 value: 
Some(protobuf::scalar_value::Value::ListValue(
                                     protobuf::ScalarListValue {
-                                        datatype: Some(datatype.try_into()?),
+                                        datatype: 
Some(datatype.as_ref().try_into()?),
                                         values: Vec::new(),
                                     },
                                 )),
                             }
                         } else {
-                            let scalar_type = match datatype {
+                            let scalar_type = match datatype.as_ref() {
                                 DataType::List(field) => 
field.as_ref().data_type(),
                                 _ => todo!("Proper error handling"),
                             };
@@ -579,16 +579,23 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for 
protobuf::ScalarValue {
                             let type_checked_values: 
Vec<protobuf::ScalarValue> = values
                                 .iter()
                                 .map(|scalar| match (scalar, scalar_type) {
-                                    (scalar::ScalarValue::List(_, 
DataType::List(list_field)), DataType::List(field)) => {
-                                        let scalar_datatype = 
field.data_type();
-                                        let list_datatype = 
list_field.data_type();
-                                        if 
std::mem::discriminant(list_datatype) != 
std::mem::discriminant(scalar_datatype) {
-                                            return Err(proto_error(format!(
-                                                "Protobuf serialization error: 
Lists with inconsistent typing {:?} and {:?} found within list",
-                                                list_datatype, scalar_datatype
-                                            )));
+                                    (scalar::ScalarValue::List(_, list_type), 
DataType::List(field)) => {
+                                        if let DataType::List(list_field) = 
list_type.as_ref() {
+                                            let scalar_datatype = 
field.data_type();
+                                            let list_datatype = 
list_field.data_type();
+                                            if 
std::mem::discriminant(list_datatype) != 
std::mem::discriminant(scalar_datatype) {
+                                                return Err(proto_error(format!(
+                                                    "Protobuf serialization 
error: Lists with inconsistent typing {:?} and {:?} found within list",
+                                                    list_datatype, 
scalar_datatype
+                                                )));
+                                            }
+                                            scalar.try_into()
+                                        } else {
+                                            Err(proto_error(format!(
+                                                "Protobuf serialization error, 
{:?} was inconsistent with designated type {:?}",
+                                                scalar, datatype
+                                            )))
                                         }
-                                        scalar.try_into()
                                     }
                                     (scalar::ScalarValue::Boolean(_), 
DataType::Boolean) => scalar.try_into(),
                                     (scalar::ScalarValue::Float32(_), 
DataType::Float32) => scalar.try_into(),
@@ -612,7 +619,7 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for 
protobuf::ScalarValue {
                             protobuf::ScalarValue {
                                 value: 
Some(protobuf::scalar_value::Value::ListValue(
                                     protobuf::ScalarListValue {
-                                        datatype: Some(datatype.try_into()?),
+                                        datatype: 
Some(datatype.as_ref().try_into()?),
                                         values: type_checked_values,
                                     },
                                 )),
@@ -621,7 +628,7 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for 
protobuf::ScalarValue {
                     }
                     None => protobuf::ScalarValue {
                         value: 
Some(protobuf::scalar_value::Value::NullListValue(
-                            datatype.try_into()?,
+                            datatype.as_ref().try_into()?,
                         )),
                     },
                 }
diff --git a/datafusion/src/physical_plan/distinct_expressions.rs 
b/datafusion/src/physical_plan/distinct_expressions.rs
index f3513c2..90c0836 100644
--- a/datafusion/src/physical_plan/distinct_expressions.rs
+++ b/datafusion/src/physical_plan/distinct_expressions.rs
@@ -178,7 +178,9 @@ impl Accumulator for DistinctCountAccumulator {
             .state_data_types
             .iter()
             .map(|state_data_type| {
-                ScalarValue::List(Some(Vec::new()), state_data_type.clone())
+                let values = Box::new(Vec::new());
+                let data_type = Box::new(state_data_type.clone());
+                ScalarValue::List(Some(values), data_type)
             })
             .collect::<Vec<_>>();
 
@@ -254,8 +256,8 @@ mod tests {
     macro_rules! state_to_vec {
         ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{
             match $LIST {
-                ScalarValue::List(_, data_type) => match data_type {
-                    DataType::$DATA_TYPE => (),
+                ScalarValue::List(_, data_type) => match data_type.as_ref() {
+                    &DataType::$DATA_TYPE => (),
                     _ => panic!("Unexpected DataType for list"),
                 },
                 _ => panic!("Expected a ScalarValue::List"),
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index ab08364..129b416 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -65,8 +65,9 @@ pub enum ScalarValue {
     Binary(Option<Vec<u8>>),
     /// large binary
     LargeBinary(Option<Vec<u8>>),
-    /// list of nested ScalarValue
-    List(Option<Vec<ScalarValue>>, DataType),
+    /// list of nested ScalarValue (boxed to reduce size_of(ScalarValue))
+    #[allow(clippy::box_vec)]
+    List(Option<Box<Vec<ScalarValue>>>, Box<DataType>),
     /// Date stored as a signed 32bit int
     Date32(Option<i32>),
     /// Date stored as a signed 64bit int
@@ -110,7 +111,7 @@ macro_rules! build_list {
                 )
             }
             Some(values) => {
-                build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, values, 
$SIZE)
+                build_values_list!($VALUE_BUILDER_TY, $SCALAR_TY, 
values.as_ref(), $SIZE)
             }
         }
     }};
@@ -130,32 +131,35 @@ macro_rules! build_timestamp_list {
                     $SIZE,
                 )
             }
-            Some(values) => match $TIME_UNIT {
-                TimeUnit::Second => build_values_list!(
-                    TimestampSecondBuilder,
-                    TimestampSecond,
-                    values,
-                    $SIZE
-                ),
-                TimeUnit::Microsecond => build_values_list!(
-                    TimestampMillisecondBuilder,
-                    TimestampMillisecond,
-                    values,
-                    $SIZE
-                ),
-                TimeUnit::Millisecond => build_values_list!(
-                    TimestampMicrosecondBuilder,
-                    TimestampMicrosecond,
-                    values,
-                    $SIZE
-                ),
-                TimeUnit::Nanosecond => build_values_list!(
-                    TimestampNanosecondBuilder,
-                    TimestampNanosecond,
-                    values,
-                    $SIZE
-                ),
-            },
+            Some(values) => {
+                let values = values.as_ref();
+                match $TIME_UNIT {
+                    TimeUnit::Second => build_values_list!(
+                        TimestampSecondBuilder,
+                        TimestampSecond,
+                        values,
+                        $SIZE
+                    ),
+                    TimeUnit::Microsecond => build_values_list!(
+                        TimestampMillisecondBuilder,
+                        TimestampMillisecond,
+                        values,
+                        $SIZE
+                    ),
+                    TimeUnit::Millisecond => build_values_list!(
+                        TimestampMicrosecondBuilder,
+                        TimestampMicrosecond,
+                        values,
+                        $SIZE
+                    ),
+                    TimeUnit::Nanosecond => build_values_list!(
+                        TimestampNanosecondBuilder,
+                        TimestampNanosecond,
+                        values,
+                        $SIZE
+                    ),
+                }
+            }
         }
     }};
 }
@@ -235,9 +239,11 @@ impl ScalarValue {
             ScalarValue::LargeUtf8(_) => DataType::LargeUtf8,
             ScalarValue::Binary(_) => DataType::Binary,
             ScalarValue::LargeBinary(_) => DataType::LargeBinary,
-            ScalarValue::List(_, data_type) => {
-                DataType::List(Box::new(Field::new("item", data_type.clone(), 
true)))
-            }
+            ScalarValue::List(_, data_type) => 
DataType::List(Box::new(Field::new(
+                "item",
+                data_type.as_ref().clone(),
+                true,
+            ))),
             ScalarValue::Date32(_) => DataType::Date32,
             ScalarValue::Date64(_) => DataType::Date64,
             ScalarValue::IntervalYearMonth(_) => {
@@ -415,6 +421,7 @@ impl ScalarValue {
                 for scalar in scalars.into_iter() {
                     match scalar {
                         ScalarValue::List(Some(xs), _) => {
+                            let xs = *xs;
                             for s in xs {
                                 match s {
                                     ScalarValue::$SCALAR_TY(Some(val)) => {
@@ -627,7 +634,7 @@ impl ScalarValue {
                         .collect::<LargeBinaryArray>(),
                 ),
             },
-            ScalarValue::List(values, data_type) => Arc::new(match data_type {
+            ScalarValue::List(values, data_type) => Arc::new(match 
data_type.as_ref() {
                 DataType::Boolean => build_list!(BooleanBuilder, Boolean, 
values, size),
                 DataType::Int8 => build_list!(Int8Builder, Int8, values, size),
                 DataType::Int16 => build_list!(Int16Builder, Int16, values, 
size),
@@ -643,7 +650,7 @@ impl ScalarValue {
                 DataType::Timestamp(unit, tz) => {
                     build_timestamp_list!(unit.clone(), tz.clone(), values, 
size)
                 }
-                DataType::LargeUtf8 => {
+                &DataType::LargeUtf8 => {
                     build_list!(LargeStringBuilder, LargeUtf8, values, size)
                 }
                 dt => panic!("Unexpected DataType for list {:?}", dt),
@@ -705,7 +712,9 @@ impl ScalarValue {
                         Some(scalar_vec)
                     }
                 };
-                ScalarValue::List(value, nested_type.data_type().clone())
+                let value = value.map(Box::new);
+                let data_type = Box::new(nested_type.data_type().clone());
+                ScalarValue::List(value, data_type)
             }
             DataType::Date32 => {
                 typed_cast!(array, index, Date32Array, Date32)
@@ -965,7 +974,7 @@ impl TryFrom<&DataType> for ScalarValue {
                 ScalarValue::TimestampNanosecond(None)
             }
             DataType::List(ref nested_type) => {
-                ScalarValue::List(None, nested_type.data_type().clone())
+                ScalarValue::List(None, 
Box::new(nested_type.data_type().clone()))
             }
             _ => {
                 return Err(DataFusionError::NotImplemented(format!(
@@ -1167,7 +1176,8 @@ mod tests {
 
     #[test]
     fn scalar_list_null_to_array() {
-        let list_array_ref = ScalarValue::List(None, 
DataType::UInt64).to_array();
+        let list_array_ref =
+            ScalarValue::List(None, Box::new(DataType::UInt64)).to_array();
         let list_array = 
list_array_ref.as_any().downcast_ref::<ListArray>().unwrap();
 
         assert!(list_array.is_null(0));
@@ -1178,12 +1188,12 @@ mod tests {
     #[test]
     fn scalar_list_to_array() {
         let list_array_ref = ScalarValue::List(
-            Some(vec![
+            Some(Box::new(vec![
                 ScalarValue::UInt64(Some(100)),
                 ScalarValue::UInt64(None),
                 ScalarValue::UInt64(Some(101)),
-            ]),
-            DataType::UInt64,
+            ])),
+            Box::new(DataType::UInt64),
         )
         .to_array();
 
@@ -1336,4 +1346,12 @@ mod tests {
         assert!(result.to_string().contains("Inconsistent types in 
ScalarValue::iter_to_array. Expected Boolean, got Int32(5)"),
                 "{}", result);
     }
+
+    #[test]
+    fn size_of_scalar() {
+        // Since ScalarValues are used in a non trivial number of places,
+        // making it larger means significant more memory consumption
+        // per distinct value.
+        assert_eq!(std::mem::size_of::<ScalarValue>(), 32);
+    }
 }

Reply via email to