This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new c37ddf72e feat: express unsigned literal in substrait (#5448)
c37ddf72e is described below
commit c37ddf72ec539bd39cce0dd4ff38db2e36ddb55f
Author: Ruihang Xia <[email protected]>
AuthorDate: Sun Mar 5 04:08:38 2023 +0800
feat: express unsigned literal in substrait (#5448)
Signed-off-by: Ruihang Xia <[email protected]>
---
datafusion/substrait/src/logical_plan/consumer.rs | 50 ++++++++++++++++++++--
datafusion/substrait/src/logical_plan/producer.rs | 16 +++++--
.../substrait/tests/roundtrip_logical_plan.rs | 8 +++-
datafusion/substrait/tests/testdata/data.csv | 6 +--
4 files changed, 69 insertions(+), 11 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index afb83058a..627e71638 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -615,16 +615,58 @@ pub async fn from_substrait_rex(
Some(RexType::Literal(lit)) => {
match &lit.literal_type {
Some(LiteralType::I8(n)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as
i8)))))
+ if lit.type_variation_reference == 0 {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int8(Some(*n as
i8)))))
+ } else if lit.type_variation_reference == 1 {
+ Ok(Arc::new(Expr::Literal(ScalarValue::UInt8(Some(*n
as u8)))))
+ } else {
+ Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {}",
+ lit.type_variation_reference
+ )))
+ }
}
Some(LiteralType::I16(n)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n as
i16)))))
+ if lit.type_variation_reference == 0 {
+ Ok(Arc::new(Expr::Literal(ScalarValue::Int16(Some(*n
as i16)))))
+ } else if lit.type_variation_reference == 1 {
+ Ok(Arc::new(Expr::Literal(ScalarValue::UInt16(Some(
+ *n as u16,
+ )))))
+ } else {
+ Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {}",
+ lit.type_variation_reference
+ )))
+ }
}
Some(LiteralType::I32(n)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
+ if lit.type_variation_reference == 0 {
+
Ok(Arc::new(Expr::Literal(ScalarValue::Int32(Some(*n)))))
+ } else if lit.type_variation_reference == 1 {
+
Ok(Arc::new(Expr::Literal(ScalarValue::UInt32(Some(unsafe {
+ std::mem::transmute_copy::<i32, u32>(n)
+ })))))
+ } else {
+ Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {}",
+ lit.type_variation_reference
+ )))
+ }
}
Some(LiteralType::I64(n)) => {
- Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
+ if lit.type_variation_reference == 0 {
+
Ok(Arc::new(Expr::Literal(ScalarValue::Int64(Some(*n)))))
+ } else if lit.type_variation_reference == 1 {
+
Ok(Arc::new(Expr::Literal(ScalarValue::UInt64(Some(unsafe {
+ std::mem::transmute_copy::<i64, u64>(n)
+ })))))
+ } else {
+ Err(DataFusionError::Substrait(format!(
+ "Unknown type variation reference {}",
+ lit.type_variation_reference
+ )))
+ }
}
Some(LiteralType::Boolean(b)) => {
Ok(Arc::new(Expr::Literal(ScalarValue::Boolean(Some(*b)))))
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index ef8a52d9d..cf4003c1c 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use std::{collections::HashMap, sync::Arc};
+use std::{collections::HashMap, mem, sync::Arc};
use datafusion::{
error::{DataFusionError, Result},
@@ -580,10 +580,17 @@ pub fn to_substrait_rex(
Expr::Literal(value) => {
let literal_type = match value {
ScalarValue::Int8(Some(n)) => Some(LiteralType::I8(*n as i32)),
+ ScalarValue::UInt8(Some(n)) => Some(LiteralType::I8(*n as
i32)),
ScalarValue::Int16(Some(n)) => Some(LiteralType::I16(*n as
i32)),
+ ScalarValue::UInt16(Some(n)) => Some(LiteralType::I16(*n as
i32)),
ScalarValue::Int32(Some(n)) => Some(LiteralType::I32(*n)),
+ ScalarValue::UInt32(Some(n)) => Some(LiteralType::I32(unsafe {
+ mem::transmute_copy::<u32, i32>(n)
+ })),
ScalarValue::Int64(Some(n)) => Some(LiteralType::I64(*n)),
- ScalarValue::UInt8(Some(n)) => Some(LiteralType::I16(*n as
i32)), // Substrait currently does not support unsigned integer
+ ScalarValue::UInt64(Some(n)) => Some(LiteralType::I64(unsafe {
+ mem::transmute_copy::<u64, i64>(n)
+ })),
ScalarValue::Boolean(Some(b)) =>
Some(LiteralType::Boolean(*b)),
ScalarValue::Float32(Some(f)) => Some(LiteralType::Fp32(*f)),
ScalarValue::Float64(Some(f)) => Some(LiteralType::Fp64(*f)),
@@ -601,10 +608,13 @@ pub fn to_substrait_rex(
ScalarValue::Date32(Some(d)) => Some(LiteralType::Date(*d)),
_ => Some(try_to_substrait_null(value)?),
};
+
+ let type_variation_reference = if value.is_unsigned() { 1 } else {
0 };
+
Ok(Expression {
rex_type: Some(RexType::Literal(Literal {
nullable: true,
- type_variation_reference: 0,
+ type_variation_reference,
literal_type,
})),
})
diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs
index 3ce3343d1..85b3ef19f 100644
--- a/datafusion/substrait/tests/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -104,6 +104,11 @@ mod tests {
roundtrip("SELECT * FROM data WHERE b = NULL").await
}
+ #[tokio::test]
+ async fn u32_literal() -> Result<()> {
+ roundtrip("SELECT * FROM data WHERE e > 4294967295").await
+ }
+
#[tokio::test]
async fn simple_distinct() -> Result<()> {
test_alias(
@@ -226,7 +231,7 @@ mod tests {
async fn simple_intersect() -> Result<()> {
assert_expected_plan(
"SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT
data2.a FROM data2);",
- "Aggregate: groupBy=[[]], aggr=[[COUNT(Int16(1))]]\
+ "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n LeftSemi Join: data.a = data2.a\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
@@ -335,6 +340,7 @@ mod tests {
Field::new("b", DataType::Decimal128(5, 2), true),
Field::new("c", DataType::Date32, true),
Field::new("d", DataType::Boolean, true),
+ Field::new("e", DataType::UInt32, true),
]);
explicit_options.schema = Some(&schema);
ctx.register_csv("data", "tests/testdata/data.csv", explicit_options)
diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv
index b0fc71024..170457da5 100644
--- a/datafusion/substrait/tests/testdata/data.csv
+++ b/datafusion/substrait/tests/testdata/data.csv
@@ -1,3 +1,3 @@
-a,b,c,d
-1,2.0,2020-01-01,false
-3,4.5,2020-01-01,true
\ No newline at end of file
+a,b,c,d,e
+1,2.0,2020-01-01,false,4294967296
+3,4.5,2020-01-01,true,2147483648
\ No newline at end of file