This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a4941ee  Remove GroupByScalar and use ScalarValue instead (#786)
a4941ee is described below

commit a4941ee3a0e9dc630b6a144ccd83e577f61e0958
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Jul 30 13:40:58 2021 -0400

    Remove GroupByScalar and use ScalarValue instead (#786)
---
 .../src/physical_plan/distinct_expressions.rs      |  13 +-
 datafusion/src/physical_plan/group_scalar.rs       | 217 ---------------------
 datafusion/src/physical_plan/hash_aggregate.rs     | 139 ++-----------
 datafusion/src/physical_plan/mod.rs                |   1 -
 datafusion/src/scalar.rs                           | 214 ++++++++++++++++++--
 5 files changed, 215 insertions(+), 369 deletions(-)

diff --git a/datafusion/src/physical_plan/distinct_expressions.rs b/datafusion/src/physical_plan/distinct_expressions.rs
index 90c0836..ae60253 100644
--- a/datafusion/src/physical_plan/distinct_expressions.rs
+++ b/datafusion/src/physical_plan/distinct_expressions.rs
@@ -18,7 +18,6 @@
 //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)`
 
 use std::any::Any;
-use std::convert::TryFrom;
 use std::fmt::Debug;
 use std::hash::Hash;
 use std::sync::Arc;
@@ -29,12 +28,11 @@ use ahash::RandomState;
 use std::collections::HashSet;
 
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::group_scalar::GroupByScalar;
 use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
 use crate::scalar::ScalarValue;
 
 #[derive(Debug, PartialEq, Eq, Hash, Clone)]
-struct DistinctScalarValues(Vec<GroupByScalar>);
+struct DistinctScalarValues(Vec<ScalarValue>);
 
 fn format_state_name(name: &str, state_name: &str) -> String {
     format!("{}[{}]", name, state_name)
@@ -137,12 +135,7 @@ impl Accumulator for DistinctCountAccumulator {
     fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
         // If a row has a NULL, it is not included in the final count.
         if !values.iter().any(|v| v.is_null()) {
-            self.values.insert(DistinctScalarValues(
-                values
-                    .iter()
-                    .map(GroupByScalar::try_from)
-                    .collect::<Result<Vec<_>>>()?,
-            ));
+            self.values.insert(DistinctScalarValues(values.to_vec()));
         }
 
         Ok(())
@@ -195,7 +188,7 @@ impl Accumulator for DistinctCountAccumulator {
         self.values.iter().for_each(|distinct_values| {
             distinct_values.0.iter().enumerate().for_each(
                 |(col_index, distinct_value)| {
-                    cols_vec[col_index].push(ScalarValue::from(distinct_value));
+                    cols_vec[col_index].push(distinct_value.clone());
                 },
             )
         });
diff --git a/datafusion/src/physical_plan/group_scalar.rs b/datafusion/src/physical_plan/group_scalar.rs
deleted file mode 100644
index d5f72b0..0000000
--- a/datafusion/src/physical_plan/group_scalar.rs
+++ /dev/null
@@ -1,217 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Defines scalars used to construct groups, ex. in GROUP BY clauses.
-
-use ordered_float::OrderedFloat;
-use std::convert::{From, TryFrom};
-
-use crate::error::{DataFusionError, Result};
-use crate::scalar::ScalarValue;
-
-/// Enumeration of types that can be used in a GROUP BY expression
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
-pub(crate) enum GroupByScalar {
-    Float32(OrderedFloat<f32>),
-    Float64(OrderedFloat<f64>),
-    UInt8(u8),
-    UInt16(u16),
-    UInt32(u32),
-    UInt64(u64),
-    Int8(i8),
-    Int16(i16),
-    Int32(i32),
-    Int64(i64),
-    Utf8(Box<String>),
-    LargeUtf8(Box<String>),
-    Boolean(bool),
-    TimeMillisecond(i64),
-    TimeMicrosecond(i64),
-    TimeNanosecond(i64),
-    Date32(i32),
-}
-
-impl TryFrom<&ScalarValue> for GroupByScalar {
-    type Error = DataFusionError;
-
-    fn try_from(scalar_value: &ScalarValue) -> Result<Self> {
-        Ok(match scalar_value {
-            ScalarValue::Float32(Some(v)) => {
-                GroupByScalar::Float32(OrderedFloat::from(*v))
-            }
-            ScalarValue::Float64(Some(v)) => {
-                GroupByScalar::Float64(OrderedFloat::from(*v))
-            }
-            ScalarValue::Boolean(Some(v)) => GroupByScalar::Boolean(*v),
-            ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v),
-            ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v),
-            ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v),
-            ScalarValue::Int64(Some(v)) => GroupByScalar::Int64(*v),
-            ScalarValue::UInt8(Some(v)) => GroupByScalar::UInt8(*v),
-            ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v),
-            ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
-            ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
-            ScalarValue::TimestampMillisecond(Some(v)) => {
-                GroupByScalar::TimeMillisecond(*v)
-            }
-            ScalarValue::TimestampMicrosecond(Some(v)) => {
-                GroupByScalar::TimeMicrosecond(*v)
-            }
-            ScalarValue::TimestampNanosecond(Some(v)) => {
-                GroupByScalar::TimeNanosecond(*v)
-            }
-            ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
-            ScalarValue::LargeUtf8(Some(v)) => {
-                GroupByScalar::LargeUtf8(Box::new(v.clone()))
-            }
-            ScalarValue::Float32(None)
-            | ScalarValue::Float64(None)
-            | ScalarValue::Boolean(None)
-            | ScalarValue::Int8(None)
-            | ScalarValue::Int16(None)
-            | ScalarValue::Int32(None)
-            | ScalarValue::Int64(None)
-            | ScalarValue::UInt8(None)
-            | ScalarValue::UInt16(None)
-            | ScalarValue::UInt32(None)
-            | ScalarValue::UInt64(None)
-            | ScalarValue::Utf8(None) => {
-                return Err(DataFusionError::Internal(format!(
-                    "Cannot convert a ScalarValue holding NULL ({:?})",
-                    scalar_value
-                )));
-            }
-            v => {
-                return Err(DataFusionError::Internal(format!(
-                    "Cannot convert a ScalarValue with associated DataType {:?}",
-                    v.get_datatype()
-                )))
-            }
-        })
-    }
-}
-
-impl From<&GroupByScalar> for ScalarValue {
-    fn from(group_by_scalar: &GroupByScalar) -> Self {
-        match group_by_scalar {
-            GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())),
-            GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())),
-            GroupByScalar::Boolean(v) => ScalarValue::Boolean(Some(*v)),
-            GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)),
-            GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)),
-            GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)),
-            GroupByScalar::Int64(v) => ScalarValue::Int64(Some(*v)),
-            GroupByScalar::UInt8(v) => ScalarValue::UInt8(Some(*v)),
-            GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)),
-            GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)),
-            GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)),
-            GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())),
-            GroupByScalar::LargeUtf8(v) => ScalarValue::LargeUtf8(Some(v.to_string())),
-            GroupByScalar::TimeMillisecond(v) => {
-                ScalarValue::TimestampMillisecond(Some(*v))
-            }
-            GroupByScalar::TimeMicrosecond(v) => {
-                ScalarValue::TimestampMicrosecond(Some(*v))
-            }
-            GroupByScalar::TimeNanosecond(v) => {
-                ScalarValue::TimestampNanosecond(Some(*v))
-            }
-            GroupByScalar::Date32(v) => ScalarValue::Date32(Some(*v)),
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    use crate::error::DataFusionError;
-
-    macro_rules! scalar_eq_test {
-        ($TYPE:expr, $VALUE:expr) => {{
-            let scalar_value = $TYPE($VALUE);
-            let a = GroupByScalar::try_from(&scalar_value).unwrap();
-
-            let scalar_value = $TYPE($VALUE);
-            let b = GroupByScalar::try_from(&scalar_value).unwrap();
-
-            assert_eq!(a, b);
-        }};
-    }
-
-    #[test]
-    fn test_scalar_ne_non_std() {
-        // Test only Scalars with non native Eq, Hash
-        scalar_eq_test!(ScalarValue::Float32, Some(1.0));
-        scalar_eq_test!(ScalarValue::Float64, Some(1.0));
-    }
-
-    macro_rules! scalar_ne_test {
-        ($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{
-            let scalar_value = $TYPE($LVALUE);
-            let a = GroupByScalar::try_from(&scalar_value).unwrap();
-
-            let scalar_value = $TYPE($RVALUE);
-            let b = GroupByScalar::try_from(&scalar_value).unwrap();
-
-            assert_ne!(a, b);
-        }};
-    }
-
-    #[test]
-    fn test_scalar_eq_non_std() {
-        // Test only Scalars with non native Eq, Hash
-        scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0));
-        scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0));
-    }
-
-    #[test]
-    fn from_scalar_holding_none() {
-        let scalar_value = ScalarValue::Int8(None);
-        let result = GroupByScalar::try_from(&scalar_value);
-
-        match result {
-            Err(DataFusionError::Internal(error_message)) => assert_eq!(
-                error_message,
-                String::from("Cannot convert a ScalarValue holding NULL (Int8(NULL))")
-            ),
-            _ => panic!("Unexpected result"),
-        }
-    }
-
-    #[test]
-    fn from_scalar_unsupported() {
-        // Use any ScalarValue type not supported by GroupByScalar.
-        let scalar_value = ScalarValue::Binary(Some(vec![1, 2]));
-        let result = GroupByScalar::try_from(&scalar_value);
-
-        match result {
-            Err(DataFusionError::Internal(error_message)) => assert_eq!(
-                error_message,
-                String::from(
-                    "Cannot convert a ScalarValue with associated DataType Binary"
-                )
-            ),
-            _ => panic!("Unexpected result"),
-        }
-    }
-
-    #[test]
-    fn size_of_group_by_scalar() {
-        assert_eq!(std::mem::size_of::<GroupByScalar>(), 16);
-    }
-}
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index ae51383..eb4a356 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -59,7 +59,6 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use hashbrown::HashMap;
-use ordered_float::OrderedFloat;
 use pin_project_lite::pin_project;
 
 use arrow::array::{
@@ -68,10 +67,7 @@ use arrow::array::{
 };
 use async_trait::async_trait;
 
-use super::{
-    expressions::Column, group_scalar::GroupByScalar, RecordBatchStream,
-    SendableRecordBatchStream,
-};
+use super::{expressions::Column, RecordBatchStream, SendableRecordBatchStream};
 
 /// Hash aggregate modes
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
@@ -362,7 +358,7 @@ fn group_aggregate_batch(
     // it will be overwritten on every iteration of the loop below
     let mut group_by_values = Vec::with_capacity(group_values.len());
     for _ in 0..group_values.len() {
-        group_by_values.push(GroupByScalar::UInt32(0));
+        group_by_values.push(ScalarValue::UInt32(Some(0)));
     }
 
     let mut group_by_values = group_by_values.into_boxed_slice();
@@ -730,7 +726,7 @@ impl GroupedHashAggregateStream {
 
 type AccumulatorItem = Box<dyn Accumulator>;
 type Accumulators =
-    HashMap<Vec<u8>, (Box<[GroupByScalar]>, Vec<AccumulatorItem>, Vec<u32>), RandomState>;
+    HashMap<Vec<u8>, (Box<[ScalarValue]>, Vec<AccumulatorItem>, Vec<u32>), RandomState>;
 
 impl Stream for GroupedHashAggregateStream {
     type Item = ArrowResult<RecordBatch>;
@@ -1004,9 +1000,11 @@ fn create_batch_from_map(
 
     let mut columns = (0..num_group_expr)
         .map(|i| {
-            ScalarValue::iter_to_array(accumulators.into_iter().map(
-                |(_, (group_by_values, _, _))| ScalarValue::from(&group_by_values[i]),
-            ))
+            ScalarValue::iter_to_array(
+                accumulators
+                    .into_iter()
+                    .map(|(_, (group_by_values, _, _))| group_by_values[i].clone()),
+            )
         })
         .collect::<Result<Vec<_>>>()
         .map_err(|x| x.into_arrow_external_error())?;
@@ -1088,124 +1086,9 @@ fn finalize_aggregation(
     }
 }
 
-/// Extract the value in `col[row]` from a dictionary a GroupByScalar
-fn dictionary_create_group_by_value<K: ArrowDictionaryKeyType>(
-    col: &ArrayRef,
-    row: usize,
-) -> Result<GroupByScalar> {
-    let dict_col = col.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
-
-    // look up the index in the values dictionary
-    let keys_col = dict_col.keys();
-    let values_index = keys_col.value(row).to_usize().ok_or_else(|| {
-        DataFusionError::Internal(format!(
-            "Can not convert index to usize in dictionary of type creating group by value {:?}",
-            keys_col.data_type()
-        ))
-    })?;
-
-    create_group_by_value(dict_col.values(), values_index)
-}
-
 /// Extract the value in `col[row]` as a GroupByScalar
-fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
-    match col.data_type() {
-        DataType::Float32 => {
-            let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
-            Ok(GroupByScalar::Float32(OrderedFloat::from(array.value(row))))
-        }
-        DataType::Float64 => {
-            let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
-            Ok(GroupByScalar::Float64(OrderedFloat::from(array.value(row))))
-        }
-        DataType::UInt8 => {
-            let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
-            Ok(GroupByScalar::UInt8(array.value(row)))
-        }
-        DataType::UInt16 => {
-            let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
-            Ok(GroupByScalar::UInt16(array.value(row)))
-        }
-        DataType::UInt32 => {
-            let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
-            Ok(GroupByScalar::UInt32(array.value(row)))
-        }
-        DataType::UInt64 => {
-            let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
-            Ok(GroupByScalar::UInt64(array.value(row)))
-        }
-        DataType::Int8 => {
-            let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
-            Ok(GroupByScalar::Int8(array.value(row)))
-        }
-        DataType::Int16 => {
-            let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
-            Ok(GroupByScalar::Int16(array.value(row)))
-        }
-        DataType::Int32 => {
-            let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
-            Ok(GroupByScalar::Int32(array.value(row)))
-        }
-        DataType::Int64 => {
-            let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
-            Ok(GroupByScalar::Int64(array.value(row)))
-        }
-        DataType::Utf8 => {
-            let array = col.as_any().downcast_ref::<StringArray>().unwrap();
-            Ok(GroupByScalar::Utf8(Box::new(array.value(row).into())))
-        }
-        DataType::LargeUtf8 => {
-            let array = col.as_any().downcast_ref::<LargeStringArray>().unwrap();
-            Ok(GroupByScalar::LargeUtf8(Box::new(array.value(row).into())))
-        }
-        DataType::Boolean => {
-            let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
-            Ok(GroupByScalar::Boolean(array.value(row)))
-        }
-        DataType::Timestamp(TimeUnit::Millisecond, None) => {
-            let array = col
-                .as_any()
-                .downcast_ref::<TimestampMillisecondArray>()
-                .unwrap();
-            Ok(GroupByScalar::TimeMillisecond(array.value(row)))
-        }
-        DataType::Timestamp(TimeUnit::Microsecond, None) => {
-            let array = col
-                .as_any()
-                .downcast_ref::<TimestampMicrosecondArray>()
-                .unwrap();
-            Ok(GroupByScalar::TimeMicrosecond(array.value(row)))
-        }
-        DataType::Timestamp(TimeUnit::Nanosecond, None) => {
-            let array = col
-                .as_any()
-                .downcast_ref::<TimestampNanosecondArray>()
-                .unwrap();
-            Ok(GroupByScalar::TimeNanosecond(array.value(row)))
-        }
-        DataType::Date32 => {
-            let array = col.as_any().downcast_ref::<Date32Array>().unwrap();
-            Ok(GroupByScalar::Date32(array.value(row)))
-        }
-        DataType::Dictionary(index_type, _) => match **index_type {
-            DataType::Int8 => dictionary_create_group_by_value::<Int8Type>(col, row),
-            DataType::Int16 => dictionary_create_group_by_value::<Int16Type>(col, row),
-            DataType::Int32 => dictionary_create_group_by_value::<Int32Type>(col, row),
-            DataType::Int64 => dictionary_create_group_by_value::<Int64Type>(col, row),
-            DataType::UInt8 => dictionary_create_group_by_value::<UInt8Type>(col, row),
-            DataType::UInt16 => dictionary_create_group_by_value::<UInt16Type>(col, row),
-            DataType::UInt32 => dictionary_create_group_by_value::<UInt32Type>(col, row),
-            DataType::UInt64 => dictionary_create_group_by_value::<UInt64Type>(col, row),
-            _ => Err(DataFusionError::NotImplemented(format!(
-                "Unsupported GROUP BY type (dictionary index type not supported) {}",
-                col.data_type(),
-            ))),
-        },
-        _ => Err(DataFusionError::NotImplemented(format!(
-            "Unsupported GROUP BY type {}",
-            col.data_type(),
-        ))),
-    }
+fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<ScalarValue> {
+    ScalarValue::try_from_array(col, row)
 }
 
 /// Extract the values in `group_by_keys` arrow arrays into the target vector
@@ -1213,7 +1096,7 @@ fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
 pub(crate) fn create_group_by_values(
     group_by_keys: &[ArrayRef],
     row: usize,
-    vec: &mut Box<[GroupByScalar]>,
+    vec: &mut Box<[ScalarValue]>,
 ) -> Result<()> {
     for (i, col) in group_by_keys.iter().enumerate() {
         vec[i] = create_group_by_value(col, row)?
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index 86bceb1..0df6e60 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -655,7 +655,6 @@ pub mod explain;
 pub mod expressions;
 pub mod filter;
 pub mod functions;
-pub mod group_scalar;
 pub mod hash_aggregate;
 pub mod hash_join;
 pub mod hash_utils;
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index 129b416..8efea63 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -27,13 +27,14 @@ use arrow::{
         TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
     },
 };
+use ordered_float::OrderedFloat;
 use std::convert::Infallible;
 use std::str::FromStr;
 use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
 
 /// Represents a dynamically typed, nullable single value.
 /// This is the single-valued counter-part of arrow’s `Array`.
-#[derive(Clone, PartialEq)]
+#[derive(Clone)]
 pub enum ScalarValue {
     /// true or false value
     Boolean(Option<bool>),
@@ -86,6 +87,120 @@ pub enum ScalarValue {
     IntervalDayTime(Option<i64>),
 }
 
+// manual implementation of `PartialEq` that uses OrderedFloat to
+// get defined behavior for floating point
+impl PartialEq for ScalarValue {
+    fn eq(&self, other: &Self) -> bool {
+        use ScalarValue::*;
+        // This purposely doesn't have a catch-all "(_, _)" so that
+        // any newly added enum variant will require editing this list
+        // or else face a compile error
+        match (self, other) {
+            (Boolean(v1), Boolean(v2)) => v1.eq(v2),
+            (Boolean(_), _) => false,
+            (Float32(v1), Float32(v2)) => {
+                let v1 = v1.map(OrderedFloat);
+                let v2 = v2.map(OrderedFloat);
+                v1.eq(&v2)
+            }
+            (Float32(_), _) => false,
+            (Float64(v1), Float64(v2)) => {
+                let v1 = v1.map(OrderedFloat);
+                let v2 = v2.map(OrderedFloat);
+                v1.eq(&v2)
+            }
+            (Float64(_), _) => false,
+            (Int8(v1), Int8(v2)) => v1.eq(v2),
+            (Int8(_), _) => false,
+            (Int16(v1), Int16(v2)) => v1.eq(v2),
+            (Int16(_), _) => false,
+            (Int32(v1), Int32(v2)) => v1.eq(v2),
+            (Int32(_), _) => false,
+            (Int64(v1), Int64(v2)) => v1.eq(v2),
+            (Int64(_), _) => false,
+            (UInt8(v1), UInt8(v2)) => v1.eq(v2),
+            (UInt8(_), _) => false,
+            (UInt16(v1), UInt16(v2)) => v1.eq(v2),
+            (UInt16(_), _) => false,
+            (UInt32(v1), UInt32(v2)) => v1.eq(v2),
+            (UInt32(_), _) => false,
+            (UInt64(v1), UInt64(v2)) => v1.eq(v2),
+            (UInt64(_), _) => false,
+            (Utf8(v1), Utf8(v2)) => v1.eq(v2),
+            (Utf8(_), _) => false,
+            (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2),
+            (LargeUtf8(_), _) => false,
+            (Binary(v1), Binary(v2)) => v1.eq(v2),
+            (Binary(_), _) => false,
+            (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2),
+            (LargeBinary(_), _) => false,
+            (List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2),
+            (List(_, _), _) => false,
+            (Date32(v1), Date32(v2)) => v1.eq(v2),
+            (Date32(_), _) => false,
+            (Date64(v1), Date64(v2)) => v1.eq(v2),
+            (Date64(_), _) => false,
+            (TimestampSecond(v1), TimestampSecond(v2)) => v1.eq(v2),
+            (TimestampSecond(_), _) => false,
+            (TimestampMillisecond(v1), TimestampMillisecond(v2)) => v1.eq(v2),
+            (TimestampMillisecond(_), _) => false,
+            (TimestampMicrosecond(v1), TimestampMicrosecond(v2)) => v1.eq(v2),
+            (TimestampMicrosecond(_), _) => false,
+            (TimestampNanosecond(v1), TimestampNanosecond(v2)) => v1.eq(v2),
+            (TimestampNanosecond(_), _) => false,
+            (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2),
+            (IntervalYearMonth(_), _) => false,
+            (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2),
+            (IntervalDayTime(_), _) => false,
+        }
+    }
+}
+
+impl Eq for ScalarValue {}
+
+// manual implementation of `Hash` that uses OrderedFloat to
+// get defined behavior for floating point
+impl std::hash::Hash for ScalarValue {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        use ScalarValue::*;
+        match self {
+            Boolean(v) => v.hash(state),
+            Float32(v) => {
+                let v = v.map(OrderedFloat);
+                v.hash(state)
+            }
+            Float64(v) => {
+                let v = v.map(OrderedFloat);
+                v.hash(state)
+            }
+            Int8(v) => v.hash(state),
+            Int16(v) => v.hash(state),
+            Int32(v) => v.hash(state),
+            Int64(v) => v.hash(state),
+            UInt8(v) => v.hash(state),
+            UInt16(v) => v.hash(state),
+            UInt32(v) => v.hash(state),
+            UInt64(v) => v.hash(state),
+            Utf8(v) => v.hash(state),
+            LargeUtf8(v) => v.hash(state),
+            Binary(v) => v.hash(state),
+            LargeBinary(v) => v.hash(state),
+            List(v, t) => {
+                v.hash(state);
+                t.hash(state);
+            }
+            Date32(v) => v.hash(state),
+            Date64(v) => v.hash(state),
+            TimestampSecond(v) => v.hash(state),
+            TimestampMillisecond(v) => v.hash(state),
+            TimestampMicrosecond(v) => v.hash(state),
+            TimestampNanosecond(v) => v.hash(state),
+            IntervalYearMonth(v) => v.hash(state),
+            IntervalDayTime(v) => v.hash(state),
+        }
+    }
+}
+
 macro_rules! typed_cast {
     ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
         let array = $array.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
@@ -795,73 +910,146 @@ impl ScalarValue {
 
 impl From<f64> for ScalarValue {
     fn from(value: f64) -> Self {
-        ScalarValue::Float64(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<f64>> for ScalarValue {
+    fn from(value: Option<f64>) -> Self {
+        ScalarValue::Float64(value)
     }
 }
 
 impl From<f32> for ScalarValue {
     fn from(value: f32) -> Self {
-        ScalarValue::Float32(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<f32>> for ScalarValue {
+    fn from(value: Option<f32>) -> Self {
+        ScalarValue::Float32(value)
     }
 }
 
 impl From<i8> for ScalarValue {
     fn from(value: i8) -> Self {
-        ScalarValue::Int8(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<i8>> for ScalarValue {
+    fn from(value: Option<i8>) -> Self {
+        ScalarValue::Int8(value)
     }
 }
 
 impl From<i16> for ScalarValue {
     fn from(value: i16) -> Self {
-        ScalarValue::Int16(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<i16>> for ScalarValue {
+    fn from(value: Option<i16>) -> Self {
+        ScalarValue::Int16(value)
     }
 }
 
 impl From<i32> for ScalarValue {
     fn from(value: i32) -> Self {
-        ScalarValue::Int32(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<i32>> for ScalarValue {
+    fn from(value: Option<i32>) -> Self {
+        ScalarValue::Int32(value)
     }
 }
 
 impl From<i64> for ScalarValue {
     fn from(value: i64) -> Self {
-        ScalarValue::Int64(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<i64>> for ScalarValue {
+    fn from(value: Option<i64>) -> Self {
+        ScalarValue::Int64(value)
     }
 }
 
 impl From<bool> for ScalarValue {
     fn from(value: bool) -> Self {
-        ScalarValue::Boolean(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<bool>> for ScalarValue {
+    fn from(value: Option<bool>) -> Self {
+        ScalarValue::Boolean(value)
     }
 }
 
 impl From<u8> for ScalarValue {
     fn from(value: u8) -> Self {
-        ScalarValue::UInt8(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<u8>> for ScalarValue {
+    fn from(value: Option<u8>) -> Self {
+        ScalarValue::UInt8(value)
     }
 }
 
 impl From<u16> for ScalarValue {
     fn from(value: u16) -> Self {
-        ScalarValue::UInt16(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<u16>> for ScalarValue {
+    fn from(value: Option<u16>) -> Self {
+        ScalarValue::UInt16(value)
     }
 }
 
 impl From<u32> for ScalarValue {
     fn from(value: u32) -> Self {
-        ScalarValue::UInt32(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<u32>> for ScalarValue {
+    fn from(value: Option<u32>) -> Self {
+        ScalarValue::UInt32(value)
     }
 }
 
 impl From<u64> for ScalarValue {
     fn from(value: u64) -> Self {
-        ScalarValue::UInt64(Some(value))
+        Some(value).into()
+    }
+}
+
+impl From<Option<u64>> for ScalarValue {
+    fn from(value: Option<u64>) -> Self {
+        ScalarValue::UInt64(value)
     }
 }
 
 impl From<&str> for ScalarValue {
     fn from(value: &str) -> Self {
-        ScalarValue::Utf8(Some(value.to_string()))
+        Some(value).into()
+    }
+}
+
+impl From<Option<&str>> for ScalarValue {
+    fn from(value: Option<&str>) -> Self {
+        let value = value.map(|s| s.to_string());
+        ScalarValue::Utf8(value)
     }
 }
 

Reply via email to