This is an automated email from the ASF dual-hosted git repository.

jiayuliu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b15fa5  fix lit function to allow multiple types (#1130)
4b15fa5 is described below

commit 4b15fa56067af50c593049c6cf43b79cbfa6d183
Author: Jiayu Liu <[email protected]>
AuthorDate: Sun Oct 17 22:42:51 2021 +0800

    fix lit function to allow multiple types (#1130)
---
 python/src/functions.rs        | 22 +++++++++++++----
 python/tests/generic.py        |  2 +-
 python/tests/test_df_sql.py    |  1 -
 python/tests/test_functions.py | 54 ++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/python/src/functions.rs b/python/src/functions.rs
index 6633f0a..8611ca5 100644
--- a/python/src/functions.rs
+++ b/python/src/functions.rs
@@ -21,7 +21,9 @@ use crate::{expression, types::PyDataType};
 use datafusion::arrow::datatypes::DataType;
 use datafusion::logical_plan;
 use datafusion::physical_plan::functions::Volatility;
-use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction};
+use pyo3::{
+    exceptions::PyTypeError, prelude::*, types::PyTuple, wrap_pyfunction, Python,
+};
 use std::sync::Arc;
 
 /// Expression representing a column on the existing plan.
@@ -36,10 +38,20 @@ fn col(name: &str) -> expression::Expression {
 /// Expression representing a constant value
 #[pyfunction]
 #[pyo3(text_signature = "(value)")]
-fn lit(value: i32) -> expression::Expression {
-    expression::Expression {
-        expr: logical_plan::lit(value),
-    }
+fn lit(value: &PyAny) -> PyResult<expression::Expression> {
+    let expr = if let Ok(v) = value.extract::<i64>() {
+        logical_plan::lit(v)
+    } else if let Ok(v) = value.extract::<f64>() {
+        logical_plan::lit(v)
+    } else if let Ok(v) = value.extract::<String>() {
+        logical_plan::lit(v)
+    } else {
+        return Err(PyTypeError::new_err(format!(
+            "Unsupported value {}, expected one of i64, f64, or String type",
+            value
+        )));
+    };
+    Ok(expression::Expression { expr })
 }
 
 #[pyfunction]
diff --git a/python/tests/generic.py b/python/tests/generic.py
index 8d5adaa..1f984a4 100644
--- a/python/tests/generic.py
+++ b/python/tests/generic.py
@@ -20,9 +20,9 @@ import datetime
 import numpy as np
 import pyarrow as pa
 import pyarrow.csv
-import pyarrow.parquet as pq
 
 # used to write parquet files
+import pyarrow.parquet as pq
 
 
 def data():
diff --git a/python/tests/test_df_sql.py b/python/tests/test_df_sql.py
index ebc38b1..c6eac6b 100644
--- a/python/tests/test_df_sql.py
+++ b/python/tests/test_df_sql.py
@@ -26,7 +26,6 @@ def ctx():
 
 
 def test_register_record_batches(ctx):
-
     # create a RecordBatch and register it as memtable
     batch = pa.RecordBatch.from_arrays(
         [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
new file mode 100644
index 0000000..c6c1cf6
--- /dev/null
+++ b/python/tests/test_functions.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyarrow as pa
+import pytest
+from datafusion import ExecutionContext
+from datafusion import functions as f
+
+
+@pytest.fixture
+def df():
+    ctx = ExecutionContext()
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    return ctx.create_dataframe([[batch]])
+
+
+def test_lit(df):
+    """test lit function"""
+    df = df.select(f.lit(1), f.lit("1"), f.lit("OK"), f.lit(3.14))
+    result = df.collect()
+    assert len(result) == 1
+    result = result[0]
+    assert result.column(0) == pa.array([1] * 3)
+    assert result.column(1) == pa.array(["1"] * 3)
+    assert result.column(2) == pa.array(["OK"] * 3)
+    assert result.column(3) == pa.array([3.14] * 3)
+
+
+def test_lit_arith(df):
+    """test lit function within arithmatics"""
+    df = df.select(f.lit(1) + f.col("b"), f.concat(f.col("a"), f.lit("!")))
+    result = df.collect()
+    assert len(result) == 1
+    result = result[0]
+    assert result.column(0) == pa.array([5, 6, 7])
+    assert result.column(1) == pa.array(["Hello!", "World!", "!!"])

Reply via email to