This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 91a44c1e5a Fix ArrayAgg schema mismatch issue (#8055)
91a44c1e5a is described below
commit 91a44c1e5aaed6b3037eb620c7a753b51755b187
Author: Jay Zhan <[email protected]>
AuthorDate: Fri Nov 10 01:44:54 2023 +0800
Fix ArrayAgg schema mismatch issue (#8055)
* fix schema
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* upd parquet-testing
Signed-off-by: jayzhan211 <[email protected]>
* avoid parquet file
Signed-off-by: jayzhan211 <[email protected]>
* reset parquet-testing
Signed-off-by: jayzhan211 <[email protected]>
* remove file
Signed-off-by: jayzhan211 <[email protected]>
* fix
Signed-off-by: jayzhan211 <[email protected]>
* fix test
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* rename and upd docstring
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/core/src/dataframe/mod.rs | 86 ++++++++++++++++++++++
.../physical-expr/src/aggregate/array_agg.rs | 43 +++++++++--
.../src/aggregate/array_agg_distinct.rs | 13 +++-
.../src/aggregate/array_agg_ordered.rs | 20 +++--
datafusion/physical-expr/src/aggregate/build_in.rs | 18 +++--
5 files changed, 160 insertions(+), 20 deletions(-)
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 0a99c33182..89e82fa952 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -1340,6 +1340,92 @@ mod tests {
use super::*;
+ async fn assert_logical_expr_schema_eq_physical_expr_schema(
+ df: DataFrame,
+ ) -> Result<()> {
+ let logical_expr_dfschema = df.schema();
+ let logical_expr_schema = SchemaRef::from(logical_expr_dfschema.to_owned());
+ let batches = df.collect().await?;
+ let physical_expr_schema = batches[0].schema();
+ assert_eq!(logical_expr_schema, physical_expr_schema);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_array_agg_ord_schema() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ let create_table_query = r#"
+ CREATE TABLE test_table (
+ "double_field" DOUBLE,
+ "string_field" VARCHAR
+ ) AS VALUES
+ (1.0, 'a'),
+ (2.0, 'b'),
+ (3.0, 'c')
+ "#;
+ ctx.sql(create_table_query).await?;
+
+ let query = r#"SELECT
+ array_agg("double_field" ORDER BY "string_field") as "double_field",
+ array_agg("string_field" ORDER BY "string_field") as "string_field"
+ FROM test_table"#;
+
+ let result = ctx.sql(query).await?;
+ assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_array_agg_schema() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ let create_table_query = r#"
+ CREATE TABLE test_table (
+ "double_field" DOUBLE,
+ "string_field" VARCHAR
+ ) AS VALUES
+ (1.0, 'a'),
+ (2.0, 'b'),
+ (3.0, 'c')
+ "#;
+ ctx.sql(create_table_query).await?;
+
+ let query = r#"SELECT
+ array_agg("double_field") as "double_field",
+ array_agg("string_field") as "string_field"
+ FROM test_table"#;
+
+ let result = ctx.sql(query).await?;
+ assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_array_agg_distinct_schema() -> Result<()> {
+ let ctx = SessionContext::new();
+
+ let create_table_query = r#"
+ CREATE TABLE test_table (
+ "double_field" DOUBLE,
+ "string_field" VARCHAR
+ ) AS VALUES
+ (1.0, 'a'),
+ (2.0, 'b'),
+ (2.0, 'a')
+ "#;
+ ctx.sql(create_table_query).await?;
+
+ let query = r#"SELECT
+ array_agg(distinct "double_field") as "double_field",
+ array_agg(distinct "string_field") as "string_field"
+ FROM test_table"#;
+
+ let result = ctx.sql(query).await?;
+ assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
+ Ok(())
+ }
+
#[tokio::test]
async fn select_columns() -> Result<()> {
// build plan using Table API
diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs
index 4dccbfef07..91d5c867d3 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg.rs
@@ -34,9 +34,14 @@ use std::sync::Arc;
/// ARRAY_AGG aggregate expression
#[derive(Debug)]
pub struct ArrayAgg {
+ /// Column name
name: String,
+ /// The DataType for the input expression
input_data_type: DataType,
+ /// The input expression
expr: Arc<dyn PhysicalExpr>,
+ /// If the input expression can have NULLs
+ nullable: bool,
}
impl ArrayAgg {
@@ -45,11 +50,13 @@ impl ArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
+ nullable: bool,
) -> Self {
Self {
name: name.into(),
- expr,
input_data_type: data_type,
+ expr,
+ nullable,
}
}
}
@@ -62,8 +69,9 @@ impl AggregateExpr for ArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
+ // This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
- false,
+ self.nullable,
))
}
@@ -77,7 +85,7 @@ impl AggregateExpr for ArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), true),
- false,
+ self.nullable,
)])
}
@@ -184,7 +192,6 @@ mod tests {
use super::*;
use crate::expressions::col;
use crate::expressions::tests::aggregate;
- use crate::generic_test_op;
use arrow::array::ArrayRef;
use arrow::array::Int32Array;
use arrow::datatypes::*;
@@ -195,6 +202,30 @@ mod tests {
use datafusion_common::DataFusionError;
use datafusion_common::Result;
+ macro_rules! test_op {
+ ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => {
+ test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type())
+ };
+ ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
+ let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]);
+
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;
+
+ let agg = Arc::new(<$OP>::new(
+ col("a", &schema)?,
+ "bla".to_string(),
+ $EXPECTED_DATATYPE,
+ true,
+ ));
+ let actual = aggregate(&batch, agg)?;
+ let expected = ScalarValue::from($EXPECTED);
+
+ assert_eq!(expected, actual);
+
+ Ok(()) as Result<(), DataFusionError>
+ }};
+ }
+
#[test]
fn array_agg_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
@@ -208,7 +239,7 @@ mod tests {
])]);
let list = ScalarValue::List(Arc::new(list));
- generic_test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
+ test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32)
}
#[test]
@@ -264,7 +295,7 @@ mod tests {
let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap();
- generic_test_op!(
+ test_op!(
array,
DataType::List(Arc::new(Field::new_list(
"item",
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
index 9b391b0c42..1efae424cc 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
@@ -40,6 +40,8 @@ pub struct DistinctArrayAgg {
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
+ /// If the input expression can have NULLs
+ nullable: bool,
}
impl DistinctArrayAgg {
@@ -48,12 +50,14 @@ impl DistinctArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
+ nullable: bool,
) -> Self {
let name = name.into();
Self {
name,
- expr,
input_data_type,
+ expr,
+ nullable,
}
}
}
@@ -67,8 +71,9 @@ impl AggregateExpr for DistinctArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
+ // This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
- false,
+ self.nullable,
))
}
@@ -82,7 +87,7 @@ impl AggregateExpr for DistinctArrayAgg {
Ok(vec![Field::new_list(
format_state_name(&self.name, "distinct_array_agg"),
Field::new("item", self.input_data_type.clone(), true),
- false,
+ self.nullable,
)])
}
@@ -238,6 +243,7 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
+ true,
));
let actual = aggregate(&batch, agg)?;
@@ -255,6 +261,7 @@ mod tests {
col("a", &schema)?,
"bla".to_string(),
datatype,
+ true,
));
let mut accum1 = agg.create_accumulator()?;
diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
index a53d53107a..9ca83a781a 100644
--- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
+++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
@@ -48,10 +48,17 @@ use itertools::izip;
/// and that can merge aggregations from multiple partitions.
#[derive(Debug)]
pub struct OrderSensitiveArrayAgg {
+ /// Column name
name: String,
+ /// The DataType for the input expression
input_data_type: DataType,
- order_by_data_types: Vec<DataType>,
+ /// The input expression
expr: Arc<dyn PhysicalExpr>,
+ /// If the input expression can have NULLs
+ nullable: bool,
+ /// Ordering data types
+ order_by_data_types: Vec<DataType>,
+ /// Ordering requirement
ordering_req: LexOrdering,
}
@@ -61,13 +68,15 @@ impl OrderSensitiveArrayAgg {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
input_data_type: DataType,
+ nullable: bool,
order_by_data_types: Vec<DataType>,
ordering_req: LexOrdering,
) -> Self {
Self {
name: name.into(),
- expr,
input_data_type,
+ expr,
+ nullable,
order_by_data_types,
ordering_req,
}
@@ -82,8 +91,9 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
+ // This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
- false,
+ self.nullable,
))
}
@@ -99,13 +109,13 @@ impl AggregateExpr for OrderSensitiveArrayAgg {
let mut fields = vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), true),
- false,
+ self.nullable, // This should be the same as field()
)];
let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types);
fields.push(Field::new_list(
format_state_name(&self.name, "array_agg_orderings"),
Field::new("item", DataType::Struct(Fields::from(orderings)), true),
- false,
+ self.nullable,
));
Ok(fields)
}
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index 6568457bc2..596197b4ee 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -114,13 +114,16 @@ pub fn create_aggregate_expr(
),
(AggregateFunction::ArrayAgg, false) => {
let expr = input_phy_exprs[0].clone();
+ let nullable = expr.nullable(input_schema)?;
+
if ordering_req.is_empty() {
- Arc::new(expressions::ArrayAgg::new(expr, name, data_type))
+ Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable))
} else {
Arc::new(expressions::OrderSensitiveArrayAgg::new(
expr,
name,
data_type,
+ nullable,
ordering_types,
ordering_req.to_vec(),
))
@@ -132,10 +135,13 @@ pub fn create_aggregate_expr(
"ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available"
);
}
+ let expr = input_phy_exprs[0].clone();
+ let is_expr_nullable = expr.nullable(input_schema)?;
Arc::new(expressions::DistinctArrayAgg::new(
- input_phy_exprs[0].clone(),
+ expr,
name,
data_type,
+ is_expr_nullable,
))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
@@ -432,8 +438,8 @@ mod tests {
assert_eq!(
Field::new_list(
"c1",
- Field::new("item", data_type.clone(), true,),
- false,
+ Field::new("item", data_type.clone(), true),
+ true,
),
result_agg_phy_exprs.field().unwrap()
);
@@ -471,8 +477,8 @@ mod tests {
assert_eq!(
Field::new_list(
"c1",
- Field::new("item", data_type.clone(), true,),
- false,
+ Field::new("item", data_type.clone(), true),
+ true,
),
result_agg_phy_exprs.field().unwrap()
);