This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new d62cbdf Add tests for recently added functionality (#199)
d62cbdf is described below
commit d62cbdfa4d963e01627b33e2b6a5d41b836fb940
Author: Andy Grove <[email protected]>
AuthorDate: Mon Feb 20 08:39:06 2023 -0700
Add tests for recently added functionality (#199)
---
datafusion/tests/test_expr.py | 83 +++++++++++++++++++++++++++++++++++++++++++
src/sql/logical.rs | 23 ++++++++++++
2 files changed, 106 insertions(+)
diff --git a/datafusion/tests/test_expr.py b/datafusion/tests/test_expr.py
new file mode 100644
index 0000000..4a7db87
--- /dev/null
+++ b/datafusion/tests/test_expr.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion import SessionContext
+from datafusion.expr import (
+ Projection,
+ Filter,
+ Aggregate,
+ Limit,
+ Sort,
+ TableScan,
+)
+import pytest
+
+
[email protected]
+def test_ctx():
+ ctx = SessionContext()
+ ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
+ return ctx
+
+
+def test_projection(test_ctx):
+ df = test_ctx.sql("select c1, 123, c1 < 123 from test")
+ plan = df.logical_plan()
+
+ plan = plan.to_variant()
+ assert isinstance(plan, Projection)
+
+ plan = plan.input().to_variant()
+ assert isinstance(plan, TableScan)
+
+
+def test_filter(test_ctx):
+ df = test_ctx.sql("select c1 from test WHERE c1 > 5")
+ plan = df.logical_plan()
+
+ plan = plan.to_variant()
+ assert isinstance(plan, Projection)
+
+ plan = plan.input().to_variant()
+ assert isinstance(plan, Filter)
+
+
+def test_limit(test_ctx):
+ df = test_ctx.sql("select c1 from test LIMIT 10")
+ plan = df.logical_plan()
+
+ plan = plan.to_variant()
+ assert isinstance(plan, Limit)
+
+
+def test_aggregate(test_ctx):
+ df = test_ctx.sql("select c1, COUNT(*) from test GROUP BY c1")
+ plan = df.logical_plan()
+
+ plan = plan.to_variant()
+ assert isinstance(plan, Projection)
+
+ plan = plan.input().to_variant()
+ assert isinstance(plan, Aggregate)
+
+
+def test_sort(test_ctx):
+ df = test_ctx.sql("select c1 from test order by c1")
+ plan = df.logical_plan()
+
+ plan = plan.to_variant()
+ assert isinstance(plan, Sort)
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index dcd7baa..08d1961 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -17,6 +17,13 @@
use std::sync::Arc;
+use crate::errors::py_runtime_err;
+use crate::expr::aggregate::PyAggregate;
+use crate::expr::filter::PyFilter;
+use crate::expr::limit::PyLimit;
+use crate::expr::projection::PyProjection;
+use crate::expr::sort::PySort;
+use crate::expr::table_scan::PyTableScan;
use datafusion_expr::LogicalPlan;
use pyo3::prelude::*;
@@ -37,6 +44,22 @@ impl PyLogicalPlan {
#[pymethods]
impl PyLogicalPlan {
+    /// Return the specific logical operator at the root of this plan.
+    ///
+    /// The root node is cloned into its matching wrapper (`PyProjection`,
+    /// `PyTableScan`, `PyFilter`, `PyLimit`, `PySort` or `PyAggregate`) and
+    /// handed back as a generic Python object.
+    ///
+    /// # Errors
+    ///
+    /// Returns the error built by `py_runtime_err` when the root node is any
+    /// other `LogicalPlan` variant; the message includes the plan's `Debug`
+    /// output.
+    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+        // `py` is proof that the GIL is already held, so it is used directly
+        // for every conversion below.
+        match self.plan.as_ref() {
+            LogicalPlan::Projection(plan) => Ok(PyProjection::from(plan.clone()).into_py(py)),
+            LogicalPlan::TableScan(plan) => Ok(PyTableScan::from(plan.clone()).into_py(py)),
+            LogicalPlan::Filter(plan) => Ok(PyFilter::from(plan.clone()).into_py(py)),
+            LogicalPlan::Limit(plan) => Ok(PyLimit::from(plan.clone()).into_py(py)),
+            LogicalPlan::Sort(plan) => Ok(PySort::from(plan.clone()).into_py(py)),
+            LogicalPlan::Aggregate(plan) => Ok(PyAggregate::from(plan.clone()).into_py(py)),
+            other => Err(py_runtime_err(format!(
+                "Cannot convert this plan to a LogicalNode: {:?}",
+                other
+            ))),
+        }
+    }
+
/// Get the inputs to this plan
pub fn inputs(&self) -> Vec<PyLogicalPlan> {
let mut inputs = vec![];