This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new ca8b055  Add experimental support for executing SQL with Polars and 
Pandas (#190)
ca8b055 is described below

commit ca8b0551cd9dc7f6832f94a8b44f31cb5b3aa218
Author: Andy Grove <[email protected]>
AuthorDate: Mon Feb 20 10:03:21 2023 -0700

    Add experimental support for executing SQL with Polars and Pandas (#190)
---
 README.md                        |  3 ++
 datafusion/pandas.py             | 62 +++++++++++++++++++++++++++++
 datafusion/polars.py             | 85 ++++++++++++++++++++++++++++++++++++++++
 datafusion/tests/test_expr.py    | 39 +++++++++++++++---
 datafusion/tests/test_imports.py | 12 +++++-
 examples/README.md               | 12 ++++++
 examples/sql-on-pandas.py        | 26 ++++++++++++
 examples/sql-on-polars.py        | 28 +++++++++++++
 src/expr.rs                      | 32 +++++++++++++++
 src/expr/aggregate_expr.rs       | 73 ++++++++++++++++++++++++++++++++++
 src/expr/binary_expr.rs          | 57 +++++++++++++++++++++++++++
 src/expr/column.rs               | 60 ++++++++++++++++++++++++++++
 src/expr/literal.rs              | 74 ++++++++++++++++++++++++++++++++++
 src/expr/projection.rs           |  6 +++
 src/expr/table_scan.rs           | 15 +++++++
 src/sql/logical.rs               |  8 +++-
 16 files changed, 583 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index ab89ff6..d465ebc 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,9 @@ from having to lock the GIL when running those operations.
 Its query engine, DataFusion, is written in 
[Rust](https://www.rust-lang.org/), which makes strong assumptions
 about thread safety and lack of memory leaks.
 
+There is also experimental support for executing SQL against other DataFrame 
libraries, such as Polars, Pandas, and any 
+drop-in replacements for Pandas.
+
 Technically, zero-copy is achieved via the [c data 
interface](https://arrow.apache.org/docs/format/CDataInterface.html).
 
 ## Example Usage
diff --git a/datafusion/pandas.py b/datafusion/pandas.py
new file mode 100644
index 0000000..36e4ba2
--- /dev/null
+++ b/datafusion/pandas.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pandas as pd
+import datafusion
+from datafusion.expr import Projection, TableScan, Column
+
+
+class SessionContext:
+    def __init__(self):
+        self.datafusion_ctx = datafusion.SessionContext()
+        self.parquet_tables = {}
+
+    def register_parquet(self, name, path):
+        self.parquet_tables[name] = path
+        self.datafusion_ctx.register_parquet(name, path)
+
+    def to_pandas_expr(self, expr):
+
+        # get Python wrapper for logical expression
+        expr = expr.to_variant()
+
+        if isinstance(expr, Column):
+            return expr.name()
+        else:
+            raise Exception("unsupported expression: {}".format(expr))
+
+    def to_pandas_df(self, plan):
+        # recurse down first to translate inputs into pandas data frames
+        inputs = [self.to_pandas_df(x) for x in plan.inputs()]
+
+        # get Python wrapper for logical operator node
+        node = plan.to_variant()
+
+        if isinstance(node, Projection):
+            args = [self.to_pandas_expr(expr) for expr in node.projections()]
+            return inputs[0][args]
+        elif isinstance(node, TableScan):
+            return pd.read_parquet(self.parquet_tables[node.table_name()])
+        else:
+            raise Exception(
+                "unsupported logical operator: {}".format(type(node))
+            )
+
+    def sql(self, sql):
+        datafusion_df = self.datafusion_ctx.sql(sql)
+        plan = datafusion_df.logical_plan()
+        return self.to_pandas_df(plan)
diff --git a/datafusion/polars.py b/datafusion/polars.py
new file mode 100644
index 0000000..e29e511
--- /dev/null
+++ b/datafusion/polars.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import polars
+import datafusion
+from datafusion.expr import Projection, TableScan, Aggregate
+from datafusion.expr import Column, AggregateFunction
+
+
+class SessionContext:
+    def __init__(self):
+        self.datafusion_ctx = datafusion.SessionContext()
+        self.parquet_tables = {}
+
+    def register_parquet(self, name, path):
+        self.parquet_tables[name] = path
+        self.datafusion_ctx.register_parquet(name, path)
+
+    def to_polars_expr(self, expr):
+
+        # get Python wrapper for logical expression
+        expr = expr.to_variant()
+
+        if isinstance(expr, Column):
+            return polars.col(expr.name())
+        else:
+            raise Exception("unsupported expression: {}".format(expr))
+
+    def to_polars_df(self, plan):
+        # recurse down first to translate inputs into Polars data frames
+        inputs = [self.to_polars_df(x) for x in plan.inputs()]
+
+        # get Python wrapper for logical operator node
+        node = plan.to_variant()
+
+        if isinstance(node, Projection):
+            args = [self.to_polars_expr(expr) for expr in node.projections()]
+            return inputs[0].select(*args)
+        elif isinstance(node, Aggregate):
+            groupby_expr = [
+                self.to_polars_expr(expr) for expr in node.group_by_exprs()
+            ]
+            aggs = []
+            for expr in node.aggregate_exprs():
+                expr = expr.to_variant()
+                if isinstance(expr, AggregateFunction):
+                    if expr.aggregate_type() == "COUNT":
+                        aggs.append(polars.count().alias("{}".format(expr)))
+                    else:
+                        raise Exception(
+                            "Unsupported aggregate function {}".format(
+                                expr.aggregate_type()
+                            )
+                        )
+                else:
+                    raise Exception(
+                        "Unsupported aggregate function {}".format(expr)
+                    )
+            df = inputs[0].groupby(groupby_expr).agg(aggs)
+            return df
+        elif isinstance(node, TableScan):
+            return polars.read_parquet(self.parquet_tables[node.table_name()])
+        else:
+            raise Exception(
+                "unsupported logical operator: {}".format(type(node))
+            )
+
+    def sql(self, sql):
+        datafusion_df = self.datafusion_ctx.sql(sql)
+        plan = datafusion_df.logical_plan()
+        return self.to_polars_df(plan)
diff --git a/datafusion/tests/test_expr.py b/datafusion/tests/test_expr.py
index 4a7db87..143eea6 100644
--- a/datafusion/tests/test_expr.py
+++ b/datafusion/tests/test_expr.py
@@ -16,6 +16,7 @@
 # under the License.
 
 from datafusion import SessionContext
+from datafusion.expr import Column, Literal, BinaryExpr, AggregateFunction
 from datafusion.expr import (
     Projection,
     Filter,
@@ -41,6 +42,24 @@ def test_projection(test_ctx):
     plan = plan.to_variant()
     assert isinstance(plan, Projection)
 
+    expr = plan.projections()
+
+    col1 = expr[0].to_variant()
+    assert isinstance(col1, Column)
+    assert col1.name() == "c1"
+    assert col1.qualified_name() == "test.c1"
+
+    col2 = expr[1].to_variant()
+    assert isinstance(col2, Literal)
+    assert col2.data_type() == "Int64"
+    assert col2.value_i64() == 123
+
+    col3 = expr[2].to_variant()
+    assert isinstance(col3, BinaryExpr)
+    assert isinstance(col3.left().to_variant(), Column)
+    assert col3.op() == "<"
+    assert isinstance(col3.right().to_variant(), Literal)
+
     plan = plan.input().to_variant()
     assert isinstance(plan, TableScan)
 
@@ -64,15 +83,23 @@ def test_limit(test_ctx):
     assert isinstance(plan, Limit)
 
 
-def test_aggregate(test_ctx):
-    df = test_ctx.sql("select c1, COUNT(*) from test GROUP BY c1")
+def test_aggregate_query(test_ctx):
+    df = test_ctx.sql("select c1, count(*) from test group by c1")
     plan = df.logical_plan()
 
-    plan = plan.to_variant()
-    assert isinstance(plan, Projection)
+    projection = plan.to_variant()
+    assert isinstance(projection, Projection)
 
-    plan = plan.input().to_variant()
-    assert isinstance(plan, Aggregate)
+    aggregate = projection.input().to_variant()
+    assert isinstance(aggregate, Aggregate)
+
+    col1 = aggregate.group_by_exprs()[0].to_variant()
+    assert isinstance(col1, Column)
+    assert col1.name() == "c1"
+    assert col1.qualified_name() == "test.c1"
+
+    col2 = aggregate.aggregate_exprs()[0].to_variant()
+    assert isinstance(col2, AggregateFunction)
 
 
 def test_sort(test_ctx):
diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py
index dfc1f65..40b005b 100644
--- a/datafusion/tests/test_imports.py
+++ b/datafusion/tests/test_imports.py
@@ -33,6 +33,10 @@ from datafusion.common import (
 
 from datafusion.expr import (
     Expr,
+    Column,
+    Literal,
+    BinaryExpr,
+    AggregateFunction,
     Projection,
     TableScan,
     Filter,
@@ -59,9 +63,15 @@ def test_class_module_is_datafusion():
     ]:
         assert klass.__module__ == "datafusion"
 
-    for klass in [Expr, Projection, TableScan, Aggregate, Sort, Limit, Filter]:
+    # expressions
+    for klass in [Expr, Column, Literal, BinaryExpr, AggregateFunction]:
         assert klass.__module__ == "datafusion.expr"
 
+    # operators
+    for klass in [Projection, TableScan, Aggregate, Sort, Limit, Filter]:
+        assert klass.__module__ == "datafusion.expr"
+
+    # schema
     for klass in [DFField, DFSchema]:
         assert klass.__module__ == "datafusion.common"
 
diff --git a/examples/README.md b/examples/README.md
index a3ae0ba..e736366 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -19,9 +19,21 @@
 
 # DataFusion Python Examples
 
+Some of the examples rely on data which can be downloaded from the following 
site:
+
+- https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page
+
+Here is a direct link to the file used in the examples:
+
+- 
https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2021-01.parquet
+
+## Examples
+
 - [Query a Parquet file using SQL](./sql-parquet.py)
 - [Query a Parquet file using the DataFrame API](./dataframe-parquet.py)
 - [Run a SQL query and store the results in a Pandas 
DataFrame](./sql-to-pandas.py)
 - [Query PyArrow Data](./query-pyarrow-data.py)
 - [Register a Python UDF with DataFusion](./python-udf.py)
 - [Register a Python UDAF with DataFusion](./python-udaf.py)
+- [Executing SQL on Polars](./sql-on-polars.py)
+- [Executing SQL on Pandas](./sql-on-pandas.py)
diff --git a/examples/sql-on-pandas.py b/examples/sql-on-pandas.py
new file mode 100644
index 0000000..8a2d593
--- /dev/null
+++ b/examples/sql-on-pandas.py
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion.pandas import SessionContext
+
+
+ctx = SessionContext()
+ctx.register_parquet(
+    "taxi", "/mnt/bigdata/nyctaxi/yellow_tripdata_2021-01.parquet"
+)
+df = ctx.sql("select passenger_count from taxi")
+print(df)
diff --git a/examples/sql-on-polars.py b/examples/sql-on-polars.py
new file mode 100644
index 0000000..0173b68
--- /dev/null
+++ b/examples/sql-on-polars.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion.polars import SessionContext
+
+
+ctx = SessionContext()
+ctx.register_parquet(
+    "taxi", "/mnt/bigdata/nyctaxi/yellow_tripdata_2021-01.parquet"
+)
+df = ctx.sql(
+    "select passenger_count, count(*) from taxi group by passenger_count"
+)
+print(df)
diff --git a/src/expr.rs b/src/expr.rs
index adb9e55..ba01c99 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -22,11 +22,20 @@ use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField};
 
+use crate::errors::py_runtime_err;
+use crate::expr::aggregate_expr::PyAggregateFunction;
+use crate::expr::binary_expr::PyBinaryExpr;
+use crate::expr::column::PyColumn;
+use crate::expr::literal::PyLiteral;
 use datafusion::scalar::ScalarValue;
 
 pub mod aggregate;
+pub mod aggregate_expr;
+pub mod binary_expr;
+pub mod column;
 pub mod filter;
 pub mod limit;
+pub mod literal;
 pub mod logical_node;
 pub mod projection;
 pub mod sort;
@@ -53,6 +62,22 @@ impl From<Expr> for PyExpr {
 
 #[pymethods]
 impl PyExpr {
+    /// Convert this expression into the Python wrapper for its specific variant (e.g. `Column`, `Literal`)
+    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+        Python::with_gil(|_| match &self.expr {
+            Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)),
+            Expr::Literal(value) => 
Ok(PyLiteral::from(value.clone()).into_py(py)),
+            Expr::BinaryExpr(expr) => 
Ok(PyBinaryExpr::from(expr.clone()).into_py(py)),
+            Expr::AggregateFunction(expr) => {
+                Ok(PyAggregateFunction::from(expr.clone()).into_py(py))
+            }
+            other => Err(py_runtime_err(format!(
+                "Cannot convert this Expr to a Python object: {:?}",
+                other
+            ))),
+        })
+    }
+
     fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
         let expr = match op {
             CompareOp::Lt => self.expr.clone().lt(other.expr),
@@ -144,7 +169,14 @@ impl PyExpr {
 
 /// Initializes the `expr` module to match the pattern of `datafusion-expr` 
https://docs.rs/datafusion-expr/latest/datafusion_expr/
 pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
+    // expressions
     m.add_class::<PyExpr>()?;
+    m.add_class::<PyColumn>()?;
+    m.add_class::<PyLiteral>()?;
+    m.add_class::<PyBinaryExpr>()?;
+    m.add_class::<PyLiteral>()?;
+    m.add_class::<PyAggregateFunction>()?;
+    // operators
     m.add_class::<table_scan::PyTableScan>()?;
     m.add_class::<projection::PyProjection>()?;
     m.add_class::<filter::PyFilter>()?;
diff --git a/src/expr/aggregate_expr.rs b/src/expr/aggregate_expr.rs
new file mode 100644
index 0000000..1801051
--- /dev/null
+++ b/src/expr/aggregate_expr.rs
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::expr::PyExpr;
+use datafusion_expr::expr::AggregateFunction;
+use pyo3::prelude::*;
+use std::fmt::{Display, Formatter};
+
+#[pyclass(name = "AggregateFunction", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyAggregateFunction {
+    aggr: AggregateFunction,
+}
+
+impl From<PyAggregateFunction> for AggregateFunction {
+    fn from(aggr: PyAggregateFunction) -> Self {
+        aggr.aggr
+    }
+}
+
+impl From<AggregateFunction> for PyAggregateFunction {
+    fn from(aggr: AggregateFunction) -> PyAggregateFunction {
+        PyAggregateFunction { aggr }
+    }
+}
+
+impl Display for PyAggregateFunction {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        let args: Vec<String> = self.aggr.args.iter().map(|expr| 
expr.to_string()).collect();
+        write!(f, "{}({})", self.aggr.fun, args.join(", "))
+    }
+}
+
+#[pymethods]
+impl PyAggregateFunction {
+    /// Get the aggregate type, such as "MIN", or "MAX"
+    fn aggregate_type(&self) -> String {
+        format!("{}", self.aggr.fun)
+    }
+
+    /// Whether this is a distinct aggregate, such as `COUNT(DISTINCT expr)`
+    fn is_distinct(&self) -> bool {
+        self.aggr.distinct
+    }
+
+    /// Get the arguments to the aggregate function
+    fn args(&self) -> Vec<PyExpr> {
+        self.aggr
+            .args
+            .iter()
+            .map(|expr| PyExpr::from(expr.clone()))
+            .collect()
+    }
+
+    /// Get a String representation of this aggregate function
+    fn __repr__(&self) -> String {
+        format!("{}", self)
+    }
+}
diff --git a/src/expr/binary_expr.rs b/src/expr/binary_expr.rs
new file mode 100644
index 0000000..5f382b7
--- /dev/null
+++ b/src/expr/binary_expr.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::expr::PyExpr;
+use datafusion_expr::BinaryExpr;
+use pyo3::prelude::*;
+
+#[pyclass(name = "BinaryExpr", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyBinaryExpr {
+    expr: BinaryExpr,
+}
+
+impl From<PyBinaryExpr> for BinaryExpr {
+    fn from(expr: PyBinaryExpr) -> Self {
+        expr.expr
+    }
+}
+
+impl From<BinaryExpr> for PyBinaryExpr {
+    fn from(expr: BinaryExpr) -> PyBinaryExpr {
+        PyBinaryExpr { expr }
+    }
+}
+
+#[pymethods]
+impl PyBinaryExpr {
+    fn left(&self) -> PyExpr {
+        self.expr.left.as_ref().clone().into()
+    }
+
+    fn right(&self) -> PyExpr {
+        self.expr.right.as_ref().clone().into()
+    }
+
+    fn op(&self) -> String {
+        format!("{}", self.expr.op)
+    }
+
+    fn __repr__(&self) -> PyResult<String> {
+        Ok(format!("{}", self.expr))
+    }
+}
diff --git a/src/expr/column.rs b/src/expr/column.rs
new file mode 100644
index 0000000..16b8bce
--- /dev/null
+++ b/src/expr/column.rs
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::Column;
+use pyo3::prelude::*;
+
+#[pyclass(name = "Column", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyColumn {
+    pub col: Column,
+}
+
+impl PyColumn {
+    pub fn new(col: Column) -> Self {
+        Self { col }
+    }
+}
+
+impl From<Column> for PyColumn {
+    fn from(col: Column) -> PyColumn {
+        PyColumn { col }
+    }
+}
+
+#[pymethods]
+impl PyColumn {
+    /// Get the column name
+    fn name(&self) -> String {
+        self.col.name.clone()
+    }
+
+    /// Get the column relation
+    fn relation(&self) -> Option<String> {
+        self.col.relation.clone()
+    }
+
+    /// Get the fully-qualified column name
+    fn qualified_name(&self) -> String {
+        self.col.flat_name()
+    }
+
+    /// Get a String representation of this column
+    fn __repr__(&self) -> String {
+        self.qualified_name()
+    }
+}
diff --git a/src/expr/literal.rs b/src/expr/literal.rs
new file mode 100644
index 0000000..27674ce
--- /dev/null
+++ b/src/expr/literal.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::errors::py_runtime_err;
+use datafusion_common::ScalarValue;
+use pyo3::prelude::*;
+
+#[pyclass(name = "Literal", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyLiteral {
+    pub value: ScalarValue,
+}
+
+impl From<PyLiteral> for ScalarValue {
+    fn from(lit: PyLiteral) -> ScalarValue {
+        lit.value
+    }
+}
+
+impl From<ScalarValue> for PyLiteral {
+    fn from(value: ScalarValue) -> PyLiteral {
+        PyLiteral { value }
+    }
+}
+
+#[pymethods]
+impl PyLiteral {
+    /// Get the data type of this literal value
+    fn data_type(&self) -> String {
+        format!("{}", self.value.get_datatype())
+    }
+
+    fn value_i32(&self) -> PyResult<i32> {
+        if let ScalarValue::Int32(Some(n)) = &self.value {
+            Ok(*n)
+        } else {
+            Err(py_runtime_err("Cannot access value as i32"))
+        }
+    }
+
+    fn value_i64(&self) -> PyResult<i64> {
+        if let ScalarValue::Int64(Some(n)) = &self.value {
+            Ok(*n)
+        } else {
+            Err(py_runtime_err("Cannot access value as i64"))
+        }
+    }
+
+    fn value_str(&self) -> PyResult<String> {
+        if let ScalarValue::Utf8(Some(str)) = &self.value {
+            Ok(str.clone())
+        } else {
+            Err(py_runtime_err("Cannot access value as string"))
+        }
+    }
+
+    fn __repr__(&self) -> PyResult<String> {
+        Ok(format!("{}", self.value))
+    }
+}
diff --git a/src/expr/projection.rs b/src/expr/projection.rs
index 2d43632..4c158f7 100644
--- a/src/expr/projection.rs
+++ b/src/expr/projection.rs
@@ -30,6 +30,12 @@ pub struct PyProjection {
     projection: Projection,
 }
 
+impl PyProjection {
+    pub fn new(projection: Projection) -> Self {
+        Self { projection }
+    }
+}
+
 impl From<Projection> for PyProjection {
     fn from(projection: Projection) -> PyProjection {
         PyProjection { projection }
diff --git a/src/expr/table_scan.rs b/src/expr/table_scan.rs
index 00504b9..2784523 100644
--- a/src/expr/table_scan.rs
+++ b/src/expr/table_scan.rs
@@ -19,6 +19,8 @@ use datafusion_expr::logical_plan::TableScan;
 use pyo3::prelude::*;
 use std::fmt::{self, Display, Formatter};
 
+use crate::expr::logical_node::LogicalNode;
+use crate::sql::logical::PyLogicalPlan;
 use crate::{common::df_schema::PyDFSchema, expr::PyExpr};
 
 #[pyclass(name = "TableScan", module = "datafusion.expr", subclass)]
@@ -27,6 +29,12 @@ pub struct PyTableScan {
     table_scan: TableScan,
 }
 
+impl PyTableScan {
+    pub fn new(table_scan: TableScan) -> Self {
+        Self { table_scan }
+    }
+}
+
 impl From<PyTableScan> for TableScan {
     fn from(tbl_scan: PyTableScan) -> TableScan {
         tbl_scan.table_scan
@@ -117,3 +125,10 @@ impl PyTableScan {
         Ok(format!("TableScan({})", self))
     }
 }
+
+impl LogicalNode for PyTableScan {
+    fn input(&self) -> Vec<PyLogicalPlan> {
+        // table scans are leaf nodes and do not have inputs
+        vec![]
+    }
+}
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index 08d1961..ce6b1fb 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -49,10 +49,10 @@ impl PyLogicalPlan {
         Python::with_gil(|_| match self.plan.as_ref() {
             LogicalPlan::Projection(plan) => 
Ok(PyProjection::from(plan.clone()).into_py(py)),
             LogicalPlan::TableScan(plan) => 
Ok(PyTableScan::from(plan.clone()).into_py(py)),
-            LogicalPlan::Filter(plan) => 
Ok(PyFilter::from(plan.clone()).into_py(py)),
+            LogicalPlan::Aggregate(plan) => 
Ok(PyAggregate::from(plan.clone()).into_py(py)),
             LogicalPlan::Limit(plan) => 
Ok(PyLimit::from(plan.clone()).into_py(py)),
             LogicalPlan::Sort(plan) => 
Ok(PySort::from(plan.clone()).into_py(py)),
-            LogicalPlan::Aggregate(plan) => 
Ok(PyAggregate::from(plan.clone()).into_py(py)),
+            LogicalPlan::Filter(plan) => 
Ok(PyFilter::from(plan.clone()).into_py(py)),
             other => Err(py_runtime_err(format!(
                 "Cannot convert this plan to a LogicalNode: {:?}",
                 other
@@ -69,6 +69,10 @@ impl PyLogicalPlan {
         inputs
     }
 
+    fn __repr__(&self) -> PyResult<String> {
+        Ok(format!("{:?}", self.plan))
+    }
+
     pub fn display(&self) -> String {
         format!("{}", self.plan.display())
     }

Reply via email to