This is an automated email from the ASF dual-hosted git repository.
jiayuliu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 4b15fa5 fix lit function to allow multiple types (#1130)
4b15fa5 is described below
commit 4b15fa56067af50c593049c6cf43b79cbfa6d183
Author: Jiayu Liu <[email protected]>
AuthorDate: Sun Oct 17 22:42:51 2021 +0800
fix lit function to allow multiple types (#1130)
---
python/src/functions.rs | 22 +++++++++++++----
python/tests/generic.py | 2 +-
python/tests/test_df_sql.py | 1 -
python/tests/test_functions.py | 54 ++++++++++++++++++++++++++++++++++++++++++
4 files changed, 72 insertions(+), 7 deletions(-)
diff --git a/python/src/functions.rs b/python/src/functions.rs
index 6633f0a..8611ca5 100644
--- a/python/src/functions.rs
+++ b/python/src/functions.rs
@@ -21,7 +21,9 @@ use crate::{expression, types::PyDataType};
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_plan;
use datafusion::physical_plan::functions::Volatility;
-use pyo3::{prelude::*, types::PyTuple, wrap_pyfunction};
+use pyo3::{
+ exceptions::PyTypeError, prelude::*, types::PyTuple, wrap_pyfunction, Python,
+};
use std::sync::Arc;
/// Expression representing a column on the existing plan.
@@ -36,10 +38,20 @@ fn col(name: &str) -> expression::Expression {
/// Expression representing a constant value
#[pyfunction]
#[pyo3(text_signature = "(value)")]
-fn lit(value: i32) -> expression::Expression {
- expression::Expression {
- expr: logical_plan::lit(value),
- }
+fn lit(value: &PyAny) -> PyResult<expression::Expression> {
+ let expr = if let Ok(v) = value.extract::<i64>() {
+ logical_plan::lit(v)
+ } else if let Ok(v) = value.extract::<f64>() {
+ logical_plan::lit(v)
+ } else if let Ok(v) = value.extract::<String>() {
+ logical_plan::lit(v)
+ } else {
+ return Err(PyTypeError::new_err(format!(
+ "Unsupported value {}, expected one of i64, f64, or String type",
+ value
+ )));
+ };
+ Ok(expression::Expression { expr })
}
#[pyfunction]
diff --git a/python/tests/generic.py b/python/tests/generic.py
index 8d5adaa..1f984a4 100644
--- a/python/tests/generic.py
+++ b/python/tests/generic.py
@@ -20,9 +20,9 @@ import datetime
import numpy as np
import pyarrow as pa
import pyarrow.csv
-import pyarrow.parquet as pq
# used to write parquet files
+import pyarrow.parquet as pq
def data():
diff --git a/python/tests/test_df_sql.py b/python/tests/test_df_sql.py
index ebc38b1..c6eac6b 100644
--- a/python/tests/test_df_sql.py
+++ b/python/tests/test_df_sql.py
@@ -26,7 +26,6 @@ def ctx():
def test_register_record_batches(ctx):
-
# create a RecordBatch and register it as memtable
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
new file mode 100644
index 0000000..c6c1cf6
--- /dev/null
+++ b/python/tests/test_functions.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyarrow as pa
+import pytest
+from datafusion import ExecutionContext
+from datafusion import functions as f
+
+
[email protected]
+def df():
+ ctx = ExecutionContext()
+ # create a RecordBatch and a new DataFrame from it
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])],
+ names=["a", "b"],
+ )
+ return ctx.create_dataframe([[batch]])
+
+
+def test_lit(df):
+ """test lit function"""
+ df = df.select(f.lit(1), f.lit("1"), f.lit("OK"), f.lit(3.14))
+ result = df.collect()
+ assert len(result) == 1
+ result = result[0]
+ assert result.column(0) == pa.array([1] * 3)
+ assert result.column(1) == pa.array(["1"] * 3)
+ assert result.column(2) == pa.array(["OK"] * 3)
+ assert result.column(3) == pa.array([3.14] * 3)
+
+
+def test_lit_arith(df):
+ """test lit function within arithmetic"""
+ df = df.select(f.lit(1) + f.col("b"), f.concat(f.col("a"), f.lit("!")))
+ result = df.collect()
+ assert len(result) == 1
+ result = result[0]
+ assert result.column(0) == pa.array([5, 6, 7])
+ assert result.column(1) == pa.array(["Hello!", "World!", "!!"])