This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new dec9adcbe Optimize the evaluation of `IN` for large lists using InSet
(#2156)
dec9adcbe is described below
commit dec9adcbe1c7b878e2544c43c9351ab9ded50e4f
Author: Yang Jiang <[email protected]>
AuthorDate: Sat Apr 9 02:36:46 2022 +0800
Optimize the evaluation of `IN` for large lists using InSet (#2156)
* commit 1
* Add an InSet as an optimized version for IN_LIST
* fix clippy
* fix ut
* fix fmt
* fix clippy
* make clear in explain
Co-authored-by: Andrew Lamb <[email protected]>
* add UT and change threshold
* fix clippy
* change OPTIMIZER_INSET_THRESHOLD
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/physical_plan/planner.rs | 60 ++-
datafusion/core/tests/sql/predicates.rs | 17 +
.../physical-expr/src/expressions/in_list.rs | 413 +++++++++++++++------
3 files changed, 372 insertions(+), 118 deletions(-)
diff --git a/datafusion/core/src/physical_plan/planner.rs
b/datafusion/core/src/physical_plan/planner.rs
index 1e6c13f9c..273ec0a85 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -1706,7 +1706,7 @@ mod tests {
.build()?;
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8
- let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 },
list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value:
Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }],
negated: false }";
+ let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 },
list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value:
Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }],
negated: false, set: None }";
assert!(format!("{:?}", execution_plan).contains(expected));
// expression: "a in (true, 'a')"
@@ -1742,6 +1742,64 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn in_set_test() -> Result<()> {
+ let testdata = crate::test_util::arrow_test_data();
+ let path = format!("{}/csv/aggregate_test_100.csv", testdata);
+ let options = CsvReadOptions::new().schema_infer_max_records(100);
+
+ // OPTIMIZER_INSET_THRESHOLD = 30; the list built below has 31 entries, which exceeds it
+ // expression: "a in ('a', 1, 2, ..30)"
+ let mut list =
vec![Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))];
+ for i in 1..31 {
+ list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
+ }
+
+ let logical_plan = LogicalPlanBuilder::scan_csv(
+ Arc::new(LocalFileSystem {}),
+ &path,
+ options,
+ None,
+ 1,
+ )
+ .await?
+ .filter(col("c12").lt(lit(0.05)))?
+ .project(vec![col("c1").in_list(list, false)])?
+ .build()?;
+ let execution_plan = plan(&logical_plan).await?;
+ let expected = "expr: [(InListExpr { expr: Column { name: \"c1\",
index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal {
value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false }
}, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options:
CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) },
cast_type: Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr:
Literal { value: Int64(4) [...]
+ assert!(format!("{:?}", execution_plan).contains(expected));
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn in_set_null_test() -> Result<()> {
+ let testdata = crate::test_util::arrow_test_data();
+ let path = format!("{}/csv/aggregate_test_100.csv", testdata);
+ let options = CsvReadOptions::new().schema_infer_max_records(100);
+ // test NULL
+ let mut list = vec![Expr::Literal(ScalarValue::Int64(None))];
+ for i in 1..31 {
+ list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
+ }
+
+ let logical_plan = LogicalPlanBuilder::scan_csv(
+ Arc::new(LocalFileSystem {}),
+ &path,
+ options,
+ None,
+ 1,
+ )
+ .await?
+ .filter(col("c12").lt(lit(0.05)))?
+ .project(vec![col("c1").in_list(list, false)])?
+ .build()?;
+ let execution_plan = plan(&logical_plan).await?;
+ let expected = "expr: [(InListExpr { expr: Column { name: \"c1\",
index: 0 }, list: [CastExpr { expr: Literal { value: Int64(NULL) }, cast_type:
Utf8, cast_options: CastOptions { safe: false } }, CastExpr { expr: Literal {
value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false }
}, CastExpr { expr: Literal { value: Int64(2) }, cast_type: Utf8, cast_options:
CastOptions { safe: false } }, CastExpr { expr: Literal { value: Int64(3) },
cast_type: Utf8, cast_opti [...]
+ assert!(format!("{:?}", execution_plan).contains(expected));
+ Ok(())
+ }
+
#[tokio::test]
async fn hash_agg_input_schema() -> Result<()> {
let testdata = crate::test_util::arrow_test_data();
diff --git a/datafusion/core/tests/sql/predicates.rs
b/datafusion/core/tests/sql/predicates.rs
index 1369baa75..ea79e2b14 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -369,3 +369,20 @@ async fn test_expect_distinct() -> Result<()> {
assert_batches_eq!(expected, &actual);
Ok(())
}
+
+#[tokio::test]
+async fn csv_in_set_test() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql = "SELECT count(*) FROM aggregate_test_100 WHERE c7 in
('25','155','204','77','208','67','139','191','26','7','202','113','129','197','249','146','129','220','154','163','220','19','71','243','150','231','196','170','99','255');";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----------------+",
+ "| COUNT(UInt8(1)) |",
+ "+-----------------+",
+ "| 36 |",
+ "+-----------------+",
+ ];
+ assert_batches_sorted_eq!(expected, &actual);
+ Ok(())
+}
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs
b/datafusion/physical-expr/src/expressions/in_list.rs
index 2aee0d87d..a6894b938 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -18,6 +18,7 @@
//! InList expression
use std::any::Any;
+use std::collections::HashSet;
use std::sync::Arc;
use arrow::array::GenericStringArray;
@@ -32,13 +33,19 @@ use arrow::{
record_batch::RecordBatch,
};
-use crate::PhysicalExpr;
+use crate::{expressions, PhysicalExpr};
use arrow::array::*;
use arrow::buffer::{Buffer, MutableBuffer};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
+/// List size above which a Set rather than a Vec is used for `IN` / `NOT IN`.
+/// The value was chosen based on the benchmark at
+/// https://github.com/apache/arrow-datafusion/pull/2156#discussion_r845198369
+/// TODO: add switch codeGen in In_List
+static OPTIMIZER_INSET_THRESHOLD: usize = 30;
+
macro_rules! compare_op_scalar {
($left: expr, $right:expr, $op:expr) => {{
let null_bit_buffer = $left.data().null_buffer().cloned();
@@ -69,6 +76,23 @@ pub struct InListExpr {
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
negated: bool,
+ set: Option<InSet>,
+}
+
+/// InSet: the values of an `IN` list held in a `HashSet` of `ScalarValue` for membership tests
+#[derive(Debug)]
+pub struct InSet {
+ set: HashSet<ScalarValue>,
+}
+
+impl InSet {
+ pub fn new(set: HashSet<ScalarValue>) -> Self {
+ Self { set }
+ }
+
+ pub fn get_set(&self) -> &HashSet<ScalarValue> {
+ &self.set
+ }
}
macro_rules! make_contains {
@@ -181,6 +205,26 @@ macro_rules! make_contains_primitive {
}};
}
+macro_rules! set_contains_with_negated {
+ ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr) => {{
+ if $NEGATED {
+ return Ok(ColumnarValue::Array(Arc::new(
+ $ARRAY
+ .iter()
+ .map(|x| x.map(|v|
!$LIST_VALUES.contains(&v.try_into().unwrap())))
+ .collect::<BooleanArray>(),
+ )));
+ } else {
+ return Ok(ColumnarValue::Array(Arc::new(
+ $ARRAY
+ .iter()
+ .map(|x| x.map(|v|
$LIST_VALUES.contains(&v.try_into().unwrap())))
+ .collect::<BooleanArray>(),
+ )));
+ }
+ }};
+}
+
// whether each value on the left (can be null) is contained in the non-null
list
fn in_list_primitive<T: ArrowPrimitiveType>(
array: &PrimitiveArray<T>,
@@ -220,6 +264,42 @@ fn not_in_list_utf8<OffsetSize: StringOffsetSizeTrait>(
compare_op_scalar!(array, values, |x, v: &[&str]| !v.contains(&x))
}
+// Checks that every value in the list of the `IN` clause is static,
+// i.e. either a `Literal` or a `CastExpr` that directly wraps a `Literal`.
+fn check_all_static_filter_expr(list: &[Arc<dyn PhysicalExpr>]) -> bool {
+ list.iter().all(|v| {
+ let cast = v.as_any().downcast_ref::<expressions::CastExpr>();
+ if let Some(c) = cast {
+ c.expr()
+ .as_any()
+ .downcast_ref::<expressions::Literal>()
+ .is_some()
+ } else {
+ let cast = v.as_any().downcast_ref::<expressions::Literal>();
+ cast.is_some()
+ }
+ })
+}
+
+fn cast_static_filter_to_set(list: &[Arc<dyn PhysicalExpr>]) ->
HashSet<ScalarValue> {
+ HashSet::from_iter(list.iter().map(|expr| {
+ if let Some(cast) =
expr.as_any().downcast_ref::<expressions::CastExpr>() {
+ cast.expr()
+ .as_any()
+ .downcast_ref::<expressions::Literal>()
+ .unwrap()
+ .value()
+ .clone()
+ } else {
+ expr.as_any()
+ .downcast_ref::<expressions::Literal>()
+ .unwrap()
+ .value()
+ .clone()
+ }
+ }))
+}
+
impl InListExpr {
/// Create a new InList expression
pub fn new(
@@ -227,10 +307,20 @@ impl InListExpr {
list: Vec<Arc<dyn PhysicalExpr>>,
negated: bool,
) -> Self {
- Self {
- expr,
- list,
- negated,
+ if list.len() > OPTIMIZER_INSET_THRESHOLD &&
check_all_static_filter_expr(&list) {
+ Self {
+ expr,
+ set: Some(InSet::new(cast_static_filter_to_set(&list))),
+ list,
+ negated,
+ }
+ } else {
+ Self {
+ expr,
+ list,
+ negated,
+ set: None,
+ }
}
}
@@ -318,7 +408,13 @@ impl InListExpr {
impl std::fmt::Display for InListExpr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.negated {
- write!(f, "{} NOT IN ({:?})", self.expr, self.list)
+ if self.set.is_some() {
+ write!(f, "{} NOT IN (SET) ({:?})", self.expr, self.list)
+ } else {
+ write!(f, "{} NOT IN ({:?})", self.expr, self.list)
+ }
+ } else if self.set.is_some() {
+ write!(f, "Use {} IN (SET) ({:?})", self.expr, self.list)
} else {
write!(f, "{} IN ({:?})", self.expr, self.list)
}
@@ -342,119 +438,202 @@ impl PhysicalExpr for InListExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let value = self.expr.evaluate(batch)?;
let value_data_type = value.data_type();
- let list_values = self
- .list
- .iter()
- .map(|expr| expr.evaluate(batch))
- .collect::<Result<Vec<_>>>()?;
-
- let array = match value {
- ColumnarValue::Array(array) => array,
- ColumnarValue::Scalar(scalar) => scalar.to_array(),
- };
- match value_data_type {
- DataType::Float32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Float32,
- Float32Array
- )
- }
- DataType::Float64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Float64,
- Float64Array
- )
- }
- DataType::Int16 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int16,
- Int16Array
- )
- }
- DataType::Int32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int32,
- Int32Array
- )
- }
- DataType::Int64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int64,
- Int64Array
- )
- }
- DataType::Int8 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int8,
- Int8Array
- )
- }
- DataType::UInt16 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt16,
- UInt16Array
- )
- }
- DataType::UInt32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt32,
- UInt32Array
- )
- }
- DataType::UInt64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt64,
- UInt64Array
- )
- }
- DataType::UInt8 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt8,
- UInt8Array
- )
- }
- DataType::Boolean => {
- make_contains!(array, list_values, self.negated, Boolean,
BooleanArray)
- }
- DataType::Utf8 => self.compare_utf8::<i32>(array, list_values,
self.negated),
- DataType::LargeUtf8 => {
- self.compare_utf8::<i64>(array, list_values, self.negated)
+ if let Some(in_set) = &self.set {
+ let array = match value {
+ ColumnarValue::Array(array) => array,
+ ColumnarValue::Scalar(scalar) => scalar.to_array(),
+ };
+ let set = in_set.get_set();
+ match value_data_type {
+ DataType::Boolean => {
+ let array =
array.as_any().downcast_ref::<BooleanArray>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::Int8 => {
+ let array =
array.as_any().downcast_ref::<Int8Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::Int16 => {
+ let array =
array.as_any().downcast_ref::<Int16Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::Int32 => {
+ let array =
array.as_any().downcast_ref::<Int32Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::Int64 => {
+ let array =
array.as_any().downcast_ref::<Int64Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::UInt8 => {
+ let array =
array.as_any().downcast_ref::<UInt8Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::UInt16 => {
+ let array =
array.as_any().downcast_ref::<UInt16Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::UInt32 => {
+ let array =
array.as_any().downcast_ref::<UInt32Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::UInt64 => {
+ let array =
array.as_any().downcast_ref::<UInt64Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::Float32 => {
+ let array =
array.as_any().downcast_ref::<Float32Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::Float64 => {
+ let array =
array.as_any().downcast_ref::<Float64Array>().unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::Utf8 => {
+ let array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<i32>>()
+ .unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ DataType::LargeUtf8 => {
+ let array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<i64>>()
+ .unwrap();
+ set_contains_with_negated!(array, set, self.negated)
+ }
+ datatype => {
+ return Result::Err(DataFusionError::NotImplemented(format!(
+ "InSet does not support datatype {:?}.",
+ datatype
+ )))
+ }
+ };
+ } else {
+ let list_values = self
+ .list
+ .iter()
+ .map(|expr| expr.evaluate(batch))
+ .collect::<Result<Vec<_>>>()?;
+
+ let array = match value {
+ ColumnarValue::Array(array) => array,
+ ColumnarValue::Scalar(scalar) => scalar.to_array(),
+ };
+
+ match value_data_type {
+ DataType::Float32 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ Float32,
+ Float32Array
+ )
+ }
+ DataType::Float64 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ Float64,
+ Float64Array
+ )
+ }
+ DataType::Int16 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ Int16,
+ Int16Array
+ )
+ }
+ DataType::Int32 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ Int32,
+ Int32Array
+ )
+ }
+ DataType::Int64 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ Int64,
+ Int64Array
+ )
+ }
+ DataType::Int8 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ Int8,
+ Int8Array
+ )
+ }
+ DataType::UInt16 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ UInt16,
+ UInt16Array
+ )
+ }
+ DataType::UInt32 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ UInt32,
+ UInt32Array
+ )
+ }
+ DataType::UInt64 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ UInt64,
+ UInt64Array
+ )
+ }
+ DataType::UInt8 => {
+ make_contains_primitive!(
+ array,
+ list_values,
+ self.negated,
+ UInt8,
+ UInt8Array
+ )
+ }
+ DataType::Boolean => {
+ make_contains!(
+ array,
+ list_values,
+ self.negated,
+ Boolean,
+ BooleanArray
+ )
+ }
+ DataType::Utf8 => {
+ self.compare_utf8::<i32>(array, list_values, self.negated)
+ }
+ DataType::LargeUtf8 => {
+ self.compare_utf8::<i64>(array, list_values, self.negated)
+ }
+ datatype =>
Result::Err(DataFusionError::NotImplemented(format!(
+ "InList does not support datatype {:?}.",
+ datatype
+ ))),
}
- datatype => Result::Err(DataFusionError::NotImplemented(format!(
- "InList does not support datatype {:?}.",
- datatype
- ))),
}
}
}