This is an automated email from the ASF dual-hosted git repository.
nevime pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 551ab49 ARROW-10163: [Rust] [DataFusion] Add DictionaryArray coercion
support
551ab49 is described below
commit 551ab4946e5974b44e92861fd17204700770cc09
Author: alamb <[email protected]>
AuthorDate: Sun Oct 18 21:12:48 2020 +0200
ARROW-10163: [Rust] [DataFusion] Add DictionaryArray coercion support
This PR adds:
1. Basic `DictionaryArray` coercion (not cast) support in DataFusion
2. A test in `sql.rs` demonstrating basic operations using DataFusion on
`DictionaryArray` arrays
Note that the performance of operating on `DictionaryArray`s is likely to
leave a lot to be desired -- specifically, almost any operation will cause the
`DictionaryArray` to be unpacked into a normal array, negating most or all of
the performance gains of dictionary encoding.
I plan to add further performance improvements over time, but I felt that
getting queries to run at all was the important first step.
Closes #8463 from alamb/alamb/ARROW-10159-dictionary-array-coercion-take-3
Authored-by: alamb <[email protected]>
Signed-off-by: Neville Dipale <[email protected]>
---
rust/datafusion/src/physical_plan/expressions.rs | 179 ++++++++++++++++++++---
rust/datafusion/tests/sql.rs | 72 ++++++++-
2 files changed, 226 insertions(+), 25 deletions(-)
diff --git a/rust/datafusion/src/physical_plan/expressions.rs
b/rust/datafusion/src/physical_plan/expressions.rs
index 084f818..e9bbe19 100644
--- a/rust/datafusion/src/physical_plan/expressions.rs
+++ b/rust/datafusion/src/physical_plan/expressions.rs
@@ -1080,7 +1080,40 @@ impl fmt::Display for BinaryExpr {
}
}
-// the type that both lhs and rhs can be casted to for the purpose of a string
computation
+/// Coercion rules for dictionary values (aka the type of the dictionary
itself)
+fn dictionary_value_coercion(
+ lhs_type: &DataType,
+ rhs_type: &DataType,
+) -> Option<DataType> {
+ numerical_coercion(lhs_type, rhs_type).or_else(||
string_coercion(lhs_type, rhs_type))
+}
+
+/// Coercion rules for Dictionaries: the type that both lhs and rhs
+/// can be casted to for the purpose of a computation.
+///
+/// It would likely be preferable to cast primitive values to
+/// dictionaries, and thus avoid unpacking dictionary as well as doing
+/// faster comparisons. However, the arrow compute kernels (e.g. eq)
+/// don't have DictionaryArray support yet, so fall back to unpacking
+/// the dictionaries
+fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) ->
Option<DataType> {
+ match (lhs_type, rhs_type) {
+ (
+ DataType::Dictionary(_lhs_index_type, lhs_value_type),
+ DataType::Dictionary(_rhs_index_type, rhs_value_type),
+ ) => dictionary_value_coercion(lhs_value_type, rhs_value_type),
+ (DataType::Dictionary(_index_type, value_type), _) => {
+ dictionary_value_coercion(value_type, rhs_type)
+ }
+ (_, DataType::Dictionary(_index_type, value_type)) => {
+ dictionary_value_coercion(lhs_type, value_type)
+ }
+ _ => None,
+ }
+}
+
+/// Coercion rules for Strings: the type that both lhs and rhs can be
+/// casted to for the purpose of a string computation
fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) ->
Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
@@ -1092,7 +1125,9 @@ fn string_coercion(lhs_type: &DataType, rhs_type:
&DataType) -> Option<DataType>
}
}
-/// coercion rule for numerical types
+/// Coercion rule for numerical types: The type that both lhs and rhs
+/// can be casted to for numerical calculation, while maintaining
+/// maximum precision
pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) ->
Option<DataType> {
use arrow::datatypes::DataType::*;
@@ -1150,6 +1185,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType)
-> Option<DataType> {
return Some(lhs_type.clone());
}
numerical_coercion(lhs_type, rhs_type)
+ .or_else(|| dictionary_coercion(lhs_type, rhs_type))
}
// coercion rules that assume an ordered set, such as "less than".
@@ -1160,16 +1196,13 @@ fn order_coercion(lhs_type: &DataType, rhs_type:
&DataType) -> Option<DataType>
return Some(lhs_type.clone());
}
- match numerical_coercion(lhs_type, rhs_type) {
- None => {
- // strings are naturally ordered, and thus ordering can be applied
to them.
- string_coercion(lhs_type, rhs_type)
- }
- t => t,
- }
+ numerical_coercion(lhs_type, rhs_type)
+ .or_else(|| string_coercion(lhs_type, rhs_type))
+ .or_else(|| dictionary_coercion(lhs_type, rhs_type))
}
-/// coercion rules for all binary operators
+/// Coercion rules for all binary operators. Returns the output type
+/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
fn common_binary_type(
lhs_type: &DataType,
op: &Operator,
@@ -1526,8 +1559,8 @@ impl PhysicalExpr for CastExpr {
}
}
-/// Returns a physical cast operation that casts `expr` to `cast_type`
-/// if casting is needed.
+/// Return a PhysicalExpression representing `expr` casted to
+/// `cast_type`, if any casting is needed.
///
/// Note that such casts may lose type information
pub fn cast(
@@ -1665,11 +1698,14 @@ impl PhysicalSortExpr {
mod tests {
use super::*;
use crate::error::Result;
- use arrow::array::{
- LargeStringArray, PrimitiveArray, PrimitiveArrayOps, StringArray,
- Time64NanosecondArray,
- };
use arrow::datatypes::*;
+ use arrow::{
+ array::{
+ LargeStringArray, PrimitiveArray, PrimitiveArrayOps,
PrimitiveBuilder,
+ StringArray, StringDictionaryBuilder, Time64NanosecondArray,
+ },
+ util::display::array_value_to_string,
+ };
// Create a binary expression without coercion. Used here when we do not
want to coerce the expressions
// to valid types. Usage can result in an execution (after plan) error.
@@ -1772,11 +1808,13 @@ mod tests {
// runs an end-to-end test of physical type coercion:
// 1. construct a record batch with two columns of type A and B
+ // (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of
the elements)
// 2. construct a physical expression of A OP B
// 3. evaluate the expression
// 4. verify that the resulting expression is of type C
+ // 5. verify that the results of evaluation are $VEC
macro_rules! test_coercion {
- ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident,
$B_TYPE:expr, $B_VEC:expr, $OP:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr)
=> {{
+ ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident,
$B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr)
=> {{
let schema = Schema::new(vec![
Field::new("a", $A_TYPE, false),
Field::new("b", $B_TYPE, false),
@@ -1792,18 +1830,18 @@ mod tests {
let expression = binary(col("a"), $OP, col("b"), &schema)?;
// verify that the expression's type is correct
- assert_eq!(expression.data_type(&schema)?, $TYPE);
+ assert_eq!(expression.data_type(&schema)?, $C_TYPE);
// compute
let result = expression.evaluate(&batch)?;
// verify that the array's data_type is correct
- assert_eq!(*result.data_type(), $TYPE);
+ assert_eq!(*result.data_type(), $C_TYPE);
// verify that the data itself is downcastable
let result = result
.as_any()
- .downcast_ref::<$TYPEARRAY>()
+ .downcast_ref::<$C_ARRAY>()
.expect("failed to downcast");
// verify that the result itself is correct
for (i, x) in $VEC.iter().enumerate() {
@@ -1878,6 +1916,107 @@ mod tests {
}
#[test]
+ fn test_dictionary_type_coersion() -> Result<()> {
+ use DataType::*;
+
+ // TODO: In the future, this would ideally return Dictionary types and
avoid unpacking
+ let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32));
+ let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
+ assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32));
+
+ let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
+ let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
+ assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None);
+
+ let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
+ let rhs_type = Utf8;
+ assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
+
+ let lhs_type = Utf8;
+ let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
+ assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
+
+ Ok(())
+ }
+
+ // Note it would be nice to use the same test_coercion macro as
+ // above, but sadly the type of the values of the dictionary are
+ // not encoded in the rust type of the DictionaryArray. Thus there
+ // is no way at the time of this writing to create a dictionary
+ // array using the `From` trait
+ #[test]
+ fn test_dictionary_type_to_array_coersion() -> Result<()> {
+ // Test string a string dictionary
+ let dict_type =
+ DataType::Dictionary(Box::new(DataType::Int32),
Box::new(DataType::Utf8));
+ let string_type = DataType::Utf8;
+
+ // build dictionary
+ let keys_builder = PrimitiveBuilder::<Int32Type>::new(10);
+ let values_builder = StringBuilder::new(10);
+ let mut dict_builder = StringDictionaryBuilder::new(keys_builder,
values_builder);
+
+ dict_builder.append("one")?;
+ dict_builder.append_null()?;
+ dict_builder.append("three")?;
+ dict_builder.append("four")?;
+ let dict_array = dict_builder.finish();
+
+ let str_array =
+ StringArray::from(vec![Some("not one"), Some("two"), None,
Some("four")]);
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("dict", dict_type.clone(), true),
+ Field::new("str", string_type.clone(), true),
+ ]));
+
+ let batch = RecordBatch::try_new(
+ schema.clone(),
+ vec![Arc::new(dict_array), Arc::new(str_array)],
+ )?;
+
+ let expected = "false\n\n\ntrue";
+
+ // Test 1: dict = str
+
+ // verify that we can construct the expression
+ let expression = binary(col("dict"), Operator::Eq, col("str"),
&schema)?;
+ assert_eq!(expression.data_type(&schema)?, DataType::Boolean);
+
+ // evaluate and verify the result type matched
+ let result = expression.evaluate(&batch)?;
+ assert_eq!(result.data_type(), &DataType::Boolean);
+
+ // verify that the result itself is correct
+ assert_eq!(expected, array_to_string(&result)?);
+
+ // Test 2: now test the other direction
+ // str = dict
+
+ // verify that we can construct the expression
+ let expression = binary(col("str"), Operator::Eq, col("dict"),
&schema)?;
+ assert_eq!(expression.data_type(&schema)?, DataType::Boolean);
+
+ // evaluate and verify the result type matched
+ let result = expression.evaluate(&batch)?;
+ assert_eq!(result.data_type(), &DataType::Boolean);
+
+ // verify that the result itself is correct
+ assert_eq!(expected, array_to_string(&result)?);
+
+ Ok(())
+ }
+
+ // Convert the array to a newline delimited string of pretty printed values
+ fn array_to_string(array: &ArrayRef) -> Result<String> {
+ let s = (0..array.len())
+ .map(|i| array_value_to_string(array, i))
+ .collect::<std::result::Result<Vec<_>,
arrow::error::ArrowError>>()?
+ .join("\n");
+ Ok(s)
+ }
+
+ #[test]
fn test_coersion_error() -> Result<()> {
let expr =
common_binary_type(&DataType::Float32, &Operator::Plus,
&DataType::Utf8);
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 52027a4..5bf4525 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -21,8 +21,8 @@ use std::sync::Arc;
extern crate arrow;
extern crate datafusion;
-use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::TimeUnit};
+use arrow::{datatypes::Int32Type, record_batch::RecordBatch};
use arrow::{
datatypes::{DataType, Field, Schema, SchemaRef},
util::display::array_value_to_string,
@@ -930,14 +930,20 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) {
/// Execute query and return result set as 2-d table of Vecs
/// `result[row][column]`
async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<Vec<String>> {
- let plan = ctx.create_logical_plan(&sql).unwrap();
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx.create_logical_plan(&sql).expect(&msg);
let logical_schema = plan.schema();
- let plan = ctx.optimize(&plan).unwrap();
+
+ let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan);
+ let plan = ctx.optimize(&plan).expect(&msg);
let optimized_logical_schema = plan.schema();
- let plan = ctx.create_physical_plan(&plan).unwrap();
+
+ let msg = format!("Creating physical plan for '{}': {:?}", sql, plan);
+ let plan = ctx.create_physical_plan(&plan).expect(&msg);
let physical_schema = plan.schema();
- let results = ctx.collect(plan).await.unwrap();
+ let msg = format!("Executing physical plan for '{}': {:?}", sql, plan);
+ let results = ctx.collect(plan).await.expect(&msg);
assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref());
assert_eq!(logical_schema.as_ref(), physical_schema.as_ref());
@@ -1238,3 +1244,59 @@ async fn query_count_distinct() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}
+
+#[tokio::test]
+async fn query_on_string_dictionary() -> Result<()> {
+ // Test to ensure DataFusion can operate on dictionary types
+ // Use StringDictionary (32 bit indexes = keys)
+ let field_type =
+ DataType::Dictionary(Box::new(DataType::Int32),
Box::new(DataType::Utf8));
+ let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type,
true)]));
+
+ let keys_builder = PrimitiveBuilder::<Int32Type>::new(10);
+ let values_builder = StringBuilder::new(10);
+ let mut builder = StringDictionaryBuilder::new(keys_builder,
values_builder);
+
+ builder.append("one")?;
+ builder.append_null()?;
+ builder.append("three")?;
+ let array = Arc::new(builder.finish());
+
+ let data = RecordBatch::try_new(schema.clone(), vec![array])?;
+
+ let table = MemTable::new(schema, vec![vec![data]])?;
+ let mut ctx = ExecutionContext::new();
+ ctx.register_table("test", Box::new(table));
+
+ // Basic SELECT
+ let sql = "SELECT * FROM test";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["one"], vec!["NULL"], vec!["three"]];
+ assert_eq!(expected, actual);
+
+ // basic filtering
+ let sql = "SELECT * FROM test WHERE d1 IS NOT NULL";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["one"], vec!["three"]];
+ assert_eq!(expected, actual);
+
+ // filtering with constant
+ let sql = "SELECT * FROM test WHERE d1 = 'three'";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["three"]];
+ assert_eq!(expected, actual);
+
+ // Expression evaluation
+ let sql = "SELECT concat(d1, '-foo') FROM test";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["one-foo"], vec!["NULL"], vec!["three-foo"]];
+ assert_eq!(expected, actual);
+
+ // aggregation
+ let sql = "SELECT COUNT(d1) FROM test";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["2"]];
+ assert_eq!(expected, actual);
+
+ Ok(())
+}