This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 4a65ee3 fix python crate with the changes to logical plan builder (#650)
4a65ee3 is described below
commit 4a65ee37abd6319ed6d342f7ff22f46f4d800a03
Author: Jiayu Liu <[email protected]>
AuthorDate: Sat Jul 3 05:18:47 2021 +0800
fix python crate with the changes to logical plan builder (#650)
---
python/.gitignore | 1 +
python/Cargo.toml | 2 +-
python/src/dataframe.rs | 16 ++++++++--------
python/tests/generic.py | 12 +++---------
python/tests/test_sql.py | 8 ++------
python/tests/test_udaf.py | 3 ++-
6 files changed, 17 insertions(+), 25 deletions(-)
diff --git a/python/.gitignore b/python/.gitignore
index 48fe4db..feb402e 100644
--- a/python/.gitignore
+++ b/python/.gitignore
@@ -18,3 +18,4 @@
/target
Cargo.lock
venv
+.venv
diff --git a/python/Cargo.toml b/python/Cargo.toml
index 8f1480d..777e427 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -31,7 +31,7 @@ libc = "0.2"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
rand = "0.7"
pyo3 = { version = "0.13.2", features = ["extension-module"] }
-datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "c92079dfb3045a9a46d12c3bc22361a44d11b8bc" }
+datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "fddab22aa562750f67385a961497dc020b18c4b2" }
[lib]
name = "datafusion"
diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs
index 8ceac64..89c85f9 100644
--- a/python/src/dataframe.rs
+++ b/python/src/dataframe.rs
@@ -51,7 +51,7 @@ impl DataFrame {
#[args(args = "*")]
fn select(&self, args: &PyTuple) -> PyResult<Self> {
let expressions = expression::from_tuple(args)?;
- let builder = LogicalPlanBuilder::from(&self.plan);
+ let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder =
errors::wrap(builder.project(expressions.into_iter().map(|e|
e.expr)))?;
let plan = errors::wrap(builder.build())?;
@@ -64,7 +64,7 @@ impl DataFrame {
/// Filter according to the `predicate` expression
fn filter(&self, predicate: expression::Expression) -> PyResult<Self> {
- let builder = LogicalPlanBuilder::from(&self.plan);
+ let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder = errors::wrap(builder.filter(predicate.expr))?;
let plan = errors::wrap(builder.build())?;
@@ -80,7 +80,7 @@ impl DataFrame {
group_by: Vec<expression::Expression>,
aggs: Vec<expression::Expression>,
) -> PyResult<Self> {
- let builder = LogicalPlanBuilder::from(&self.plan);
+ let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder = errors::wrap(builder.aggregate(
group_by.into_iter().map(|e| e.expr),
aggs.into_iter().map(|e| e.expr),
@@ -96,7 +96,7 @@ impl DataFrame {
/// Sort by specified sorting expressions
fn sort(&self, exprs: Vec<expression::Expression>) -> PyResult<Self> {
let exprs = exprs.into_iter().map(|e| e.expr);
- let builder = LogicalPlanBuilder::from(&self.plan);
+ let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder = errors::wrap(builder.sort(exprs))?;
let plan = errors::wrap(builder.build())?;
Ok(DataFrame {
@@ -107,7 +107,7 @@ impl DataFrame {
/// Limits the plan to return at most `count` rows
fn limit(&self, count: usize) -> PyResult<Self> {
- let builder = LogicalPlanBuilder::from(&self.plan);
+ let builder = LogicalPlanBuilder::from(self.plan.clone());
let builder = errors::wrap(builder.limit(count))?;
let plan = errors::wrap(builder.build())?;
@@ -141,7 +141,7 @@ impl DataFrame {
/// Returns the join of two DataFrames `on`.
fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) ->
PyResult<Self> {
- let builder = LogicalPlanBuilder::from(&self.plan);
+ let builder = LogicalPlanBuilder::from(self.plan.clone());
let join_type = match how {
"inner" => JoinType::Inner,
@@ -162,8 +162,8 @@ impl DataFrame {
let builder = errors::wrap(builder.join(
&right.plan,
join_type,
- on.as_slice(),
- on.as_slice(),
+ on.clone(),
+ on,
))?;
let plan = errors::wrap(builder.build())?;
diff --git a/python/tests/generic.py b/python/tests/generic.py
index e61542e..5871c5e 100644
--- a/python/tests/generic.py
+++ b/python/tests/generic.py
@@ -49,9 +49,7 @@ def data_datetime(f):
datetime.datetime.now() - datetime.timedelta(days=1),
datetime.datetime.now() + datetime.timedelta(days=1),
]
- return pa.array(
- data, type=pa.timestamp(f), mask=np.array([False, True, False])
- )
+ return pa.array(data, type=pa.timestamp(f), mask=np.array([False, True,
False]))
def data_date32():
@@ -60,9 +58,7 @@ def data_date32():
datetime.date(1980, 1, 1),
datetime.date(2030, 1, 1),
]
- return pa.array(
- data, type=pa.date32(), mask=np.array([False, True, False])
- )
+ return pa.array(data, type=pa.date32(), mask=np.array([False, True,
False]))
def data_timedelta(f):
@@ -71,9 +67,7 @@ def data_timedelta(f):
datetime.timedelta(days=1),
datetime.timedelta(seconds=1),
]
- return pa.array(
- data, type=pa.duration(f), mask=np.array([False, True, False])
- )
+ return pa.array(data, type=pa.duration(f), mask=np.array([False, True,
False]))
def data_binary_other():
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 361526d..62d6c09 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -112,9 +112,7 @@ def test_cast(ctx, tmp_path):
"float",
]
- select = ", ".join(
- [f"CAST(9 AS {t}) AS A{i}" for i, t in enumerate(valid_types)]
- )
+ select = ", ".join([f"CAST(9 AS {t}) AS A{i}" for i, t in
enumerate(valid_types)])
# can execute, which implies that we can cast
ctx.sql(f"SELECT {select} FROM t").collect()
@@ -143,9 +141,7 @@ def test_udf(
ctx, tmp_path, fn, input_types, output_type, input_values, expected_values
):
# write to disk
- path = helpers.write_parquet(
- tmp_path / "a.parquet", pa.array(input_values)
- )
+ path = helpers.write_parquet(tmp_path / "a.parquet",
pa.array(input_values))
ctx.register_parquet("t", path)
ctx.register_udf("udf", fn, input_types, output_type)
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index b24c08d..103d967 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+from typing import List
import pyarrow as pa
import pyarrow.compute as pc
import pytest
@@ -30,7 +31,7 @@ class Accumulator:
def __init__(self):
self._sum = pa.scalar(0.0)
- def to_scalars(self) -> [pa.Scalar]:
+ def to_scalars(self) -> List[pa.Scalar]:
return [self._sum]
def update(self, values: pa.Array) -> None: