This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new c2972a29c support decimal for inlist expr (#2764)
c2972a29c is described below
commit c2972a29c3165dc8cb92f8d437de80d56d99740e
Author: Kun Liu <[email protected]>
AuthorDate: Sat Jun 25 02:03:33 2022 +0800
support decimal for inlist expr (#2764)
---
datafusion/expr/src/binary_rule.rs | 6 +-
.../physical-expr/src/expressions/in_list.rs | 353 +++++++++++++++++++--
datafusion/physical-expr/src/planner.rs | 7 +-
3 files changed, 335 insertions(+), 31 deletions(-)
diff --git a/datafusion/expr/src/binary_rule.rs
b/datafusion/expr/src/binary_rule.rs
index 6770fccd7..88b4d95ec 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -150,7 +150,11 @@ fn bitwise_coercion(left_type: &DataType, right_type:
&DataType) -> Option<DataT
}
}
-fn comparison_eq_coercion(lhs_type: &DataType, rhs_type: &DataType) ->
Option<DataType> {
+/// Get the coerced data type for `eq` or `not eq` operation
+pub fn comparison_eq_coercion(
+ lhs_type: &DataType,
+ rhs_type: &DataType,
+) -> Option<DataType> {
// can't compare dictionaries directly due to
// https://github.com/apache/arrow-rs/issues/1201
if lhs_type == rhs_type && !is_dictionary(lhs_type) {
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs
b/datafusion/physical-expr/src/expressions/in_list.rs
index 7d15dd5e9..346eea472 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -33,11 +33,14 @@ use arrow::{
record_batch::RecordBatch,
};
+use crate::expressions::try_cast;
use crate::{expressions, PhysicalExpr};
use arrow::array::*;
use arrow::buffer::{Buffer, MutableBuffer};
use datafusion_common::ScalarValue;
+use datafusion_common::ScalarValue::Decimal128;
use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::binary_rule::comparison_eq_coercion;
use datafusion_expr::ColumnarValue;
/// Size at which to use a Set rather than Vec for `IN` / `NOT IN`
@@ -82,6 +85,8 @@ pub struct InListExpr {
/// InSet
#[derive(Debug)]
pub struct InSet {
+ // TODO: optimization: for `IN` or `NOT IN` we don't need to consider the NULL value.
+ // The data types are the same, so we can use set: HashSet<T>
set: HashSet<ScalarValue>,
}
@@ -160,6 +165,7 @@ macro_rules! make_contains_primitive {
ColumnarValue::Scalar(s) => match s {
ScalarValue::$SCALAR_VALUE(Some(v)) => Some(*v),
ScalarValue::$SCALAR_VALUE(None) => None,
+ // TODO: this is a bug; for primitive types the expr list should be cast to the same data type
ScalarValue::Utf8(None) => None,
datatype => unimplemented!("Unexpected type {} for
InList", datatype),
},
@@ -300,6 +306,90 @@ fn cast_static_filter_to_set(list: &[Arc<dyn
PhysicalExpr>]) -> HashSet<ScalarVa
}))
}
+/// Evaluates `array IN (list)` (or `array NOT IN (list)` when `negated` is set)
+/// for a decimal array against a list of already-evaluated expressions.
+///
+/// The result follows SQL three-valued logic:
+/// * a NULL input value yields NULL;
+/// * a value found in the list yields `true` (`false` when `negated`);
+/// * a value not found in the list yields NULL if the list contains a NULL,
+///   otherwise `false` (`true` when `negated`).
+///
+/// # Panics
+/// Panics if a list entry is an array, or a scalar that is not `Decimal128`.
+fn make_list_contains_decimal(
+    array: &DecimalArray,
+    list: Vec<ColumnarValue>,
+    negated: bool,
+) -> BooleanArray {
+    let contains_null = list
+        .iter()
+        .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
+    // `flat_map` over the `Option<i128>` drops the NULL entries; their presence
+    // is already recorded in `contains_null`.
+    let values = list
+        .iter()
+        .flat_map(|v| match v {
+            ColumnarValue::Scalar(Decimal128(v128op, _, _)) => *v128op,
+            ColumnarValue::Scalar(_) => {
+                unreachable!("InList can't reach other data type for decimal data type.")
+            }
+            ColumnarValue::Array(_) => {
+                unimplemented!("InList does not yet support nested columns.")
+            }
+        })
+        .collect::<Vec<_>>();
+
+    array
+        .iter()
+        .map(|v| match v.map(|v128| values.contains(&v128)) {
+            // A definite match is never affected by NULLs in the list.
+            Some(true) => Some(!negated),
+            // No match, but a NULL in the list might have matched: unknown.
+            Some(false) if contains_null => None,
+            Some(false) => Some(negated),
+            None => None,
+        })
+        .collect::<BooleanArray>()
+}
+
+/// Evaluates `array IN (set)` (or `array NOT IN (set)` when `negated` is set)
+/// for a decimal array against a pre-built set of scalar values.
+///
+/// The result follows SQL three-valued logic:
+/// * a NULL input value yields NULL;
+/// * a value found in the set yields `true` (`false` when `negated`);
+/// * a value not found in the set yields NULL if the set contains a NULL,
+///   otherwise `false` (`true` when `negated`).
+///
+/// # Panics
+/// Panics if the set holds a scalar that is not `Decimal128`.
+fn make_set_contains_decimal(
+    array: &DecimalArray,
+    set: &HashSet<ScalarValue>,
+    negated: bool,
+) -> BooleanArray {
+    let contains_null = set.iter().any(|v| v.is_null());
+    // Collect the non-NULL raw `i128` values straight into a set; NULL entries
+    // are dropped here because `contains_null` already accounts for them.
+    let native_set = set
+        .iter()
+        .flat_map(|v| match v {
+            Decimal128(v128op, _, _) => *v128op,
+            _ => {
+                unreachable!("InList can't reach other data type for decimal data type.")
+            }
+        })
+        .collect::<HashSet<i128>>();
+
+    array
+        .iter()
+        .map(|v| match v.map(|v128| native_set.contains(&v128)) {
+            // A definite match is never affected by NULLs in the set.
+            Some(true) => Some(!negated),
+            // No match, but a NULL in the set might have matched: unknown.
+            Some(false) if contains_null => None,
+            Some(false) => Some(negated),
+            None => None,
+        })
+        .collect::<BooleanArray>()
+}
+
impl InListExpr {
/// Create a new InList expression
pub fn new(
@@ -504,6 +594,11 @@ impl PhysicalExpr for InListExpr {
.unwrap();
set_contains_with_negated!(array, set, self.negated)
}
+ DataType::Decimal(_, _) => {
+ let array =
array.as_any().downcast_ref::<DecimalArray>().unwrap();
+ let result = make_set_contains_decimal(array, set,
self.negated);
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
datatype =>
Result::Err(DataFusionError::NotImplemented(format!(
"InSet does not support datatype {:?}.",
datatype
@@ -631,6 +726,16 @@ impl PhysicalExpr for InListExpr {
let null_array = new_null_array(&DataType::Boolean,
array.len());
Ok(ColumnarValue::Array(Arc::new(null_array)))
}
+ DataType::Decimal(_, _) => {
+ let decimal_array =
+ array.as_any().downcast_ref::<DecimalArray>().unwrap();
+ let result = make_list_contains_decimal(
+ decimal_array,
+ list_values,
+ self.negated,
+ );
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
datatype =>
Result::Err(DataFusionError::NotImplemented(format!(
"InList does not support datatype {:?}.",
datatype
@@ -640,13 +745,63 @@ impl PhysicalExpr for InListExpr {
}
}
+type InListCastResult = (Arc<dyn PhysicalExpr>, Vec<Arc<dyn PhysicalExpr>>);
+
/// Creates a unary expression InList
pub fn in_list(
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
negated: &bool,
+ input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
- Ok(Arc::new(InListExpr::new(expr, list, *negated)))
+ let (cast_expr, cast_list) = in_list_cast(expr, list, input_schema)?;
+ Ok(Arc::new(InListExpr::new(cast_expr, cast_list, *negated)))
+}
+
+/// Casts the `IN` value expression and every list expression to one common
+/// data type, wrapping each of them in a `try_cast`.
+///
+/// # Errors
+/// Returns an error if the data type of any expression cannot be resolved
+/// against `input_schema`, if no common type can be found for the value and
+/// list types, or if building one of the `try_cast` expressions fails.
+fn in_list_cast(
+    expr: Arc<dyn PhysicalExpr>,
+    list: Vec<Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
+) -> Result<InListCastResult> {
+    let expr_type = expr.data_type(input_schema)?;
+    // Propagate a type-resolution failure instead of panicking.
+    let list_types = list
+        .iter()
+        .map(|list_expr| list_expr.data_type(input_schema))
+        .collect::<Result<Vec<DataType>>>()?;
+    // TODO in the arrow-rs, should support NULL type to Decimal Data type
+    // TODO support in the arrow-rs, NULL value cast to Decimal Value
+    // https://github.com/apache/arrow-datafusion/issues/2759
+    let data_type = get_coerce_type(&expr_type, &list_types).ok_or_else(|| {
+        DataFusionError::Internal(format!(
+            "In expr can't find the coerced type for {:?} in {:?}",
+            expr_type, list_types
+        ))
+    })?;
+
+    let cast_expr = try_cast(expr, input_schema, data_type.clone())?;
+    let cast_list_expr = list
+        .into_iter()
+        .map(|list_expr| try_cast(list_expr, input_schema, data_type.clone()))
+        .collect::<Result<Vec<_>>>()?;
+    Ok((cast_expr, cast_list_expr))
+}
+
+/// Combines the value type with each list item type in turn using the
+/// equality-comparison coercion rule, returning the resulting common type.
+///
+/// Returns `None` as soon as one pair of types has no coerced type.
+fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option<DataType> {
+    // TODO refactor a framework to do the data type coercion
+    list_type
+        .iter()
+        .try_fold(expr_type.clone(), |coerced, item_type| {
+            comparison_eq_coercion(&coerced, item_type)
+        })
}
#[cfg(test)]
@@ -659,8 +814,8 @@ mod tests {
// applies the in_list expr to an input batch and list
macro_rules! in_list {
- ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr) =>
{{
- let expr = in_list($COL, $LIST, $NEGATED).unwrap();
+ ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr,
$SCHEMA:expr) => {{
+ let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap();
let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
let result = result
.as_any()
@@ -676,7 +831,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
let a = StringArray::from(vec![Some("a"), Some("d"), None]);
let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
// expression: "a in ("a", "b")"
let list = vec![
@@ -688,7 +843,8 @@ mod tests {
list,
&false,
vec![Some(true), Some(false), None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a not in ("a", "b")"
@@ -701,7 +857,8 @@ mod tests {
list,
&true,
vec![Some(false), Some(true), None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a not in ("a", "b")"
@@ -715,7 +872,8 @@ mod tests {
list,
&false,
vec![Some(true), None, None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a not in ("a", "b")"
@@ -729,7 +887,8 @@ mod tests {
list,
&true,
vec![Some(false), None, None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
Ok(())
@@ -740,7 +899,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
let a = Int64Array::from(vec![Some(0), Some(2), None]);
let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
// expression: "a in (0, 1)"
let list = vec![
@@ -752,7 +911,8 @@ mod tests {
list,
&false,
vec![Some(true), Some(false), None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a not in (0, 1)"
@@ -765,35 +925,38 @@ mod tests {
list,
&true,
vec![Some(false), Some(true), None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a in (0, 1, NULL)"
let list = vec![
lit(ScalarValue::Int64(Some(0))),
lit(ScalarValue::Int64(Some(1))),
- lit(ScalarValue::Utf8(None)),
+ lit(ScalarValue::Null),
];
in_list!(
batch,
list,
&false,
vec![Some(true), None, None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a not in (0, 1, NULL)"
let list = vec![
lit(ScalarValue::Int64(Some(0))),
lit(ScalarValue::Int64(Some(1))),
- lit(ScalarValue::Utf8(None)),
+ lit(ScalarValue::Null),
];
in_list!(
batch,
list,
&true,
vec![Some(false), None, None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
Ok(())
@@ -804,7 +967,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Float64,
true)]);
let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]);
let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
// expression: "a in (0.0, 0.2)"
let list = vec![
@@ -816,7 +979,8 @@ mod tests {
list,
&false,
vec![Some(true), Some(false), None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a not in (0.0, 0.2)"
@@ -829,35 +993,38 @@ mod tests {
list,
&true,
vec![Some(false), Some(true), None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a in (0.0, 0.2, NULL)"
let list = vec![
lit(ScalarValue::Float64(Some(0.0))),
lit(ScalarValue::Float64(Some(0.1))),
- lit(ScalarValue::Utf8(None)),
+ lit(ScalarValue::Null),
];
in_list!(
batch,
list,
&false,
vec![Some(true), None, None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
// expression: "a not in (0.0, 0.2, NULL)"
let list = vec![
lit(ScalarValue::Float64(Some(0.0))),
lit(ScalarValue::Float64(Some(0.1))),
- lit(ScalarValue::Utf8(None)),
+ lit(ScalarValue::Null),
];
in_list!(
batch,
list,
&true,
vec![Some(false), None, None],
- col_a.clone()
+ col_a.clone(),
+ &schema
);
Ok(())
@@ -868,29 +1035,157 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean,
true)]);
let a = BooleanArray::from(vec![Some(true), None]);
let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(a)])?;
// expression: "a in (true)"
let list = vec![lit(ScalarValue::Boolean(Some(true)))];
- in_list!(batch, list, &false, vec![Some(true), None], col_a.clone());
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), None],
+ col_a.clone(),
+ &schema
+ );
// expression: "a not in (true)"
let list = vec![lit(ScalarValue::Boolean(Some(true)))];
- in_list!(batch, list, &true, vec![Some(false), None], col_a.clone());
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None],
+ col_a.clone(),
+ &schema
+ );
// expression: "a in (true, NULL)"
let list = vec![
lit(ScalarValue::Boolean(Some(true))),
- lit(ScalarValue::Utf8(None)),
+ lit(ScalarValue::Null),
];
- in_list!(batch, list, &false, vec![Some(true), None], col_a.clone());
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), None],
+ col_a.clone(),
+ &schema
+ );
// expression: "a not in (true, NULL)"
let list = vec![
lit(ScalarValue::Boolean(Some(true))),
- lit(ScalarValue::Utf8(None)),
+ lit(ScalarValue::Null),
+ ];
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None],
+ col_a.clone(),
+ &schema
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn in_list_decimal() -> Result<()> {
+ // Now, we can check the NULL type
+ let schema = Schema::new(vec![Field::new("a", DataType::Decimal(13,
4), true)]);
+ let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)]
+ .into_iter()
+ .collect::<DecimalArray>();
+ let array = array.with_precision_and_scale(13, 4).unwrap();
+ let col_a = col("a", &schema)?;
+ let batch =
+ RecordBatch::try_new(Arc::new(schema.clone()),
vec![Arc::new(array)])?;
+
+ // expression: "a in (100,200), the data type of list is INT32
+ let list = vec![
+ lit(ScalarValue::Int32(Some(100))),
+ lit(ScalarValue::Int32(Some(200))),
+ ];
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), None, Some(false)],
+ col_a.clone(),
+ &schema
+ );
+ // expression: "a not in (100,200)
+ let list = vec![
+ lit(ScalarValue::Int32(Some(100))),
+ lit(ScalarValue::Int32(Some(200))),
+ ];
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, Some(true)],
+ col_a.clone(),
+ &schema
+ );
+
+ // expression: "a in (200,NULL), the data type of list is INT32 AND
NULL
+ // TODO support: NULL data type to decimal in arrow-rs
+ // let list = vec![lit(ScalarValue::Int32(Some(100))),
lit(ScalarValue::Null)];
+ // in_list!(batch, list, &false, vec![Some(true), None, Some(false)],
col_a.clone(), &schema);
+
+ // expression: "a in (200.5, 100), the data type of list is FLOAT32
and INT32
+ let list = vec![
+ lit(ScalarValue::Float32(Some(200.50f32))),
+ lit(ScalarValue::Int32(Some(100))),
];
- in_list!(batch, list, &true, vec![Some(false), None], col_a.clone());
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), None, Some(true)],
+ col_a.clone(),
+ &schema
+ );
+
+ // expression: "a not in (200.5, 101), the data type of list is FLOAT32 and INT32
+ let list = vec![
+ lit(ScalarValue::Float32(Some(200.50f32))),
+ lit(ScalarValue::Int32(Some(101))),
+ ];
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(true), None, Some(false)],
+ col_a.clone(),
+ &schema
+ );
+
+ // test the optimization: set
+ // expression: "a in (99..300), the data type of list is INT32
+ let list = (99..300)
+ .into_iter()
+ .map(|v| lit(ScalarValue::Int32(Some(v))))
+ .collect::<Vec<_>>();
+
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, Some(false)],
+ col_a.clone(),
+ &schema
+ );
+
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, Some(true)],
+ col_a.clone(),
+ &schema
+ );
Ok(())
}
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index 92580fce0..26583cd28 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -294,6 +294,9 @@ pub fn create_physical_expr(
input_schema,
execution_props,
),
+ // TODO: refactor the data type coercion logic.
+ // The data types in the `list expr` may conflict with that of the `value expr`,
+ // so we should not just compare the data type of the `value expr` with each `list expr`.
_ => {
let list_expr = create_physical_expr(
expr,
@@ -310,6 +313,8 @@ pub fn create_physical_expr(
&list_expr_data_type,
&value_expr_data_type,
) {
+ // TODO: we can't cast from the list type to the value type directly;
+ // we should use the coercion rule to get the common data type.
expressions::cast(
list_expr,
input_schema,
@@ -325,7 +330,7 @@ pub fn create_physical_expr(
})
.collect::<Result<Vec<_>>>()?;
- expressions::in_list(value_expr, list_exprs, negated)
+ expressions::in_list(value_expr, list_exprs, negated,
input_schema)
}
},
other => Err(DataFusionError::NotImplemented(format!(