This is an automated email from the ASF dual-hosted git repository.

jiayuliu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 80c309a  add approx_distinct function (#1087)
80c309a is described below

commit 80c309ac2ce77afd97e3417b0367baacd69d43c1
Author: Jiayu Liu <[email protected]>
AuthorDate: Tue Oct 12 18:44:24 2021 +0800

    add approx_distinct function (#1087)
---
 README.md                                          |   2 +
 ballista/rust/core/proto/ballista.proto            |   1 +
 .../rust/core/src/serde/logical_plan/to_proto.rs   |   4 +
 ballista/rust/core/src/serde/mod.rs                |   3 +
 datafusion/src/logical_plan/expr.rs                |  15 +
 datafusion/src/logical_plan/mod.rs                 |  22 +-
 datafusion/src/physical_plan/aggregates.rs         |  14 +-
 .../physical_plan/expressions/approx_distinct.rs   | 321 +++++++++++++++++++++
 datafusion/src/physical_plan/expressions/mod.rs    |   2 +
 datafusion/src/physical_plan/hyperloglog/mod.rs    |  67 ++++-
 datafusion/tests/sql.rs                            |  17 ++
 11 files changed, 451 insertions(+), 17 deletions(-)

diff --git a/README.md b/README.md
index 00d868c..8b12917 100644
--- a/README.md
+++ b/README.md
@@ -190,6 +190,8 @@ DataFusion also includes a simple command-line interactive 
SQL utility. See the
   - [x] trim
 - Miscellaneous/Boolean functions
   - [x] nullif
+- Approximation functions
+  - [ ] approx_distinct
 - Common date/time functions
   - [ ] Basic date functions
   - [ ] Basic time functions
diff --git a/ballista/rust/core/proto/ballista.proto 
b/ballista/rust/core/proto/ballista.proto
index 8175156..9a2ec71 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -167,6 +167,7 @@ enum AggregateFunction {
   SUM = 2;
   AVG = 3;
   COUNT = 4;
+  APPROX_DISTINCT = 5;
 }
 
 message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs 
b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index c3ffb1a..402422a 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1137,6 +1137,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
                 ref fun, ref args, ..
             } => {
                 let aggr_function = match fun {
+                    AggregateFunction::ApproxDistinct => {
+                        protobuf::AggregateFunction::ApproxDistinct
+                    }
                     AggregateFunction::Min => protobuf::AggregateFunction::Min,
                     AggregateFunction::Max => protobuf::AggregateFunction::Max,
                     AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
@@ -1370,6 +1373,7 @@ impl From<&AggregateFunction> for 
protobuf::AggregateFunction {
             AggregateFunction::Sum => Self::Sum,
             AggregateFunction::Avg => Self::Avg,
             AggregateFunction::Count => Self::Count,
+            AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
         }
     }
 }
diff --git a/ballista/rust/core/src/serde/mod.rs 
b/ballista/rust/core/src/serde/mod.rs
index 1383ba8..a4df5a4 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -114,6 +114,9 @@ impl From<protobuf::AggregateFunction> for 
AggregateFunction {
             protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
             protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
             protobuf::AggregateFunction::Count => AggregateFunction::Count,
+            protobuf::AggregateFunction::ApproxDistinct => {
+                AggregateFunction::ApproxDistinct
+            }
         }
     }
 }
diff --git a/datafusion/src/logical_plan/expr.rs 
b/datafusion/src/logical_plan/expr.rs
index f61ed83..8ef69e9 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -1495,6 +1495,21 @@ pub fn random() -> Expr {
     }
 }
 
+/// Returns the approximate number of distinct input values.
+/// This function provides an approximation of count(DISTINCT x).
+/// Zero is returned if all input values are null.
+/// This function should produce a standard error of 0.81%,
+/// which is the standard deviation of the (approximately normal)
+/// error distribution over all possible sets.
+/// It does not guarantee an upper bound on the error for any specific input set.
+pub fn approx_distinct(expr: Expr) -> Expr {
+    Expr::AggregateFunction {
+        fun: aggregates::AggregateFunction::ApproxDistinct,
+        distinct: false,
+        args: vec![expr],
+    }
+}
+
 /// Create an convenience function representing a unary scalar function
 macro_rules! unary_scalar_expr {
     ($ENUM:ident, $FUNC:ident) => {
diff --git a/datafusion/src/logical_plan/mod.rs 
b/datafusion/src/logical_plan/mod.rs
index 3f0c7d2..8569b35 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -36,17 +36,17 @@ pub use builder::{
 pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
 pub use display::display_schema;
 pub use expr::{
-    abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, 
btrim, case,
-    ceil, character_length, chr, col, columnize_expr, combine_filters, concat, 
concat_ws,
-    cos, count, count_distinct, create_udaf, create_udf, date_part, 
date_trunc, digest,
-    exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit,
-    lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
-    normalize_col, normalize_cols, now, octet_length, or, random, regexp_match,
-    regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, 
rtrim,
-    sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, 
starts_with, strpos,
-    substr, sum, tan, to_hex, translate, trim, trunc, unnormalize_col, 
unnormalize_cols,
-    upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal, 
Recursion,
-    RewriteRecursion,
+    abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, 
binary_expr,
+    bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr,
+    combine_filters, concat, concat_ws, cos, count, count_distinct, 
create_udaf,
+    create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, 
in_list,
+    initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, 
lpad, ltrim,
+    max, md5, min, normalize_col, normalize_cols, now, octet_length, or, 
random,
+    regexp_match, regexp_replace, repeat, replace, replace_col, reverse, 
right, round,
+    rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
+    starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc,
+    unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
+    ExpressionVisitor, Literal, Recursion, RewriteRecursion,
 };
 pub use extension::UserDefinedLogicalNode;
 pub use operators::Operator;
diff --git a/datafusion/src/physical_plan/aggregates.rs 
b/datafusion/src/physical_plan/aggregates.rs
index aad43cc..eb3f6ca 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -59,6 +59,8 @@ pub enum AggregateFunction {
     Max,
     /// avg
     Avg,
+    /// Approximate aggregate function
+    ApproxDistinct,
 }
 
 impl fmt::Display for AggregateFunction {
@@ -77,6 +79,7 @@ impl FromStr for AggregateFunction {
             "count" => AggregateFunction::Count,
             "avg" => AggregateFunction::Avg,
             "sum" => AggregateFunction::Sum,
+            "approx_distinct" => AggregateFunction::ApproxDistinct,
             _ => {
                 return Err(DataFusionError::Plan(format!(
                     "There is no built-in function named {}",
@@ -96,7 +99,9 @@ pub fn return_type(fun: &AggregateFunction, arg_types: 
&[DataType]) -> Result<Da
     data_types(arg_types, &signature(fun))?;
 
     match fun {
-        AggregateFunction::Count => Ok(DataType::UInt64),
+        AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
+            Ok(DataType::UInt64)
+        }
         AggregateFunction::Max | AggregateFunction::Min => 
Ok(arg_types[0].clone()),
         AggregateFunction::Sum => sum_return_type(&arg_types[0]),
         AggregateFunction::Avg => avg_return_type(&arg_types[0]),
@@ -149,6 +154,9 @@ pub fn create_aggregate_expr(
                 "SUM(DISTINCT) aggregations are not available".to_string(),
             ));
         }
+        (AggregateFunction::ApproxDistinct, _) => Arc::new(
+            expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
+        ),
         (AggregateFunction::Min, _) => {
             Arc::new(expressions::Min::new(arg, name, return_type))
         }
@@ -194,7 +202,9 @@ static DATES: &[DataType] = &[DataType::Date32, 
DataType::Date64];
 pub fn signature(fun: &AggregateFunction) -> Signature {
     // note: the physical expression must accept the type returned by this 
function or the execution panics.
     match fun {
-        AggregateFunction::Count => Signature::any(1, Volatility::Immutable),
+        AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
+            Signature::any(1, Volatility::Immutable)
+        }
         AggregateFunction::Min | AggregateFunction::Max => {
             let valid = STRINGS
                 .iter()
diff --git a/datafusion/src/physical_plan/expressions/approx_distinct.rs 
b/datafusion/src/physical_plan/expressions/approx_distinct.rs
new file mode 100644
index 0000000..7a19b6c
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/approx_distinct.rs
@@ -0,0 +1,321 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query execution
+
+use super::format_state_name;
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{
+    hyperloglog::HyperLogLog, Accumulator, AggregateExpr, PhysicalExpr,
+};
+use crate::scalar::ScalarValue;
+use arrow::array::{
+    ArrayRef, BinaryArray, BinaryOffsetSizeTrait, GenericBinaryArray, 
GenericStringArray,
+    PrimitiveArray, StringOffsetSizeTrait,
+};
+use arrow::datatypes::{
+    ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, Int64Type, 
Int8Type,
+    UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use std::any::type_name;
+use std::any::Any;
+use std::convert::TryFrom;
+use std::convert::TryInto;
+use std::hash::Hash;
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+/// APPROX_DISTINCT aggregate expression
+#[derive(Debug)]
+pub struct ApproxDistinct {
+    name: String,
+    input_data_type: DataType,
+    expr: Arc<dyn PhysicalExpr>,
+}
+
+impl ApproxDistinct {
+    /// Create a new ApproxDistinct aggregate function.
+    pub fn new(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        input_data_type: DataType,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            input_data_type,
+            expr,
+        }
+    }
+}
+
+impl AggregateExpr for ApproxDistinct {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, DataType::UInt64, false))
+    }
+
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![Field::new(
+            &format_state_name(&self.name, "hll_registers"),
+            DataType::Binary,
+            false,
+        )])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        let accumulator: Box<dyn Accumulator> = match &self.input_data_type {
+            // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
+            // TODO support for boolean (trivial case)
+            DataType::UInt8 => 
Box::new(NumericHLLAccumulator::<UInt8Type>::new()),
+            DataType::UInt16 => 
Box::new(NumericHLLAccumulator::<UInt16Type>::new()),
+            DataType::UInt32 => 
Box::new(NumericHLLAccumulator::<UInt32Type>::new()),
+            DataType::UInt64 => 
Box::new(NumericHLLAccumulator::<UInt64Type>::new()),
+            DataType::Int8 => 
Box::new(NumericHLLAccumulator::<Int8Type>::new()),
+            DataType::Int16 => 
Box::new(NumericHLLAccumulator::<Int16Type>::new()),
+            DataType::Int32 => 
Box::new(NumericHLLAccumulator::<Int32Type>::new()),
+            DataType::Int64 => 
Box::new(NumericHLLAccumulator::<Int64Type>::new()),
+            DataType::Utf8 => Box::new(StringHLLAccumulator::<i32>::new()),
+            DataType::LargeUtf8 => 
Box::new(StringHLLAccumulator::<i64>::new()),
+            DataType::Binary => Box::new(BinaryHLLAccumulator::<i32>::new()),
+            DataType::LargeBinary => 
Box::new(BinaryHLLAccumulator::<i64>::new()),
+            other => {
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Support for count_distinct for data type {} is not 
implemented",
+                    other
+                )))
+            }
+        };
+        Ok(accumulator)
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+#[derive(Debug)]
+struct BinaryHLLAccumulator<T>
+where
+    T: BinaryOffsetSizeTrait,
+{
+    hll: HyperLogLog<Vec<u8>>,
+    phantom_data: PhantomData<T>,
+}
+
+impl<T> BinaryHLLAccumulator<T>
+where
+    T: BinaryOffsetSizeTrait,
+{
+    /// new approx_distinct accumulator
+    pub fn new() -> Self {
+        Self {
+            hll: HyperLogLog::new(),
+            phantom_data: PhantomData,
+        }
+    }
+}
+
+#[derive(Debug)]
+struct StringHLLAccumulator<T>
+where
+    T: StringOffsetSizeTrait,
+{
+    hll: HyperLogLog<String>,
+    phantom_data: PhantomData<T>,
+}
+
+impl<T> StringHLLAccumulator<T>
+where
+    T: StringOffsetSizeTrait,
+{
+    /// new approx_distinct accumulator
+    pub fn new() -> Self {
+        Self {
+            hll: HyperLogLog::new(),
+            phantom_data: PhantomData,
+        }
+    }
+}
+
+#[derive(Debug)]
+struct NumericHLLAccumulator<T>
+where
+    T: ArrowPrimitiveType,
+    T::Native: Hash,
+{
+    hll: HyperLogLog<T::Native>,
+}
+
+impl<T> NumericHLLAccumulator<T>
+where
+    T: ArrowPrimitiveType,
+    T::Native: Hash,
+{
+    /// new approx_distinct accumulator
+    pub fn new() -> Self {
+        Self {
+            hll: HyperLogLog::new(),
+        }
+    }
+}
+
+impl<T: Hash> From<&HyperLogLog<T>> for ScalarValue {
+    fn from(v: &HyperLogLog<T>) -> ScalarValue {
+        let values = v.as_ref().to_vec();
+        ScalarValue::Binary(Some(values))
+    }
+}
+
+impl<T: Hash> TryFrom<&[u8]> for HyperLogLog<T> {
+    type Error = DataFusionError;
+    fn try_from(v: &[u8]) -> Result<HyperLogLog<T>> {
+        let arr: [u8; 16384] = v.try_into().map_err(|_| {
+            DataFusionError::Internal(
+                "Impossibly got invalid binary array from states".into(),
+            )
+        })?;
+        Ok(HyperLogLog::<T>::new_with_registers(arr))
+    }
+}
+
+impl<T: Hash> TryFrom<&ScalarValue> for HyperLogLog<T> {
+    type Error = DataFusionError;
+    fn try_from(v: &ScalarValue) -> Result<HyperLogLog<T>> {
+        if let ScalarValue::Binary(Some(slice)) = v {
+            slice.as_slice().try_into()
+        } else {
+            Err(DataFusionError::Internal(
+                "Impossibly got invalid scalar value while converting to 
HyperLogLog"
+                    .into(),
+            ))
+        }
+    }
+}
+
+macro_rules! default_accumulator_impl {
+    () => {
+        fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+            self.update_batch(
+                values
+                    .iter()
+                    .map(|s| s.to_array() as ArrayRef)
+                    .collect::<Vec<_>>()
+                    .as_slice(),
+            )
+        }
+
+        fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+            assert_eq!(1, states.len(), "expect only 1 element in the states");
+            let other = HyperLogLog::try_from(&states[0])?;
+            self.hll.merge(&other);
+            Ok(())
+        }
+
+        fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+            assert_eq!(1, states.len(), "expect only 1 element in the states");
+            let binary_array = 
states[0].as_any().downcast_ref::<BinaryArray>().unwrap();
+            for v in binary_array.iter() {
+                let v = v.ok_or_else(|| {
+                    DataFusionError::Internal(
+                        "Impossibly got empty binary array from states".into(),
+                    )
+                })?;
+                let other = v.try_into()?;
+                self.hll.merge(&other);
+            }
+            Ok(())
+        }
+
+        fn state(&self) -> Result<Vec<ScalarValue>> {
+            let value = ScalarValue::from(&self.hll);
+            Ok(vec![value])
+        }
+
+        fn evaluate(&self) -> Result<ScalarValue> {
+            Ok(ScalarValue::UInt64(Some(self.hll.count() as u64)))
+        }
+    };
+}
+
+macro_rules! downcast_value {
+    ($Value: expr, $Type: ident, $T: tt) => {{
+        $Value[0]
+            .as_any()
+            .downcast_ref::<$Type<T>>()
+            .ok_or_else(|| {
+                DataFusionError::Internal(format!(
+                    "could not cast value to {}",
+                    type_name::<$Type<T>>()
+                ))
+            })?
+    }};
+}
+
+impl<T> Accumulator for BinaryHLLAccumulator<T>
+where
+    T: BinaryOffsetSizeTrait,
+{
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array: &GenericBinaryArray<T> =
+            downcast_value!(values, GenericBinaryArray, T);
+        // flatten so that null values are skipped
+        self.hll
+            .extend(array.into_iter().flatten().map(|v| v.to_vec()));
+        Ok(())
+    }
+
+    default_accumulator_impl!();
+}
+
+impl<T> Accumulator for StringHLLAccumulator<T>
+where
+    T: StringOffsetSizeTrait,
+{
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array: &GenericStringArray<T> =
+            downcast_value!(values, GenericStringArray, T);
+        // flatten so that null values are skipped
+        self.hll
+            .extend(array.into_iter().flatten().map(|i| i.to_string()));
+        Ok(())
+    }
+
+    default_accumulator_impl!();
+}
+
+impl<T> Accumulator for NumericHLLAccumulator<T>
+where
+    T: ArrowPrimitiveType + std::fmt::Debug,
+    T::Native: Hash,
+{
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array: &PrimitiveArray<T> = downcast_value!(values, 
PrimitiveArray, T);
+        // flatten so that null values are skipped
+        self.hll.extend(array.into_iter().flatten());
+        Ok(())
+    }
+
+    default_accumulator_impl!();
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs 
b/datafusion/src/physical_plan/expressions/mod.rs
index 9f7a6cc..4ca0036 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -25,6 +25,7 @@ use crate::physical_plan::PhysicalExpr;
 use arrow::compute::kernels::sort::{SortColumn, SortOptions};
 use arrow::record_batch::RecordBatch;
 
+mod approx_distinct;
 mod average;
 #[macro_use]
 mod binary;
@@ -55,6 +56,7 @@ pub mod helpers {
     pub use super::min_max::{max, min};
 }
 
+pub use approx_distinct::ApproxDistinct;
 pub use average::{avg_return_type, Avg, AvgAccumulator};
 pub use binary::{binary, binary_operator_data_type, BinaryExpr};
 pub use case::{case, CaseExpr};
diff --git a/datafusion/src/physical_plan/hyperloglog/mod.rs 
b/datafusion/src/physical_plan/hyperloglog/mod.rs
index 25e5213..3b91d30 100644
--- a/datafusion/src/physical_plan/hyperloglog/mod.rs
+++ b/datafusion/src/physical_plan/hyperloglog/mod.rs
@@ -34,9 +34,6 @@
 //!
 //! This module also borrows some code structure from 
[pdatastructs.rs](https://github.com/crepererum/pdatastructs.rs/blob/3997ed50f6b6871c9e53c4c5e0f48f431405fc63/src/hyperloglog.rs).
 
-// TODO remove this when hooked up with the rest
-#![allow(dead_code)]
-
 use ahash::{AHasher, RandomState};
 use std::hash::{BuildHasher, Hash, Hasher};
 use std::marker::PhantomData;
@@ -58,7 +55,12 @@ where
     phantom: PhantomData<T>,
 }
 
-/// fixed seed for the hashing so that values are consistent across runs
+/// Fixed seed for the hashing so that values are consistent across runs
+///
+/// Note that when we later move on to have serialized HLL register binaries
+/// shared across cluster, this SEED will have to be consistent across all
+/// parties otherwise we might have corruption. So ideally for later this seed
+/// shall be part of the serialized form (or stay unchanged across versions).
 const SEED: RandomState = RandomState::with_seeds(
     0x885f6cab121d01a3_u64,
     0x71e4379f2976ad8f_u64,
@@ -73,6 +75,13 @@ where
     /// Creates a new, empty HyperLogLog.
     pub fn new() -> Self {
         let registers = [0; NUM_REGISTERS];
+        Self::new_with_registers(registers)
+    }
+
+    /// Creates a HyperLogLog from already populated registers.
+    /// Note that this method should not be invoked in an untrusted environment
+    /// because the internal structure of the registers is not examined.
+    pub(crate) fn new_with_registers(registers: [u8; NUM_REGISTERS]) -> Self {
         Self {
             registers,
             phantom: PhantomData,
@@ -109,6 +118,19 @@ where
         histogram
     }
 
+    /// Merge the other [`HyperLogLog`] into this one
+    pub fn merge(&mut self, other: &HyperLogLog<T>) {
+        assert!(
+            self.registers.len() == other.registers.len(),
+            "unexpected got unequal register size, expect {}, got {}",
+            self.registers.len(),
+            other.registers.len()
+        );
+        for i in 0..self.registers.len() {
+            self.registers[i] = self.registers[i].max(other.registers[i]);
+        }
+    }
+
     /// Guess the number of unique elements seen by the HyperLogLog.
     pub fn count(&self) -> usize {
         let histogram = self.get_histogram();
@@ -171,6 +193,15 @@ fn hll_tau(x: f64) -> f64 {
     }
 }
 
+impl<T> AsRef<[u8]> for HyperLogLog<T>
+where
+    T: Hash + ?Sized,
+{
+    fn as_ref(&self) -> &[u8] {
+        &self.registers
+    }
+}
+
 impl<T> Extend<T> for HyperLogLog<T>
 where
     T: Hash,
@@ -300,4 +331,32 @@ mod tests {
         hll.extend((0..1000).map(|i| i.to_string()));
         compare_with_delta(hll.count(), 1000);
     }
+
+    #[test]
+    fn test_empty_merge() {
+        let mut hll = HyperLogLog::<u64>::new();
+        hll.merge(&HyperLogLog::<u64>::new());
+        assert_eq!(hll.count(), 0);
+    }
+
+    #[test]
+    fn test_merge_overlapped() {
+        let mut hll = HyperLogLog::<String>::new();
+        hll.extend((0..1000).map(|i| i.to_string()));
+
+        let mut other = HyperLogLog::<String>::new();
+        other.extend((0..1000).map(|i| i.to_string()));
+
+        hll.merge(&other);
+        compare_with_delta(hll.count(), 1000);
+    }
+
+    #[test]
+    fn test_repetition() {
+        let mut hll = HyperLogLog::<u32>::new();
+        for i in 0..1_000_000 {
+            hll.add(&(i % 1000));
+        }
+        compare_with_delta(hll.count(), 1000);
+    }
 }
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 801451f..e822542 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -983,6 +983,23 @@ async fn csv_query_count() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_approx_count() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_csv(&mut ctx)?;
+    let sql = "SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as 
varchar)) count_c9_str FROM aggregate_test_100";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+----------+--------------+",
+        "| count_c9 | count_c9_str |",
+        "+----------+--------------+",
+        "| 100      | 99           |",
+        "+----------+--------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 /// for window functions without order by the first, last, and nth function 
call does not make sense
 #[tokio::test]
 async fn csv_query_window_with_empty_over() -> Result<()> {

Reply via email to