This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6167ce915e feat: add substrait support for Interval types and literals (#10646)
6167ce915e is described below

commit 6167ce915e1768ea1951557118068bd68ce4aac3
Author: Ruihang Xia <[email protected]>
AuthorDate: Mon May 27 03:04:50 2024 +0800

    feat: add substrait support for Interval types and literals (#10646)
    
    * feat: support interval types
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * impl literals
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    * fix deadlink in doc
    
    Signed-off-by: Ruihang Xia <[email protected]>
    
    ---------
    
    Signed-off-by: Ruihang Xia <[email protected]>
---
 datafusion/substrait/src/logical_plan/consumer.rs  |  76 ++++++++++++-
 datafusion/substrait/src/logical_plan/producer.rs  | 125 ++++++++++++++++++++-
 datafusion/substrait/src/variation_const.rs        |  56 +++++++++
 .../tests/cases/roundtrip_logical_plan.rs          |  26 ++++-
 4 files changed, 273 insertions(+), 10 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 7e8a0cadb5..d6c60ebdde 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use async_recursion::async_recursion;
-use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
+use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
 use datafusion::common::{
     not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
 };
@@ -39,6 +39,7 @@ use datafusion::{
     scalar::ScalarValue,
 };
 use substrait::proto::exchange_rel::ExchangeKind;
+use substrait::proto::expression::literal::user_defined::Val;
 use substrait::proto::expression::subquery::SubqueryType;
 use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
 use substrait::proto::{
@@ -71,9 +72,10 @@ use std::sync::Arc;
 
 use crate::variation_const::{
     DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
-    DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
-    TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
-    TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
+    DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
+    INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
+    LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF,
+    TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
 };
 
 enum ScalarFunctionType {
@@ -1162,6 +1164,24 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result<DataTyp
                     "Unsupported Substrait type variation {v} of type {s_kind:?}"
                 ),
             },
+            r#type::Kind::UserDefined(u) => {
+                match u.type_reference {
+                    INTERVAL_YEAR_MONTH_TYPE_REF => {
+                        Ok(DataType::Interval(IntervalUnit::YearMonth))
+                    }
+                    INTERVAL_DAY_TIME_TYPE_REF => {
+                        Ok(DataType::Interval(IntervalUnit::DayTime))
+                    }
+                    INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
+                        Ok(DataType::Interval(IntervalUnit::MonthDayNano))
+                    }
+                    _ => not_impl_err!(
+                        "Unsupported Substrait user defined type with ref {} and variation {}",
+                        u.type_reference,
+                        u.type_variation_reference
+                    ),
+                }
+            },
             r#type::Kind::Struct(s) => {
                 let mut fields = vec![];
                 for (i, f) in s.types.iter().enumerate() {
@@ -1387,6 +1407,54 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
             builder.build()?
         }
         Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
+        Some(LiteralType::UserDefined(user_defined)) => {
+            match user_defined.type_reference {
+                INTERVAL_YEAR_MONTH_TYPE_REF => {
+                    let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
+                        return substrait_err!("Interval year month value is empty");
+                    };
+                    let value_slice: [u8; 4] =
+                        raw_val.value.clone().try_into().map_err(|_| {
+                            substrait_datafusion_err!(
+                                "Failed to parse interval year month value"
+                            )
+                        })?;
+                    ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice)))
+                }
+                INTERVAL_DAY_TIME_TYPE_REF => {
+                    let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
+                        return substrait_err!("Interval day time value is empty");
+                    };
+                    let value_slice: [u8; 8] =
+                        raw_val.value.clone().try_into().map_err(|_| {
+                            substrait_datafusion_err!(
+                                "Failed to parse interval day time value"
+                            )
+                        })?;
+                    ScalarValue::IntervalDayTime(Some(i64::from_le_bytes(value_slice)))
+                }
+                INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
+                    let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else {
+                        return substrait_err!("Interval month day nano value is empty");
+                    };
+                    let value_slice: [u8; 16] =
+                        raw_val.value.clone().try_into().map_err(|_| {
+                            substrait_datafusion_err!(
+                                "Failed to parse interval month day nano value"
+                            )
+                        })?;
+                    ScalarValue::IntervalMonthDayNano(Some(i128::from_le_bytes(
+                        value_slice,
+                    )))
+                }
+                _ => {
+                    return not_impl_err!(
+                        "Unsupported Substrait user defined type with ref {}",
+                        user_defined.type_reference
+                    )
+                }
+            }
+        }
         _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type),
     };
 
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index c0aac0c0a4..400609ff14 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -19,6 +19,7 @@ use std::collections::HashMap;
 use std::ops::Deref;
 use std::sync::Arc;
 
+use datafusion::arrow::datatypes::IntervalUnit;
 use datafusion::logical_expr::{
     CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits,
 };
@@ -43,9 +44,12 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera
 use datafusion::prelude::Expr;
 use prost_types::Any as ProtoAny;
 use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
+use substrait::proto::expression::literal::user_defined::Val;
+use substrait::proto::expression::literal::UserDefined;
 use substrait::proto::expression::literal::{List, Struct};
 use substrait::proto::expression::subquery::InPredicate;
 use substrait::proto::expression::window_function::BoundsType;
+use substrait::proto::r#type::{parameter, Parameter};
 use substrait::proto::{CrossRel, ExchangeRel};
 use substrait::{
     proto::{
@@ -84,9 +88,12 @@ use substrait::{
 
 use crate::variation_const::{
     DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF,
-    DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
-    TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
-    TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
+    DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
+    INTERVAL_DAY_TIME_TYPE_URL, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
+    INTERVAL_MONTH_DAY_NANO_TYPE_URL, INTERVAL_YEAR_MONTH_TYPE_REF,
+    INTERVAL_YEAR_MONTH_TYPE_URL, LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
+    TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
+    UNSIGNED_INTEGER_TYPE_REF,
 };
 
 /// Convert DataFusion LogicalPlan to Substrait Plan
@@ -1398,6 +1405,49 @@ fn to_substrait_type(dt: &DataType, nullable: bool) -> Result<substrait::proto::
                 nullability,
             })),
         }),
+        DataType::Interval(interval_unit) => {
+            // define two type parameters for convenience
+            let i32_param = Parameter {
+                parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
+                    kind: Some(r#type::Kind::I32(r#type::I32 {
+                        type_variation_reference: DEFAULT_TYPE_REF,
+                        nullability: default_nullability,
+                    })),
+                })),
+            };
+            let i64_param = Parameter {
+                parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
+                    kind: Some(r#type::Kind::I64(r#type::I64 {
+                        type_variation_reference: DEFAULT_TYPE_REF,
+                        nullability: default_nullability,
+                    })),
+                })),
+            };
+
+            let (type_parameters, type_reference) = match interval_unit {
+                IntervalUnit::YearMonth => {
+                    let type_parameters = vec![i32_param];
+                    (type_parameters, INTERVAL_YEAR_MONTH_TYPE_REF)
+                }
+                IntervalUnit::DayTime => {
+                    let type_parameters = vec![i64_param];
+                    (type_parameters, INTERVAL_DAY_TIME_TYPE_REF)
+                }
+                IntervalUnit::MonthDayNano => {
+                    // use 2 `i64` as `i128`
+                    let type_parameters = vec![i64_param.clone(), i64_param];
+                    (type_parameters, INTERVAL_MONTH_DAY_NANO_TYPE_REF)
+                }
+            };
+            Ok(substrait::proto::Type {
+                kind: Some(r#type::Kind::UserDefined(r#type::UserDefined {
+                    type_reference,
+                    type_variation_reference: DEFAULT_TYPE_REF,
+                    nullability: default_nullability,
+                    type_parameters,
+                })),
+            })
+        }
         DataType::Binary => Ok(substrait::proto::Type {
             kind: Some(r#type::Kind::Binary(r#type::Binary {
                 type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
@@ -1735,6 +1785,75 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Literal> {
         }
         ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF),
         // Date64 literal is not supported in Substrait
+        ScalarValue::IntervalYearMonth(Some(i)) => {
+            let bytes = i.to_le_bytes();
+            (
+                LiteralType::UserDefined(UserDefined {
+                    type_reference: INTERVAL_YEAR_MONTH_TYPE_REF,
+                    type_parameters: vec![Parameter {
+                        parameter: Some(parameter::Parameter::DataType(
+                            substrait::proto::Type {
+                                kind: Some(r#type::Kind::I32(r#type::I32 {
+                                    type_variation_reference: DEFAULT_TYPE_REF,
+                                    nullability: r#type::Nullability::Required as i32,
+                                })),
+                            },
+                        )),
+                    }],
+                    val: Some(Val::Value(ProtoAny {
+                        type_url: INTERVAL_YEAR_MONTH_TYPE_URL.to_string(),
+                        value: bytes.to_vec(),
+                    })),
+                }),
+                INTERVAL_YEAR_MONTH_TYPE_REF,
+            )
+        }
+        ScalarValue::IntervalMonthDayNano(Some(i)) => {
+            // treat `i128` as two contiguous `i64`
+            let bytes = i.to_le_bytes();
+            let i64_param = Parameter {
+                parameter: Some(parameter::Parameter::DataType(substrait::proto::Type {
+                    kind: Some(r#type::Kind::I64(r#type::I64 {
+                        type_variation_reference: DEFAULT_TYPE_REF,
+                        nullability: r#type::Nullability::Required as i32,
+                    })),
+                })),
+            };
+            (
+                LiteralType::UserDefined(UserDefined {
+                    type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF,
+                    type_parameters: vec![i64_param.clone(), i64_param],
+                    val: Some(Val::Value(ProtoAny {
+                        type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(),
+                        value: bytes.to_vec(),
+                    })),
+                }),
+                INTERVAL_MONTH_DAY_NANO_TYPE_REF,
+            )
+        }
+        ScalarValue::IntervalDayTime(Some(i)) => {
+            let bytes = i.to_le_bytes();
+            (
+                LiteralType::UserDefined(UserDefined {
+                    type_reference: INTERVAL_DAY_TIME_TYPE_REF,
+                    type_parameters: vec![Parameter {
+                        parameter: Some(parameter::Parameter::DataType(
+                            substrait::proto::Type {
+                                kind: Some(r#type::Kind::I64(r#type::I64 {
+                                    type_variation_reference: DEFAULT_TYPE_REF,
+                                    nullability: r#type::Nullability::Required as i32,
+                                })),
+                            },
+                        )),
+                    }],
+                    val: Some(Val::Value(ProtoAny {
+                        type_url: INTERVAL_DAY_TIME_TYPE_URL.to_string(),
+                        value: bytes.to_vec(),
+                    })),
+                }),
+                INTERVAL_DAY_TIME_TYPE_REF,
+            )
+        }
         ScalarValue::Binary(Some(b)) => {
             (LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF)
         }
diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs
index 27ef15153b..51c0d3b021 100644
--- a/datafusion/substrait/src/variation_const.rs
+++ b/datafusion/substrait/src/variation_const.rs
@@ -25,6 +25,7 @@
 //! - Default type reference is 0. It is used when the actual type is the same with the original type.
 //! - Extended variant type references start from 1, and ususlly increase by 1.
 
+// For type variations
 pub const DEFAULT_TYPE_REF: u32 = 0;
 pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1;
 pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0;
@@ -37,3 +38,58 @@ pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0;
 pub const LARGE_CONTAINER_TYPE_REF: u32 = 1;
 pub const DECIMAL_128_TYPE_REF: u32 = 0;
 pub const DECIMAL_256_TYPE_REF: u32 = 1;
+
+// For custom types
+/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
+///
+/// An `i32` for elapsed whole months. See also [`ScalarValue::IntervalYearMonth`]
+/// for the literal definition in DataFusion.
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth
+/// [`ScalarValue::IntervalYearMonth`]: datafusion::common::ScalarValue::IntervalYearMonth
+pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1;
+
+/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
+///
+/// An `i64` as:
+/// - days: `i32`
+/// - milliseconds: `i32`
+///
+/// See also [`ScalarValue::IntervalDayTime`] for the literal definition in DataFusion.
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
+/// [`ScalarValue::IntervalDayTime`]: datafusion::common::ScalarValue::IntervalDayTime
+pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2;
+
+/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
+///
+/// An `i128` as:
+/// - months: `i32`
+/// - days: `i32`
+/// - nanoseconds: `i64`
+///
+/// See also [`ScalarValue::IntervalMonthDayNano`] for the literal definition in DataFusion.
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
+/// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano
+pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;
+
+// For User Defined URLs
+/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth
+pub const INTERVAL_YEAR_MONTH_TYPE_URL: &str = "interval-year-month";
+/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime
+pub const INTERVAL_DAY_TIME_TYPE_URL: &str = "interval-day-time";
+/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
+pub const INTERVAL_MONTH_DAY_NANO_TYPE_URL: &str = "interval-month-day-nano";
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 8d0e96cedd..de989001df 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -25,7 +25,7 @@ use datafusion_substrait::logical_plan::{
 use std::hash::Hash;
 use std::sync::Arc;
 
-use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
 use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
 use datafusion::error::Result;
 use datafusion::execution::context::SessionState;
@@ -496,6 +496,24 @@ async fn roundtrip_arithmetic_ops() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn roundtrip_interval_literal() -> Result<()> {
+    roundtrip(
+        "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(YearMonth)')",
+    )
+    .await?;
+    roundtrip(
+        "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(DayTime)')",
+    )
+    .await?;
+    roundtrip(
+    "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR', 'Interval(MonthDayNano)')",
+    )
+    .await?;
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn roundtrip_like() -> Result<()> {
     roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await
@@ -1035,14 +1053,16 @@ async fn create_context() -> Result<SessionContext> {
     .with_serializer_registry(Arc::new(MockSerializerRegistry));
     let ctx = SessionContext::new_with_state(state);
     let mut explicit_options = CsvReadOptions::new();
-    let schema = Schema::new(vec![
+    let fields = vec![
         Field::new("a", DataType::Int64, true),
         Field::new("b", DataType::Decimal128(5, 2), true),
         Field::new("c", DataType::Date32, true),
         Field::new("d", DataType::Boolean, true),
         Field::new("e", DataType::UInt32, true),
         Field::new("f", DataType::Utf8, true),
-    ]);
+        Field::new("g", DataType::Interval(IntervalUnit::DayTime), true),
+    ];
+    let schema = Schema::new(fields);
     explicit_options.schema = Some(&schema);
     ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
         .await?;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to