waynexia commented on code in PR #5554:
URL: https://github.com/apache/arrow-datafusion/pull/5554#discussion_r1137258276


##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -85,64 +85,126 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use arrow::datatypes;
+        use datatypes::DataType::*;
+
+        Ok(match &self.state_data_type {
+            Dictionary(key, val) if key.is_dictionary_key_type() => {
+                let val_type = *val.clone();
+                match **key {
+                    Int8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int16 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int16Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int32 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int32Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int64 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int64Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::UInt8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt16 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt16Type,
+                    >::new(val_type)),
+                    UInt32 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt32Type,
+                    >::new(val_type)),
+                    UInt64 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt64Type,
+                    >::new(val_type)),
+                    _ => {
+                        return Err(DataFusionError::Internal(
+                            "Dict key has invalid datatype".to_string(),
+                        ))
+                    }
+                }
+            }
+            _ => Box::new(DistinctCountAccumulator {
+                values: HashSet::default(),
+                state_data_type: self.state_data_type.clone(),
+            }),
+        })
     }
 
     fn name(&self) -> &str {
         &self.name
     }
 }
 
-#[derive(Debug)]
-struct DistinctCountAccumulator {
-    values: HashSet<DistinctScalarValues, RandomState>,
-    state_data_type: DataType,
+// calculating the size of values hashset for fixed length values,
+// taking first batch size * number of batches.
+// This method is faster than full_size(), however it is not suitable for variable length
+// values like strings or complex types

Review Comment:
   ```suggestion
   /// calculating the size of values hashset for fixed length values,
   /// taking first batch size * number of batches.
   /// This method is faster than full_size(), however it is not suitable for variable length
   /// values like strings or complex types
   ```
   
   style: prefer doc comments (`///`) over plain `//` comments here



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -85,64 +85,126 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use arrow::datatypes;
+        use datatypes::DataType::*;
+
+        Ok(match &self.state_data_type {
+            Dictionary(key, val) if key.is_dictionary_key_type() => {
+                let val_type = *val.clone();
+                match **key {
+                    Int8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int16 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int16Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int32 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int32Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int64 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int64Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::UInt8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt16 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt16Type,
+                    >::new(val_type)),
+                    UInt32 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt32Type,
+                    >::new(val_type)),
+                    UInt64 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt64Type,
+                    >::new(val_type)),
+                    _ => {
+                        return Err(DataFusionError::Internal(
+                            "Dict key has invalid datatype".to_string(),

Review Comment:
   nit: I would prefer to include the concrete key type in the error message, e.g. `format!("Dict key has invalid datatype: {key}")`



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -158,38 +220,96 @@ impl Accumulator for DistinctCountAccumulator {
         })
     }
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        if states.is_empty() {
-            return Ok(());
+        merge_values(&mut self.values, states)
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        let values_size = match &self.state_data_type {
+            DataType::Boolean | DataType::Null => values_fixed_size(&self.values),
+            d if d.is_primitive() => values_fixed_size(&self.values),
+            _ => values_full_size(&self.values),
+        };
+        std::mem::size_of_val(self) + values_size + std::mem::size_of::<DataType>()
+    }
+}
+/// Special case accumulator for counting distinct values in a dict
+struct CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    /// `K` is required when casting to dict array
+    _dt: core::marker::PhantomData<K>,
+    values_datatype: DataType,
+    values: ValueSet,
+}
+
+impl<K> std::fmt::Debug for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CountDistinctDictAccumulator")
+            .field("values", &self.values)
+            .field("values_datatype", &self.values_datatype)
+            .finish()
+    }
+}
+impl<K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync>

Review Comment:
   ```suggestion
   }
   
   impl<K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync>
   ```
   
   style: add an empty line between the two blocks



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -85,64 +85,126 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use arrow::datatypes;
+        use datatypes::DataType::*;
+
+        Ok(match &self.state_data_type {
+            Dictionary(key, val) if key.is_dictionary_key_type() => {
+                let val_type = *val.clone();
+                match **key {
+                    Int8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int16 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int16Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int32 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int32Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int64 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int64Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::UInt8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt16 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt16Type,
+                    >::new(val_type)),
+                    UInt32 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt32Type,
+                    >::new(val_type)),
+                    UInt64 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt64Type,
+                    >::new(val_type)),
+                    _ => {
+                        return Err(DataFusionError::Internal(
+                            "Dict key has invalid datatype".to_string(),
+                        ))
+                    }
+                }
+            }
+            _ => Box::new(DistinctCountAccumulator {
+                values: HashSet::default(),
+                state_data_type: self.state_data_type.clone(),
+            }),
+        })
     }
 
     fn name(&self) -> &str {
         &self.name
     }
 }
 
-#[derive(Debug)]
-struct DistinctCountAccumulator {
-    values: HashSet<DistinctScalarValues, RandomState>,
-    state_data_type: DataType,
+// calculating the size of values hashset for fixed length values,
+// taking first batch size * number of batches.
+// This method is faster than full_size(), however it is not suitable for variable length
+// values like strings or complex types
+fn values_fixed_size(values: &ValueSet) -> usize {
+    (std::mem::size_of::<ScalarValue>() * values.capacity())
+        + values
+            .iter()
+            .next()
+            .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
+            .unwrap_or(0)
+}
+// calculates the size as accurate as possible, call to this method is expensive

Review Comment:
   ```suggestion
   }
   
   // calculates the size as accurate as possible, call to this method is expensive
   ```
   
   style: add an empty line between the two functions



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to