This is an automated email from the ASF dual-hosted git repository.

liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new def6114  feat: make BoundPredicate,Datum serializable (#406)
def6114 is described below

commit def6114b910f842946a05febe5fe310dd36a5991
Author: ZENOTME <[email protected]>
AuthorDate: Wed Jun 19 09:37:50 2024 +0800

    feat: make BoundPredicate,Datum serializable (#406)
    
    * make BoundPredicate,Datum serializable
    
    * refine error
    
    * fix float check
    
    * use value instead of string to avoid precision loss
    
    ---------
    
    Co-authored-by: ZENOTME <[email protected]>
---
 crates/iceberg/src/expr/mod.rs       |   3 +-
 crates/iceberg/src/expr/predicate.rs |  73 ++++++++-
 crates/iceberg/src/expr/term.rs      |   5 +-
 crates/iceberg/src/spec/values.rs    | 290 +++++++++++++++++++++++++++++++++++
 4 files changed, 362 insertions(+), 9 deletions(-)

diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs
index 3d77c4d..16f75b0 100644
--- a/crates/iceberg/src/expr/mod.rs
+++ b/crates/iceberg/src/expr/mod.rs
@@ -18,6 +18,7 @@
 //! This module contains expressions.
 
 mod term;
+use serde::{Deserialize, Serialize};
 pub use term::*;
 pub(crate) mod accessor;
 mod predicate;
@@ -32,7 +33,7 @@ use std::fmt::{Display, Formatter};
 /// The discriminant of this enum is used for determining the type of the 
operator, see
 /// [`PredicateOperator::is_unary`], [`PredicateOperator::is_binary`], 
[`PredicateOperator::is_set`]
 #[allow(missing_docs)]
-#[derive(Debug, Clone, Copy, PartialEq)]
+#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
 #[non_exhaustive]
 #[repr(u16)]
 pub enum PredicateOperator {
diff --git a/crates/iceberg/src/expr/predicate.rs 
b/crates/iceberg/src/expr/predicate.rs
index 158ab13..3a91d6b 100644
--- a/crates/iceberg/src/expr/predicate.rs
+++ b/crates/iceberg/src/expr/predicate.rs
@@ -25,6 +25,7 @@ use std::ops::Not;
 use array_init::array_init;
 use fnv::FnvHashSet;
 use itertools::Itertools;
+use serde::{Deserialize, Serialize};
 
 use crate::error::Result;
 use crate::expr::{Bind, BoundReference, PredicateOperator, Reference};
@@ -37,6 +38,29 @@ pub struct LogicalExpression<T, const N: usize> {
     inputs: [Box<T>; N],
 }
 
+impl<T: Serialize, const N: usize> Serialize for LogicalExpression<T, N> {
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        self.inputs.serialize(serializer)
+    }
+}
+
+impl<'de, T: Deserialize<'de>, const N: usize> Deserialize<'de> for 
LogicalExpression<T, N> {
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        let inputs = Vec::<Box<T>>::deserialize(deserializer)?;
+        Ok(LogicalExpression::new(
+            array_init::from_iter(inputs.into_iter()).ok_or_else(|| {
+                serde::de::Error::custom(format!("Failed to deserialize 
LogicalExpression: the len of inputs is not match with the len of 
LogicalExpression {}",N))
+            })?,
+        ))
+    }
+}
+
 impl<T: Debug, const N: usize> Debug for LogicalExpression<T, N> {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("LogicalExpression")
@@ -79,11 +103,12 @@ where
 }
 
 /// Unary predicate, for example, `a IS NULL`.
-#[derive(PartialEq, Clone)]
+#[derive(PartialEq, Clone, Serialize, Deserialize)]
 pub struct UnaryExpression<T> {
     /// Operator of this predicate, must be single operand operator.
     op: PredicateOperator,
     /// Term of this predicate, for example, `a` in `a IS NULL`.
+    #[serde(bound(serialize = "T: Serialize", deserialize = "T: 
Deserialize<'de>"))]
     term: T,
 }
 
@@ -129,11 +154,12 @@ impl<T> UnaryExpression<T> {
 }
 
 /// Binary predicate, for example, `a > 10`.
-#[derive(PartialEq, Clone)]
+#[derive(PartialEq, Clone, Serialize, Deserialize)]
 pub struct BinaryExpression<T> {
     /// Operator of this predicate, must be binary operator, such as `=`, `>`, 
`<`, etc.
     op: PredicateOperator,
     /// Term of this predicate, for example, `a` in `a > 10`.
+    #[serde(bound(serialize = "T: Serialize", deserialize = "T: 
Deserialize<'de>"))]
     term: T,
     /// Literal of this predicate, for example, `10` in `a > 10`.
     literal: Datum,
@@ -190,7 +216,7 @@ impl<T: Bind> Bind for BinaryExpression<T> {
 }
 
 /// Set predicates, for example, `a in (1, 2, 3)`.
-#[derive(PartialEq, Clone)]
+#[derive(PartialEq, Clone, Serialize, Deserialize)]
 pub struct SetExpression<T> {
     /// Operator of this predicate, must be set operator, such as `IN`, `NOT 
IN`, etc.
     op: PredicateOperator,
@@ -253,7 +279,7 @@ impl<T: Display + Debug> Display for SetExpression<T> {
 }
 
 /// Unbound predicate expression before binding to a schema.
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
 pub enum Predicate {
     /// AlwaysTrue predicate, for example, `TRUE`.
     AlwaysTrue,
@@ -622,7 +648,7 @@ impl Not for Predicate {
 }
 
 /// Bound predicate expression after binding to a schema.
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub enum BoundPredicate {
     /// An expression always evaluates to true.
     AlwaysTrue,
@@ -678,9 +704,9 @@ mod tests {
     use std::ops::Not;
     use std::sync::Arc;
 
-    use crate::expr::Bind;
     use crate::expr::Predicate::{AlwaysFalse, AlwaysTrue};
     use crate::expr::Reference;
+    use crate::expr::{Bind, BoundPredicate};
     use crate::spec::Datum;
     use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
 
@@ -879,12 +905,19 @@ mod tests {
         )
     }
 
+    fn test_bound_predicate_serialize_diserialize(bound_predicate: 
BoundPredicate) {
+        let serialized = serde_json::to_string(&bound_predicate).unwrap();
+        let deserialized: BoundPredicate = 
serde_json::from_str(&serialized).unwrap();
+        assert_eq!(bound_predicate, deserialized);
+    }
+
     #[test]
     fn test_bind_is_null() {
         let schema = table_schema_simple();
         let expr = Reference::new("foo").is_null();
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "foo IS NULL");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -893,6 +926,7 @@ mod tests {
         let expr = Reference::new("bar").is_null();
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "False");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -901,6 +935,7 @@ mod tests {
         let expr = Reference::new("foo").is_not_null();
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "foo IS NOT NULL");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -909,6 +944,7 @@ mod tests {
         let expr = Reference::new("bar").is_not_null();
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "True");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -922,6 +958,7 @@ mod tests {
         let expr_string = Reference::new("foo").is_nan();
         let bound_expr_string = expr_string.bind(schema_string, true);
         assert!(bound_expr_string.is_err());
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -938,6 +975,7 @@ mod tests {
         let expr = Reference::new("qux").is_not_nan();
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "qux IS NOT NAN");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -954,6 +992,7 @@ mod tests {
         let expr = Reference::new("bar").less_than(Datum::int(10));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar < 10");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -970,6 +1009,7 @@ mod tests {
         let expr = Reference::new("bar").less_than_or_equal_to(Datum::int(10));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar <= 10");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -986,6 +1026,7 @@ mod tests {
         let expr = Reference::new("bar").greater_than(Datum::int(10));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar > 10");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1002,6 +1043,7 @@ mod tests {
         let expr = 
Reference::new("bar").greater_than_or_equal_to(Datum::int(10));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar >= 10");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1018,6 +1060,7 @@ mod tests {
         let expr = Reference::new("bar").equal_to(Datum::int(10));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar = 10");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1034,6 +1077,7 @@ mod tests {
         let expr = Reference::new("bar").not_equal_to(Datum::int(10));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar != 10");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1050,6 +1094,7 @@ mod tests {
         let expr = Reference::new("foo").starts_with(Datum::string("abcd"));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), r#"foo STARTS WITH "abcd""#);
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1066,6 +1111,7 @@ mod tests {
         let expr = 
Reference::new("foo").not_starts_with(Datum::string("abcd"));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), r#"foo NOT STARTS WITH "abcd""#);
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1082,6 +1128,7 @@ mod tests {
         let expr = Reference::new("bar").is_in([Datum::int(10), 
Datum::int(20)]);
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar IN (20, 10)");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1090,6 +1137,7 @@ mod tests {
         let expr = Reference::new("bar").is_in(vec![]);
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "False");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1098,6 +1146,7 @@ mod tests {
         let expr = Reference::new("bar").is_in(vec![Datum::int(10)]);
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar = 10");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1114,6 +1163,7 @@ mod tests {
         let expr = Reference::new("bar").is_not_in([Datum::int(10), 
Datum::int(20)]);
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar NOT IN (20, 10)");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1122,6 +1172,7 @@ mod tests {
         let expr = Reference::new("bar").is_not_in(vec![]);
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "True");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1130,6 +1181,7 @@ mod tests {
         let expr = Reference::new("bar").is_not_in(vec![Datum::int(10)]);
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "bar != 10");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1148,6 +1200,7 @@ mod tests {
             .and(Reference::new("foo").is_null());
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "(bar < 10) AND (foo IS NULL)");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1158,6 +1211,7 @@ mod tests {
             .and(Reference::new("bar").is_null());
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "False");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1168,6 +1222,7 @@ mod tests {
             .and(Reference::new("bar").is_not_null());
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#);
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1178,6 +1233,7 @@ mod tests {
             .or(Reference::new("foo").is_null());
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "(bar < 10) OR (foo IS NULL)");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1188,6 +1244,7 @@ mod tests {
             .or(Reference::new("bar").is_not_null());
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "True");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1198,6 +1255,7 @@ mod tests {
             .or(Reference::new("bar").is_null());
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#);
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1206,6 +1264,7 @@ mod tests {
         let expr = !Reference::new("bar").less_than(Datum::int(10));
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "NOT (bar < 10)");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1214,6 +1273,7 @@ mod tests {
         let expr = !Reference::new("bar").is_not_null();
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), "False");
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 
     #[test]
@@ -1222,5 +1282,6 @@ mod tests {
         let expr = !Reference::new("bar").is_null();
         let bound_expr = expr.bind(schema, true).unwrap();
         assert_eq!(&format!("{bound_expr}"), r#"True"#);
+        test_bound_predicate_serialize_diserialize(bound_expr);
     }
 }
diff --git a/crates/iceberg/src/expr/term.rs b/crates/iceberg/src/expr/term.rs
index 1fbf86c..909aa62 100644
--- a/crates/iceberg/src/expr/term.rs
+++ b/crates/iceberg/src/expr/term.rs
@@ -20,6 +20,7 @@
 use std::fmt::{Display, Formatter};
 
 use fnv::FnvHashSet;
+use serde::{Deserialize, Serialize};
 
 use crate::expr::accessor::{StructAccessor, StructAccessorRef};
 use crate::expr::Bind;
@@ -32,7 +33,7 @@ pub type Term = Reference;
 
 /// A named reference in an unbound expression.
 /// For example, `a` in `a > 10`.
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
 pub struct Reference {
     name: String,
 }
@@ -351,7 +352,7 @@ impl Bind for Reference {
 }
 
 /// A named reference in a bound expression after binding to a schema.
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct BoundReference {
     // This maybe different from [`name`] filed in [`NestedField`] since this 
contains full path.
     // For example, if the field is `a.b.c`, then `field.name` is `c`, but 
`original_name` is `a.b.c`.
diff --git a/crates/iceberg/src/spec/values.rs 
b/crates/iceberg/src/spec/values.rs
index 567a847..a905903 100644
--- a/crates/iceberg/src/spec/values.rs
+++ b/crates/iceberg/src/spec/values.rs
@@ -30,6 +30,9 @@ use bitvec::vec::BitVec;
 use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
 use ordered_float::OrderedFloat;
 use rust_decimal::Decimal;
+use serde::de::{self, MapAccess};
+use serde::ser::SerializeStruct;
+use serde::{Deserialize, Serialize};
 use serde_bytes::ByteBuf;
 use serde_json::{Map as JsonMap, Number, Value as JsonValue};
 use uuid::Uuid;
@@ -105,6 +108,115 @@ pub struct Datum {
     literal: PrimitiveLiteral,
 }
 
+impl Serialize for Datum {
+    fn serialize<S: serde::Serializer>(
+        &self,
+        serializer: S,
+    ) -> std::result::Result<S::Ok, S::Error> {
+        let mut struct_ser = serializer
+            .serialize_struct("Datum", 2)
+            .map_err(serde::ser::Error::custom)?;
+        struct_ser
+            .serialize_field("type", &self.r#type)
+            .map_err(serde::ser::Error::custom)?;
+        struct_ser
+            .serialize_field(
+                "literal",
+                &RawLiteral::try_from(
+                    Literal::Primitive(self.literal.clone()),
+                    &Type::Primitive(self.r#type.clone()),
+                )
+                .map_err(serde::ser::Error::custom)?,
+            )
+            .map_err(serde::ser::Error::custom)?;
+        struct_ser.end()
+    }
+}
+
+impl<'de> Deserialize<'de> for Datum {
+    fn deserialize<D: serde::Deserializer<'de>>(
+        deserializer: D,
+    ) -> std::result::Result<Self, D::Error> {
+        #[derive(Deserialize)]
+        #[serde(field_identifier, rename_all = "lowercase")]
+        enum Field {
+            Type,
+            Literal,
+        }
+
+        struct DatumVisitor;
+
+        impl<'de> serde::de::Visitor<'de> for DatumVisitor {
+            type Value = Datum;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> 
std::fmt::Result {
+                formatter.write_str("struct Datum")
+            }
+
+            fn visit_seq<A>(self, mut seq: A) -> 
std::result::Result<Self::Value, A::Error>
+            where
+                A: serde::de::SeqAccess<'de>,
+            {
+                let r#type = seq
+                    .next_element::<PrimitiveType>()?
+                    .ok_or_else(|| serde::de::Error::invalid_length(0, 
&self))?;
+                let value = seq
+                    .next_element::<RawLiteral>()?
+                    .ok_or_else(|| serde::de::Error::invalid_length(1, 
&self))?;
+                let Literal::Primitive(primitive) = value
+                    .try_into(&Type::Primitive(r#type.clone()))
+                    .map_err(serde::de::Error::custom)?
+                    .ok_or_else(|| serde::de::Error::custom("None value"))?
+                else {
+                    return Err(serde::de::Error::custom("Invalid value"));
+                };
+
+                Ok(Datum::new(r#type, primitive))
+            }
+
+            fn visit_map<V>(self, mut map: V) -> std::result::Result<Datum, 
V::Error>
+            where
+                V: MapAccess<'de>,
+            {
+                let mut raw_primitive: Option<RawLiteral> = None;
+                let mut r#type: Option<PrimitiveType> = None;
+                while let Some(key) = map.next_key()? {
+                    match key {
+                        Field::Type => {
+                            if r#type.is_some() {
+                                return Err(de::Error::duplicate_field("type"));
+                            }
+                            r#type = Some(map.next_value()?);
+                        }
+                        Field::Literal => {
+                            if raw_primitive.is_some() {
+                                return 
Err(de::Error::duplicate_field("literal"));
+                            }
+                            raw_primitive = Some(map.next_value()?);
+                        }
+                    }
+                }
+                let Some(r#type) = r#type else {
+                    return Err(serde::de::Error::missing_field("type"));
+                };
+                let Some(raw_primitive) = raw_primitive else {
+                    return Err(serde::de::Error::missing_field("literal"));
+                };
+                let Literal::Primitive(primitive) = raw_primitive
+                    .try_into(&Type::Primitive(r#type.clone()))
+                    .map_err(serde::de::Error::custom)?
+                    .ok_or_else(|| serde::de::Error::custom("None value"))?
+                else {
+                    return Err(serde::de::Error::custom("Invalid value"));
+                };
+                Ok(Datum::new(r#type, primitive))
+            }
+        }
+        const FIELDS: &[&str] = &["type", "literal"];
+        deserializer.deserialize_struct("Datum", FIELDS, DatumVisitor)
+    }
+}
+
 impl PartialOrd for Datum {
     fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
         match (&self.literal, &other.literal, &self.r#type, &other.r#type) {
@@ -2320,10 +2432,17 @@ mod _serde {
                 RawLiteralEnum::Boolean(v) => Ok(Some(Literal::bool(v))),
                 RawLiteralEnum::Int(v) => match ty {
                     Type::Primitive(PrimitiveType::Int) => 
Ok(Some(Literal::int(v))),
+                    Type::Primitive(PrimitiveType::Long) => 
Ok(Some(Literal::long(i64::from(v)))),
                     Type::Primitive(PrimitiveType::Date) => 
Ok(Some(Literal::date(v))),
                     _ => Err(invalid_err("int")),
                 },
                 RawLiteralEnum::Long(v) => match ty {
+                    Type::Primitive(PrimitiveType::Int) => 
Ok(Some(Literal::int(
+                        i32::try_from(v).map_err(|_| invalid_err("long"))?,
+                    ))),
+                    Type::Primitive(PrimitiveType::Date) => 
Ok(Some(Literal::date(
+                        i32::try_from(v).map_err(|_| invalid_err("long"))?,
+                    ))),
                     Type::Primitive(PrimitiveType::Long) => 
Ok(Some(Literal::long(v))),
                     Type::Primitive(PrimitiveType::Time) => 
Ok(Some(Literal::time(v))),
                     Type::Primitive(PrimitiveType::Timestamp) => 
Ok(Some(Literal::timestamp(v))),
@@ -2334,9 +2453,23 @@ mod _serde {
                 },
                 RawLiteralEnum::Float(v) => match ty {
                     Type::Primitive(PrimitiveType::Float) => 
Ok(Some(Literal::float(v))),
+                    Type::Primitive(PrimitiveType::Double) => {
+                        Ok(Some(Literal::double(f64::from(v))))
+                    }
                     _ => Err(invalid_err("float")),
                 },
                 RawLiteralEnum::Double(v) => match ty {
+                    Type::Primitive(PrimitiveType::Float) => {
+                        let v_32 = v as f32;
+                        if v_32.is_finite() {
+                            let v_64 = f64::from(v_32);
+                            if (v_64 - v).abs() > f32::EPSILON as f64 {
+                                // there is a precision loss
+                                return Err(invalid_err("double"));
+                            }
+                        }
+                        Ok(Some(Literal::float(v_32)))
+                    }
                     Type::Primitive(PrimitiveType::Double) => 
Ok(Some(Literal::double(v))),
                     _ => Err(invalid_err("double")),
                 },
@@ -2418,6 +2551,89 @@ mod _serde {
                             }
                             Ok(Some(Literal::Map(map)))
                         }
+                        Type::Primitive(PrimitiveType::Uuid) => {
+                            if v.list.len() != 16 {
+                                return Err(invalid_err_with_reason(
+                                    "list",
+                                    "The length of list should be 16",
+                                ));
+                            }
+                            let mut bytes = [0u8; 16];
+                            for (i, v) in v.list.iter().enumerate() {
+                                if let Some(RawLiteralEnum::Long(v)) = v {
+                                    bytes[i] = *v as u8;
+                                } else {
+                                    return Err(invalid_err_with_reason(
+                                        "list",
+                                        "The element of list should be int",
+                                    ));
+                                }
+                            }
+                            
Ok(Some(Literal::uuid(uuid::Uuid::from_bytes(bytes))))
+                        }
+                        Type::Primitive(PrimitiveType::Decimal {
+                            precision: _,
+                            scale: _,
+                        }) => {
+                            if v.list.len() != 16 {
+                                return Err(invalid_err_with_reason(
+                                    "list",
+                                    "The length of list should be 16",
+                                ));
+                            }
+                            let mut bytes = [0u8; 16];
+                            for (i, v) in v.list.iter().enumerate() {
+                                if let Some(RawLiteralEnum::Long(v)) = v {
+                                    bytes[i] = *v as u8;
+                                } else {
+                                    return Err(invalid_err_with_reason(
+                                        "list",
+                                        "The element of list should be int",
+                                    ));
+                                }
+                            }
+                            
Ok(Some(Literal::decimal(i128::from_be_bytes(bytes))))
+                        }
+                        Type::Primitive(PrimitiveType::Binary) => {
+                            let bytes = v
+                                .list
+                                .into_iter()
+                                .map(|v| {
+                                    if let Some(RawLiteralEnum::Long(v)) = v {
+                                        Ok(v as u8)
+                                    } else {
+                                        Err(invalid_err_with_reason(
+                                            "list",
+                                            "The element of list should be 
int",
+                                        ))
+                                    }
+                                })
+                                .collect::<Result<Vec<_>, Error>>()?;
+                            Ok(Some(Literal::binary(bytes)))
+                        }
+                        Type::Primitive(PrimitiveType::Fixed(size)) => {
+                            if v.list.len() != *size as usize {
+                                return Err(invalid_err_with_reason(
+                                    "list",
+                                    "The length of list should be equal to 
size",
+                                ));
+                            }
+                            let bytes = v
+                                .list
+                                .into_iter()
+                                .map(|v| {
+                                    if let Some(RawLiteralEnum::Long(v)) = v {
+                                        Ok(v as u8)
+                                    } else {
+                                        Err(invalid_err_with_reason(
+                                            "list",
+                                            "The element of list should be 
int",
+                                        ))
+                                    }
+                                })
+                                .collect::<Result<Vec<_>, Error>>()?;
+                            Ok(Some(Literal::fixed(bytes)))
+                        }
                         _ => Err(invalid_err("list")),
                     }
                 }
@@ -3180,4 +3396,78 @@ mod tests {
             "Parse timestamptz with invalid input should fail!"
         );
     }
+
+    #[test]
+    fn test_datum_ser_deser() {
+        let test_fn = |datum: Datum| {
+            let json = serde_json::to_value(&datum).unwrap();
+            let desered_datum: Datum = serde_json::from_value(json).unwrap();
+            assert_eq!(datum, desered_datum);
+        };
+        let datum = Datum::int(1);
+        test_fn(datum);
+        let datum = Datum::long(1);
+        test_fn(datum);
+
+        let datum = Datum::float(1.0);
+        test_fn(datum);
+        let datum = Datum::float(0_f32);
+        test_fn(datum);
+        let datum = Datum::float(-0_f32);
+        test_fn(datum);
+        let datum = Datum::float(f32::MAX);
+        test_fn(datum);
+        let datum = Datum::float(f32::MIN);
+        test_fn(datum);
+
+        // serde_json can't serialize f32::INFINITY, f32::NEG_INFINITY, 
f32::NAN
+        let datum = Datum::float(f32::INFINITY);
+        let json = serde_json::to_string(&datum).unwrap();
+        assert!(serde_json::from_str::<Datum>(&json).is_err());
+        let datum = Datum::float(f32::NEG_INFINITY);
+        let json = serde_json::to_string(&datum).unwrap();
+        assert!(serde_json::from_str::<Datum>(&json).is_err());
+        let datum = Datum::float(f32::NAN);
+        let json = serde_json::to_string(&datum).unwrap();
+        assert!(serde_json::from_str::<Datum>(&json).is_err());
+
+        let datum = Datum::double(1.0);
+        test_fn(datum);
+        let datum = Datum::double(f64::MAX);
+        test_fn(datum);
+        let datum = Datum::double(f64::MIN);
+        test_fn(datum);
+
+        // serde_json can't serialize f64::INFINITY, f64::NEG_INFINITY, 
f64::NAN
+        let datum = Datum::double(f64::INFINITY);
+        let json = serde_json::to_string(&datum).unwrap();
+        assert!(serde_json::from_str::<Datum>(&json).is_err());
+        let datum = Datum::double(f64::NEG_INFINITY);
+        let json = serde_json::to_string(&datum).unwrap();
+        assert!(serde_json::from_str::<Datum>(&json).is_err());
+        let datum = Datum::double(f64::NAN);
+        let json = serde_json::to_string(&datum).unwrap();
+        assert!(serde_json::from_str::<Datum>(&json).is_err());
+
+        let datum = Datum::string("iceberg");
+        test_fn(datum);
+        let datum = Datum::bool(true);
+        test_fn(datum);
+        let datum = Datum::date(17486);
+        test_fn(datum);
+        let datum = Datum::time_from_hms_micro(22, 15, 33, 111).unwrap();
+        test_fn(datum);
+        let datum = Datum::timestamp_micros(1510871468123456);
+        test_fn(datum);
+        let datum = Datum::timestamptz_micros(1510871468123456);
+        test_fn(datum);
+        let datum = 
Datum::uuid(Uuid::parse_str("f79c3e09-677c-4bbd-a479-3f349cb785e7").unwrap());
+        test_fn(datum);
+        let datum = Datum::decimal(1420).unwrap();
+        test_fn(datum);
+        let datum = Datum::binary(vec![1, 2, 3, 4, 5]);
+        test_fn(datum);
+        let datum = Datum::fixed(vec![1, 2, 3, 4, 5]);
+        test_fn(datum);
+    }
 }

Reply via email to