This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4ce35ef07d Remove transmute in datafusion-proto (#5946)
4ce35ef07d is described below
commit 4ce35ef07daa72b10fba7dd03a62e86d361514dd
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Apr 10 11:32:18 2023 -0400
Remove transmute in datafusion-proto (#5946)
---
datafusion/substrait/src/logical_plan/consumer.rs | 17 ++---
datafusion/substrait/src/logical_plan/producer.rs | 75 +++++++++++++++++++---
.../substrait/tests/roundtrip_logical_plan.rs | 2 +-
3 files changed, 76 insertions(+), 18 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 607012bfd6..b807a9acf7 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -453,7 +453,12 @@ pub async fn from_substrait_sorts(
let asc_nullfirst = match &s.sort_kind {
Some(k) => match k {
Direction(d) => {
- let direction: SortDirection = unsafe {
::std::mem::transmute(*d) };
+ let Some(direction) = SortDirection::from_i32(*d) else {
+ return Err(DataFusionError::NotImplemented(
+ format!("Unsupported Substrait SortDirection value
{d}"),
+ ))
+ };
+
match direction {
SortDirection::AscNullsFirst => Ok((true, true)),
SortDirection::AscNullsLast => Ok((true, false)),
@@ -908,7 +913,7 @@ fn from_substrait_bound(
}
}
-fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
+pub(crate) fn from_substrait_literal(lit: &Literal) -> Result<ScalarValue> {
let scalar_value = match &lit.literal_type {
Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)),
Some(LiteralType::I8(n)) => match lit.type_variation_reference {
@@ -931,9 +936,7 @@ fn from_substrait_literal(lit: &Literal) ->
Result<ScalarValue> {
},
Some(LiteralType::I32(n)) => match lit.type_variation_reference {
DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)),
- UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(unsafe {
- std::mem::transmute_copy::<i32, u32>(n)
- })),
+ UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)),
others => {
return Err(DataFusionError::Substrait(format!(
"Unknown type variation reference {others}",
@@ -942,9 +945,7 @@ fn from_substrait_literal(lit: &Literal) ->
Result<ScalarValue> {
},
Some(LiteralType::I64(n)) => match lit.type_variation_reference {
DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)),
- UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(unsafe {
- std::mem::transmute_copy::<i64, u64>(n)
- })),
+ UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)),
others => {
return Err(DataFusionError::Substrait(format!(
"Unknown type variation reference {others}",
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 9ad9645ffc..b3df6af238 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use std::{collections::HashMap, mem, sync::Arc};
+use std::{collections::HashMap, sync::Arc};
use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
@@ -1021,15 +1021,13 @@ fn to_substrait_literal(value: &ScalarValue) ->
Result<Expression> {
(LiteralType::I16(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
}
ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n),
DEFAULT_TYPE_REF),
- ScalarValue::UInt32(Some(n)) => (
- LiteralType::I32(unsafe { mem::transmute_copy::<u32, i32>(n) }),
- UNSIGNED_INTEGER_TYPE_REF,
- ),
+ ScalarValue::UInt32(Some(n)) => {
+ (LiteralType::I32(*n as i32), UNSIGNED_INTEGER_TYPE_REF)
+ }
ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n),
DEFAULT_TYPE_REF),
- ScalarValue::UInt64(Some(n)) => (
- LiteralType::I64(unsafe { mem::transmute_copy::<u64, i64>(n) }),
- UNSIGNED_INTEGER_TYPE_REF,
- ),
+ ScalarValue::UInt64(Some(n)) => {
+ (LiteralType::I64(*n as i64), UNSIGNED_INTEGER_TYPE_REF)
+ }
ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f),
DEFAULT_TYPE_REF),
ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f),
DEFAULT_TYPE_REF),
ScalarValue::TimestampSecond(Some(t), _) => {
@@ -1285,3 +1283,62 @@ fn substrait_field_ref(index: usize) ->
Result<Expression> {
}))),
})
}
+
+#[cfg(test)]
+mod test {
+ use crate::logical_plan::consumer::from_substrait_literal;
+
+ use super::*;
+
+ #[test]
+ fn round_trip_literals() -> Result<()> {
+ round_trip_literal(ScalarValue::Boolean(None))?;
+ round_trip_literal(ScalarValue::Boolean(Some(true)))?;
+ round_trip_literal(ScalarValue::Boolean(Some(false)))?;
+
+ round_trip_literal(ScalarValue::Int8(None))?;
+ round_trip_literal(ScalarValue::Int8(Some(i8::MIN)))?;
+ round_trip_literal(ScalarValue::Int8(Some(i8::MAX)))?;
+ round_trip_literal(ScalarValue::UInt8(None))?;
+ round_trip_literal(ScalarValue::UInt8(Some(u8::MIN)))?;
+ round_trip_literal(ScalarValue::UInt8(Some(u8::MAX)))?;
+
+ round_trip_literal(ScalarValue::Int16(None))?;
+ round_trip_literal(ScalarValue::Int16(Some(i16::MIN)))?;
+ round_trip_literal(ScalarValue::Int16(Some(i16::MAX)))?;
+ round_trip_literal(ScalarValue::UInt16(None))?;
+ round_trip_literal(ScalarValue::UInt16(Some(u16::MIN)))?;
+ round_trip_literal(ScalarValue::UInt16(Some(u16::MAX)))?;
+
+ round_trip_literal(ScalarValue::Int32(None))?;
+ round_trip_literal(ScalarValue::Int32(Some(i32::MIN)))?;
+ round_trip_literal(ScalarValue::Int32(Some(i32::MAX)))?;
+ round_trip_literal(ScalarValue::UInt32(None))?;
+ round_trip_literal(ScalarValue::UInt32(Some(u32::MIN)))?;
+ round_trip_literal(ScalarValue::UInt32(Some(u32::MAX)))?;
+
+ round_trip_literal(ScalarValue::Int64(None))?;
+ round_trip_literal(ScalarValue::Int64(Some(i64::MIN)))?;
+ round_trip_literal(ScalarValue::Int64(Some(i64::MAX)))?;
+ round_trip_literal(ScalarValue::UInt64(None))?;
+ round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?;
+ round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?;
+
+ Ok(())
+ }
+
+ /// Converts `scalar` to a Substrait literal and back, asserting the value is unchanged.
+ fn round_trip_literal(scalar: ScalarValue) -> Result<()> {
+ println!("Checking round trip of {:?}", scalar);
+
+ let substrait = to_substrait_literal(&scalar)?;
+
+ let Expression { rex_type: Some(RexType::Literal(substrait_literal)) }
= substrait else {
+ panic!("Expected Literal expression, got {:?}", substrait);
+ };
+
+ let roundtrip_scalar = from_substrait_literal(&substrait_literal)?;
+ assert_eq!(scalar, roundtrip_scalar);
+ Ok(())
+ }
+}
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 965c007e98..0b012ffded 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -274,7 +274,7 @@ mod tests {
#[tokio::test]
async fn all_type_literal() -> Result<()> {
roundtrip_all_types(
- "select * from data where
+ "select * from data where
bool_col = TRUE AND
int8_col = arrow_cast('0', 'Int8') AND
uint8_col = arrow_cast('0', 'UInt8') AND