This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c37ddf72e feat: express unsigned literal in substrait (#5448)
c37ddf72e is described below

commit c37ddf72ec539bd39cce0dd4ff38db2e36ddb55f
Author: Ruihang Xia <[email protected]>
AuthorDate: Sun Mar 5 04:08:38 2023 +0800

    feat: express unsigned literal in substrait (#5448)
    
    Signed-off-by: Ruihang Xia <[email protected]>
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 50 ++++++++++++++++++++--
 datafusion/substrait/src/logical_plan/producer.rs  | 16 +++++--
 .../substrait/tests/roundtrip_logical_plan.rs      |  8 +++-
 datafusion/substrait/tests/testdata/data.csv       |  6 +--
 4 files changed, 69 insertions(+), 11 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index afb83058a..627e71638 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -615,16 +615,58 @@ pub async fn from_substrait_rex(
         Some(RexType::Literal(lit)) => {
             match &lit.literal_type {
                 Some(LiteralType::I8(n)) => {
-                    Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
+                    if lit.type_variation_reference == 0 {
+                        Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as i8)))))
+                    } else if lit.type_variation_reference == 1 {
+                        Ok(Arc::new(Expr::Literal(ScalarValue::UInt8(Some(*n as u8)))))
+                    } else {
+                        Err(DataFusionError::Substrait(format!(
+                            "Unknown type variation reference {}",
+                            lit.type_variation_reference
+                        )))
+                    }
                 }
                 Some(LiteralType::I16(n)) => {
-                    Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
+                    if lit.type_variation_reference == 0 {
+                        Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as i16)))))
+                    } else if lit.type_variation_reference == 1 {
+                        Ok(Arc::new(Expr::Literal(ScalarValue::UInt16(Some(
+                            *n as u16,
+                        )))))
+                    } else {
+                        Err(DataFusionError::Substrait(format!(
+                            "Unknown type variation reference {}",
+                            lit.type_variation_reference
+                        )))
+                    }
                 }
                 Some(LiteralType::I32(n)) => {
-                    Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
+                    if lit.type_variation_reference == 0 {
+                        Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
+                    } else if lit.type_variation_reference == 1 {
+                        Ok(Arc::new(Expr::Literal(ScalarValue::UInt32(Some(unsafe {
+                            std::mem::transmute_copy::<i32, u32>(n)
+                        })))))
+                    } else {
+                        Err(DataFusionError::Substrait(format!(
+                            "Unknown type variation reference {}",
+                            lit.type_variation_reference
+                        )))
+                    }
                 }
                 Some(LiteralType::I64(n)) => {
-                    Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
+                    if lit.type_variation_reference == 0 {
+                        Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
+                    } else if lit.type_variation_reference == 1 {
+                        Ok(Arc::new(Expr::Literal(ScalarValue::UInt64(Some(unsafe {
+                            std::mem::transmute_copy::<i64, u64>(n)
+                        })))))
+                    } else {
+                        Err(DataFusionError::Substrait(format!(
+                            "Unknown type variation reference {}",
+                            lit.type_variation_reference
+                        )))
+                    }
                 }
                 Some(LiteralType::Boolean(b)) => {
                     Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index ef8a52d9d..cf4003c1c 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{collections::HashMap, sync::Arc};
+use std::{collections::HashMap, mem, sync::Arc};
 
 use datafusion::{
     error::{DataFusionError, Result},
@@ -580,10 +580,17 @@ pub fn to_substrait_rex(
         Expr::Literal(value) => {
             let literal_type = match value {
                 ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)),
+                ScalarValue::UInt8(Some(n)) => Some(LiteralType::I8(*n as i32)),
                 ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as i32)),
+                ScalarValue::UInt16(Some(n)) => Some(LiteralType::I16(*n as i32)),
                 ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
+                ScalarValue::UInt32(Some(n)) => Some(LiteralType::I32(unsafe {
+                    mem::transmute_copy::<u32, i32>(n)
+                })),
                 ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
-                ScalarValue::UInt8(Some(n)) => Some(LiteralType::I16(*n as i32)), // Substrait currently does not support unsigned integer
+                ScalarValue::UInt64(Some(n)) => Some(LiteralType::I64(unsafe {
+                    mem::transmute_copy::<u64, i64>(n)
+                })),
                 ScalarValue::Boolean(Some(b)) => Some(LiteralType::Boolean(*b)),
                 ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
                 ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
@@ -601,10 +608,13 @@ pub fn to_substrait_rex(
                 ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
                 _ => Some(try_to_substrait_null(value)?),
             };
+
+            let type_variation_reference = if value.is_unsigned() { 1 } else { 0 };
+
             Ok(Expression {
                 rex_type: Some(RexType::Literal(Literal {
                     nullable: true,
-                    type_variation_reference: 0,
+                    type_variation_reference,
                     literal_type,
                 })),
             })
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 3ce3343d1..85b3ef19f 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -104,6 +104,11 @@ mod tests {
         roundtrip("SELECT * FROM data WHERE b = NULL").await
     }
 
+    #[tokio::test]
+    async fn u32_literal() -> Result<()> {
+        roundtrip("SELECT * FROM data WHERE e > 4294967295").await
+    }
+
     #[tokio::test]
     async fn simple_distinct() -> Result<()> {
         test_alias(
@@ -226,7 +231,7 @@ mod tests {
     async fn simple_intersect() -> Result<()> {
         assert_expected_plan(
             "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);",
-            "Aggregate: groupBy=[[]], aggr=[[COUNT(Int16(1))]]\
+            "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
             \n  LeftSemi Join: data.a = data2.a\
             \n    Aggregate: groupBy=[[data.a]], aggr=[[]]\
             \n      TableScan: data projection=[a]\
@@ -335,6 +340,7 @@ mod tests {
             Field::new("b", DataType::Decimal128(5, 2), true),
             Field::new("c", DataType::Date32, true),
             Field::new("d", DataType::Boolean, true),
+            Field::new("e", DataType::UInt32, true),
         ]);
         explicit_options.schema = Some(&schema);
         ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv
index b0fc71024..170457da5 100644
--- a/datafusion/substrait/tests/testdata/data.csv
+++ b/datafusion/substrait/tests/testdata/data.csv
@@ -1,3 +1,3 @@
-a,b,c,d
-1,2.0,2020-01-01,false
-3,4.5,2020-01-01,true
\ No newline at end of file
+a,b,c,d,e
+1,2.0,2020-01-01,false,4294967296
+3,4.5,2020-01-01,true,2147483648
\ No newline at end of file

Reply via email to