This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 31094b00e2 fix bug with `to_timestamp` and `InitCap` logical
serialization, add roundtrip test between expression and proto, (#8868)
31094b00e2 is described below
commit 31094b00e2e5f764a89a2e9806e98acf0576729f
Author: Alex Huang <[email protected]>
AuthorDate: Thu Jan 18 00:04:59 2024 +0800
fix bug with `to_timestamp` and `InitCap` logical serialization, add
roundtrip test between expression and proto, (#8868)
* add roundtrip test between expression and proto
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/proto/Cargo.toml | 1 +
datafusion/proto/src/logical_plan/from_proto.rs | 17 +++++++++---
datafusion/proto/tests/cases/serialize.rs | 37 +++++++++++++++++++++++++
3 files changed, 51 insertions(+), 4 deletions(-)
diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml
index f9f24b28db..e423220216 100644
--- a/datafusion/proto/Cargo.toml
+++ b/datafusion/proto/Cargo.toml
@@ -54,4 +54,5 @@ serde_json = { workspace = true, optional = true }
[dev-dependencies]
doc-comment = { workspace = true }
+strum = { version = "0.25.0", features = ["derive"] }
tokio = "1.18"
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 9185bdb804..973e366d0b 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -58,8 +58,8 @@ use datafusion_expr::{
current_date, current_time, date_bin, date_part, date_trunc, decode,
degrees, digest,
encode, exp,
expr::{self, InList, Sort, WindowFunction},
- factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range,
isnan, iszero,
- lcm, left, levenshtein, ln, log, log10, log2,
+ factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range,
initcap,
+ isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2,
logical_plan::{PlanType, StringifiedPlan},
lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi,
power,
radians, random, regexp_match, regexp_replace, repeat, replace, reverse,
right,
@@ -1585,7 +1585,7 @@ pub fn parse_expr(
Ok(character_length(parse_expr(&args[0], registry)?))
}
ScalarFunction::Chr => Ok(chr(parse_expr(&args[0],
registry)?)),
- ScalarFunction::InitCap => Ok(ascii(parse_expr(&args[0],
registry)?)),
+ ScalarFunction::InitCap => Ok(initcap(parse_expr(&args[0],
registry)?)),
ScalarFunction::Gcd => Ok(gcd(
parse_expr(&args[0], registry)?,
parse_expr(&args[1], registry)?,
@@ -1742,7 +1742,16 @@ pub fn parse_expr(
Ok(arrow_typeof(parse_expr(&args[0], registry)?))
}
ScalarFunction::ToTimestamp => {
- Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?))
+ let args: Vec<_> = args
+ .iter()
+ .map(|expr| parse_expr(expr, registry))
+ .collect::<Result<_, _>>()?;
+ Ok(Expr::ScalarFunction(
+ datafusion_expr::expr::ScalarFunction::new(
+ BuiltinScalarFunction::ToTimestamp,
+ args,
+ ),
+ ))
}
ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0],
registry)?)),
ScalarFunction::StringToArray => Ok(string_to_array(
diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs
index 5b890accd8..222d1a3a62 100644
--- a/datafusion/proto/tests/cases/serialize.rs
+++ b/datafusion/proto/tests/cases/serialize.rs
@@ -243,3 +243,40 @@ fn context_with_udf() -> SessionContext {
ctx
}
+
+#[test]
+fn test_expression_serialization_roundtrip() {
+ use datafusion_common::ScalarValue;
+ use datafusion_expr::expr::ScalarFunction;
+ use datafusion_expr::BuiltinScalarFunction;
+ use datafusion_proto::logical_plan::from_proto::parse_expr;
+ use datafusion_proto::protobuf::LogicalExprNode;
+ use strum::IntoEnumIterator;
+
+ let ctx = SessionContext::new();
+ let lit = Expr::Literal(ScalarValue::Utf8(None));
+ for builtin_fun in BuiltinScalarFunction::iter() {
+ // default to 4 args (though some exprs like substr have error checking)
+ let num_args = match builtin_fun {
+ BuiltinScalarFunction::Substr => 3,
+ _ => 4,
+ };
+ let args: Vec<_> =
std::iter::repeat(&lit).take(num_args).cloned().collect();
+ let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun,
args));
+
+ let proto = LogicalExprNode::try_from(&expr).unwrap();
+ let deserialize = parse_expr(&proto, &ctx).unwrap();
+
+ let serialize_name = extract_function_name(&expr);
+ let deserialize_name = extract_function_name(&deserialize);
+
+ assert_eq!(serialize_name, deserialize_name);
+ }
+
+ /// Extracts the first part of a function name
+ /// 'foo(bar)' -> 'foo'
+ fn extract_function_name(expr: &Expr) -> String {
+ let name = expr.display_name().unwrap();
+ name.split('(').next().unwrap().to_string()
+ }
+}