alamb commented on a change in pull request #2031:
URL: https://github.com/apache/arrow-datafusion/pull/2031#discussion_r833685712
##########
File path: datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs
##########
@@ -152,6 +152,27 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
+ AggregateFunction::ApproxPercentileContWithWeight => {
+ if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
+ return Err(DataFusionError::Plan(format!(
+ "The function {:?} does not support inputs of type {:?}.",
+ agg_fun, input_types[0]
+ )));
+ }
+ if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
+ return Err(DataFusionError::Plan(format!(
+ "The weight argument for {:?} does not support inputs of
type {:?}.",
+ agg_fun, input_types[0]
Review comment:
```suggestion
agg_fun, input_types[1]
```
##########
File path: datafusion/tests/sql/aggregates.rs
##########
@@ -476,6 +476,59 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
+ let mut ctx = SessionContext::new();
+ register_aggregate_csv(&mut ctx).await?;
+
+ // compare approx_percentile_cont and approx_percentile_cont_with_weight
+ let sql = "SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM
aggregate_test_100 GROUP BY 1 ORDER BY 1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+--------+",
+ "| c1 | c3_p95 |",
+ "+----+--------+",
+ "| a | 73 |",
+ "| b | 68 |",
+ "| c | 122 |",
+ "| d | 124 |",
+ "| e | 115 |",
+ "+----+--------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS
c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+--------+",
+ "| c1 | c3_p95 |",
+ "+----+--------+",
+ "| a | 73 |",
+ "| b | 68 |",
+ "| c | 122 |",
+ "| d | 124 |",
+ "| e | 115 |",
+ "+----+--------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS
c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+--------+",
Review comment:
these values are different (weighting by `c2` changes the p95 for groups `a`
and `c`) -- honestly, I don't know whether they are correct or not 🤷
##########
File path: datafusion-physical-expr/src/expressions/approx_percentile_cont.rs
##########
@@ -194,75 +204,125 @@ pub struct ApproxPercentileAccumulator {
impl ApproxPercentileAccumulator {
pub fn new(percentile: f64, return_type: DataType) -> Self {
Self {
- digest: TDigest::new(100),
+ digest: TDigest::new(DEFAULT_MAX_SIZE),
percentile,
return_type,
}
}
-}
-impl Accumulator for ApproxPercentileAccumulator {
- fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(self.digest.to_scalar_state())
+ pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
+ self.digest = TDigest::merge_digests(digests);
}
- fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- debug_assert_eq!(
- values.len(),
- 1,
- "invalid number of values in batch percentile update"
- );
- let values = &values[0];
-
- self.digest = match values.data_type() {
+ pub(crate) fn convert_to_ordered_float(
+ values: &ArrayRef,
+ ) -> Result<Vec<OrderedFloat<f64>>> {
+ match values.data_type() {
DataType::Float64 => {
let array =
values.as_any().downcast_ref::<Float64Array>().unwrap();
- self.digest.merge_unsorted(array.values().iter().cloned())?
+ Ok(array
+ .values()
+ .iter()
+ .filter_map(|v| v.try_as_f64().transpose())
+ .collect::<Result<Vec<_>>>()?)
Review comment:
Yeah, it is tough because the types of the various branches are different.
You could avoid the intermediate `Vec` copy by letting the caller
provide a function that gets invoked for each element (untested):
```rust
pub(crate) fn convert_to_ordered_float(
values: &ArrayRef,
f: impl FnMut(Option<OrderedFloat<f64>))
) -> Result<()> {
```
And then call `f()` on each element:
```rust
...
DataType::Float32 => {
let array =
values.as_any().downcast_ref::<Float32Array>().unwrap();
array
.values()
.iter()
.try_for_each(|v| {
f(v.try_as_f64()?)
})
}
...
```
##########
File path: datafusion/tests/sql/aggregates.rs
##########
@@ -476,6 +476,59 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
+ let mut ctx = SessionContext::new();
+ register_aggregate_csv(&mut ctx).await?;
+
+ // compare approx_percentile_cont and approx_percentile_cont_with_weight
+ let sql = "SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM
aggregate_test_100 GROUP BY 1 ORDER BY 1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+--------+",
+ "| c1 | c3_p95 |",
+ "+----+--------+",
+ "| a | 73 |",
+ "| b | 68 |",
+ "| c | 122 |",
+ "| d | 124 |",
+ "| e | 115 |",
+ "+----+--------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS
c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+--------+",
+ "| c1 | c3_p95 |",
+ "+----+--------+",
+ "| a | 73 |",
+ "| b | 68 |",
+ "| c | 122 |",
+ "| d | 124 |",
+ "| e | 115 |",
+ "+----+--------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS
c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+--------+",
+ "| c1 | c3_p95 |",
+ "+----+--------+",
+ "| a | 74 |",
+ "| b | 68 |",
+ "| c | 123 |",
+ "| d | 124 |",
+ "| e | 115 |",
+ "+----+--------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
Review comment:
I wonder if it is worth adding one or two error cases here (e.g. try to invoke
this function on a `StringArray`)
##########
File path: datafusion-physical-expr/src/expressions/approx_percentile_cont.rs
##########
@@ -194,75 +204,125 @@ pub struct ApproxPercentileAccumulator {
impl ApproxPercentileAccumulator {
pub fn new(percentile: f64, return_type: DataType) -> Self {
Self {
- digest: TDigest::new(100),
+ digest: TDigest::new(DEFAULT_MAX_SIZE),
Review comment:
changing to a symbolic constant `DEFAULT_MAX_SIZE` is a nice improvement
👍
##########
File path: datafusion/tests/sql/aggregates.rs
##########
@@ -476,6 +476,59 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
+ let mut ctx = SessionContext::new();
+ register_aggregate_csv(&mut ctx).await?;
+
+ // compare approx_percentile_cont and approx_percentile_cont_with_weight
+ let sql = "SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM
aggregate_test_100 GROUP BY 1 ORDER BY 1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+--------+",
+ "| c1 | c3_p95 |",
+ "+----+--------+",
+ "| a | 73 |",
+ "| b | 68 |",
+ "| c | 122 |",
+ "| d | 124 |",
+ "| e | 115 |",
+ "+----+--------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS
c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+--------+",
+ "| c1 | c3_p95 |",
+ "+----+--------+",
+ "| a | 73 |",
+ "| b | 68 |",
+ "| c | 122 |",
+ "| d | 124 |",
+ "| e | 115 |",
+ "+----+--------+",
+ ];
Review comment:
these values seem to be the same as the ones from
`approx_percentile_cont(c3, 0.95)`, which I think is the point of the test.
Perhaps we could encode that into the test by reusing the first `expected` table:
```suggestion
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]