This is an automated email from the ASF dual-hosted git repository.

kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 083adadb5b44b7cbe97ac892025e4c4ffebd3387
Author: Davis Silverman <sinistersn...@gmail.com>
AuthorDate: Fri Oct 4 08:27:44 2019 -0600

    ARROW-6657: [Rust] [DataFusion] Add Count Aggregate Expression
    
    Hi, I added this code, and the tests pass. I still need to actually test it
    using a real example, so I would say it's not completely ready for merge yet.
    
    Closes #5513 from sinistersnare/ARROW-6657 and squashes the following commits:
    
    64d0c00b0 <Andy Grove> formatting
    12d0c2c56 <Davis Silverman> Add Count Aggregate Expression
    
    Lead-authored-by: Davis Silverman <sinistersn...@gmail.com>
    Co-authored-by: Andy Grove <andygrov...@gmail.com>
    Signed-off-by: Andy Grove <andygrov...@gmail.com>
---
 rust/datafusion/src/execution/context.rs           |  44 ++++++++-
 .../src/execution/physical_plan/expressions.rs     | 110 +++++++++++++++++++++
 2 files changed, 153 insertions(+), 1 deletion(-)

diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index dc54b99..f07c8b9 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation;
 use crate::execution::physical_plan::common;
 use crate::execution::physical_plan::datasource::DatasourceExec;
 use crate::execution::physical_plan::expressions::{
-    BinaryExpr, CastExpr, Column, Literal, Sum,
+    BinaryExpr, CastExpr, Column, Count, Literal, Sum,
 };
 use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
 use crate::execution::physical_plan::merge::MergeExec;
@@ -333,6 +333,9 @@ impl ExecutionContext {
                     "sum" => Ok(Arc::new(Sum::new(
                         self.create_physical_expr(&args[0], input_schema)?,
                     ))),
+                    "count" => Ok(Arc::new(Count::new(
+                        self.create_physical_expr(&args[0], input_schema)?,
+                    ))),
                     other => Err(ExecutionError::NotImplemented(format!(
                         "Unsupported aggregate function '{}'",
                         other
@@ -641,6 +644,45 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn count_basic() -> Result<()> {
+        // Single partition: both columns are expected to count 10 values.
+        let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1)?;
+        assert_eq!(results.len(), 1);
+
+        let expected: Vec<&str> = vec!["10,10"];
+        let mut rows = test::format_batch(&results[0]);
+        rows.sort();
+        assert_eq!(rows, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn count_partitioned() -> Result<()> {
+        // Four partitions: partial counts must be merged into one total per column.
+        let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 4)?;
+        assert_eq!(results.len(), 1);
+
+        let expected: Vec<&str> = vec!["40,40"];
+        let mut rows = test::format_batch(&results[0]);
+        rows.sort();
+        assert_eq!(rows, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn count_aggregated() -> Result<()> {
+        // Grouped COUNT over 4 partitions: each c1 group should count 10 c2 values.
+        let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4)?;
+        assert_eq!(results.len(), 1);
+
+        let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"];
+        let mut rows = test::format_batch(&results[0]);
+        rows.sort();
+        assert_eq!(rows, expected);
+        Ok(())
+    }
+
     /// Execute SQL and return results
     fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
         let tmp_dir = TempDir::new("execute")?;
diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs
index f63b40c..5f53536 100644
--- a/rust/datafusion/src/execution/physical_plan/expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/expressions.rs
@@ -147,6 +147,7 @@ macro_rules! sum_accumulate {
         }
     }};
 }
+
 struct SumAccumulator {
     expr: Arc<dyn PhysicalExpr>,
     sum: Option<ScalarValue>,
@@ -207,6 +208,68 @@ pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
     Arc::new(Sum::new(expr))
 }
 
+/// COUNT aggregate expression.
+/// Returns the number of non-null values of the given expression.
+pub struct Count {
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl Count {
+    /// Create a new COUNT aggregate function over `expr`.
+    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
+        Self { expr }
+    }
+}
+
+impl AggregateExpr for Count {
+    fn name(&self) -> String {
+        String::from("COUNT")
+    }
+    /// COUNT yields an unsigned 64-bit result regardless of the input schema.
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(DataType::UInt64)
+    }
+    /// The values to count are the wrapped expression evaluated on `batch`.
+    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+        self.expr.evaluate(batch)
+    }
+    /// Every accumulator starts with nothing counted.
+    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
+        Rc::new(RefCell::new(CountAccumulator { count: 0 }))
+    }
+    /// Partial counts held in `column_index` are merged by summing them.
+    fn create_combiner(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
+        sum(Arc::new(Column::new(column_index)))
+    }
+}
+
+/// Running tally of the non-null values passed to `accumulate`.
+struct CountAccumulator {
+    count: u64,
+}
+
+impl Accumulator for CountAccumulator {
+    /// Adds one for the value at `row_index` of `array` unless it is null.
+    fn accumulate(
+        &mut self,
+        _batch: &RecordBatch,
+        array: &ArrayRef,
+        row_index: usize,
+    ) -> Result<()> {
+        self.count += u64::from(array.is_valid(row_index));
+        Ok(())
+    }
+    /// Always `Some`: an accumulator that saw no rows reports a count of 0.
+    fn get_value(&self) -> Result<Option<ScalarValue>> {
+        Ok(Some(ScalarValue::UInt64(self.count)))
+    }
+}
+
+/// Wrap `expr` in a `Count` aggregate and return it as a shared `AggregateExpr`.
+pub fn count(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
+    Arc::new(Count { expr })
+}
+
 /// Invoke a compute kernel on a pair of arrays
 macro_rules! compute_op {
     ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
@@ -702,6 +765,42 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn count_elements() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(5)));
+        Ok(())
+    }
+
+    #[test]
+    fn count_with_nulls() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+        let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(3)));
+        Ok(())
+    }
+
+    #[test]
+    fn count_all_nulls() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
+        let a = BooleanArray::from(vec![None, None, None, None, None, None, None, None]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0)));
+        Ok(())
+    }
+
+    #[test]
+    fn count_empty() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
+        let a = BooleanArray::from(Vec::<bool>::new());
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0)));
+        Ok(())
+    }
+
     fn do_sum(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
         let sum = sum(col(0));
         let accum = sum.create_accumulator();
@@ -712,4 +811,15 @@ mod tests {
         }
         accum.get_value()
     }
+
+    /// Feed every row of column 0 of `batch` through a fresh COUNT accumulator
+    /// and return the accumulator's final value.
+    fn do_count(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
+        let agg = count(col(0));
+        let input = agg.evaluate_input(batch)?;
+        let accum = agg.create_accumulator();
+        let mut accum = accum.borrow_mut();
+        (0..batch.num_rows()).try_for_each(|row| accum.accumulate(batch, &input, row))?;
+        accum.get_value()
+    }
 }

Reply via email to