This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 2e6dc35 feature: Add additional aggregation functions (#170)
2e6dc35 is described below
commit 2e6dc3555287fa0ce06fd629280621d9c7f433f7
Author: Dejan Simic <[email protected]>
AuthorDate: Mon Feb 6 16:33:58 2023 +0100
feature: Add additional aggregation functions (#170)
* Enable uuid & struct
* Add missing aggregation functions
* Write unit tests for aggregation functions
* Use numpy wherever possible
* Apply clippy suggestions
---
datafusion/tests/test_aggregation.py | 99 ++++++++++++++++++++++++++++++++----
src/functions.rs | 47 +++++++++++++++--
2 files changed, 131 insertions(+), 15 deletions(-)
diff --git a/datafusion/tests/test_aggregation.py b/datafusion/tests/test_aggregation.py
index b274e18..2c8c064 100644
--- a/datafusion/tests/test_aggregation.py
+++ b/datafusion/tests/test_aggregation.py
@@ -15,10 +15,11 @@
# specific language governing permissions and limitations
# under the License.
+import numpy as np
import pyarrow as pa
import pytest
-from datafusion import SessionContext, column
+from datafusion import SessionContext, column, lit
from datafusion import functions as f
@@ -28,8 +29,12 @@ def df():
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
- [pa.array([1, 2, 3]), pa.array([4, 4, 6])],
- names=["a", "b"],
+ [
+ pa.array([1, 2, 3]),
+ pa.array([4, 4, 6]),
+ pa.array([9, 8, 5]),
+ ],
+ names=["a", "b", "c"],
)
return ctx.create_dataframe([[batch]])
@@ -37,12 +42,86 @@ def df():
def test_built_in_aggregation(df):
col_a = column("a")
col_b = column("b")
- df = df.aggregate(
+ col_c = column("c")
+
+ agg_df = df.aggregate(
[],
- [f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)],
+ [
+ f.approx_distinct(col_b),
+ f.approx_median(col_b),
+ f.approx_percentile_cont(col_b, lit(0.5)),
+ f.approx_percentile_cont_with_weight(col_b, lit(0.6), lit(0.5)),
+ f.array_agg(col_b),
+ f.avg(col_a),
+ f.corr(col_a, col_b),
+ f.count(col_a),
+ f.covar(col_a, col_b),
+ f.covar_pop(col_a, col_c),
+ f.covar_samp(col_b, col_c),
+ # f.grouping(col_a), # No physical plan implemented yet
+ f.max(col_a),
+ f.mean(col_b),
+ f.median(col_b),
+ f.min(col_a),
+ f.sum(col_b),
+ f.stddev(col_a),
+ f.stddev_pop(col_b),
+ f.stddev_samp(col_c),
+ f.var(col_a),
+ f.var_pop(col_b),
+ f.var_samp(col_c),
+ ],
+ )
+ result = agg_df.collect()[0]
+ values_a, values_b, values_c = df.collect()[0]
+
+ assert result.column(0) == pa.array([2], type=pa.uint64())
+ assert result.column(1) == pa.array([4])
+ assert result.column(2) == pa.array([4])
+ assert result.column(3) == pa.array([6])
+ assert result.column(4) == pa.array([[4, 4, 6]])
+ np.testing.assert_array_almost_equal(
+ result.column(5), np.average(values_a)
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(6), np.corrcoef(values_a, values_b)[0][1]
+ )
+ assert result.column(7) == pa.array([len(values_a)])
+ # Sample (co)variance -> ddof=1
+ # Population (co)variance -> ddof=0
+ np.testing.assert_array_almost_equal(
+ result.column(8), np.cov(values_a, values_b, ddof=1)[0][1]
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(9), np.cov(values_a, values_c, ddof=0)[0][1]
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(10), np.cov(values_b, values_c, ddof=1)[0][1]
+ )
+ np.testing.assert_array_almost_equal(result.column(11), np.max(values_a))
+ np.testing.assert_array_almost_equal(result.column(12), np.mean(values_b))
+ np.testing.assert_array_almost_equal(
+ result.column(13), np.median(values_b)
+ )
+ np.testing.assert_array_almost_equal(result.column(14), np.min(values_a))
+ np.testing.assert_array_almost_equal(
+ result.column(15), np.sum(values_b.to_pylist())
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(16), np.std(values_a, ddof=1)
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(17), np.std(values_b, ddof=0)
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(18), np.std(values_c, ddof=1)
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(19), np.var(values_a, ddof=1)
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(20), np.var(values_b, ddof=0)
+ )
+ np.testing.assert_array_almost_equal(
+ result.column(21), np.var(values_c, ddof=1)
)
- result = df.collect()[0]
- assert result.column(0) == pa.array([3])
- assert result.column(1) == pa.array([1])
- assert result.column(2) == pa.array([3], type=pa.int64())
- assert result.column(3) == pa.array([2], type=pa.uint64())
diff --git a/src/functions.rs b/src/functions.rs
index ac1077e..8847dab 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -287,25 +287,49 @@ scalar_function!(upper, Upper, "Converts the string to all upper case.");
scalar_function!(make_array, MakeArray);
scalar_function!(array, MakeArray);
scalar_function!(nullif, NullIf);
-//scalar_function!(uuid, Uuid);
-//scalar_function!(struct, Struct);
+scalar_function!(uuid, Uuid);
+scalar_function!(r#struct, Struct); // Use raw identifier since struct is a keyword
scalar_function!(from_unixtime, FromUnixtime);
scalar_function!(arrow_typeof, ArrowTypeof);
scalar_function!(random, Random);
+aggregate_function!(approx_distinct, ApproxDistinct);
+aggregate_function!(approx_median, ApproxMedian);
+aggregate_function!(approx_percentile_cont, ApproxPercentileCont);
+aggregate_function!(
+ approx_percentile_cont_with_weight,
+ ApproxPercentileContWithWeight
+);
+aggregate_function!(array_agg, ArrayAgg);
aggregate_function!(avg, Avg);
+aggregate_function!(corr, Correlation);
aggregate_function!(count, Count);
+aggregate_function!(covar, Covariance);
+aggregate_function!(covar_pop, CovariancePop);
+aggregate_function!(covar_samp, Covariance);
+aggregate_function!(grouping, Grouping);
aggregate_function!(max, Max);
+aggregate_function!(mean, Avg);
+aggregate_function!(median, Median);
aggregate_function!(min, Min);
aggregate_function!(sum, Sum);
-aggregate_function!(approx_distinct, ApproxDistinct);
+aggregate_function!(stddev, Stddev);
+aggregate_function!(stddev_pop, StddevPop);
+aggregate_function!(stddev_samp, Stddev);
+aggregate_function!(var, Variance);
+aggregate_function!(var_pop, VariancePop);
+aggregate_function!(var_samp, Variance);
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(abs))?;
m.add_wrapped(wrap_pyfunction!(acos))?;
m.add_wrapped(wrap_pyfunction!(approx_distinct))?;
m.add_wrapped(wrap_pyfunction!(alias))?;
+ m.add_wrapped(wrap_pyfunction!(approx_median))?;
+ m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
+ m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
m.add_wrapped(wrap_pyfunction!(array))?;
+ m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
m.add_wrapped(wrap_pyfunction!(asin))?;
@@ -322,9 +346,13 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(col))?;
m.add_wrapped(wrap_pyfunction!(concat_ws))?;
m.add_wrapped(wrap_pyfunction!(concat))?;
+ m.add_wrapped(wrap_pyfunction!(corr))?;
m.add_wrapped(wrap_pyfunction!(cos))?;
m.add_wrapped(wrap_pyfunction!(count))?;
m.add_wrapped(wrap_pyfunction!(count_star))?;
+ m.add_wrapped(wrap_pyfunction!(covar))?;
+ m.add_wrapped(wrap_pyfunction!(covar_pop))?;
+ m.add_wrapped(wrap_pyfunction!(covar_samp))?;
m.add_wrapped(wrap_pyfunction!(current_date))?;
m.add_wrapped(wrap_pyfunction!(current_time))?;
m.add_wrapped(wrap_pyfunction!(date_bin))?;
@@ -336,6 +364,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(exp))?;
m.add_wrapped(wrap_pyfunction!(floor))?;
m.add_wrapped(wrap_pyfunction!(from_unixtime))?;
+ m.add_wrapped(wrap_pyfunction!(grouping))?;
m.add_wrapped(wrap_pyfunction!(in_list))?;
m.add_wrapped(wrap_pyfunction!(initcap))?;
m.add_wrapped(wrap_pyfunction!(left))?;
@@ -350,6 +379,8 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(max))?;
m.add_wrapped(wrap_pyfunction!(make_array))?;
m.add_wrapped(wrap_pyfunction!(md5))?;
+ m.add_wrapped(wrap_pyfunction!(mean))?;
+ m.add_wrapped(wrap_pyfunction!(median))?;
m.add_wrapped(wrap_pyfunction!(min))?;
m.add_wrapped(wrap_pyfunction!(now))?;
m.add_wrapped(wrap_pyfunction!(nullif))?;
@@ -376,8 +407,11 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(split_part))?;
m.add_wrapped(wrap_pyfunction!(sqrt))?;
m.add_wrapped(wrap_pyfunction!(starts_with))?;
+ m.add_wrapped(wrap_pyfunction!(stddev))?;
+ m.add_wrapped(wrap_pyfunction!(stddev_pop))?;
+ m.add_wrapped(wrap_pyfunction!(stddev_samp))?;
m.add_wrapped(wrap_pyfunction!(strpos))?;
- //m.add_wrapped(wrap_pyfunction!(struct))?;
+ m.add_wrapped(wrap_pyfunction!(r#struct))?; // Use raw identifier since struct is a keyword
m.add_wrapped(wrap_pyfunction!(substr))?;
m.add_wrapped(wrap_pyfunction!(sum))?;
m.add_wrapped(wrap_pyfunction!(tan))?;
@@ -390,7 +424,10 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(trim))?;
m.add_wrapped(wrap_pyfunction!(trunc))?;
m.add_wrapped(wrap_pyfunction!(upper))?;
- //m.add_wrapped(wrap_pyfunction!(uuid))?;
+ m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
+ m.add_wrapped(wrap_pyfunction!(var))?;
+ m.add_wrapped(wrap_pyfunction!(var_pop))?;
+ m.add_wrapped(wrap_pyfunction!(var_samp))?;
m.add_wrapped(wrap_pyfunction!(window))?;
Ok(())
}