This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 88fa0dfce6 Add `Field` to `Expr::Cast` -- allow logical expressions to
express a cast to an extension type (#18136)
88fa0dfce6 is described below
commit 88fa0dfce64d146508319fd3dd8a3eba137ed410
Author: Dewey Dunnington <[email protected]>
AuthorDate: Tue Mar 3 11:01:16 2026 -0600
Add `Field` to `Expr::Cast` -- allow logical expressions to express a cast
to an extension type (#18136)
## Which issue does this PR close?
- Closes #18060.
I am sorry that I missed the previous PR implementing this (
https://github.com/apache/datafusion/pull/18120 ) and I'm also happy to
review that one instead of updating this!
## Rationale for this change
Other systems that interact with the logical plan (e.g., SQL, Substrait)
can express types that are not strictly within the arrow DataType enum.
## What changes are included in this PR?
For the Cast and TryCast structs, the destination data type was changed
from a DataType to a FieldRef.
## Are these changes tested?
Yes.
## Are there any user-facing changes?
Yes, any code using `Cast { .. }` to create an expression would need to
use `Cast::new()` instead (or pass on field metadata if it has it).
Existing matches will need to be updated for the `data_type` -> `field`
member rename.
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
.../provider_filter_pushdown.rs | 2 +-
.../user_defined/user_defined_scalar_functions.rs | 5 +-
datafusion/expr/src/expr.rs | 60 +++++---
datafusion/expr/src/expr_rewriter/order_by.rs | 8 +-
datafusion/expr/src/expr_schema.rs | 23 ++-
datafusion/expr/src/tree_node.rs | 8 +-
datafusion/functions/src/core/arrow_cast.rs | 8 +-
datafusion/optimizer/src/eliminate_outer_join.rs | 4 +-
.../src/simplify_expressions/expr_simplifier.rs | 5 +-
datafusion/physical-expr/src/planner.rs | 89 ++++++++++--
datafusion/proto/proto/datafusion.proto | 4 +
datafusion/proto/src/generated/pbjson.rs | 72 ++++++++++
datafusion/proto/src/generated/prost.rs | 14 ++
datafusion/proto/src/logical_plan/from_proto.rs | 20 ++-
datafusion/proto/src/logical_plan/to_proto.rs | 12 +-
datafusion/sql/src/unparser/expr.rs | 160 +++++++++------------
.../src/logical_plan/consumer/expr/cast.rs | 11 +-
.../substrait/src/logical_plan/consumer/types.rs | 20 ++-
.../src/logical_plan/producer/expr/cast.rs | 16 +--
.../substrait/src/logical_plan/producer/types.rs | 44 +++---
20 files changed, 387 insertions(+), 198 deletions(-)
diff --git
a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
index 6bb119fcb7..8078b0a7ec 100644
--- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
@@ -207,7 +207,7 @@ impl TableProvider for CustomProvider {
Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i,
- Expr::Cast(Cast { expr, data_type: _ }) => match
expr.deref() {
+ Expr::Cast(Cast { expr, field: _ }) => match expr.deref() {
Expr::Literal(lit_value, _) => match lit_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
diff --git
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index b4ce3a03db..025ee9767c 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -715,10 +715,7 @@ impl ScalarUDFImpl for CastToI64UDF {
arg
} else {
// need to use an actual cast to get the correct type
- Expr::Cast(datafusion_expr::Cast {
- expr: Box::new(arg),
- data_type: DataType::Int64,
- })
+ Expr::Cast(datafusion_expr::Cast::new(Box::new(arg),
DataType::Int64))
};
// return the newly written argument to DataFusion
Ok(ExprSimplifyResult::Simplified(new_expr))
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 87e8e029a6..ef0187a878 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -32,6 +32,8 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame,
WindowUDF};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
+use datafusion_common::datatype::DataTypeExt;
+use datafusion_common::metadata::format_type_and_metadata;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeContainer,
TreeNodeRecursion,
};
@@ -800,13 +802,20 @@ pub struct Cast {
/// The expression being cast
pub expr: Box<Expr>,
/// The `DataType` the expression will yield
- pub data_type: DataType,
+ pub field: FieldRef,
}
impl Cast {
/// Create a new Cast expression
pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
- Self { expr, data_type }
+ Self {
+ expr,
+ field: data_type.into_nullable_field_ref(),
+ }
+ }
+
+ pub fn new_from_field(expr: Box<Expr>, field: FieldRef) -> Self {
+ Self { expr, field }
}
}
@@ -816,13 +825,20 @@ pub struct TryCast {
/// The expression being cast
pub expr: Box<Expr>,
/// The `DataType` the expression will yield
- pub data_type: DataType,
+ pub field: FieldRef,
}
impl TryCast {
/// Create a new TryCast expression
pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
- Self { expr, data_type }
+ Self {
+ expr,
+ field: data_type.into_nullable_field_ref(),
+ }
+ }
+
+ pub fn new_from_field(expr: Box<Expr>, field: FieldRef) -> Self {
+ Self { expr, field }
}
}
@@ -2323,23 +2339,23 @@ impl NormalizeEq for Expr {
(
Expr::Cast(Cast {
expr: self_expr,
- data_type: self_data_type,
+ field: self_field,
}),
Expr::Cast(Cast {
expr: other_expr,
- data_type: other_data_type,
+ field: other_field,
}),
)
| (
Expr::TryCast(TryCast {
expr: self_expr,
- data_type: self_data_type,
+ field: self_field,
}),
Expr::TryCast(TryCast {
expr: other_expr,
- data_type: other_data_type,
+ field: other_field,
}),
- ) => self_data_type == other_data_type &&
self_expr.normalize_eq(other_expr),
+ ) => self_field == other_field &&
self_expr.normalize_eq(other_expr),
(
Expr::ScalarFunction(ScalarFunction {
func: self_func,
@@ -2655,15 +2671,9 @@ impl HashNode for Expr {
when_then_expr: _when_then_expr,
else_expr: _else_expr,
}) => {}
- Expr::Cast(Cast {
- expr: _expr,
- data_type,
- })
- | Expr::TryCast(TryCast {
- expr: _expr,
- data_type,
- }) => {
- data_type.hash(state);
+ Expr::Cast(Cast { expr: _expr, field })
+ | Expr::TryCast(TryCast { expr: _expr, field }) => {
+ field.hash(state);
}
Expr::ScalarFunction(ScalarFunction { func, args: _args }) => {
func.hash(state);
@@ -3369,11 +3379,15 @@ impl Display for Expr {
}
write!(f, "END")
}
- Expr::Cast(Cast { expr, data_type }) => {
- write!(f, "CAST({expr} AS {data_type})")
+ Expr::Cast(Cast { expr, field }) => {
+ let formatted =
+ format_type_and_metadata(field.data_type(),
Some(field.metadata()));
+ write!(f, "CAST({expr} AS {formatted})")
}
- Expr::TryCast(TryCast { expr, data_type }) => {
- write!(f, "TRY_CAST({expr} AS {data_type})")
+ Expr::TryCast(TryCast { expr, field }) => {
+ let formatted =
+ format_type_and_metadata(field.data_type(),
Some(field.metadata()));
+ write!(f, "TRY_CAST({expr} AS {formatted})")
}
Expr::Not(expr) => write!(f, "NOT {expr}"),
Expr::Negative(expr) => write!(f, "(- {expr})"),
@@ -3765,7 +3779,7 @@ mod test {
fn format_cast() -> Result<()> {
let expr = Expr::Cast(Cast {
expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)),
None)),
- data_type: DataType::Utf8,
+ field: DataType::Utf8.into_nullable_field_ref(),
});
let expected_canonical = "CAST(Float32(1.23) AS Utf8)";
assert_eq!(expected_canonical, format!("{expr}"));
diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs
b/datafusion/expr/src/expr_rewriter/order_by.rs
index ec22be5254..7c6af56c8d 100644
--- a/datafusion/expr/src/expr_rewriter/order_by.rs
+++ b/datafusion/expr/src/expr_rewriter/order_by.rs
@@ -116,13 +116,13 @@ fn rewrite_in_terms_of_projection(
if let Some(found) = found {
return Ok(Transformed::yes(match normalized_expr {
- Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
+ Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
expr: Box::new(found),
- data_type,
+ field,
}),
- Expr::TryCast(TryCast { expr: _, data_type }) =>
Expr::TryCast(TryCast {
+ Expr::TryCast(TryCast { expr: _, field }) =>
Expr::TryCast(TryCast {
expr: Box::new(found),
- data_type,
+ field,
}),
_ => found,
}));
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index f4e4f014f5..f2ac777af3 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -132,8 +132,9 @@ impl ExprSchemable for Expr {
.as_ref()
.map_or(Ok(DataType::Null), |e| e.get_type(schema))
}
- Expr::Cast(Cast { data_type, .. })
- | Expr::TryCast(TryCast { data_type, .. }) =>
Ok(data_type.clone()),
+ Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, ..
}) => {
+ Ok(field.data_type().clone())
+ }
Expr::Unnest(Unnest { expr }) => {
let arg_data_type = expr.get_type(schema)?;
// Unnest's output type is the inner type of the list
@@ -550,9 +551,23 @@ impl ExprSchemable for Expr {
func.return_field_from_args(args)
}
// _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
- Expr::Cast(Cast { expr, data_type }) => expr
+ Expr::Cast(Cast { expr, field }) => expr
.to_field(schema)
- .map(|(_, f)| f.retyped(data_type.clone())),
+ .map(|(_table_ref, destination_field)| {
+ // This propagates the nullability of the input rather than
+ // force the nullability of the destination field. This is
+ // usually the desired behaviour (i.e., specifying a cast
+ // destination type usually does not force a user to pick
+ // nullability, and assuming `true` would prevent the
non-nullability
+ // of the parent expression to make the result eligible for
+ // optimizations that only apply to non-nullable values).
+ destination_field
+ .as_ref()
+ .clone()
+ .with_data_type(field.data_type().clone())
+ .with_metadata(destination_field.metadata().clone())
+ })
+ .map(Arc::new),
Expr::Placeholder(Placeholder {
id: _,
field: Some(field),
diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs
index 226c512a97..f3bec6bbf9 100644
--- a/datafusion/expr/src/tree_node.rs
+++ b/datafusion/expr/src/tree_node.rs
@@ -234,12 +234,12 @@ impl TreeNode for Expr {
.update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
Expr::Case(Case::new(new_expr, new_when_then_expr,
new_else_expr))
}),
- Expr::Cast(Cast { expr, data_type }) => expr
+ Expr::Cast(Cast { expr, field }) => expr
.map_elements(f)?
- .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
- Expr::TryCast(TryCast { expr, data_type }) => expr
+ .update_data(|be| Expr::Cast(Cast::new_from_field(be, field))),
+ Expr::TryCast(TryCast { expr, field }) => expr
.map_elements(f)?
- .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
+ .update_data(|be| Expr::TryCast(TryCast::new_from_field(be,
field))),
Expr::ScalarFunction(ScalarFunction { func, args }) => {
args.map_elements(f)?.map_data(|new_args| {
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
diff --git a/datafusion/functions/src/core/arrow_cast.rs
b/datafusion/functions/src/core/arrow_cast.rs
index 7c24450adf..e555081e41 100644
--- a/datafusion/functions/src/core/arrow_cast.rs
+++ b/datafusion/functions/src/core/arrow_cast.rs
@@ -19,11 +19,11 @@
use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::error::ArrowError;
-use datafusion_common::types::logical_string;
use datafusion_common::{
- Result, ScalarValue, arrow_datafusion_err, exec_err, internal_err,
+ Result, ScalarValue, arrow_datafusion_err, datatype::DataTypeExt,
+ exec_datafusion_err, exec_err, internal_err, types::logical_string,
+ utils::take_function_args,
};
-use datafusion_common::{exec_datafusion_err, utils::take_function_args};
use std::any::Any;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
@@ -176,7 +176,7 @@ impl ScalarUDFImpl for ArrowCastFunc {
// Use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
- data_type: target_type,
+ field: target_type.into_nullable_field_ref(),
})
};
// return the newly written argument to DataFusion
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs
b/datafusion/optimizer/src/eliminate_outer_join.rs
index 58abe38d04..5c47d6b7c5 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -290,8 +290,8 @@ fn extract_non_nullable_columns(
false,
)
}
- Expr::Cast(Cast { expr, data_type: _ })
- | Expr::TryCast(TryCast { expr, data_type: _ }) =>
extract_non_nullable_columns(
+ Expr::Cast(Cast { expr, field: _ })
+ | Expr::TryCast(TryCast { expr, field: _ }) =>
extract_non_nullable_columns(
expr,
non_nullable_cols,
left_schema,
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index fe2e1a3b04..28fcdf1ded 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -646,12 +646,11 @@ impl ConstEvaluator {
Expr::ScalarFunction(ScalarFunction { func, .. }) => {
Self::volatility_ok(func.signature().volatility)
}
- Expr::Cast(Cast { expr, data_type })
- | Expr::TryCast(TryCast { expr, data_type }) => {
+ Expr::Cast(Cast { expr, field }) | Expr::TryCast(TryCast { expr,
field }) => {
if let (
Ok(DataType::Struct(source_fields)),
DataType::Struct(target_fields),
- ) = (expr.get_type(&DFSchema::empty()), data_type)
+ ) = (expr.get_type(&DFSchema::empty()), field.data_type())
{
// Don't const-fold struct casts with different field
counts
if source_fields.len() != target_fields.len() {
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index 84a6aa4309..5c170700d9 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -25,7 +25,7 @@ use crate::{
use arrow::datatypes::Schema;
use datafusion_common::config::ConfigOptions;
-use datafusion_common::metadata::FieldMetadata;
+use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata};
use datafusion_common::{
DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err,
plan_err,
};
@@ -34,7 +34,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder,
ScalarFunction};
use datafusion_expr::var_provider::VarType;
use datafusion_expr::var_provider::is_system_variables;
use datafusion_expr::{
- Between, BinaryExpr, Expr, Like, Operator, TryCast, binary_expr, lit,
+ Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast,
binary_expr, lit,
};
/// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1
@@ -288,16 +288,44 @@ pub fn create_physical_expr(
};
Ok(expressions::case(expr, when_then_expr, else_expr)?)
}
- Expr::Cast(Cast { expr, data_type }) => expressions::cast(
- create_physical_expr(expr, input_dfschema, execution_props)?,
- input_schema,
- data_type.clone(),
- ),
- Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast(
- create_physical_expr(expr, input_dfschema, execution_props)?,
- input_schema,
- data_type.clone(),
- ),
+ Expr::Cast(Cast { expr, field }) => {
+ if !field.metadata().is_empty() {
+ let (_, src_field) = expr.to_field(input_dfschema)?;
+ return plan_err!(
+ "Cast from {} to {} is not supported",
+ format_type_and_metadata(
+ src_field.data_type(),
+ Some(src_field.metadata()),
+ ),
+ format_type_and_metadata(field.data_type(),
Some(field.metadata()))
+ );
+ }
+
+ expressions::cast(
+ create_physical_expr(expr, input_dfschema, execution_props)?,
+ input_schema,
+ field.data_type().clone(),
+ )
+ }
+ Expr::TryCast(TryCast { expr, field }) => {
+ if !field.metadata().is_empty() {
+ let (_, src_field) = expr.to_field(input_dfschema)?;
+ return plan_err!(
+ "TryCast from {} to {} is not supported",
+ format_type_and_metadata(
+ src_field.data_type(),
+ Some(src_field.metadata()),
+ ),
+ format_type_and_metadata(field.data_type(),
Some(field.metadata()))
+ );
+ }
+
+ expressions::try_cast(
+ create_physical_expr(expr, input_dfschema, execution_props)?,
+ input_schema,
+ field.data_type().clone(),
+ )
+ }
Expr::Not(expr) => {
expressions::not(create_physical_expr(expr, input_dfschema,
execution_props)?)
}
@@ -417,7 +445,7 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) ->
Arc<dyn PhysicalExpr> {
mod tests {
use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field};
-
+ use datafusion_common::datatype::DataTypeExt;
use datafusion_expr::{Operator, col, lit};
use super::*;
@@ -447,6 +475,41 @@ mod tests {
Ok(())
}
+ #[test]
+ fn test_cast_to_extension_type() -> Result<()> {
+ let extension_field_type = Arc::new(
+ DataType::FixedSizeBinary(16)
+ .into_nullable_field()
+ .with_metadata(
+ [("ARROW:extension:name".to_string(),
"arrow.uuid".to_string())]
+ .into(),
+ ),
+ );
+ let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58");
+ let cast_expr = Expr::Cast(Cast::new_from_field(
+ Box::new(expr.clone()),
+ Arc::clone(&extension_field_type),
+ ));
+ let err =
+ create_physical_expr(&cast_expr, &DFSchema::empty(),
&ExecutionProps::new())
+ .unwrap_err();
+ assert!(err.message().contains("arrow.uuid"));
+
+ let try_cast_expr = Expr::TryCast(TryCast::new_from_field(
+ Box::new(expr.clone()),
+ Arc::clone(&extension_field_type),
+ ));
+ let err = create_physical_expr(
+ &try_cast_expr,
+ &DFSchema::empty(),
+ &ExecutionProps::new(),
+ )
+ .unwrap_err();
+ assert!(err.message().contains("arrow.uuid"));
+
+ Ok(())
+ }
+
/// Test that deeply nested expressions do not cause a stack overflow.
///
/// This test only runs when the `recursive_protection` feature is enabled,
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 7c02688676..e64bcdb41e 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -594,11 +594,15 @@ message WhenThen {
message CastNode {
LogicalExprNode expr = 1;
datafusion_common.ArrowType arrow_type = 2;
+ map<string, string> metadata = 3;
+ optional bool nullable = 4;
}
message TryCastNode {
LogicalExprNode expr = 1;
datafusion_common.ArrowType arrow_type = 2;
+ map<string, string> metadata = 3;
+ optional bool nullable = 4;
}
message SortExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 5b2b9133ce..15b9ba88f4 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -2203,6 +2203,12 @@ impl serde::Serialize for CastNode {
if self.arrow_type.is_some() {
len += 1;
}
+ if !self.metadata.is_empty() {
+ len += 1;
+ }
+ if self.nullable.is_some() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.CastNode", len)?;
if let Some(v) = self.expr.as_ref() {
struct_ser.serialize_field("expr", v)?;
@@ -2210,6 +2216,12 @@ impl serde::Serialize for CastNode {
if let Some(v) = self.arrow_type.as_ref() {
struct_ser.serialize_field("arrowType", v)?;
}
+ if !self.metadata.is_empty() {
+ struct_ser.serialize_field("metadata", &self.metadata)?;
+ }
+ if let Some(v) = self.nullable.as_ref() {
+ struct_ser.serialize_field("nullable", v)?;
+ }
struct_ser.end()
}
}
@@ -2223,12 +2235,16 @@ impl<'de> serde::Deserialize<'de> for CastNode {
"expr",
"arrow_type",
"arrowType",
+ "metadata",
+ "nullable",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Expr,
ArrowType,
+ Metadata,
+ Nullable,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -2252,6 +2268,8 @@ impl<'de> serde::Deserialize<'de> for CastNode {
match value {
"expr" => Ok(GeneratedField::Expr),
"arrowType" | "arrow_type" =>
Ok(GeneratedField::ArrowType),
+ "metadata" => Ok(GeneratedField::Metadata),
+ "nullable" => Ok(GeneratedField::Nullable),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -2273,6 +2291,8 @@ impl<'de> serde::Deserialize<'de> for CastNode {
{
let mut expr__ = None;
let mut arrow_type__ = None;
+ let mut metadata__ = None;
+ let mut nullable__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Expr => {
@@ -2287,11 +2307,27 @@ impl<'de> serde::Deserialize<'de> for CastNode {
}
arrow_type__ = map_.next_value()?;
}
+ GeneratedField::Metadata => {
+ if metadata__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("metadata"));
+ }
+ metadata__ = Some(
+ map_.next_value::<std::collections::HashMap<_,
_>>()?
+ );
+ }
+ GeneratedField::Nullable => {
+ if nullable__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("nullable"));
+ }
+ nullable__ = map_.next_value()?;
+ }
}
}
Ok(CastNode {
expr: expr__,
arrow_type: arrow_type__,
+ metadata: metadata__.unwrap_or_default(),
+ nullable: nullable__,
})
}
}
@@ -23040,6 +23076,12 @@ impl serde::Serialize for TryCastNode {
if self.arrow_type.is_some() {
len += 1;
}
+ if !self.metadata.is_empty() {
+ len += 1;
+ }
+ if self.nullable.is_some() {
+ len += 1;
+ }
let mut struct_ser =
serializer.serialize_struct("datafusion.TryCastNode", len)?;
if let Some(v) = self.expr.as_ref() {
struct_ser.serialize_field("expr", v)?;
@@ -23047,6 +23089,12 @@ impl serde::Serialize for TryCastNode {
if let Some(v) = self.arrow_type.as_ref() {
struct_ser.serialize_field("arrowType", v)?;
}
+ if !self.metadata.is_empty() {
+ struct_ser.serialize_field("metadata", &self.metadata)?;
+ }
+ if let Some(v) = self.nullable.as_ref() {
+ struct_ser.serialize_field("nullable", v)?;
+ }
struct_ser.end()
}
}
@@ -23060,12 +23108,16 @@ impl<'de> serde::Deserialize<'de> for TryCastNode {
"expr",
"arrow_type",
"arrowType",
+ "metadata",
+ "nullable",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
Expr,
ArrowType,
+ Metadata,
+ Nullable,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -23089,6 +23141,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode {
match value {
"expr" => Ok(GeneratedField::Expr),
"arrowType" | "arrow_type" =>
Ok(GeneratedField::ArrowType),
+ "metadata" => Ok(GeneratedField::Metadata),
+ "nullable" => Ok(GeneratedField::Nullable),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -23110,6 +23164,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode {
{
let mut expr__ = None;
let mut arrow_type__ = None;
+ let mut metadata__ = None;
+ let mut nullable__ = None;
while let Some(k) = map_.next_key()? {
match k {
GeneratedField::Expr => {
@@ -23124,11 +23180,27 @@ impl<'de> serde::Deserialize<'de> for TryCastNode {
}
arrow_type__ = map_.next_value()?;
}
+ GeneratedField::Metadata => {
+ if metadata__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("metadata"));
+ }
+ metadata__ = Some(
+ map_.next_value::<std::collections::HashMap<_,
_>>()?
+ );
+ }
+ GeneratedField::Nullable => {
+ if nullable__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("nullable"));
+ }
+ nullable__ = map_.next_value()?;
+ }
}
}
Ok(TryCastNode {
expr: expr__,
arrow_type: arrow_type__,
+ metadata: metadata__.unwrap_or_default(),
+ nullable: nullable__,
})
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index d9602665c2..48ad8dcfea 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -922,6 +922,13 @@ pub struct CastNode {
pub expr:
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
#[prost(message, optional, tag = "2")]
pub arrow_type:
::core::option::Option<super::datafusion_common::ArrowType>,
+ #[prost(map = "string, string", tag = "3")]
+ pub metadata: ::std::collections::HashMap<
+ ::prost::alloc::string::String,
+ ::prost::alloc::string::String,
+ >,
+ #[prost(bool, optional, tag = "4")]
+ pub nullable: ::core::option::Option<bool>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TryCastNode {
@@ -929,6 +936,13 @@ pub struct TryCastNode {
pub expr:
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
#[prost(message, optional, tag = "2")]
pub arrow_type:
::core::option::Option<super::datafusion_common::ArrowType>,
+ #[prost(map = "string, string", tag = "3")]
+ pub metadata: ::std::collections::HashMap<
+ ::prost::alloc::string::String,
+ ::prost::alloc::string::String,
+ >,
+ #[prost(bool, optional, tag = "4")]
+ pub nullable: ::core::option::Option<bool>,
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct SortExprNode {
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index a653f517b7..ed33d9fab1 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -17,7 +17,8 @@
use std::sync::Arc;
-use arrow::datatypes::Field;
+use arrow::datatypes::{DataType, Field};
+use datafusion_common::datatype::DataTypeExt;
use datafusion_common::{
NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference,
UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err,
@@ -527,8 +528,11 @@ pub fn parse_expr(
"expr",
codec,
)?);
- let data_type = cast.arrow_type.as_ref().required("arrow_type")?;
- Ok(Expr::Cast(Cast::new(expr, data_type)))
+ let data_type: DataType =
cast.arrow_type.as_ref().required("arrow_type")?;
+ let field = data_type
+ .into_nullable_field()
+ .with_nullable(cast.nullable.unwrap_or(true));
+ Ok(Expr::Cast(Cast::new_from_field(expr, Arc::new(field))))
}
ExprType::TryCast(cast) => {
let expr = Box::new(parse_required_expr(
@@ -537,8 +541,14 @@ pub fn parse_expr(
"expr",
codec,
)?);
- let data_type = cast.arrow_type.as_ref().required("arrow_type")?;
- Ok(Expr::TryCast(TryCast::new(expr, data_type)))
+ let data_type: DataType =
cast.arrow_type.as_ref().required("arrow_type")?;
+ let field = data_type
+ .into_nullable_field()
+ .with_nullable(cast.nullable.unwrap_or(true));
+ Ok(Expr::TryCast(TryCast::new_from_field(
+ expr,
+ Arc::new(field),
+ )))
}
ExprType::Negative(negative) => Ok(Expr::Negative(Box::new(
parse_required_expr(negative.expr.as_deref(), registry, "expr",
codec)?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index fe63fce6ee..6fcb738992 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -521,19 +521,23 @@ pub fn serialize_expr(
expr_type: Some(ExprType::Case(expr)),
}
}
- Expr::Cast(Cast { expr, data_type }) => {
+ Expr::Cast(Cast { expr, field }) => {
let expr = Box::new(protobuf::CastNode {
expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)),
- arrow_type: Some(data_type.try_into()?),
+ arrow_type: Some(field.data_type().try_into()?),
+ metadata: field.metadata().clone(),
+ nullable: Some(field.is_nullable()),
});
protobuf::LogicalExprNode {
expr_type: Some(ExprType::Cast(expr)),
}
}
- Expr::TryCast(TryCast { expr, data_type }) => {
+ Expr::TryCast(TryCast { expr, field }) => {
let expr = Box::new(protobuf::TryCastNode {
expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)),
- arrow_type: Some(data_type.try_into()?),
+ arrow_type: Some(field.data_type().try_into()?),
+ metadata: field.metadata().clone(),
+ nullable: Some(field.is_nullable()),
});
protobuf::LogicalExprNode {
expr_type: Some(ExprType::TryCast(expr)),
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index 59a9207b51..bbe1f3dd9d 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use datafusion_common::datatype::DataTypeExt;
use datafusion_expr::expr::{AggregateFunctionParams, Unnest,
WindowFunctionParams};
use sqlparser::ast::Value::SingleQuotedString;
use sqlparser::ast::{
@@ -37,6 +38,7 @@ use arrow::array::{
};
use arrow::datatypes::{
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type,
DecimalType,
+ FieldRef,
};
use arrow::util::display::array_value_to_string;
use datafusion_common::{
@@ -188,9 +190,7 @@ impl Unparser<'_> {
end_token: AttachedToken::empty(),
})
}
- Expr::Cast(Cast { expr, data_type }) => {
- Ok(self.cast_to_sql(expr, data_type)?)
- }
+ Expr::Cast(Cast { expr, field }) => Ok(self.cast_to_sql(expr,
field)?),
Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?),
Expr::Alias(Alias { expr, name: _, .. }) =>
self.expr_to_sql_inner(expr),
Expr::WindowFunction(window_fun) => {
@@ -488,12 +488,12 @@ impl Unparser<'_> {
)
})
}
- Expr::TryCast(TryCast { expr, data_type }) => {
+ Expr::TryCast(TryCast { expr, field }) => {
let inner_expr = self.expr_to_sql_inner(expr)?;
Ok(ast::Expr::Cast {
kind: ast::CastKind::TryCast,
expr: Box::new(inner_expr),
- data_type: self.arrow_dtype_to_ast_dtype(data_type)?,
+ data_type: self.arrow_dtype_to_ast_dtype(field)?,
array: false,
format: None,
})
@@ -1176,17 +1176,20 @@ impl Unparser<'_> {
// Explicit type cast on ast::Expr::Value is not needed by underlying
engine for certain types
// For example: CAST(Utf8("binary_value") AS Binary) and
CAST(Utf8("dictionary_value") AS Dictionary)
- fn cast_to_sql(&self, expr: &Expr, data_type: &DataType) ->
Result<ast::Expr> {
+ fn cast_to_sql(&self, expr: &Expr, field: &FieldRef) -> Result<ast::Expr> {
let inner_expr = self.expr_to_sql_inner(expr)?;
+ let data_type = field.data_type();
match inner_expr {
ast::Expr::Value(_) => match data_type {
- DataType::Dictionary(_, _) | DataType::Binary |
DataType::BinaryView => {
+ DataType::Dictionary(_, _) | DataType::Binary |
DataType::BinaryView
+ if field.metadata().is_empty() =>
+ {
Ok(inner_expr)
}
_ => Ok(ast::Expr::Cast {
kind: ast::CastKind::Cast,
expr: Box::new(inner_expr),
- data_type: self.arrow_dtype_to_ast_dtype(data_type)?,
+ data_type: self.arrow_dtype_to_ast_dtype(field)?,
array: false,
format: None,
}),
@@ -1194,7 +1197,7 @@ impl Unparser<'_> {
_ => Ok(ast::Expr::Cast {
kind: ast::CastKind::Cast,
expr: Box::new(inner_expr),
- data_type: self.arrow_dtype_to_ast_dtype(data_type)?,
+ data_type: self.arrow_dtype_to_ast_dtype(field)?,
array: false,
format: None,
}),
@@ -1724,7 +1727,8 @@ impl Unparser<'_> {
}))
}
- fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) ->
Result<ast::DataType> {
+ fn arrow_dtype_to_ast_dtype(&self, field: &FieldRef) ->
Result<ast::DataType> {
+ let data_type = field.data_type();
match data_type {
DataType::Null => {
not_impl_err!("Unsupported DataType: conversion: {data_type}")
@@ -1797,10 +1801,10 @@ impl Unparser<'_> {
DataType::Union(_, _) => {
not_impl_err!("Unsupported DataType: conversion: {data_type}")
}
- DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val),
- DataType::RunEndEncoded(_, val) => {
- self.arrow_dtype_to_ast_dtype(val.data_type())
+ DataType::Dictionary(_, val) => {
+
self.arrow_dtype_to_ast_dtype(&val.clone().into_nullable_field_ref())
}
+ DataType::RunEndEncoded(_, val) =>
self.arrow_dtype_to_ast_dtype(val),
DataType::Decimal32(precision, scale)
| DataType::Decimal64(precision, scale)
| DataType::Decimal128(precision, scale)
@@ -1938,34 +1942,25 @@ mod tests {
r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#,
),
(
- Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Date64,
- }),
+ Expr::Cast(Cast::new(Box::new(col("a")), DataType::Date64)),
r#"CAST(a AS DATETIME)"#,
),
(
- Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Timestamp(
- TimeUnit::Nanosecond,
- Some("+08:00".into()),
- ),
- }),
+ Expr::Cast(Cast::new(
+ Box::new(col("a")),
+ DataType::Timestamp(TimeUnit::Nanosecond,
Some("+08:00".into())),
+ )),
r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#,
),
(
- Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Timestamp(TimeUnit::Millisecond,
None),
- }),
+ Expr::Cast(Cast::new(
+ Box::new(col("a")),
+ DataType::Timestamp(TimeUnit::Millisecond, None),
+ )),
r#"CAST(a AS TIMESTAMP)"#,
),
(
- Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::UInt32,
- }),
+ Expr::Cast(Cast::new(Box::new(col("a")), DataType::UInt32)),
r#"CAST(a AS INTEGER UNSIGNED)"#,
),
(
@@ -2283,10 +2278,7 @@ mod tests {
r#"((a + b) > 100.123)"#,
),
(
- Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Decimal128(10, -2),
- }),
+ Expr::Cast(Cast::new(Box::new(col("a")),
DataType::Decimal128(10, -2))),
r#"CAST(a AS DECIMAL(12,0))"#,
),
(
@@ -2434,10 +2426,7 @@ mod tests {
.build();
let unparser = Unparser::new(&dialect);
- let expr = Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Date64,
- });
+ let expr = Expr::Cast(Cast::new(Box::new(col("a")),
DataType::Date64));
let ast = unparser.expr_to_sql(&expr)?;
let actual = format!("{ast}");
@@ -2459,10 +2448,7 @@ mod tests {
.build();
let unparser = Unparser::new(&dialect);
- let expr = Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Float64,
- });
+ let expr = Expr::Cast(Cast::new(Box::new(col("a")),
DataType::Float64));
let ast = unparser.expr_to_sql(&expr)?;
let actual = format!("{ast}");
@@ -2692,23 +2678,23 @@ mod tests {
fn test_cast_value_to_binary_expr() {
let tests = [
(
- Expr::Cast(Cast {
- expr: Box::new(Expr::Literal(
+ Expr::Cast(Cast::new(
+ Box::new(Expr::Literal(
ScalarValue::Utf8(Some("blah".to_string())),
None,
)),
- data_type: DataType::Binary,
- }),
+ DataType::Binary,
+ )),
"'blah'",
),
(
- Expr::Cast(Cast {
- expr: Box::new(Expr::Literal(
+ Expr::Cast(Cast::new(
+ Box::new(Expr::Literal(
ScalarValue::Utf8(Some("blah".to_string())),
None,
)),
- data_type: DataType::BinaryView,
- }),
+ DataType::BinaryView,
+ )),
"'blah'",
),
];
@@ -2739,10 +2725,7 @@ mod tests {
] {
let unparser = Unparser::new(dialect);
- let expr = Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type,
- });
+ let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type));
let ast = unparser.expr_to_sql(&expr)?;
let actual = format!("{ast}");
@@ -2825,10 +2808,7 @@ mod tests {
[(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")]
{
let unparser = Unparser::new(&dialect);
- let expr = Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Int64,
- });
+ let expr = Expr::Cast(Cast::new(Box::new(col("a")),
DataType::Int64));
let ast = unparser.expr_to_sql(&expr)?;
let actual = format!("{ast}");
@@ -2853,10 +2833,7 @@ mod tests {
[(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")]
{
let unparser = Unparser::new(&dialect);
- let expr = Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Int32,
- });
+ let expr = Expr::Cast(Cast::new(Box::new(col("a")),
DataType::Int32));
let ast = unparser.expr_to_sql(&expr)?;
let actual = format!("{ast}");
@@ -2892,10 +2869,7 @@ mod tests {
(&mysql_dialect, &timestamp_with_tz, "DATETIME"),
] {
let unparser = Unparser::new(dialect);
- let expr = Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: data_type.clone(),
- });
+ let expr = Expr::Cast(Cast::new(Box::new(col("a")),
data_type.clone()));
let ast = unparser.expr_to_sql(&expr)?;
let actual = format!("{ast}");
@@ -2948,10 +2922,7 @@ mod tests {
] {
let unparser = Unparser::new(dialect);
- let expr = Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type,
- });
+ let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type));
let ast = unparser.expr_to_sql(&expr)?;
let actual = format!("{ast}");
@@ -2991,13 +2962,13 @@ mod tests {
#[test]
fn test_cast_value_to_dict_expr() {
let tests = [(
- Expr::Cast(Cast {
- expr: Box::new(Expr::Literal(
+ Expr::Cast(Cast::new(
+ Box::new(Expr::Literal(
ScalarValue::Utf8(Some("variation".to_string())),
None,
)),
- data_type: DataType::Dictionary(Box::new(Int8),
Box::new(DataType::Utf8)),
- }),
+ DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)),
+ )),
"'variation'",
)];
for (value, expected) in tests {
@@ -3029,10 +3000,7 @@ mod tests {
datafusion_functions::math::round::RoundFunc::new(),
)),
args: vec![
- Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Float64,
- }),
+ Expr::Cast(Cast::new(Box::new(col("a")),
DataType::Float64)),
Expr::Literal(ScalarValue::Int64(Some(2)), None),
],
});
@@ -3194,10 +3162,12 @@ mod tests {
let unparser = Unparser::new(&dialect);
- let ast_dtype =
unparser.arrow_dtype_to_ast_dtype(&DataType::Dictionary(
- Box::new(DataType::Int32),
- Box::new(DataType::Utf8),
- ))?;
+ let arrow_field = Arc::new(Field::new(
+ "",
+ DataType::Dictionary(Box::new(DataType::Int32),
Box::new(DataType::Utf8)),
+ true,
+ ));
+ let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?;
assert_eq!(ast_dtype, ast::DataType::Varchar(None));
@@ -3210,10 +3180,13 @@ mod tests {
let unparser = Unparser::new(&dialect);
- let ast_dtype =
unparser.arrow_dtype_to_ast_dtype(&DataType::RunEndEncoded(
- Field::new("run_ends", DataType::Int32, false).into(),
- Field::new("values", DataType::Utf8, true).into(),
- ))?;
+ let ast_dtype = unparser.arrow_dtype_to_ast_dtype(
+ &DataType::RunEndEncoded(
+ Field::new("run_ends", DataType::Int32, false).into(),
+ Field::new("values", DataType::Utf8, true).into(),
+ )
+ .into_nullable_field_ref(),
+ )?;
assert_eq!(ast_dtype, ast::DataType::Varchar(None));
@@ -3227,7 +3200,8 @@ mod tests {
.build();
let unparser = Unparser::new(&dialect);
- let ast_dtype =
unparser.arrow_dtype_to_ast_dtype(&DataType::Utf8View)?;
+ let arrow_field = Arc::new(Field::new("", DataType::Utf8View, true));
+ let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?;
assert_eq!(ast_dtype, ast::DataType::Char(None));
@@ -3295,10 +3269,10 @@ mod tests {
let dialect: Arc<dyn Dialect> = Arc::new(SqliteDialect {});
let unparser = Unparser::new(dialect.as_ref());
- let expr = Expr::Cast(Cast {
- expr: Box::new(col("a")),
- data_type: DataType::Timestamp(TimeUnit::Nanosecond, None),
- });
+ let expr = Expr::Cast(Cast::new(
+ Box::new(col("a")),
+ DataType::Timestamp(TimeUnit::Nanosecond, None),
+ ));
let ast = unparser.expr_to_sql(&expr)?;
diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs
b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs
index ec70ac3fec..3dd62afe8f 100644
--- a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs
+++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs
@@ -15,8 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-use crate::logical_plan::consumer::SubstraitConsumer;
-use crate::logical_plan::consumer::types::from_substrait_type_without_names;
+use crate::logical_plan::consumer::{
+ SubstraitConsumer, field_from_substrait_type_without_names,
+};
use datafusion::common::{DFSchema, substrait_err};
use datafusion::logical_expr::{Cast, Expr, TryCast};
use substrait::proto::expression as substrait_expression;
@@ -37,11 +38,11 @@ pub async fn from_cast(
)
.await?,
);
- let data_type = from_substrait_type_without_names(consumer,
output_type)?;
+ let field = field_from_substrait_type_without_names(consumer,
output_type)?;
if cast.failure_behavior() == ReturnNull {
- Ok(Expr::TryCast(TryCast::new(input_expr, data_type)))
+ Ok(Expr::TryCast(TryCast::new_from_field(input_expr, field)))
} else {
- Ok(Expr::Cast(Cast::new(input_expr, data_type)))
+ Ok(Expr::Cast(Cast::new_from_field(input_expr, field)))
}
}
None => substrait_err!("Cast expression without output type is not
allowed"),
diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs
b/datafusion/substrait/src/logical_plan/consumer/types.rs
index 9ef7a0dd46..2493ac1e5a 100644
--- a/datafusion/substrait/src/logical_plan/consumer/types.rs
+++ b/datafusion/substrait/src/logical_plan/consumer/types.rs
@@ -34,14 +34,22 @@ use crate::variation_const::{
};
use crate::variation_const::{FLOAT_16_TYPE_NAME, NULL_TYPE_NAME};
use datafusion::arrow::datatypes::{
- DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
+ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
+use datafusion::common::datatype::DataTypeExt;
use datafusion::common::{
DFSchema, not_impl_err, substrait_datafusion_err, substrait_err,
};
use std::sync::Arc;
use substrait::proto::{NamedStruct, Type, r#type};
+pub(crate) fn field_from_substrait_type_without_names(
+ consumer: &impl SubstraitConsumer,
+ dt: &Type,
+) -> datafusion::common::Result<FieldRef> {
+ Ok(from_substrait_type_without_names(consumer,
dt)?.into_nullable_field_ref())
+}
+
pub(crate) fn from_substrait_type_without_names(
consumer: &impl SubstraitConsumer,
dt: &Type,
@@ -49,6 +57,16 @@ pub(crate) fn from_substrait_type_without_names(
from_substrait_type(consumer, dt, &[], &mut 0)
}
+pub fn field_from_substrait_type(
+ consumer: &impl SubstraitConsumer,
+ dt: &Type,
+ dfs_names: &[String],
+ name_idx: &mut usize,
+) -> datafusion::common::Result<FieldRef> {
+ // We could add nullability here now that we are returning a Field
+ Ok(from_substrait_type(consumer, dt, dfs_names,
name_idx)?.into_nullable_field_ref())
+}
+
pub fn from_substrait_type(
consumer: &impl SubstraitConsumer,
dt: &Type,
diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs
b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs
index 6eb27fc39d..2a5a6fe5c3 100644
--- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs
+++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::logical_plan::producer::{SubstraitProducer, to_substrait_type};
+use crate::logical_plan::producer::{SubstraitProducer,
to_substrait_type_from_field};
use crate::variation_const::DEFAULT_TYPE_VARIATION_REF;
use datafusion::common::{DFSchemaRef, ScalarValue};
use datafusion::logical_expr::{Cast, Expr, TryCast};
@@ -29,7 +29,7 @@ pub fn from_cast(
cast: &Cast,
schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
- let Cast { expr, data_type } = cast;
+ let Cast { expr, field } = cast;
// since substrait Null must be typed, so if we see a cast(null, dt), we
make it a typed null
if let Expr::Literal(lit, _) = expr.as_ref() {
// only the untyped(a null scalar value) null literal need this
special handling
@@ -39,8 +39,8 @@ pub fn from_cast(
let lit = Literal {
nullable: true,
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
- literal_type: Some(LiteralType::Null(to_substrait_type(
- producer, data_type, true,
+ literal_type:
Some(LiteralType::Null(to_substrait_type_from_field(
+ producer, field,
)?)),
};
return Ok(Expression {
@@ -51,7 +51,7 @@ pub fn from_cast(
Ok(Expression {
rex_type: Some(RexType::Cast(Box::new(
substrait::proto::expression::Cast {
- r#type: Some(to_substrait_type(producer, data_type, true)?),
+ r#type: Some(to_substrait_type_from_field(producer, field)?),
input: Some(Box::new(producer.handle_expr(expr, schema)?)),
failure_behavior: FailureBehavior::ThrowException.into(),
},
@@ -64,11 +64,11 @@ pub fn from_try_cast(
cast: &TryCast,
schema: &DFSchemaRef,
) -> datafusion::common::Result<Expression> {
- let TryCast { expr, data_type } = cast;
+ let TryCast { expr, field } = cast;
Ok(Expression {
rex_type: Some(RexType::Cast(Box::new(
substrait::proto::expression::Cast {
- r#type: Some(to_substrait_type(producer, data_type, true)?),
+ r#type: Some(to_substrait_type_from_field(producer, field)?),
input: Some(Box::new(producer.handle_expr(expr, schema)?)),
failure_behavior: FailureBehavior::ReturnNull.into(),
},
@@ -80,7 +80,7 @@ pub fn from_try_cast(
mod tests {
use super::*;
use crate::logical_plan::producer::{
- DefaultSubstraitProducer, to_substrait_extended_expr,
+ DefaultSubstraitProducer, to_substrait_extended_expr,
to_substrait_type,
};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::DFSchema;
diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs
b/datafusion/substrait/src/logical_plan/producer/types.rs
index 3727596119..fa58949e6e 100644
--- a/datafusion/substrait/src/logical_plan/producer/types.rs
+++ b/datafusion/substrait/src/logical_plan/producer/types.rs
@@ -27,7 +27,7 @@ use crate::variation_const::{
TIME_32_TYPE_VARIATION_REF, TIME_64_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF,
};
-use datafusion::arrow::datatypes::{DataType, IntervalUnit};
+use datafusion::arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit};
use datafusion::common::{DFSchemaRef, not_impl_err, plan_err};
use substrait::proto::{NamedStruct, r#type};
@@ -36,12 +36,19 @@ pub(crate) fn to_substrait_type(
dt: &DataType,
nullable: bool,
) -> datafusion::common::Result<substrait::proto::Type> {
- let nullability = if nullable {
+ to_substrait_type_from_field(producer, &Field::new("", dt.clone(),
nullable).into())
+}
+
+pub(crate) fn to_substrait_type_from_field(
+ producer: &mut impl SubstraitProducer,
+ field: &FieldRef,
+) -> datafusion::common::Result<substrait::proto::Type> {
+ let nullability = if field.is_nullable() {
r#type::Nullability::Nullable as i32
} else {
r#type::Nullability::Required as i32
};
- match dt {
+ match field.data_type() {
DataType::Null => {
let type_anchor =
producer.register_type(NULL_TYPE_NAME.to_string());
Ok(substrait::proto::Type {
@@ -288,16 +295,9 @@ pub(crate) fn to_substrait_type(
}
DataType::Map(inner, _) => match inner.data_type() {
DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
- let key_type = to_substrait_type(
- producer,
- key_and_value[0].data_type(),
- key_and_value[0].is_nullable(),
- )?;
- let value_type = to_substrait_type(
- producer,
- key_and_value[1].data_type(),
- key_and_value[1].is_nullable(),
- )?;
+ let key_type = to_substrait_type_from_field(producer,
&key_and_value[0])?;
+ let value_type =
+ to_substrait_type_from_field(producer, &key_and_value[1])?;
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Map(Box::new(r#type::Map {
key: Some(Box::new(key_type)),
@@ -310,8 +310,14 @@ pub(crate) fn to_substrait_type(
_ => plan_err!("Map fields must contain a Struct with exactly 2
fields"),
},
DataType::Dictionary(key_type, value_type) => {
- let key_type = to_substrait_type(producer, key_type, nullable)?;
- let value_type = to_substrait_type(producer, value_type,
nullable)?;
+ let key_type = to_substrait_type_from_field(
+ producer,
+ &Field::new("", key_type.as_ref().clone(),
field.is_nullable()).into(),
+ )?;
+ let value_type = to_substrait_type_from_field(
+ producer,
+ &Field::new("", value_type.as_ref().clone(),
field.is_nullable()).into(),
+ )?;
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Map(Box::new(r#type::Map {
key: Some(Box::new(key_type)),
@@ -324,9 +330,7 @@ pub(crate) fn to_substrait_type(
DataType::Struct(fields) => {
let field_types = fields
.iter()
- .map(|field| {
- to_substrait_type(producer, field.data_type(),
field.is_nullable())
- })
+ .map(|field| to_substrait_type_from_field(producer, field))
.collect::<datafusion::common::Result<Vec<_>>>()?;
Ok(substrait::proto::Type {
kind: Some(r#type::Kind::Struct(r#type::Struct {
@@ -352,7 +356,7 @@ pub(crate) fn to_substrait_type(
precision: *p as i32,
})),
}),
- _ => not_impl_err!("Unsupported cast type: {dt}"),
+ _ => not_impl_err!("Unsupported cast type: {field}"),
}
}
@@ -369,7 +373,7 @@ pub(crate) fn to_substrait_named_struct(
types: schema
.fields()
.iter()
- .map(|f| to_substrait_type(producer, f.data_type(),
f.is_nullable()))
+ .map(|f| to_substrait_type_from_field(producer, f))
.collect::<datafusion::common::Result<_>>()?,
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
nullability: r#type::Nullability::Required as i32,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]