This is an automated email from the ASF dual-hosted git repository.
jiayuliu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 80c309a add approx_distinct function (#1087)
80c309a is described below
commit 80c309ac2ce77afd97e3417b0367baacd69d43c1
Author: Jiayu Liu <[email protected]>
AuthorDate: Tue Oct 12 18:44:24 2021 +0800
add approx_distinct function (#1087)
---
README.md | 2 +
ballista/rust/core/proto/ballista.proto | 1 +
.../rust/core/src/serde/logical_plan/to_proto.rs | 4 +
ballista/rust/core/src/serde/mod.rs | 3 +
datafusion/src/logical_plan/expr.rs | 15 +
datafusion/src/logical_plan/mod.rs | 22 +-
datafusion/src/physical_plan/aggregates.rs | 14 +-
.../physical_plan/expressions/approx_distinct.rs | 321 +++++++++++++++++++++
datafusion/src/physical_plan/expressions/mod.rs | 2 +
datafusion/src/physical_plan/hyperloglog/mod.rs | 67 ++++-
datafusion/tests/sql.rs | 17 ++
11 files changed, 451 insertions(+), 17 deletions(-)
diff --git a/README.md b/README.md
index 00d868c..8b12917 100644
--- a/README.md
+++ b/README.md
@@ -190,6 +190,8 @@ DataFusion also includes a simple command-line interactive
SQL utility. See the
- [x] trim
- Miscellaneous/Boolean functions
- [x] nullif
+- Approximation functions
+ - [ ] approx_distinct
- Common date/time functions
- [ ] Basic date functions
- [ ] Basic time functions
diff --git a/ballista/rust/core/proto/ballista.proto
b/ballista/rust/core/proto/ballista.proto
index 8175156..9a2ec71 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -167,6 +167,7 @@ enum AggregateFunction {
SUM = 2;
AVG = 3;
COUNT = 4;
+ APPROX_DISTINCT = 5;
}
message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index c3ffb1a..402422a 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -1137,6 +1137,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
ref fun, ref args, ..
} => {
let aggr_function = match fun {
+ AggregateFunction::ApproxDistinct => {
+ protobuf::AggregateFunction::ApproxDistinct
+ }
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
AggregateFunction::Sum => protobuf::AggregateFunction::Sum,
@@ -1370,6 +1373,7 @@ impl From<&AggregateFunction> for
protobuf::AggregateFunction {
AggregateFunction::Sum => Self::Sum,
AggregateFunction::Avg => Self::Avg,
AggregateFunction::Count => Self::Count,
+ AggregateFunction::ApproxDistinct => Self::ApproxDistinct,
}
}
}
diff --git a/ballista/rust/core/src/serde/mod.rs
b/ballista/rust/core/src/serde/mod.rs
index 1383ba8..a4df5a4 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -114,6 +114,9 @@ impl From<protobuf::AggregateFunction> for
AggregateFunction {
protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
protobuf::AggregateFunction::Count => AggregateFunction::Count,
+ protobuf::AggregateFunction::ApproxDistinct => {
+ AggregateFunction::ApproxDistinct
+ }
}
}
}
diff --git a/datafusion/src/logical_plan/expr.rs
b/datafusion/src/logical_plan/expr.rs
index f61ed83..8ef69e9 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -1495,6 +1495,21 @@ pub fn random() -> Expr {
}
}
+/// Returns the approximate number of distinct input values.
+/// This function provides an approximation of count(DISTINCT x).
+/// Zero is returned if all input values are null.
+/// This function should produce a standard error of 0.81%,
+/// which is the standard deviation of the (approximately normal)
+/// error distribution over all possible sets.
+/// It does not guarantee an upper bound on the error for any specific input
set.
+pub fn approx_distinct(expr: Expr) -> Expr {
+ Expr::AggregateFunction {
+ fun: aggregates::AggregateFunction::ApproxDistinct,
+ distinct: false,
+ args: vec![expr],
+ }
+}
+
/// Create an convenience function representing a unary scalar function
macro_rules! unary_scalar_expr {
($ENUM:ident, $FUNC:ident) => {
diff --git a/datafusion/src/logical_plan/mod.rs
b/datafusion/src/logical_plan/mod.rs
index 3f0c7d2..8569b35 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -36,17 +36,17 @@ pub use builder::{
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
pub use display::display_schema;
pub use expr::{
- abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length,
btrim, case,
- ceil, character_length, chr, col, columnize_expr, combine_filters, concat,
concat_ws,
- cos, count, count_distinct, create_udaf, create_udf, date_part,
date_trunc, digest,
- exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit,
- lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
- normalize_col, normalize_cols, now, octet_length, or, random, regexp_match,
- regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad,
rtrim,
- sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
starts_with, strpos,
- substr, sum, tan, to_hex, translate, trim, trunc, unnormalize_col,
unnormalize_cols,
- upper, when, Column, Expr, ExprRewriter, ExpressionVisitor, Literal,
Recursion,
- RewriteRecursion,
+ abs, acos, and, approx_distinct, array, ascii, asin, atan, avg,
binary_expr,
+ bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr,
+ combine_filters, concat, concat_ws, cos, count, count_distinct,
create_udaf,
+ create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor,
in_list,
+ initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower,
lpad, ltrim,
+ max, md5, min, normalize_col, normalize_cols, now, octet_length, or,
random,
+ regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
right, round,
+ rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
+ starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc,
+ unnormalize_col, unnormalize_cols, upper, when, Column, Expr, ExprRewriter,
+ ExpressionVisitor, Literal, Recursion, RewriteRecursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
diff --git a/datafusion/src/physical_plan/aggregates.rs
b/datafusion/src/physical_plan/aggregates.rs
index aad43cc..eb3f6ca 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -59,6 +59,8 @@ pub enum AggregateFunction {
Max,
/// avg
Avg,
+ /// approximate distinct count (approx_distinct)
+ ApproxDistinct,
}
impl fmt::Display for AggregateFunction {
@@ -77,6 +79,7 @@ impl FromStr for AggregateFunction {
"count" => AggregateFunction::Count,
"avg" => AggregateFunction::Avg,
"sum" => AggregateFunction::Sum,
+ "approx_distinct" => AggregateFunction::ApproxDistinct,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
@@ -96,7 +99,9 @@ pub fn return_type(fun: &AggregateFunction, arg_types:
&[DataType]) -> Result<Da
data_types(arg_types, &signature(fun))?;
match fun {
- AggregateFunction::Count => Ok(DataType::UInt64),
+ AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
+ Ok(DataType::UInt64)
+ }
AggregateFunction::Max | AggregateFunction::Min =>
Ok(arg_types[0].clone()),
AggregateFunction::Sum => sum_return_type(&arg_types[0]),
AggregateFunction::Avg => avg_return_type(&arg_types[0]),
@@ -149,6 +154,9 @@ pub fn create_aggregate_expr(
"SUM(DISTINCT) aggregations are not available".to_string(),
));
}
+ (AggregateFunction::ApproxDistinct, _) => Arc::new(
+ expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
+ ),
(AggregateFunction::Min, _) => {
Arc::new(expressions::Min::new(arg, name, return_type))
}
@@ -194,7 +202,9 @@ static DATES: &[DataType] = &[DataType::Date32,
DataType::Date64];
pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this
function or the execution panics.
match fun {
- AggregateFunction::Count => Signature::any(1, Volatility::Immutable),
+ AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
+ Signature::any(1, Volatility::Immutable)
+ }
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
diff --git a/datafusion/src/physical_plan/expressions/approx_distinct.rs
b/datafusion/src/physical_plan/expressions/approx_distinct.rs
new file mode 100644
index 0000000..7a19b6c
--- /dev/null
+++ b/datafusion/src/physical_plan/expressions/approx_distinct.rs
@@ -0,0 +1,321 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines physical expressions that can be evaluated at runtime during query
execution
+
+use super::format_state_name;
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{
+ hyperloglog::HyperLogLog, Accumulator, AggregateExpr, PhysicalExpr,
+};
+use crate::scalar::ScalarValue;
+use arrow::array::{
+ ArrayRef, BinaryArray, BinaryOffsetSizeTrait, GenericBinaryArray,
GenericStringArray,
+ PrimitiveArray, StringOffsetSizeTrait,
+};
+use arrow::datatypes::{
+ ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, Int64Type,
Int8Type,
+ UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use std::any::type_name;
+use std::any::Any;
+use std::convert::TryFrom;
+use std::convert::TryInto;
+use std::hash::Hash;
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+/// APPROX_DISTINCT aggregate expression
+#[derive(Debug)]
+pub struct ApproxDistinct {
+ name: String,
+ input_data_type: DataType,
+ expr: Arc<dyn PhysicalExpr>,
+}
+
+impl ApproxDistinct {
+ /// Create a new ApproxDistinct aggregate function.
+ pub fn new(
+ expr: Arc<dyn PhysicalExpr>,
+ name: impl Into<String>,
+ input_data_type: DataType,
+ ) -> Self {
+ Self {
+ name: name.into(),
+ input_data_type,
+ expr,
+ }
+ }
+}
+
+impl AggregateExpr for ApproxDistinct {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn field(&self) -> Result<Field> {
+ Ok(Field::new(&self.name, DataType::UInt64, false))
+ }
+
+ fn state_fields(&self) -> Result<Vec<Field>> {
+ Ok(vec![Field::new(
+ &format_state_name(&self.name, "hll_registers"),
+ DataType::Binary,
+ false,
+ )])
+ }
+
+ fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+ vec![self.expr.clone()]
+ }
+
+    /// Builds the HyperLogLog-backed accumulator that matches `input_data_type`.
+    ///
+    /// Integer types map to `NumericHLLAccumulator`, `Utf8`/`LargeUtf8` map to
+    /// `StringHLLAccumulator`, and `Binary`/`LargeBinary` map to
+    /// `BinaryHLLAccumulator`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `DataFusionError::NotImplemented` for every other input type.
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        let accumulator: Box<dyn Accumulator> = match &self.input_data_type {
+            // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
+            // TODO support for boolean (trivial case)
+            DataType::UInt8 => Box::new(NumericHLLAccumulator::<UInt8Type>::new()),
+            DataType::UInt16 => Box::new(NumericHLLAccumulator::<UInt16Type>::new()),
+            DataType::UInt32 => Box::new(NumericHLLAccumulator::<UInt32Type>::new()),
+            DataType::UInt64 => Box::new(NumericHLLAccumulator::<UInt64Type>::new()),
+            DataType::Int8 => Box::new(NumericHLLAccumulator::<Int8Type>::new()),
+            DataType::Int16 => Box::new(NumericHLLAccumulator::<Int16Type>::new()),
+            DataType::Int32 => Box::new(NumericHLLAccumulator::<Int32Type>::new()),
+            DataType::Int64 => Box::new(NumericHLLAccumulator::<Int64Type>::new()),
+            DataType::Utf8 => Box::new(StringHLLAccumulator::<i32>::new()),
+            DataType::LargeUtf8 => Box::new(StringHLLAccumulator::<i64>::new()),
+            DataType::Binary => Box::new(BinaryHLLAccumulator::<i32>::new()),
+            DataType::LargeBinary => Box::new(BinaryHLLAccumulator::<i64>::new()),
+            other => {
+                // Name the aggregate the user actually called (approx_distinct).
+                return Err(DataFusionError::NotImplemented(format!(
+                    "Support for approx_distinct for data type {} is not implemented",
+                    other
+                )))
+            }
+        };
+        Ok(accumulator)
+    }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+}
+
+#[derive(Debug)]
+struct BinaryHLLAccumulator<T>
+where
+ T: BinaryOffsetSizeTrait,
+{
+ hll: HyperLogLog<Vec<u8>>,
+ phantom_data: PhantomData<T>,
+}
+
+impl<T> BinaryHLLAccumulator<T>
+where
+ T: BinaryOffsetSizeTrait,
+{
+ /// new approx_distinct accumulator
+ pub fn new() -> Self {
+ Self {
+ hll: HyperLogLog::new(),
+ phantom_data: PhantomData,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct StringHLLAccumulator<T>
+where
+ T: StringOffsetSizeTrait,
+{
+ hll: HyperLogLog<String>,
+ phantom_data: PhantomData<T>,
+}
+
+impl<T> StringHLLAccumulator<T>
+where
+ T: StringOffsetSizeTrait,
+{
+ /// new approx_distinct accumulator
+ pub fn new() -> Self {
+ Self {
+ hll: HyperLogLog::new(),
+ phantom_data: PhantomData,
+ }
+ }
+}
+
+#[derive(Debug)]
+struct NumericHLLAccumulator<T>
+where
+ T: ArrowPrimitiveType,
+ T::Native: Hash,
+{
+ hll: HyperLogLog<T::Native>,
+}
+
+impl<T> NumericHLLAccumulator<T>
+where
+ T: ArrowPrimitiveType,
+ T::Native: Hash,
+{
+ /// new approx_distinct accumulator
+ pub fn new() -> Self {
+ Self {
+ hll: HyperLogLog::new(),
+ }
+ }
+}
+
+impl<T: Hash> From<&HyperLogLog<T>> for ScalarValue {
+ fn from(v: &HyperLogLog<T>) -> ScalarValue {
+ let values = v.as_ref().to_vec();
+ ScalarValue::Binary(Some(values))
+ }
+}
+
+impl<T: Hash> TryFrom<&[u8]> for HyperLogLog<T> {
+ type Error = DataFusionError;
+ fn try_from(v: &[u8]) -> Result<HyperLogLog<T>> {
+ let arr: [u8; 16384] = v.try_into().map_err(|_| {
+ DataFusionError::Internal(
+ "Impossibly got invalid binary array from states".into(),
+ )
+ })?;
+ Ok(HyperLogLog::<T>::new_with_registers(arr))
+ }
+}
+
+impl<T: Hash> TryFrom<&ScalarValue> for HyperLogLog<T> {
+ type Error = DataFusionError;
+ fn try_from(v: &ScalarValue) -> Result<HyperLogLog<T>> {
+ if let ScalarValue::Binary(Some(slice)) = v {
+ slice.as_slice().try_into()
+ } else {
+ Err(DataFusionError::Internal(
+ "Impossibly got invalid scalar value while converting to
HyperLogLog"
+ .into(),
+ ))
+ }
+ }
+}
+
+macro_rules! default_accumulator_impl {
+ () => {
+ fn update(&mut self, values: &[ScalarValue]) -> Result<()> {
+ self.update_batch(
+ values
+ .iter()
+ .map(|s| s.to_array() as ArrayRef)
+ .collect::<Vec<_>>()
+ .as_slice(),
+ )
+ }
+
+ fn merge(&mut self, states: &[ScalarValue]) -> Result<()> {
+ assert_eq!(1, states.len(), "expect only 1 element in the states");
+ let other = HyperLogLog::try_from(&states[0])?;
+ self.hll.merge(&other);
+ Ok(())
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ assert_eq!(1, states.len(), "expect only 1 element in the states");
+ let binary_array =
states[0].as_any().downcast_ref::<BinaryArray>().unwrap();
+ for v in binary_array.iter() {
+ let v = v.ok_or_else(|| {
+ DataFusionError::Internal(
+ "Impossibly got empty binary array from states".into(),
+ )
+ })?;
+ let other = v.try_into()?;
+ self.hll.merge(&other);
+ }
+ Ok(())
+ }
+
+ fn state(&self) -> Result<Vec<ScalarValue>> {
+ let value = ScalarValue::from(&self.hll);
+ Ok(vec![value])
+ }
+
+ fn evaluate(&self) -> Result<ScalarValue> {
+ Ok(ScalarValue::UInt64(Some(self.hll.count() as u64)))
+ }
+ };
+}
+
+macro_rules! downcast_value {
+ ($Value: expr, $Type: ident, $T: tt) => {{
+ $Value[0]
+ .as_any()
+ .downcast_ref::<$Type<T>>()
+ .ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "could not cast value to {}",
+ type_name::<$Type<T>>()
+ ))
+ })?
+ }};
+}
+
+impl<T> Accumulator for BinaryHLLAccumulator<T>
+where
+ T: BinaryOffsetSizeTrait,
+{
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array: &GenericBinaryArray<T> =
+ downcast_value!(values, GenericBinaryArray, T);
+ // flatten because we would skip nulls
+ self.hll
+ .extend(array.into_iter().flatten().map(|v| v.to_vec()));
+ Ok(())
+ }
+
+ default_accumulator_impl!();
+}
+
+impl<T> Accumulator for StringHLLAccumulator<T>
+where
+ T: StringOffsetSizeTrait,
+{
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array: &GenericStringArray<T> =
+ downcast_value!(values, GenericStringArray, T);
+ // flatten because we would skip nulls
+ self.hll
+ .extend(array.into_iter().flatten().map(|i| i.to_string()));
+ Ok(())
+ }
+
+ default_accumulator_impl!();
+}
+
+impl<T> Accumulator for NumericHLLAccumulator<T>
+where
+ T: ArrowPrimitiveType + std::fmt::Debug,
+ T::Native: Hash,
+{
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+ let array: &PrimitiveArray<T> = downcast_value!(values,
PrimitiveArray, T);
+ // flatten because we would skip nulls
+ self.hll.extend(array.into_iter().flatten());
+ Ok(())
+ }
+
+ default_accumulator_impl!();
+}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs
b/datafusion/src/physical_plan/expressions/mod.rs
index 9f7a6cc..4ca0036 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -25,6 +25,7 @@ use crate::physical_plan::PhysicalExpr;
use arrow::compute::kernels::sort::{SortColumn, SortOptions};
use arrow::record_batch::RecordBatch;
+mod approx_distinct;
mod average;
#[macro_use]
mod binary;
@@ -55,6 +56,7 @@ pub mod helpers {
pub use super::min_max::{max, min};
}
+pub use approx_distinct::ApproxDistinct;
pub use average::{avg_return_type, Avg, AvgAccumulator};
pub use binary::{binary, binary_operator_data_type, BinaryExpr};
pub use case::{case, CaseExpr};
diff --git a/datafusion/src/physical_plan/hyperloglog/mod.rs
b/datafusion/src/physical_plan/hyperloglog/mod.rs
index 25e5213..3b91d30 100644
--- a/datafusion/src/physical_plan/hyperloglog/mod.rs
+++ b/datafusion/src/physical_plan/hyperloglog/mod.rs
@@ -34,9 +34,6 @@
//!
//! This module also borrows some code structure from
[pdatastructs.rs](https://github.com/crepererum/pdatastructs.rs/blob/3997ed50f6b6871c9e53c4c5e0f48f431405fc63/src/hyperloglog.rs).
-// TODO remove this when hooked up with the rest
-#![allow(dead_code)]
-
use ahash::{AHasher, RandomState};
use std::hash::{BuildHasher, Hash, Hasher};
use std::marker::PhantomData;
@@ -58,7 +55,12 @@ where
phantom: PhantomData<T>,
}
-/// fixed seed for the hashing so that values are consistent across runs
+/// Fixed seed for the hashing so that values are consistent across runs
+///
+/// Note that when we later move on to have serialized HLL register binaries
+/// shared across a cluster, this SEED will have to be consistent across all
+/// parties, otherwise we might have corruption. So ideally, in the future, this
+/// seed shall be part of the serialized form (or stay unchanged across versions).
const SEED: RandomState = RandomState::with_seeds(
0x885f6cab121d01a3_u64,
0x71e4379f2976ad8f_u64,
@@ -73,6 +75,13 @@ where
/// Creates a new, empty HyperLogLog.
pub fn new() -> Self {
let registers = [0; NUM_REGISTERS];
+ Self::new_with_registers(registers)
+ }
+
+ /// Creates a HyperLogLog from already populated registers.
+ /// Note that this method should not be invoked in an untrusted environment
+ /// because the internal structure of the registers is not examined.
+ pub(crate) fn new_with_registers(registers: [u8; NUM_REGISTERS]) -> Self {
Self {
registers,
phantom: PhantomData,
@@ -109,6 +118,19 @@ where
histogram
}
+    /// Merge the other [`HyperLogLog`] into this one.
+    ///
+    /// Every register of `self` is replaced by the larger of its own value
+    /// and the corresponding register of `other`; `other` is left untouched.
+    pub fn merge(&mut self, other: &HyperLogLog<T>) {
+        // Both register arrays have the fixed length `NUM_REGISTERS`, so
+        // zipping them visits every register exactly once and no runtime
+        // length check is needed.
+        for (mine, theirs) in self.registers.iter_mut().zip(other.registers.iter()) {
+            *mine = (*mine).max(*theirs);
+        }
+    }
+
/// Guess the number of unique elements seen by the HyperLogLog.
pub fn count(&self) -> usize {
let histogram = self.get_histogram();
@@ -171,6 +193,15 @@ fn hll_tau(x: f64) -> f64 {
}
}
+impl<T> AsRef<[u8]> for HyperLogLog<T>
+where
+ T: Hash + ?Sized,
+{
+ fn as_ref(&self) -> &[u8] {
+ &self.registers
+ }
+}
+
impl<T> Extend<T> for HyperLogLog<T>
where
T: Hash,
@@ -300,4 +331,32 @@ mod tests {
hll.extend((0..1000).map(|i| i.to_string()));
compare_with_delta(hll.count(), 1000);
}
+
+ #[test]
+ fn test_empty_merge() {
+ let mut hll = HyperLogLog::<u64>::new();
+ hll.merge(&HyperLogLog::<u64>::new());
+ assert_eq!(hll.count(), 0);
+ }
+
+ #[test]
+ fn test_merge_overlapped() {
+ let mut hll = HyperLogLog::<String>::new();
+ hll.extend((0..1000).map(|i| i.to_string()));
+
+ let mut other = HyperLogLog::<String>::new();
+ other.extend((0..1000).map(|i| i.to_string()));
+
+ hll.merge(&other);
+ compare_with_delta(hll.count(), 1000);
+ }
+
+ #[test]
+ fn test_repetition() {
+ let mut hll = HyperLogLog::<u32>::new();
+ for i in 0..1_000_000 {
+ hll.add(&(i % 1000));
+ }
+ compare_with_delta(hll.count(), 1000);
+ }
}
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 801451f..e822542 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -983,6 +983,23 @@ async fn csv_query_count() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn csv_query_approx_count() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx)?;
+ let sql = "SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as
varchar)) count_c9_str FROM aggregate_test_100";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ let expected = vec![
+ "+----------+--------------+",
+ "| count_c9 | count_c9_str |",
+ "+----------+--------------+",
+ "| 100 | 99 |",
+ "+----------+--------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
/// for window functions without order by the first, last, and nth function
call does not make sense
#[tokio::test]
async fn csv_query_window_with_empty_over() -> Result<()> {