This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e980ef8  ARROW-10817: [Rust] [DataFusion] Implement TypedString and DATE coercion
e980ef8 is described below

commit e980ef843922d8a2a07f0150b4a4ca54b23f280a
Author: Mike Seddon <[email protected]>
AuthorDate: Sun Dec 13 07:34:46 2020 -0500

    ARROW-10817: [Rust] [DataFusion] Implement TypedString and DATE coercion
    
    This PR adds support for what the `sqlparser` crate calls `TypedString`,
    which is essentially syntactic sugar for an inline cast. Since this was
    part of an effort to get the `TPC-H` queries behaving correctly, I went a
    step further and also added support for `Date` (temporal) coercion. I can
    split this PR if needed.
    
    ```sql
    where
        l_shipdate <= date '1998-09-02'
    ```
    
    is equivalent to
    
    ```sql
    where
        l_shipdate <= CAST('1998-09-02' AS DATE)
    ```
    
    FYI, I am planning to tackle `INTERVAL` next.
    
    Closes #8892 from seddonm1/typed_string
    
    Authored-by: Mike Seddon <[email protected]>
    Signed-off-by: Andrew Lamb <[email protected]>
---
 rust/arrow/src/compute/kernels/cast.rs           | 46 +++++++++++++++
 rust/benchmarks/src/bin/tpch.rs                  | 75 +++++++++++++++++-------
 rust/datafusion/src/physical_plan/expressions.rs | 72 ++++++++++++++++++++++-
 rust/datafusion/src/sql/planner.rs               | 16 +++++
 4 files changed, 185 insertions(+), 24 deletions(-)

diff --git a/rust/arrow/src/compute/kernels/cast.rs 
b/rust/arrow/src/compute/kernels/cast.rs
index 7b0c6bc..70acf5a 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -72,6 +72,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: 
&DataType) -> bool {
         (Boolean, _) => DataType::is_numeric(to_type) || to_type == &Utf8,
 
         (Utf8, Date32(DateUnit::Day)) => true,
+        (Utf8, Date64(DateUnit::Millisecond)) => true,
         (Utf8, _) => DataType::is_numeric(to_type),
         (_, Utf8) => DataType::is_numeric(from_type) || from_type == &Binary,
 
@@ -399,6 +400,26 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> 
Result<ArrayRef> {
                 }
                 Ok(Arc::new(builder.finish()) as ArrayRef)
             }
+            Date64(DateUnit::Millisecond) => {
+                use chrono::{NaiveDate, NaiveTime};
+                let zero_time = NaiveTime::from_hms(0, 0, 0);
+                let string_array = 
array.as_any().downcast_ref::<StringArray>().unwrap();
+                let mut builder = 
PrimitiveBuilder::<Date64Type>::new(string_array.len());
+                for i in 0..string_array.len() {
+                    if string_array.is_null(i) {
+                        builder.append_null()?;
+                    } else {
+                        match NaiveDate::parse_from_str(string_array.value(i), 
"%Y-%m-%d")
+                        {
+                            Ok(date) => builder.append_value(
+                                date.and_time(zero_time).timestamp_millis() as 
i64,
+                            )?,
+                            Err(_) => builder.append_null()?, // not a valid 
date
+                        };
+                    }
+                }
+                Ok(Arc::new(builder.finish()) as ArrayRef)
+            }
             _ => Err(ArrowError::ComputeError(format!(
                 "Casting from {:?} to {:?} not supported",
                 from_type, to_type,
@@ -2781,6 +2802,31 @@ mod tests {
     }
 
     #[test]
+    fn test_cast_utf8_to_date64() {
+        let a = StringArray::from(vec![
+            "2000-01-01",          // valid date with leading 0s
+            "2000-2-2",            // valid date without leading 0s
+            "2000-00-00",          // invalid month and day
+            "2000-01-01T12:00:00", // date + time is invalid
+            "2000",                // just a year is invalid
+        ]);
+        let array = Arc::new(a) as ArrayRef;
+        let b = cast(&array, 
&DataType::Date64(DateUnit::Millisecond)).unwrap();
+        let c = b.as_any().downcast_ref::<Date64Array>().unwrap();
+
+        // test valid inputs
+        assert_eq!(true, c.is_valid(0)); // "2000-01-01"
+        assert_eq!(946684800000, c.value(0));
+        assert_eq!(true, c.is_valid(1)); // "2000-2-2"
+        assert_eq!(949449600000, c.value(1));
+
+        // test invalid inputs
+        assert_eq!(false, c.is_valid(2)); // "2000-00-00"
+        assert_eq!(false, c.is_valid(3)); // "2000-01-01T12:00:00"
+        assert_eq!(false, c.is_valid(4)); // "2000"
+    }
+
+    #[test]
     fn test_can_cast_types() {
         // this function attempts to ensure that can_cast_types stays
         // in sync with cast.  It simply tries all combinations of
diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs
index 2ed9ab0..cd3d9d8 100644
--- a/rust/benchmarks/src/bin/tpch.rs
+++ b/rust/benchmarks/src/bin/tpch.rs
@@ -21,7 +21,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::time::Instant;
 
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, DateUnit, Field, Schema};
 use arrow::util::pretty;
 use datafusion::datasource::parquet::ParquetTable;
 use datafusion::datasource::{CsvFile, MemTable, TableProvider};
@@ -187,7 +187,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
             from
                 lineitem
             where
-                l_shipdate <= '1998-09-02'
+                l_shipdate <= date '1998-09-02'
             group by
                 l_returnflag,
                 l_linestatus
@@ -256,8 +256,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
                 c_mktsegment = 'BUILDING'
                 and c_custkey = o_custkey
                 and l_orderkey = o_orderkey
-                and o_orderdate < '1995-03-15'
-                and l_shipdate > '1995-03-15'
+                and o_orderdate < date '1995-03-15'
+                and l_shipdate > date '1995-03-15'
             group by
                 l_orderkey,
                 o_orderdate,
@@ -337,8 +337,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
                 and s_nationkey = n_nationkey
                 and n_regionkey = r_regionkey
                 and r_name = 'ASIA'
-                and o_orderdate >= '1994-01-01'
-                and o_orderdate < '1995-01-01'
+                and o_orderdate >= date '1994-01-01'
+                and o_orderdate < date '1995-01-01'
             group by
                 n_name
             order by
@@ -363,9 +363,9 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
             from
                 lineitem
             where
-                l_shipdate >= '1994-01-01'
-                and l_shipdate < '1995-01-01'
-                and l_discount between 0.06 - 0.01 and 0.06 + 0.01
+                l_shipdate >= date '1994-01-01'
+                and l_shipdate < date '1995-01-01'
+                and l_discount > 0.06 - 0.01 and l_discount < 0.06 + 0.01
                 and l_quantity < 24;"
         ),
 
@@ -399,7 +399,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
                             (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY')
                             or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE')
                         )
-                        and l_shipdate > '1995-01-01' and l_shipdate < 
'1996-12-31'
+                        and l_shipdate > date '1995-01-01' and l_shipdate < 
date '1996-12-31'
                 ) as shipping
             group by
                 supp_nation,
@@ -442,7 +442,7 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
                         and n1.n_regionkey = r_regionkey
                         and r_name = 'AMERICA'
                         and s_nationkey = n2.n_nationkey
-                        and o_orderdate between '1995-01-01' and '1996-12-31'
+                        and o_orderdate between date '1995-01-01' and date 
'1996-12-31'
                         and p_type = 'ECONOMY ANODIZED STEEL'
                 ) as all_nations
             group by
@@ -486,6 +486,39 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
                 o_year desc;"
         ),
 
+        // 10 => ctx.create_logical_plan(
+        //     "select
+        //         c_custkey,
+        //         c_name,
+        //         sum(l_extendedprice * (1 - l_discount)) as revenue,
+        //         c_acctbal,
+        //         n_name,
+        //         c_address,
+        //         c_phone,
+        //         c_comment
+        //     from
+        //         customer,
+        //         orders,
+        //         lineitem,
+        //         nation
+        //     where
+        //         c_custkey = o_custkey
+        //         and l_orderkey = o_orderkey
+        //         and o_orderdate >= date '1993-10-01'
+        //         and o_orderdate < date '1993-10-01' + interval '3' month
+        //         and l_returnflag = 'R'
+        //         and c_nationkey = n_nationkey
+        //     group by
+        //         c_custkey,
+        //         c_name,
+        //         c_acctbal,
+        //         c_phone,
+        //         n_name,
+        //         c_address,
+        //         c_comment
+        //     order by
+        //         revenue desc;"
+        // ),
         10 => ctx.create_logical_plan(
             "select
                 c_custkey,
@@ -504,8 +537,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
             where
                 c_custkey = o_custkey
                 and l_orderkey = o_orderkey
-                and o_orderdate >= '1993-10-01'
-                and o_orderdate < '1994-01-01'
+                and o_orderdate >= date '1993-10-01'
+                and o_orderdate < date '1994-01-01'
                 and l_returnflag = 'R'
                 and c_nationkey = n_nationkey
             group by
@@ -606,8 +639,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
                 (l_shipmode = 'MAIL' or l_shipmode = 'SHIP')
                 and l_commitdate < l_receiptdate
                 and l_shipdate < l_commitdate
-                and l_receiptdate >= '1994-01-01'
-                and l_receiptdate < '1995-01-01'
+                and l_receiptdate >= date '1994-01-01'
+                and l_receiptdate < date '1995-01-01'
             group by
                 l_shipmode
             order by
@@ -649,8 +682,8 @@ fn create_logical_plan(ctx: &mut ExecutionContext, query: 
usize) -> Result<Logic
                 part
             where
                 l_partkey = p_partkey
-                and l_shipdate >= '1995-09-01'
-                and l_shipdate < '1995-10-01';"
+                and l_shipdate >= date '1995-09-01'
+                and l_shipdate < date '1995-10-01';"
         ),
 
         15 => ctx.create_logical_plan(
@@ -1072,7 +1105,7 @@ fn get_schema(table: &str) -> Schema {
             Field::new("o_custkey", DataType::UInt32, false),
             Field::new("o_orderstatus", DataType::Utf8, false),
             Field::new("o_totalprice", DataType::Float64, false), // decimal
-            Field::new("o_orderdate", DataType::Utf8, false),
+            Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false),
             Field::new("o_orderpriority", DataType::Utf8, false),
             Field::new("o_clerk", DataType::Utf8, false),
             Field::new("o_shippriority", DataType::UInt32, false),
@@ -1090,9 +1123,9 @@ fn get_schema(table: &str) -> Schema {
             Field::new("l_tax", DataType::Float64, false),      // decimal
             Field::new("l_returnflag", DataType::Utf8, false),
             Field::new("l_linestatus", DataType::Utf8, false),
-            Field::new("l_shipdate", DataType::Utf8, false),
-            Field::new("l_commitdate", DataType::Utf8, false),
-            Field::new("l_receiptdate", DataType::Utf8, false),
+            Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false),
+            Field::new("l_commitdate", DataType::Date32(DateUnit::Day), false),
+            Field::new("l_receiptdate", DataType::Date32(DateUnit::Day), 
false),
             Field::new("l_shipinstruct", DataType::Utf8, false),
             Field::new("l_shipmode", DataType::Utf8, false),
             Field::new("l_comment", DataType::Utf8, false),
diff --git a/rust/datafusion/src/physical_plan/expressions.rs 
b/rust/datafusion/src/physical_plan/expressions.rs
index 79045d9..ffac95b 100644
--- a/rust/datafusion/src/physical_plan/expressions.rs
+++ b/rust/datafusion/src/physical_plan/expressions.rs
@@ -48,9 +48,9 @@ use arrow::datatypes::{DataType, DateUnit, Schema, TimeUnit};
 use arrow::record_batch::RecordBatch;
 use arrow::{
     array::{
-        ArrayRef, BooleanArray, Date32Array, Float32Array, Float64Array, 
Int16Array,
-        Int32Array, Int64Array, Int8Array, StringArray, 
TimestampNanosecondArray,
-        UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+        ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, 
Float64Array,
+        Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
+        TimestampNanosecondArray, UInt16Array, UInt32Array, UInt64Array, 
UInt8Array,
     },
     datatypes::Field,
 };
@@ -1135,6 +1135,9 @@ macro_rules! binary_array_op {
             DataType::Date32(DateUnit::Day) => {
                 compute_op!($LEFT, $RIGHT, $OP, Date32Array)
             }
+            DataType::Date64(DateUnit::Millisecond) => {
+                compute_op!($LEFT, $RIGHT, $OP, Date64Array)
+            }
             other => Err(DataFusionError::Internal(format!(
                 "Unsupported data type {:?}",
                 other
@@ -1227,6 +1230,19 @@ fn string_coercion(lhs_type: &DataType, rhs_type: 
&DataType) -> Option<DataType>
     }
 }
 
+/// Coercion rules for Temporal columns: the type that both lhs and rhs can be
+/// casted to for the purpose of a date computation
+fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> 
Option<DataType> {
+    use arrow::datatypes::DataType::*;
+    match (lhs_type, rhs_type) {
+        (Utf8, Date32(DateUnit::Day)) => Some(Date32(DateUnit::Day)),
+        (Date32(DateUnit::Day), Utf8) => Some(Date32(DateUnit::Day)),
+        (Utf8, Date64(DateUnit::Millisecond)) => 
Some(Date64(DateUnit::Millisecond)),
+        (Date64(DateUnit::Millisecond), Utf8) => 
Some(Date64(DateUnit::Millisecond)),
+        _ => None,
+    }
+}
+
 /// Coercion rule for numerical types: The type that both lhs and rhs
 /// can be casted to for numerical calculation, while maintaining
 /// maximum precision
@@ -1288,6 +1304,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) 
-> Option<DataType> {
     }
     numerical_coercion(lhs_type, rhs_type)
         .or_else(|| dictionary_coercion(lhs_type, rhs_type))
+        .or_else(|| temporal_coercion(lhs_type, rhs_type))
 }
 
 // coercion rules that assume an ordered set, such as "less than".
@@ -1301,6 +1318,7 @@ fn order_coercion(lhs_type: &DataType, rhs_type: 
&DataType) -> Option<DataType>
     numerical_coercion(lhs_type, rhs_type)
         .or_else(|| string_coercion(lhs_type, rhs_type))
         .or_else(|| dictionary_coercion(lhs_type, rhs_type))
+        .or_else(|| temporal_coercion(lhs_type, rhs_type))
 }
 
 /// Coercion rules for all binary operators. Returns the output type
@@ -2638,6 +2656,54 @@ mod tests {
             DataType::Boolean,
             vec![true, false]
         );
+        test_coercion!(
+            StringArray,
+            DataType::Utf8,
+            vec!["1994-12-13", "1995-01-26"],
+            Date32Array,
+            DataType::Date32(DateUnit::Day),
+            vec![9112, 9156],
+            Operator::Eq,
+            BooleanArray,
+            DataType::Boolean,
+            vec![true, true]
+        );
+        test_coercion!(
+            StringArray,
+            DataType::Utf8,
+            vec!["1994-12-13", "1995-01-26"],
+            Date32Array,
+            DataType::Date32(DateUnit::Day),
+            vec![9113, 9154],
+            Operator::Lt,
+            BooleanArray,
+            DataType::Boolean,
+            vec![true, false]
+        );
+        test_coercion!(
+            StringArray,
+            DataType::Utf8,
+            vec!["1994-12-13", "1995-01-26"],
+            Date64Array,
+            DataType::Date64(DateUnit::Millisecond),
+            vec![787276800000, 791078400000],
+            Operator::Eq,
+            BooleanArray,
+            DataType::Boolean,
+            vec![true, true]
+        );
+        test_coercion!(
+            StringArray,
+            DataType::Utf8,
+            vec!["1994-12-13", "1995-01-26"],
+            Date64Array,
+            DataType::Date64(DateUnit::Millisecond),
+            vec![787276800001, 791078399999],
+            Operator::Lt,
+            BooleanArray,
+            DataType::Boolean,
+            vec![true, false]
+        );
         Ok(())
     }
 
diff --git a/rust/datafusion/src/sql/planner.rs 
b/rust/datafusion/src/sql/planner.rs
index 562e580..eb9a1a5 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -629,6 +629,14 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
                 data_type: convert_data_type(data_type)?,
             }),
 
+            SQLExpr::TypedString {
+                ref data_type,
+                ref value,
+            } => Ok(Expr::Cast {
+                expr: Box::new(lit(&**value)),
+                data_type: convert_data_type(data_type)?,
+            }),
+
             SQLExpr::IsNull(ref expr) => {
                 Ok(Expr::IsNull(Box::new(self.sql_to_rex(expr, schema)?)))
             }
@@ -1311,6 +1319,14 @@ mod tests {
         quick_test(sql, expected);
     }
 
+    #[test]
+    fn select_typedstring() {
+        let sql = "SELECT date '2020-12-10' AS date FROM person";
+        let expected = "Projection: CAST(Utf8(\"2020-12-10\") AS Date32(Day)) 
AS date\
+            \n  TableScan: person projection=None";
+        quick_test(sql, expected);
+    }
+
     fn logical_plan(sql: &str) -> Result<LogicalPlan> {
         let planner = SqlToRel::new(&MockSchemaProvider {});
         let result = DFParser::parse_sql(&sql);

Reply via email to