This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new d62cbdf  Add tests for recently added functionality (#199)
d62cbdf is described below

commit d62cbdfa4d963e01627b33e2b6a5d41b836fb940
Author: Andy Grove <[email protected]>
AuthorDate: Mon Feb 20 08:39:06 2023 -0700

    Add tests for recently added functionality (#199)
---
 datafusion/tests/test_expr.py | 83 +++++++++++++++++++++++++++++++++++++++++++
 src/sql/logical.rs            | 23 ++++++++++++
 2 files changed, 106 insertions(+)

diff --git a/datafusion/tests/test_expr.py b/datafusion/tests/test_expr.py
new file mode 100644
index 0000000..4a7db87
--- /dev/null
+++ b/datafusion/tests/test_expr.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion import SessionContext
+from datafusion.expr import (
+    Projection,
+    Filter,
+    Aggregate,
+    Limit,
+    Sort,
+    TableScan,
+)
+import pytest
+
+
[email protected]
+def test_ctx():
+    ctx = SessionContext()
+    ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
+    return ctx
+
+
+def test_projection(test_ctx):
+    df = test_ctx.sql("select c1, 123, c1 < 123 from test")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Projection)
+
+    plan = plan.input().to_variant()
+    assert isinstance(plan, TableScan)
+
+
+def test_filter(test_ctx):
+    df = test_ctx.sql("select c1 from test WHERE c1 > 5")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Projection)
+
+    plan = plan.input().to_variant()
+    assert isinstance(plan, Filter)
+
+
+def test_limit(test_ctx):
+    df = test_ctx.sql("select c1 from test LIMIT 10")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Limit)
+
+
+def test_aggregate(test_ctx):
+    df = test_ctx.sql("select c1, COUNT(*) from test GROUP BY c1")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Projection)
+
+    plan = plan.input().to_variant()
+    assert isinstance(plan, Aggregate)
+
+
+def test_sort(test_ctx):
+    df = test_ctx.sql("select c1 from test order by c1")
+    plan = df.logical_plan()
+
+    plan = plan.to_variant()
+    assert isinstance(plan, Sort)
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index dcd7baa..08d1961 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -17,6 +17,13 @@
 
 use std::sync::Arc;
 
+use crate::errors::py_runtime_err;
+use crate::expr::aggregate::PyAggregate;
+use crate::expr::filter::PyFilter;
+use crate::expr::limit::PyLimit;
+use crate::expr::projection::PyProjection;
+use crate::expr::sort::PySort;
+use crate::expr::table_scan::PyTableScan;
 use datafusion_expr::LogicalPlan;
 use pyo3::prelude::*;
 
@@ -37,6 +44,22 @@ impl PyLogicalPlan {
 
 #[pymethods]
 impl PyLogicalPlan {
+    /// Return the Python wrapper for this plan's specific logical operator.
+    ///
+    /// Supported variants are `Projection`, `TableScan`, `Filter`, `Limit`,
+    /// `Sort` and `Aggregate`; the matched node is cloned into its wrapper.
+    ///
+    /// # Errors
+    ///
+    /// Any other `LogicalPlan` variant yields an error built by
+    /// `py_runtime_err`, whose message includes the plan's `Debug` output.
+    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+        // Holding `py` already proves the GIL is acquired, so there is no need
+        // to re-enter it with `Python::with_gil`.
+        match self.plan.as_ref() {
+            LogicalPlan::Projection(plan) => Ok(PyProjection::from(plan.clone()).into_py(py)),
+            LogicalPlan::TableScan(plan) => Ok(PyTableScan::from(plan.clone()).into_py(py)),
+            LogicalPlan::Filter(plan) => Ok(PyFilter::from(plan.clone()).into_py(py)),
+            LogicalPlan::Limit(plan) => Ok(PyLimit::from(plan.clone()).into_py(py)),
+            LogicalPlan::Sort(plan) => Ok(PySort::from(plan.clone()).into_py(py)),
+            LogicalPlan::Aggregate(plan) => Ok(PyAggregate::from(plan.clone()).into_py(py)),
+            other => Err(py_runtime_err(format!(
+                "Cannot convert this plan to a LogicalNode: {:?}",
+                other
+            ))),
+        }
+    }
+
     /// Get the inputs to this plan
     pub fn inputs(&self) -> Vec<PyLogicalPlan> {
         let mut inputs = vec![];

Reply via email to