This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 6167ce915e feat: add substrait support for Interval types and literals
(#10646)
6167ce915e is described below
commit 6167ce915e1768ea1951557118068bd68ce4aac3
Author: Ruihang Xia <[email protected]>
AuthorDate: Mon May 27 03:04:50 2024 +0800
feat: add substrait support for Interval types and literals (#10646)
* feat: support interval types
Signed-off-by: Ruihang Xia <[email protected]>
* impl literals
Signed-off-by: Ruihang Xia <[email protected]>
* fix deadlink in doc
Signed-off-by: Ruihang Xia <[email protected]>
---------
Signed-off-by: Ruihang Xia <[email protected]>
---
datafusion/substrait/src/logical_plan/consumer.rs | 76 ++++++++++++-
datafusion/substrait/src/logical_plan/producer.rs | 125 ++++++++++++++++++++-
datafusion/substrait/src/variation_const.rs | 56 +++++++++
.../tests/cases/roundtrip_logical_plan.rs | 26 ++++-
4 files changed, 273 insertions(+), 10 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 7e8a0cadb5..d6c60ebdde 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -16,7 +16,7 @@
// under the License.
use async_recursion::async_recursion;
-use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
+use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema,
DFSchemaRef,
};
@@ -39,6 +39,7 @@ use datafusion::{
scalar::ScalarValue,
};
use substrait::proto::exchange_rel::ExchangeKind;
+use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
use substrait::proto::{
@@ -71,9 +72,10 @@ use std::sync::Arc;
use crate::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF,
DECIMAL_256_TYPE_REF,
- DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
- TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF,
TIMESTAMP_NANO_TYPE_REF,
- TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
+ DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
+ INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF,
+ LARGE_CONTAINER_TYPE_REF, TIMESTAMP_MICRO_TYPE_REF,
TIMESTAMP_MILLI_TYPE_REF,
+ TIMESTAMP_NANO_TYPE_REF, TIMESTAMP_SECOND_TYPE_REF,
UNSIGNED_INTEGER_TYPE_REF,
};
enum ScalarFunctionType {
@@ -1162,6 +1164,24 @@ pub(crate) fn from_substrait_type(dt:
&substrait::proto::Type) -> Result<DataTyp
"Unsupported Substrait type variation {v} of type
{s_kind:?}"
),
},
+ r#type::Kind::UserDefined(u) => {
+ match u.type_reference {
+ INTERVAL_YEAR_MONTH_TYPE_REF => {
+ Ok(DataType::Interval(IntervalUnit::YearMonth))
+ }
+ INTERVAL_DAY_TIME_TYPE_REF => {
+ Ok(DataType::Interval(IntervalUnit::DayTime))
+ }
+ INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
+ Ok(DataType::Interval(IntervalUnit::MonthDayNano))
+ }
+ _ => not_impl_err!(
+ "Unsupported Substrait user defined type with ref {}
and variation {}",
+ u.type_reference,
+ u.type_variation_reference
+ ),
+ }
+ },
r#type::Kind::Struct(s) => {
let mut fields = vec![];
for (i, f) in s.types.iter().enumerate() {
@@ -1387,6 +1407,54 @@ pub(crate) fn from_substrait_literal(lit: &Literal) ->
Result<ScalarValue> {
builder.build()?
}
Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
+ Some(LiteralType::UserDefined(user_defined)) => {
+ match user_defined.type_reference {
+ INTERVAL_YEAR_MONTH_TYPE_REF => {
+ let Some(Val::Value(raw_val)) = user_defined.val.as_ref()
else {
+ return substrait_err!("Interval year month value is
empty");
+ };
+ let value_slice: [u8; 4] =
+ raw_val.value.clone().try_into().map_err(|_| {
+ substrait_datafusion_err!(
+ "Failed to parse interval year month value"
+ )
+ })?;
+
ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes(value_slice)))
+ }
+ INTERVAL_DAY_TIME_TYPE_REF => {
+ let Some(Val::Value(raw_val)) = user_defined.val.as_ref()
else {
+ return substrait_err!("Interval day time value is
empty");
+ };
+ let value_slice: [u8; 8] =
+ raw_val.value.clone().try_into().map_err(|_| {
+ substrait_datafusion_err!(
+ "Failed to parse interval day time value"
+ )
+ })?;
+
ScalarValue::IntervalDayTime(Some(i64::from_le_bytes(value_slice)))
+ }
+ INTERVAL_MONTH_DAY_NANO_TYPE_REF => {
+ let Some(Val::Value(raw_val)) = user_defined.val.as_ref()
else {
+ return substrait_err!("Interval month day nano value
is empty");
+ };
+ let value_slice: [u8; 16] =
+ raw_val.value.clone().try_into().map_err(|_| {
+ substrait_datafusion_err!(
+ "Failed to parse interval month day nano value"
+ )
+ })?;
+ ScalarValue::IntervalMonthDayNano(Some(i128::from_le_bytes(
+ value_slice,
+ )))
+ }
+ _ => {
+ return not_impl_err!(
+ "Unsupported Substrait user defined type with ref {}",
+ user_defined.type_reference
+ )
+ }
+ }
+ }
_ => return not_impl_err!("Unsupported literal_type: {:?}",
lit.literal_type),
};
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index c0aac0c0a4..400609ff14 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -19,6 +19,7 @@ use std::collections::HashMap;
use std::ops::Deref;
use std::sync::Arc;
+use datafusion::arrow::datatypes::IntervalUnit;
use datafusion::logical_expr::{
CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits,
};
@@ -43,9 +44,12 @@ use datafusion::logical_expr::{expr, Between,
JoinConstraint, LogicalPlan, Opera
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
+use substrait::proto::expression::literal::user_defined::Val;
+use substrait::proto::expression::literal::UserDefined;
use substrait::proto::expression::literal::{List, Struct};
use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
+use substrait::proto::r#type::{parameter, Parameter};
use substrait::proto::{CrossRel, ExchangeRel};
use substrait::{
proto::{
@@ -84,9 +88,12 @@ use substrait::{
use crate::variation_const::{
DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF,
DECIMAL_256_TYPE_REF,
- DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF,
- TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF,
TIMESTAMP_NANO_TYPE_REF,
- TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF,
+ DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, INTERVAL_DAY_TIME_TYPE_REF,
+ INTERVAL_DAY_TIME_TYPE_URL, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
+ INTERVAL_MONTH_DAY_NANO_TYPE_URL, INTERVAL_YEAR_MONTH_TYPE_REF,
+ INTERVAL_YEAR_MONTH_TYPE_URL, LARGE_CONTAINER_TYPE_REF,
TIMESTAMP_MICRO_TYPE_REF,
+ TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF,
TIMESTAMP_SECOND_TYPE_REF,
+ UNSIGNED_INTEGER_TYPE_REF,
};
/// Convert DataFusion LogicalPlan to Substrait Plan
@@ -1398,6 +1405,49 @@ fn to_substrait_type(dt: &DataType, nullable: bool) ->
Result<substrait::proto::
nullability,
})),
}),
+ DataType::Interval(interval_unit) => {
+ // define two type parameters for convenience
+ let i32_param = Parameter {
+ parameter:
Some(parameter::Parameter::DataType(substrait::proto::Type {
+ kind: Some(r#type::Kind::I32(r#type::I32 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ };
+ let i64_param = Parameter {
+ parameter:
Some(parameter::Parameter::DataType(substrait::proto::Type {
+ kind: Some(r#type::Kind::I64(r#type::I64 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ })),
+ })),
+ };
+
+ let (type_parameters, type_reference) = match interval_unit {
+ IntervalUnit::YearMonth => {
+ let type_parameters = vec![i32_param];
+ (type_parameters, INTERVAL_YEAR_MONTH_TYPE_REF)
+ }
+ IntervalUnit::DayTime => {
+ let type_parameters = vec![i64_param];
+ (type_parameters, INTERVAL_DAY_TIME_TYPE_REF)
+ }
+ IntervalUnit::MonthDayNano => {
+ // use 2 `i64` as `i128`
+ let type_parameters = vec![i64_param.clone(), i64_param];
+ (type_parameters, INTERVAL_MONTH_DAY_NANO_TYPE_REF)
+ }
+ };
+ Ok(substrait::proto::Type {
+ kind: Some(r#type::Kind::UserDefined(r#type::UserDefined {
+ type_reference,
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: default_nullability,
+ type_parameters,
+ })),
+ })
+ }
DataType::Binary => Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Binary(r#type::Binary {
type_variation_reference: DEFAULT_CONTAINER_TYPE_REF,
@@ -1735,6 +1785,75 @@ fn to_substrait_literal(value: &ScalarValue) ->
Result<Literal> {
}
ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d),
DATE_32_TYPE_REF),
// Date64 literal is not supported in Substrait
+ ScalarValue::IntervalYearMonth(Some(i)) => {
+ let bytes = i.to_le_bytes();
+ (
+ LiteralType::UserDefined(UserDefined {
+ type_reference: INTERVAL_YEAR_MONTH_TYPE_REF,
+ type_parameters: vec![Parameter {
+ parameter: Some(parameter::Parameter::DataType(
+ substrait::proto::Type {
+ kind: Some(r#type::Kind::I32(r#type::I32 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: r#type::Nullability::Required
as i32,
+ })),
+ },
+ )),
+ }],
+ val: Some(Val::Value(ProtoAny {
+ type_url: INTERVAL_YEAR_MONTH_TYPE_URL.to_string(),
+ value: bytes.to_vec(),
+ })),
+ }),
+ INTERVAL_YEAR_MONTH_TYPE_REF,
+ )
+ }
+ ScalarValue::IntervalMonthDayNano(Some(i)) => {
+ // treat `i128` as two contiguous `i64`
+ let bytes = i.to_le_bytes();
+ let i64_param = Parameter {
+ parameter:
Some(parameter::Parameter::DataType(substrait::proto::Type {
+ kind: Some(r#type::Kind::I64(r#type::I64 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: r#type::Nullability::Required as i32,
+ })),
+ })),
+ };
+ (
+ LiteralType::UserDefined(UserDefined {
+ type_reference: INTERVAL_MONTH_DAY_NANO_TYPE_REF,
+ type_parameters: vec![i64_param.clone(), i64_param],
+ val: Some(Val::Value(ProtoAny {
+ type_url: INTERVAL_MONTH_DAY_NANO_TYPE_URL.to_string(),
+ value: bytes.to_vec(),
+ })),
+ }),
+ INTERVAL_MONTH_DAY_NANO_TYPE_REF,
+ )
+ }
+ ScalarValue::IntervalDayTime(Some(i)) => {
+ let bytes = i.to_le_bytes();
+ (
+ LiteralType::UserDefined(UserDefined {
+ type_reference: INTERVAL_DAY_TIME_TYPE_REF,
+ type_parameters: vec![Parameter {
+ parameter: Some(parameter::Parameter::DataType(
+ substrait::proto::Type {
+ kind: Some(r#type::Kind::I64(r#type::I64 {
+ type_variation_reference: DEFAULT_TYPE_REF,
+ nullability: r#type::Nullability::Required
as i32,
+ })),
+ },
+ )),
+ }],
+ val: Some(Val::Value(ProtoAny {
+ type_url: INTERVAL_DAY_TIME_TYPE_URL.to_string(),
+ value: bytes.to_vec(),
+ })),
+ }),
+ INTERVAL_DAY_TIME_TYPE_REF,
+ )
+ }
ScalarValue::Binary(Some(b)) => {
(LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF)
}
diff --git a/datafusion/substrait/src/variation_const.rs
b/datafusion/substrait/src/variation_const.rs
index 27ef15153b..51c0d3b021 100644
--- a/datafusion/substrait/src/variation_const.rs
+++ b/datafusion/substrait/src/variation_const.rs
@@ -25,6 +25,7 @@
//! - Default type reference is 0. It is used when the actual type is the same
as the original type.
//! - Extended variant type references start from 1, and usually increase by 1.
+// For type variations
pub const DEFAULT_TYPE_REF: u32 = 0;
pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1;
pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0;
@@ -37,3 +38,58 @@ pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0;
pub const LARGE_CONTAINER_TYPE_REF: u32 = 1;
pub const DECIMAL_128_TYPE_REF: u32 = 0;
pub const DECIMAL_256_TYPE_REF: u32 = 1;
+
+// For custom types
+/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
+///
+/// An `i32` for elapsed whole months. See also
[`ScalarValue::IntervalYearMonth`]
+/// for the literal definition in DataFusion.
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::YearMonth`]:
datafusion::arrow::datatypes::IntervalUnit::YearMonth
+/// [`ScalarValue::IntervalYearMonth`]:
datafusion::common::ScalarValue::IntervalYearMonth
+pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1;
+
+/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
+///
+/// An `i64` as:
+/// - days: `i32`
+/// - milliseconds: `i32`
+///
+/// See also [`ScalarValue::IntervalDayTime`] for the literal definition in
DataFusion.
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::DayTime`]:
datafusion::arrow::datatypes::IntervalUnit::DayTime
+/// [`ScalarValue::IntervalDayTime`]:
datafusion::common::ScalarValue::IntervalDayTime
+pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2;
+
+/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
+///
+/// An `i128` as:
+/// - months: `i32`
+/// - days: `i32`
+/// - nanoseconds: `i64`
+///
+/// See also [`ScalarValue::IntervalMonthDayNano`] for the literal definition
in DataFusion.
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::MonthDayNano`]:
datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
+/// [`ScalarValue::IntervalMonthDayNano`]:
datafusion::common::ScalarValue::IntervalMonthDayNano
+pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3;
+
+// For User Defined URLs
+/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`].
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::YearMonth`]:
datafusion::arrow::datatypes::IntervalUnit::YearMonth
+pub const INTERVAL_YEAR_MONTH_TYPE_URL: &str = "interval-year-month";
+/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`].
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::DayTime`]:
datafusion::arrow::datatypes::IntervalUnit::DayTime
+pub const INTERVAL_DAY_TIME_TYPE_URL: &str = "interval-day-time";
+/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`].
+///
+/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval
+/// [`IntervalUnit::MonthDayNano`]:
datafusion::arrow::datatypes::IntervalUnit::MonthDayNano
+pub const INTERVAL_MONTH_DAY_NANO_TYPE_URL: &str = "interval-month-day-nano";
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 8d0e96cedd..de989001df 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -25,7 +25,7 @@ use datafusion_substrait::logical_plan::{
use std::hash::Hash;
use std::sync::Arc;
-use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema,
TimeUnit};
use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef};
use datafusion::error::Result;
use datafusion::execution::context::SessionState;
@@ -496,6 +496,24 @@ async fn roundtrip_arithmetic_ops() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn roundtrip_interval_literal() -> Result<()> {
+ roundtrip(
+ "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR',
'Interval(YearMonth)')",
+ )
+ .await?;
+ roundtrip(
+ "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR',
'Interval(DayTime)')",
+ )
+ .await?;
+ roundtrip(
+ "SELECT g from data where g = arrow_cast(INTERVAL '1 YEAR',
'Interval(MonthDayNano)')",
+ )
+ .await?;
+
+ Ok(())
+}
+
#[tokio::test]
async fn roundtrip_like() -> Result<()> {
roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await
@@ -1035,14 +1053,16 @@ async fn create_context() -> Result<SessionContext> {
.with_serializer_registry(Arc::new(MockSerializerRegistry));
let ctx = SessionContext::new_with_state(state);
let mut explicit_options = CsvReadOptions::new();
- let schema = Schema::new(vec![
+ let fields = vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Decimal128(5, 2), true),
Field::new("c", DataType::Date32, true),
Field::new("d", DataType::Boolean, true),
Field::new("e", DataType::UInt32, true),
Field::new("f", DataType::Utf8, true),
- ]);
+ Field::new("g", DataType::Interval(IntervalUnit::DayTime), true),
+ ];
+ let schema = Schema::new(fields);
explicit_options.schema = Some(&schema);
ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
.await?;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]