tustvold commented on code in PR #4488:
URL: https://github.com/apache/arrow-datafusion/pull/4488#discussion_r1043705396


##########
datafusion/common/src/scalar.rs:
##########
@@ -720,7 +720,7 @@ impl std::hash::Hash for ScalarValue {
 /// dictionary array
 #[inline]
 fn get_dict_value<K: ArrowDictionaryKeyType>(
-    array: &ArrayRef,
+    array: &dyn Array,

Review Comment:
   :heart:



##########
datafusion/physical-expr/src/aggregate/median.rs:
##########
@@ -91,157 +91,124 @@ impl AggregateExpr for Median {
 }
 
 #[derive(Debug)]
+/// The median accumulator accumulates the raw input values
+/// as `ScalarValue`s
+///
+/// The intermediate state is represented as a List of those scalars
 struct MedianAccumulator {
     data_type: DataType,
-    all_values: Vec<ArrayRef>,
-}
-
-macro_rules! median {
-    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{
-        let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?;
-        if combined.is_empty() {
-            return Ok(ScalarValue::Null);
-        }
-        let sorted = sort(&combined, None)?;
-        let array = as_primitive_array::<$TY>(&sorted)?;
-        let len = sorted.len();
-        let mid = len / 2;
-        if len % 2 == 0 {
-            Ok(ScalarValue::$SCALAR_TY(Some(
-                (array.value(mid - 1) + array.value(mid)) / $TWO,
-            )))
-        } else {
-            Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid))))
-        }
-    }};
+    all_values: Vec<ScalarValue>,
 }
 
 impl Accumulator for MedianAccumulator {
     fn state(&self) -> Result<Vec<AggregateState>> {
-        let mut vec: Vec<AggregateState> = self
-            .all_values
-            .iter()
-            .map(|v| AggregateState::Array(v.clone()))
-            .collect();
-        if vec.is_empty() {
-            match self.data_type {
-                DataType::UInt8 => vec.push(empty_array::<UInt8Type>()),
-                DataType::UInt16 => vec.push(empty_array::<UInt16Type>()),
-                DataType::UInt32 => vec.push(empty_array::<UInt32Type>()),
-                DataType::UInt64 => vec.push(empty_array::<UInt64Type>()),
-                DataType::Int8 => vec.push(empty_array::<Int8Type>()),
-                DataType::Int16 => vec.push(empty_array::<Int16Type>()),
-                DataType::Int32 => vec.push(empty_array::<Int32Type>()),
-                DataType::Int64 => vec.push(empty_array::<Int64Type>()),
-                DataType::Float32 => vec.push(empty_array::<Float32Type>()),
-                DataType::Float64 => vec.push(empty_array::<Float64Type>()),
-                _ => {
-                    return Err(DataFusionError::Execution(
-                        "unsupported data type for median".to_string(),
-                    ))
-                }
-            }
-        }
-        Ok(vec)
+        let state =
+            ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone());
+        Ok(vec![AggregateState::Scalar(state)])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let x = values[0].clone();
-        self.all_values.extend_from_slice(&[x]);
-        Ok(())
-    }
+        assert_eq!(values.len(), 1);
+        let array = &values[0];
 
-    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        for array in states {
-            self.all_values.extend_from_slice(&[array.clone()]);
+        self.all_values.reserve(self.all_values.len() + array.len());

Review Comment:
   ```suggestion
           assert!(matches!(array.data_type(), DataType::List(_)));
           self.all_values.reserve(self.all_values.len() + array.len());
   ```
   ?



##########
datafusion/physical-expr/src/aggregate/median.rs:
##########
@@ -91,157 +91,124 @@ impl AggregateExpr for Median {
 }
 
 #[derive(Debug)]
+/// The median accumulator accumulates the raw input values
+/// as `ScalarValue`s
+///
+/// The intermediate state is represented as a List of those scalars
 struct MedianAccumulator {
     data_type: DataType,
-    all_values: Vec<ArrayRef>,
-}
-
-macro_rules! median {
-    ($SELF:ident, $TY:ty, $SCALAR_TY:ident, $TWO:expr) => {{
-        let combined = combine_arrays::<$TY>($SELF.all_values.as_slice())?;
-        if combined.is_empty() {
-            return Ok(ScalarValue::Null);
-        }
-        let sorted = sort(&combined, None)?;
-        let array = as_primitive_array::<$TY>(&sorted)?;
-        let len = sorted.len();
-        let mid = len / 2;
-        if len % 2 == 0 {
-            Ok(ScalarValue::$SCALAR_TY(Some(
-                (array.value(mid - 1) + array.value(mid)) / $TWO,
-            )))
-        } else {
-            Ok(ScalarValue::$SCALAR_TY(Some(array.value(mid))))
-        }
-    }};
+    all_values: Vec<ScalarValue>,
 }
 
 impl Accumulator for MedianAccumulator {
     fn state(&self) -> Result<Vec<AggregateState>> {
-        let mut vec: Vec<AggregateState> = self
-            .all_values
-            .iter()
-            .map(|v| AggregateState::Array(v.clone()))
-            .collect();
-        if vec.is_empty() {
-            match self.data_type {
-                DataType::UInt8 => vec.push(empty_array::<UInt8Type>()),
-                DataType::UInt16 => vec.push(empty_array::<UInt16Type>()),
-                DataType::UInt32 => vec.push(empty_array::<UInt32Type>()),
-                DataType::UInt64 => vec.push(empty_array::<UInt64Type>()),
-                DataType::Int8 => vec.push(empty_array::<Int8Type>()),
-                DataType::Int16 => vec.push(empty_array::<Int16Type>()),
-                DataType::Int32 => vec.push(empty_array::<Int32Type>()),
-                DataType::Int64 => vec.push(empty_array::<Int64Type>()),
-                DataType::Float32 => vec.push(empty_array::<Float32Type>()),
-                DataType::Float64 => vec.push(empty_array::<Float64Type>()),
-                _ => {
-                    return Err(DataFusionError::Execution(
-                        "unsupported data type for median".to_string(),
-                    ))
-                }
-            }
-        }
-        Ok(vec)
+        let state =
+            ScalarValue::new_list(Some(self.all_values.clone()), self.data_type.clone());
+        Ok(vec![AggregateState::Scalar(state)])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let x = values[0].clone();
-        self.all_values.extend_from_slice(&[x]);
-        Ok(())
-    }
+        assert_eq!(values.len(), 1);
+        let array = &values[0];
 
-    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        for array in states {
-            self.all_values.extend_from_slice(&[array.clone()]);
+        self.all_values.reserve(self.all_values.len() + array.len());
+        for index in 0..array.len() {
+            self.all_values
+                .push(ScalarValue::try_from_array(array, index)?);
         }
+
         Ok(())
     }
 
-    fn evaluate(&self) -> Result<ScalarValue> {
-        match self.all_values[0].data_type() {
-            DataType::Int8 => median!(self, arrow::datatypes::Int8Type, Int8, 2),
-            DataType::Int16 => median!(self, arrow::datatypes::Int16Type, Int16, 2),
-            DataType::Int32 => median!(self, arrow::datatypes::Int32Type, Int32, 2),
-            DataType::Int64 => median!(self, arrow::datatypes::Int64Type, Int64, 2),
-            DataType::UInt8 => median!(self, arrow::datatypes::UInt8Type, UInt8, 2),
-            DataType::UInt16 => median!(self, arrow::datatypes::UInt16Type, UInt16, 2),
-            DataType::UInt32 => median!(self, arrow::datatypes::UInt32Type, UInt32, 2),
-            DataType::UInt64 => median!(self, arrow::datatypes::UInt64Type, UInt64, 2),
-            DataType::Float32 => {
-                median!(self, arrow::datatypes::Float32Type, Float32, 2_f32)
-            }
-            DataType::Float64 => {
-                median!(self, arrow::datatypes::Float64Type, Float64, 2_f64)
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        assert_eq!(states.len(), 1);
+
+        let array = &states[0];
+        for index in 0..array.len() {

Review Comment:
   ```suggestion
           assert!(matches!(array.data_type(), DataType::List(_)));
           for index in 0..array.len() {
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to