alamb commented on code in PR #5554:
URL: https://github.com/apache/arrow-datafusion/pull/5554#discussion_r1135497797


##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -31,7 +32,7 @@ use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
 
 type DistinctScalarValues = ScalarValue;
-
+type ValueSet = HashSet<DistinctScalarValues, RandomState>;

Review Comment:
   I wonder what value these type aliases add. The extra indirection of 
`DistinctScalarValues` --> `ScalarValue` simply seems to make things more 
complicated 🤔 



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -85,64 +86,124 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use arrow::datatypes;
+        use datatypes::DataType::*;
+
+        Ok(match &self.state_data_type {
+            Dictionary(key, val) if key.is_dictionary_key_type() => {
+                let val_type = *val.clone();
+                match **key {
+                    Int8 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::Int8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int16 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::Int16Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int32 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::Int32Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int64 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::Int64Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt8 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::UInt8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt16 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt16Type,
+                    >::new(val_type)),
+                    UInt32 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt32Type,
+                    >::new(val_type)),
+                    UInt64 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt64Type,
+                    >::new(val_type)),
+                    _ => {
+                        // just checked that datatype is a valid dict key type
+                        unreachable!()
+                    }
+                }
+            }
+            _ => Box::new(DistinctCountAccumulator {
+                values: HashSet::default(),
+                state_data_type: self.state_data_type.clone(),
+            }),
+        })
     }
 
     fn name(&self) -> &str {
         &self.name
     }
 }
 
-#[derive(Debug)]
-struct DistinctCountAccumulator {
-    values: HashSet<DistinctScalarValues, RandomState>,
-    state_data_type: DataType,
+// calculating the size of values hashset for fixed length values,
+// taking first batch size * number of batches.
+// This method is faster than full_size(), however it is not suitable for 
variable length
+// values like strings or complex types
+fn values_fixed_size(values: &ValueSet) -> usize {
+    (std::mem::size_of::<DistinctScalarValues>() * values.capacity())
+        + values
+            .iter()
+            .next()
+            .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
+            .unwrap_or(0)
+}
+// calculates the size as accurate as possible, call to this method is 
expensive

Review Comment:
   ```suggestion
   // calculates the size as accurately as possible, call to this method is expensive
   // but necessary to correctly account for variable length strings
   ```



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -158,38 +219,98 @@ impl Accumulator for DistinctCountAccumulator {
         })
     }
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        if states.is_empty() {
-            return Ok(());
+        merge_values(&mut self.values, states)
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        let values_size = match &self.state_data_type {
+            DataType::Boolean | DataType::Null => 
values_fixed_size(&self.values),
+            d if d.is_primitive() => values_fixed_size(&self.values),
+            _ => values_full_size(&self.values),
+        };
+        std::mem::size_of_val(self) + values_size + 
std::mem::size_of::<DataType>()
+    }
+}
+/// Special case accumulator for counting distinct values in a dict
+struct CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    /// `K` is required when casting to dict array
+    _dt: core::marker::PhantomData<K>,
+    values_datatype: DataType,
+    values: ValueSet,
+}
+
+impl<K> std::fmt::Debug for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CountDistinctDictAccumulator")
+            .field("values", &self.values)
+            .field("values_datatype", &self.values_datatype)
+            .finish()
+    }
+}
+impl<K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync>
+    CountDistinctDictAccumulator<K>
+{
+    fn new(values_datatype: DataType) -> Self {
+        Self {
+            _dt: core::marker::PhantomData,
+            values: Default::default(),
+            values_datatype,
         }
-        let arr = &states[0];
-        (0..arr.len()).try_for_each(|index| {
-            let scalar = ScalarValue::try_from_array(arr, index)?;
+    }
+}
+impl<K> Accumulator for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        values_to_state(&self.values, &self.values_datatype)
+    }
 
-            if let ScalarValue::List(Some(scalar), _) = scalar {
-                scalar.iter().for_each(|scalar| {
-                    if !ScalarValue::is_null(scalar) {
-                        self.values.insert(scalar.clone());
-                    }
-                });
-            } else {
-                return Err(DataFusionError::Internal(
-                    "Unexpected accumulator state".into(),
-                ));
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = as_dictionary_array::<K>(&values[0])?;
+        let nvalues = arr.values().len();
+        // map keys to whether their corresponding value has been seen or not
+        let mut seen_map = (0..nvalues).map(|_| false).collect::<Vec<_>>();

Review Comment:
   ```suggestion
           let mut seen_map = vec![false; nvalues];
   ```



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -158,38 +219,98 @@ impl Accumulator for DistinctCountAccumulator {
         })
     }
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        if states.is_empty() {
-            return Ok(());
+        merge_values(&mut self.values, states)
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        let values_size = match &self.state_data_type {
+            DataType::Boolean | DataType::Null => 
values_fixed_size(&self.values),
+            d if d.is_primitive() => values_fixed_size(&self.values),
+            _ => values_full_size(&self.values),
+        };
+        std::mem::size_of_val(self) + values_size + 
std::mem::size_of::<DataType>()
+    }
+}
+/// Special case accumulator for counting distinct values in a dict
+struct CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    /// `K` is required when casting to dict array
+    _dt: core::marker::PhantomData<K>,
+    values_datatype: DataType,
+    values: ValueSet,
+}
+
+impl<K> std::fmt::Debug for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CountDistinctDictAccumulator")
+            .field("values", &self.values)
+            .field("values_datatype", &self.values_datatype)
+            .finish()
+    }
+}
+impl<K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync>
+    CountDistinctDictAccumulator<K>
+{
+    fn new(values_datatype: DataType) -> Self {
+        Self {
+            _dt: core::marker::PhantomData,
+            values: Default::default(),
+            values_datatype,
         }
-        let arr = &states[0];
-        (0..arr.len()).try_for_each(|index| {
-            let scalar = ScalarValue::try_from_array(arr, index)?;
+    }
+}
+impl<K> Accumulator for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        values_to_state(&self.values, &self.values_datatype)
+    }
 
-            if let ScalarValue::List(Some(scalar), _) = scalar {
-                scalar.iter().for_each(|scalar| {
-                    if !ScalarValue::is_null(scalar) {
-                        self.values.insert(scalar.clone());
-                    }
-                });
-            } else {
-                return Err(DataFusionError::Internal(
-                    "Unexpected accumulator state".into(),
-                ));
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {

Review Comment:
   As an alternate construction, since it is only `update_batch` that varies by 
the dictionary key type, I suspect you could make this PR quite a bit smaller 
by dispatching at runtime to an appropriate type of `update_batch`. However, that 
would require a dispatch on each batch, where this PR only requires a single 
dispatch during planning (at the expense of larger code).



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -577,4 +697,76 @@ mod tests {
         assert_eq!(result, ScalarValue::Int64(Some(2)));
         Ok(())
     }
+
+    #[test]
+    fn count_distinct_dict_update() -> Result<()> {
+        let values = StringArray::from_iter_values(["a", "b", "c"]);
+        // value "b" is never used
+        let keys =
+            Int8Array::from_iter(vec![Some(0), Some(0), Some(0), Some(0), 
None, Some(2)]);
+        let arrays =
+            vec![
+                Arc::new(DictionaryArray::<Int8Type>::try_new(&keys, 
&values).unwrap())
+                    as ArrayRef,
+            ];
+        let agg = DistinctCount::new(
+            arrays[0].data_type().clone(),
+            Arc::new(NoOp::new()),
+            String::from("__col_name__"),
+        );
+        let mut accum = agg.create_accumulator()?;
+        accum.update_batch(&arrays)?;
+        // should evaluate to 2 since "b" never seen
+        assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(2)));
+        // now update with a new batch that does use "b" (and non-normalized 
values)
+        let values = StringArray::from_iter_values(["b", "a", "c", "d"]);
+        let keys = Int8Array::from_iter(vec![Some(0), Some(0), None]);
+        let arrays =
+            vec![
+                Arc::new(DictionaryArray::<Int8Type>::try_new(&keys, 
&values).unwrap())
+                    as ArrayRef,
+            ];
+        accum.update_batch(&arrays)?;
+        assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(3)));
+        Ok(())
+    }
+
+    #[test]
+    fn count_distinct_dict_merge() -> Result<()> {
+        let values = StringArray::from_iter_values(["a", "b", "c"]);
+        let keys = Int8Array::from_iter(vec![Some(0), Some(0), None]);
+        let arrays =
+            vec![
+                Arc::new(DictionaryArray::<Int8Type>::try_new(&keys, 
&values).unwrap())
+                    as ArrayRef,
+            ];
+        let agg = DistinctCount::new(
+            arrays[0].data_type().clone(),
+            Arc::new(NoOp::new()),
+            String::from("__col_name__"),
+        );
+        // create accum with 1 value seen
+        let mut accum = agg.create_accumulator()?;
+        accum.update_batch(&arrays)?;
+        assert_eq!(accum.evaluate()?, ScalarValue::Int64(Some(1)));
+        // create accum with state that has seen "a" and "b" but not "c"
+        let values = StringArray::from_iter_values(["c", "b", "a"]);

Review Comment:
   👍 good call to use a different dictionary



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -85,64 +86,124 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use arrow::datatypes;
+        use datatypes::DataType::*;
+
+        Ok(match &self.state_data_type {
+            Dictionary(key, val) if key.is_dictionary_key_type() => {
+                let val_type = *val.clone();
+                match **key {
+                    Int8 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::Int8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int16 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::Int16Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int32 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::Int32Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int64 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::Int64Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt8 => Box::new(
+                        
CountDistinctDictAccumulator::<datatypes::UInt8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt16 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt16Type,
+                    >::new(val_type)),
+                    UInt32 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt32Type,
+                    >::new(val_type)),
+                    UInt64 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt64Type,
+                    >::new(val_type)),
+                    _ => {
+                        // just checked that datatype is a valid dict key type
+                        unreachable!()

Review Comment:
   Though to be clear, I do think various parts of the Rust arrow 
implementation will panic if another type is used as a dictionary key. Being 
defensive and returning an internal error sounds like a good idea to me.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to