This is an automated email from the ASF dual-hosted git repository.

kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 8251cc90a20b0f429d1dbade0a32646437168222
Author: Jorge C. Leitao <[email protected]>
AuthorDate: Wed Jul 22 07:43:41 2020 -0600

    ARROW-9534: [Rust] [DataFusion] Added support for lit to all supported Rust types.
    
    @andygrove fyi
    
    Closes #7811 from jorgecarleitao/lit
    
    Authored-by: Jorge C. Leitao <[email protected]>
    Signed-off-by: Andy Grove <[email protected]>
---
 rust/datafusion/examples/memory_table_api.rs       |  4 +-
 rust/datafusion/src/logicalplan.rs                 | 51 +++++++++++++++++++---
 .../src/optimizer/projection_push_down.rs          |  7 +--
 rust/datafusion/src/sql/planner.rs                 | 18 +++-----
 4 files changed, 56 insertions(+), 24 deletions(-)

diff --git a/rust/datafusion/examples/memory_table_api.rs b/rust/datafusion/examples/memory_table_api.rs
index bfa8612..937d80e 100644
--- a/rust/datafusion/examples/memory_table_api.rs
+++ b/rust/datafusion/examples/memory_table_api.rs
@@ -26,7 +26,7 @@ use arrow::util::pretty;
 use datafusion::datasource::MemTable;
 use datafusion::error::Result;
 use datafusion::execution::context::ExecutionContext;
-use datafusion::logicalplan::{Expr, ScalarValue};
+use datafusion::logicalplan::lit;
 
 /// This example demonstrates basic uses of the Table API on an in-memory table
 fn main() -> Result<()> {
@@ -54,7 +54,7 @@ fn main() -> Result<()> {
     let t = ctx.table("t")?;
 
     // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 
10" in SQL
-    let filter = t.col("b")?.eq(&Expr::Literal(ScalarValue::Int32(10)));
+    let filter = t.col("b")?.eq(&lit(10));
 
     let t = t.select_columns(vec!["a", "b"])?.filter(filter)?;
 
diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index fe711ee..032bfb9 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -378,9 +378,50 @@ pub fn col(name: &str) -> Expr {
     Expr::UnresolvedColumn(name.to_owned())
 }
 
-/// Create a literal string expression
-pub fn lit_str(str: &str) -> Expr {
-    Expr::Literal(ScalarValue::Utf8(str.to_owned()))
+/// Whether it can be represented as a literal expression
+pub trait Literal {
+    /// convert the value to a Literal expression
+    fn lit(&self) -> Expr;
+}
+
+impl Literal for &str {
+    fn lit(&self) -> Expr {
+        Expr::Literal(ScalarValue::Utf8((*self).to_owned()))
+    }
+}
+
+impl Literal for String {
+    fn lit(&self) -> Expr {
+        Expr::Literal(ScalarValue::Utf8((*self).to_owned()))
+    }
+}
+
+macro_rules! make_literal {
+    ($TYPE:ty, $SCALAR:ident) => {
+        #[allow(missing_docs)]
+        impl Literal for $TYPE {
+            fn lit(&self) -> Expr {
+                Expr::Literal(ScalarValue::$SCALAR(self.clone()))
+            }
+        }
+    };
+}
+
+make_literal!(bool, Boolean);
+make_literal!(f32, Float32);
+make_literal!(f64, Float64);
+make_literal!(i8, Int8);
+make_literal!(i16, Int16);
+make_literal!(i32, Int32);
+make_literal!(i64, Int64);
+make_literal!(u8, UInt8);
+make_literal!(u16, UInt16);
+make_literal!(u32, UInt32);
+make_literal!(u64, UInt64);
+
+/// Create a literal expression
+pub fn lit<T: Literal>(n: T) -> Expr {
+    n.lit()
 }
 
 /// Create an convenience function representing a unary scalar function
@@ -965,7 +1006,7 @@ mod tests {
             &employee_schema(),
             Some(vec![0, 3]),
         )?
-        .filter(col("state").eq(&lit_str("CO")))?
+        .filter(col("state").eq(&lit("CO")))?
         .project(vec![col("id")])?
         .build()?;
 
@@ -985,7 +1026,7 @@ mod tests {
             CsvReadOptions::new().schema(&employee_schema()),
             Some(vec![0, 3]),
         )?
-        .filter(col("state").eq(&lit_str("CO")))?
+        .filter(col("state").eq(&lit("CO")))?
         .project(vec![col("id")])?
         .build()?;
 
diff --git a/rust/datafusion/src/optimizer/projection_push_down.rs b/rust/datafusion/src/optimizer/projection_push_down.rs
index 8fc203e..d1bba6e 100644
--- a/rust/datafusion/src/optimizer/projection_push_down.rs
+++ b/rust/datafusion/src/optimizer/projection_push_down.rs
@@ -368,8 +368,8 @@ fn get_projected_schema(
 mod tests {
 
     use super::*;
+    use crate::logicalplan::lit;
     use crate::logicalplan::Expr::*;
-    use crate::logicalplan::ScalarValue;
     use crate::test::*;
     use arrow::datatypes::DataType;
 
@@ -498,10 +498,7 @@ mod tests {
     fn table_scan_with_literal_projection() -> Result<()> {
         let table_scan = test_table_scan()?;
         let plan = LogicalPlanBuilder::from(&table_scan)
-            .project(vec![
-                Expr::Literal(ScalarValue::Int64(1)),
-                Expr::Literal(ScalarValue::Int64(2)),
-            ])?
+            .project(vec![lit(1_i64), lit(2_i64)])?
             .build()?;
         let expected = "Projection: Int64(1), Int64(2)\
                       \n  TableScan: test projection=Some([0])";
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index d6bd296..8c74c5a 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;
 
 use crate::error::{ExecutionError, Result};
 use crate::logicalplan::{
-    Expr, FunctionMeta, LogicalPlan, LogicalPlanBuilder, Operator, ScalarValue,
+    lit, Expr, FunctionMeta, LogicalPlan, LogicalPlanBuilder, Operator, 
ScalarValue,
 };
 
 use arrow::datatypes::*;
@@ -262,14 +262,10 @@ impl<S: SchemaProvider> SqlToRel<S> {
     /// Generate a relational expression from a SQL expression
     pub fn sql_to_rex(&self, sql: &ASTNode, schema: &Schema) -> Result<Expr> {
         match *sql {
-            ASTNode::SQLValue(sqlparser::sqlast::Value::Long(n)) => {
-                Ok(Expr::Literal(ScalarValue::Int64(n)))
-            }
-            ASTNode::SQLValue(sqlparser::sqlast::Value::Double(n)) => {
-                Ok(Expr::Literal(ScalarValue::Float64(n)))
-            }
+            ASTNode::SQLValue(sqlparser::sqlast::Value::Long(n)) => Ok(lit(n)),
+            ASTNode::SQLValue(sqlparser::sqlast::Value::Double(n)) => 
Ok(lit(n)),
             ASTNode::SQLValue(sqlparser::sqlast::Value::SingleQuotedString(ref 
s)) => {
-                Ok(Expr::Literal(ScalarValue::Utf8(s.clone())))
+                Ok(lit(s.clone()))
             }
 
             ASTNode::SQLAliasedExpr(ref expr, ref alias) => Ok(Alias(
@@ -382,11 +378,9 @@ impl<S: SchemaProvider> SqlToRel<S> {
                             .iter()
                             .map(|a| match a {
                                 
ASTNode::SQLValue(sqlparser::sqlast::Value::Long(_)) => {
-                                    Ok(Expr::Literal(ScalarValue::UInt8(1)))
-                                }
-                                ASTNode::SQLWildcard => {
-                                    Ok(Expr::Literal(ScalarValue::UInt8(1)))
+                                    Ok(lit(1_u8))
                                 }
+                                ASTNode::SQLWildcard => Ok(lit(1_u8)),
                                 _ => self.sql_to_rex(a, schema),
                             })
                             .collect::<Result<Vec<Expr>>>()?;

Reply via email to