This is an automated email from the ASF dual-hosted git repository.

jiayuliu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new f455357  python `lit` function to support bool and byte vec (#1152)
f455357 is described below

commit f455357bf159763a19312bab2c9238bc101792e0
Author: Jiayu Liu <jimex...@users.noreply.github.com>
AuthorDate: Thu Oct 21 13:04:41 2021 +0800

    python `lit` function to support bool and byte vec (#1152)
    
    * python lit function to support bool and byte vec
    
    * update per comment
---
 datafusion/src/logical_plan/expr.rs |  12 ++
 python/Cargo.lock                   | 224 ++++++++++++++++--------------------
 python/src/functions.rs             |  48 +++++---
 python/tests/test_functions.py      |  11 +-
 4 files changed, 156 insertions(+), 139 deletions(-)

diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index d50d533..011068d 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -1407,6 +1407,18 @@ impl Literal for String {
     }
 }
 
+impl Literal for Vec<u8> {
+    fn lit(&self) -> Expr {
+        Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
+    }
+}
+
+impl Literal for &[u8] {
+    fn lit(&self) -> Expr {
+        Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
+    }
+}
+
 impl Literal for ScalarValue {
     fn lit(&self) -> Expr {
         Expr::Literal(self.clone())
diff --git a/python/Cargo.lock b/python/Cargo.lock
index 6daefea..6ae2702 100644
--- a/python/Cargo.lock
+++ b/python/Cargo.lock
@@ -51,24 +51,19 @@ checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544"
 
 [[package]]
 name = "arrayvec"
-version = "0.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b"
-
-[[package]]
-name = "arrayvec"
 version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "be4dc07131ffa69b8072d35f5007352af944213cde02545e2103680baed38fcd"
 
 [[package]]
 name = "arrow"
-version = "5.3.0"
+version = "6.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "a4091f84cacfdbd5238e161d314e585820269926f79e05d184db8a2898782d44"
+checksum = "337e668497751234149fd607f5cb41a6ae7b286b6329589126fe67f0ac55d637"
 dependencies = [
  "bitflags",
  "chrono",
+ "comfy-table",
  "csv",
  "flatbuffers",
  "hex",
@@ -77,7 +72,6 @@ dependencies = [
  "lexical-core",
  "multiversion",
  "num",
- "prettytable-rs",
  "rand 0.8.4",
  "regex",
  "serde",
@@ -97,17 +91,6 @@ dependencies = [
 ]
 
 [[package]]
-name = "atty"
-version = "0.2.14"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
-dependencies = [
- "hermit-abi",
- "libc",
- "winapi",
-]
-
-[[package]]
 name = "autocfg"
 version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -137,24 +120,13 @@ dependencies = [
 ]
 
 [[package]]
-name = "blake2b_simd"
-version = "0.5.11"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "afa748e348ad3be8263be728124b24a24f268266f6f5d58af9d75f6a40b5c587"
-dependencies = [
- "arrayref",
- "arrayvec 0.5.2",
- "constant_time_eq",
-]
-
-[[package]]
 name = "blake3"
 version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "dcd555c66291d5f836dbb6883b48660ece810fe25a31f3bdfb911945dff2691f"
 dependencies = [
  "arrayref",
- "arrayvec 0.7.1",
+ "arrayvec",
  "cc",
  "cfg-if",
  "constant_time_eq",
@@ -238,6 +210,17 @@ dependencies = [
 ]
 
 [[package]]
+name = "comfy-table"
+version = "4.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "11e95a3e867422fd8d04049041f5671f94d53c32a9dcd82e2be268714942f3f3"
+dependencies = [
+ "strum",
+ "strum_macros",
+ "unicode-width",
+]
+
+[[package]]
 name = "constant_time_eq"
 version = "0.1.5"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -262,16 +245,6 @@ dependencies = [
 ]
 
 [[package]]
-name = "crossbeam-utils"
-version = "0.8.5"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db"
-dependencies = [
- "cfg-if",
- "lazy_static",
-]
-
-[[package]]
 name = "crypto-mac"
 version = "0.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -355,23 +328,6 @@ dependencies = [
 ]
 
 [[package]]
-name = "dirs"
-version = "1.0.5"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "3fd78930633bd1c6e35c4b42b1df7b0cbc6bc191146e512bb3bedf243fcc3901"
-dependencies = [
- "libc",
- "redox_users",
- "winapi",
-]
-
-[[package]]
-name = "encode_unicode"
-version = "0.3.6"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
-
-[[package]]
 name = "flatbuffers"
 version = "2.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -530,6 +486,15 @@ dependencies = [
 ]
 
 [[package]]
+name = "heck"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c"
+dependencies = [
+ "unicode-segmentation",
+]
+
+[[package]]
 name = "hermit-abi"
 version = "0.1.19"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -615,14 +580,65 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
 
 [[package]]
 name = "lexical-core"
-version = "0.7.6"
+version = "0.8.2"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "6607c62aa161d23d17a9072cc5da0be67cdfc89d3afb1e8d9c842bebc2525ffe"
+checksum = "6a3926d8f156019890be4abe5fd3785e0cff1001e06f59c597641fd513a5a284"
 dependencies = [
- "arrayvec 0.5.2",
- "bitflags",
- "cfg-if",
- "ryu",
+ "lexical-parse-float",
+ "lexical-parse-integer",
+ "lexical-util",
+ "lexical-write-float",
+ "lexical-write-integer",
+]
+
+[[package]]
+name = "lexical-parse-float"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "b4d066d004fa762d9da995ed21aa8845bb9f6e4265f540d716fb4b315197bf0e"
+dependencies = [
+ "lexical-parse-integer",
+ "lexical-util",
+ "static_assertions",
+]
+
+[[package]]
+name = "lexical-parse-integer"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "f2c92badda8cc0fc4f3d3cc1c30aaefafb830510c8781ce4e8669881f3ed53ac"
+dependencies = [
+ "lexical-util",
+ "static_assertions",
+]
+
+[[package]]
+name = "lexical-util"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "6ff669ccaae16ee33af90dc51125755efed17f1309626ba5c12052512b11e291"
+dependencies = [
+ "static_assertions",
+]
+
+[[package]]
+name = "lexical-write-float"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "8b5186948c7b297abaaa51560f2581dae625e5ce7dfc2d8fdc56345adb6dc576"
+dependencies = [
+ "lexical-util",
+ "lexical-write-integer",
+ "static_assertions",
+]
+
+[[package]]
+name = "lexical-write-integer"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "ece956492e0e40fd95ef8658a34d53a3b8c2015762fdcaaff2167b28de1f56ef"
+dependencies = [
+ "lexical-util",
  "static_assertions",
 ]
 
@@ -853,16 +869,16 @@ dependencies = [
  "cfg-if",
  "instant",
  "libc",
- "redox_syscall 0.2.10",
+ "redox_syscall",
  "smallvec",
  "winapi",
 ]
 
 [[package]]
 name = "parquet"
-version = "5.3.0"
+version = "6.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "2943ec261708f7aaa51ea8a1fa2cd3367f5dec3219f7c231d8759fef4b93cc06"
+checksum = "d263b9b59ba260518de9e57bd65931c3f765fea0fabacfe84f40d6fde38e841a"
 dependencies = [
  "arrow",
  "base64",
@@ -932,20 +948,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ac74c624d6b2d21f425f752262f42188365d7b8ff1aff74c82e45136510a4857"
 
 [[package]]
-name = "prettytable-rs"
-version = "0.8.0"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "0fd04b170004fa2daccf418a7f8253aaf033c27760b5f225889024cf66d7ac2e"
-dependencies = [
- "atty",
- "csv",
- "encode_unicode",
- "lazy_static",
- "term",
- "unicode-width",
-]
-
-[[package]]
 name = "proc-macro-hack"
 version = "0.5.19"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -1106,12 +1108,6 @@ dependencies = [
 
 [[package]]
 name = "redox_syscall"
-version = "0.1.57"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce"
-
-[[package]]
-name = "redox_syscall"
 version = "0.2.10"
 source = "registry+https://github.com/rust-lang/crates.io-index";
 checksum = "8383f39639269cde97d255a32bdb68c047337295414940c68bdd30c2e13203ff"
@@ -1120,17 +1116,6 @@ dependencies = [
 ]
 
 [[package]]
-name = "redox_users"
-version = "0.3.5"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "de0737333e7a9502c789a36d7c7fa6092a49895d4faa31ca5df163857ded2e9d"
-dependencies = [
- "getrandom 0.1.16",
- "redox_syscall 0.1.57",
- "rust-argon2",
-]
-
-[[package]]
 name = "regex"
 version = "1.5.4"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -1154,18 +1139,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
 
 [[package]]
-name = "rust-argon2"
-version = "0.8.3"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "4b18820d944b33caa75a71378964ac46f58517c92b6ae5f762636247c09e78fb"
-dependencies = [
- "base64",
- "blake2b_simd",
- "constant_time_eq",
- "crossbeam-utils",
-]
-
-[[package]]
 name = "ryu"
 version = "1.0.5"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -1253,6 +1226,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
 
 [[package]]
+name = "strum"
+version = "0.21.0"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "aaf86bbcfd1fa9670b7a129f64fc0c9fcbbfe4f1bc4210e9e98fe71ffc12cde2"
+
+[[package]]
+name = "strum_macros"
+version = "0.21.1"
+source = "registry+https://github.com/rust-lang/crates.io-index";
+checksum = "d06aaeeee809dbc59eb4556183dd927df67db1540de5be8d3ec0b6636358a5ec"
+dependencies = [
+ "heck",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
 name = "subtle"
 version = "2.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index";
@@ -1270,17 +1261,6 @@ dependencies = [
 ]
 
 [[package]]
-name = "term"
-version = "0.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index";
-checksum = "edd106a334b7657c10b7c540a0106114feadeb4dc314513e97df481d5d966f42"
-dependencies = [
- "byteorder",
- "dirs",
- "winapi",
-]
-
-[[package]]
 name = "thiserror"
 version = "1.0.29"
 source = "registry+https://github.com/rust-lang/crates.io-index";
diff --git a/python/src/functions.rs b/python/src/functions.rs
index 22a5ce4..9c52497 100644
--- a/python/src/functions.rs
+++ b/python/src/functions.rs
@@ -19,11 +19,9 @@ use crate::udaf;
 use crate::udf;
 use crate::{expression, types::PyDataType};
 use datafusion::arrow::datatypes::DataType;
-use datafusion::logical_plan;
+use datafusion::logical_plan::{self, Literal};
 use datafusion::physical_plan::functions::Volatility;
-use pyo3::{
-    exceptions::PyTypeError, prelude::*, types::PyTuple, wrap_pyfunction, Python,
-};
+use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction, Python};
 use std::sync::Arc;
 
 /// Expression representing a column on the existing plan.
@@ -35,22 +33,40 @@ fn col(name: &str) -> expression::Expression {
     }
 }
 
+/// # A bridge type that converts PyAny data into datafusion literal
+///
+/// Note that the ordering here matters because it has to be from
+/// narrow to wider values because Python has duck typing so putting
+/// Int before Boolean results in a premature match.
+#[derive(FromPyObject)]
+enum PythonLiteral<'a> {
+    Boolean(bool),
+    Int(i64),
+    UInt(u64),
+    Float(f64),
+    Str(&'a str),
+    Binary(&'a [u8]),
+}
+
+impl<'a> Literal for PythonLiteral<'a> {
+    fn lit(&self) -> logical_plan::Expr {
+        match self {
+            PythonLiteral::Boolean(val) => val.lit(),
+            PythonLiteral::Int(val) => val.lit(),
+            PythonLiteral::UInt(val) => val.lit(),
+            PythonLiteral::Float(val) => val.lit(),
+            PythonLiteral::Str(val) => val.lit(),
+            PythonLiteral::Binary(val) => val.lit(),
+        }
+    }
+}
+
 /// Expression representing a constant value
 #[pyfunction]
 #[pyo3(text_signature = "(value)")]
 fn lit(value: &PyAny) -> PyResult<expression::Expression> {
-    let expr = if let Ok(v) = value.extract::<i64>() {
-        logical_plan::lit(v)
-    } else if let Ok(v) = value.extract::<f64>() {
-        logical_plan::lit(v)
-    } else if let Ok(v) = value.extract::<String>() {
-        logical_plan::lit(v)
-    } else {
-        return Err(PyTypeError::new_err(format!(
-            "Unsupported value {}, expected one of i64, f64, or String type",
-            value
-        )));
-    };
+    let py_lit = value.extract::<PythonLiteral>()?;
+    let expr = py_lit.lit();
     Ok(expression::Expression { expr })
 }
 
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index c6c1cf6..67cf502 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -34,7 +34,14 @@ def df():
 
 def test_lit(df):
     """test lit function"""
-    df = df.select(f.lit(1), f.lit("1"), f.lit("OK"), f.lit(3.14))
+    df = df.select(
+        f.lit(1),
+        f.lit("1"),
+        f.lit("OK"),
+        f.lit(3.14),
+        f.lit(True),
+        f.lit(b"hello world"),
+    )
     result = df.collect()
     assert len(result) == 1
     result = result[0]
@@ -42,6 +49,8 @@ def test_lit(df):
     assert result.column(1) == pa.array(["1"] * 3)
     assert result.column(2) == pa.array(["OK"] * 3)
     assert result.column(3) == pa.array([3.14] * 3)
+    assert result.column(4) == pa.array([True] * 3)
+    assert result.column(5) == pa.array([b"hello world"] * 3)
 
 
 def test_lit_arith(df):

Reply via email to