This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4ce35ef07d Remove transmute in datafusion-proto (#5946)
4ce35ef07d is described below

commit 4ce35ef07daa72b10fba7dd03a62e86d361514dd
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Apr 10 11:32:18 2023 -0400

    Remove transmute in datafusion-proto (#5946)
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 17 ++---
 datafusion/substrait/src/logical_plan/producer.rs  | 75 +++++++++++++++++++---
 .../substrait/tests/roundtrip_logical_plan.rs      |  2 +-
 3 files changed, 76 insertions(+), 18 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 607012bfd6..b807a9acf7 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -453,7 +453,12 @@ pub async fn from_substrait_sorts(
         let asc_nullfirst = match &s.sort_kind {
             Some(k) => match k {
                 Direction(d) => {
-                    let direction: SortDirection = unsafe { 
::std::mem::transmute(*d) };
+                    let Some(direction) = SortDirection::from_i32(*d) else {
+                        return Err(DataFusionError::NotImplemented(
+                            format!("Unsupported Substrait SortDirection value 
{d}"),
+                        ))
+                    };
+
                     match direction {
                         SortDirection::AscNullsFirst => Ok((true, true)),
                         SortDirection::AscNullsLast => Ok((true, false)),
@@ -908,7 +913,7 @@ fn from_substrait_bound(
     }
 }
 
-fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
+pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
     let scalar_value = match &lit.literal_type {
         Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
         Some(LiteralType::I8(n)) => match lit.type_variation_reference {
@@ -931,9 +936,7 @@ fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
         },
         Some(LiteralType::I32(n)) => match lit.type_variation_reference {
             DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)),
-            UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(unsafe {
-                std::mem::transmute_copy::<i32, u32>(n)
-            })),
+            UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)),
             others => {
                 return Err(DataFusionError::Substrait(format!(
                     "Unknown type variation reference {others}",
@@ -942,9 +945,7 @@ fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
         },
         Some(LiteralType::I64(n)) => match lit.type_variation_reference {
             DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)),
-            UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(unsafe {
-                std::mem::transmute_copy::<i64, u64>(n)
-            })),
+            UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)),
             others => {
                 return Err(DataFusionError::Substrait(format!(
                     "Unknown type variation reference {others}",
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 9ad9645ffc..b3df6af238 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{collections::HashMap, mem, sync::Arc};
+use std::{collections::HashMap, sync::Arc};
 
 use datafusion::{
     arrow::datatypes::{DataType, TimeUnit},
@@ -1021,15 +1021,13 @@ fn to_substrait_literal(value: &ScalarValue) -> Result<Expression> {
             (LiteralType::I16(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
         }
         ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), 
DEFAULT_TYPE_REF),
-        ScalarValue::UInt32(Some(n)) => (
-            LiteralType::I32(unsafe { mem::transmute_copy::<u32, i32>(n) }),
-            UNSIGNED_INTEGER_TYPE_REF,
-        ),
+        ScalarValue::UInt32(Some(n)) => {
+            (LiteralType::I32(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
+        }
         ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), 
DEFAULT_TYPE_REF),
-        ScalarValue::UInt64(Some(n)) => (
-            LiteralType::I64(unsafe { mem::transmute_copy::<u64, i64>(n) }),
-            UNSIGNED_INTEGER_TYPE_REF,
-        ),
+        ScalarValue::UInt64(Some(n)) => {
+            (LiteralType::I64(*n as i64), UNSIGNED_INTEGER_TYPE_REF)
+        }
         ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f), 
DEFAULT_TYPE_REF),
         ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f), 
DEFAULT_TYPE_REF),
         ScalarValue::TimestampSecond(Some(t), _) => {
@@ -1285,3 +1283,62 @@ fn substrait_field_ref(index: usize) -> Result<Expression> {
         }))),
     })
 }
+
+#[cfg(test)]
+mod test {
+    use crate::logical_plan::consumer::from_substrait_literal;
+
+    use super::*;
+
+    /// Round-trips booleans plus null, MIN and MAX of every integer width through Substrait.
+    /// Unsigned values exercise the casts to/from Substrait's signed integer literal types.
+    #[test]
+    fn round_trip_literals() -> Result<()> {
+        round_trip_literal(ScalarValue::Boolean(None))?;
+        round_trip_literal(ScalarValue::Boolean(Some(true)))?;
+        round_trip_literal(ScalarValue::Boolean(Some(false)))?;
+
+        round_trip_literal(ScalarValue::Int8(None))?;
+        round_trip_literal(ScalarValue::Int8(Some(i8::MIN)))?;
+        round_trip_literal(ScalarValue::Int8(Some(i8::MAX)))?;
+        round_trip_literal(ScalarValue::UInt8(None))?;
+        round_trip_literal(ScalarValue::UInt8(Some(u8::MIN)))?;
+        round_trip_literal(ScalarValue::UInt8(Some(u8::MAX)))?;
+
+        round_trip_literal(ScalarValue::Int16(None))?;
+        round_trip_literal(ScalarValue::Int16(Some(i16::MIN)))?;
+        round_trip_literal(ScalarValue::Int16(Some(i16::MAX)))?;
+        round_trip_literal(ScalarValue::UInt16(None))?;
+        round_trip_literal(ScalarValue::UInt16(Some(u16::MIN)))?;
+        round_trip_literal(ScalarValue::UInt16(Some(u16::MAX)))?;
+
+        round_trip_literal(ScalarValue::Int32(None))?;
+        round_trip_literal(ScalarValue::Int32(Some(i32::MIN)))?;
+        round_trip_literal(ScalarValue::Int32(Some(i32::MAX)))?;
+        round_trip_literal(ScalarValue::UInt32(None))?;
+        round_trip_literal(ScalarValue::UInt32(Some(u32::MIN)))?;
+        round_trip_literal(ScalarValue::UInt32(Some(u32::MAX)))?;
+
+        round_trip_literal(ScalarValue::Int64(None))?;
+        round_trip_literal(ScalarValue::Int64(Some(i64::MIN)))?;
+        round_trip_literal(ScalarValue::Int64(Some(i64::MAX)))?;
+        round_trip_literal(ScalarValue::UInt64(None))?;
+        round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?;
+        round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?;
+
+        Ok(())
+    }
+
+    /// Converts `scalar` to a Substrait literal and back, asserting it is unchanged.
+    fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
+        println!("Checking round trip of {scalar:?}");
+        let substrait = to_substrait_literal(&scalar)?;
+        let Expression { rex_type: Some(RexType::Literal(substrait_literal)) } = substrait else {
+            panic!("Expected Literal expression, got {substrait:?}");
+        };
+
+        let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
+        assert_eq!(scalar, roundtrip_scalar);
+        Ok(())
+    }
+}
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 965c007e98..0b012ffded 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -274,7 +274,7 @@ mod tests {
     #[tokio::test]
     async fn all_type_literal() -> Result<()> {
         roundtrip_all_types(
-            "select * from data where 
+            "select * from data where
             bool_col = TRUE AND
             int8_col = arrow_cast('0', 'Int8') AND
             uint8_col = arrow_cast('0', 'UInt8') AND

Reply via email to