This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new ca8b055 Add experimental support for executing SQL with Polars and Pandas (#190)
ca8b055 is described below
commit ca8b0551cd9dc7f6832f94a8b44f31cb5b3aa218
Author: Andy Grove <[email protected]>
AuthorDate: Mon Feb 20 10:03:21 2023 -0700
Add experimental support for executing SQL with Polars and Pandas (#190)
---
README.md | 3 ++
datafusion/pandas.py | 62 +++++++++++++++++++++++++++++
datafusion/polars.py | 85 ++++++++++++++++++++++++++++++++++++++++
datafusion/tests/test_expr.py | 39 +++++++++++++++---
datafusion/tests/test_imports.py | 12 +++++-
examples/README.md | 12 ++++++
examples/sql-on-pandas.py | 26 ++++++++++++
examples/sql-on-polars.py | 28 +++++++++++++
src/expr.rs | 32 +++++++++++++++
src/expr/aggregate_expr.rs | 73 ++++++++++++++++++++++++++++++++++
src/expr/binary_expr.rs | 57 +++++++++++++++++++++++++++
src/expr/column.rs | 60 ++++++++++++++++++++++++++++
src/expr/literal.rs | 74 ++++++++++++++++++++++++++++++++++
src/expr/projection.rs | 6 +++
src/expr/table_scan.rs | 15 +++++++
src/sql/logical.rs | 8 +++-
16 files changed, 583 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
index ab89ff6..d465ebc 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,9 @@ from having to lock the GIL when running those operations.
Its query engine, DataFusion, is written in
[Rust](https://www.rust-lang.org/), which makes strong assumptions
about thread safety and lack of memory leaks.
+There is also experimental support for executing SQL against other DataFrame
+libraries, such as Polars, Pandas, and any drop-in replacements for Pandas.
+
Technically, zero-copy is achieved via the [c data
interface](https://arrow.apache.org/docs/format/CDataInterface.html).
## Example Usage
diff --git a/datafusion/pandas.py b/datafusion/pandas.py
new file mode 100644
index 0000000..36e4ba2
--- /dev/null
+++ b/datafusion/pandas.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pandas as pd
+import datafusion
+from datafusion.expr import Projection, TableScan, Column
+
+
+class SessionContext:
+ def __init__(self):
+ self.datafusion_ctx = datafusion.SessionContext()
+ self.parquet_tables = {}
+
+ def register_parquet(self, name, path):
+ self.parquet_tables[name] = path
+ self.datafusion_ctx.register_parquet(name, path)
+
+ def to_pandas_expr(self, expr):
+
+ # get Python wrapper for logical expression
+ expr = expr.to_variant()
+
+ if isinstance(expr, Column):
+ return expr.name()
+ else:
+ raise Exception("unsupported expression: {}".format(expr))
+
+ def to_pandas_df(self, plan):
+ # recurse down first to translate inputs into pandas data frames
+ inputs = [self.to_pandas_df(x) for x in plan.inputs()]
+
+ # get Python wrapper for logical operator node
+ node = plan.to_variant()
+
+ if isinstance(node, Projection):
+ args = [self.to_pandas_expr(expr) for expr in node.projections()]
+ return inputs[0][args]
+ elif isinstance(node, TableScan):
+ return pd.read_parquet(self.parquet_tables[node.table_name()])
+ else:
+ raise Exception(
+ "unsupported logical operator: {}".format(type(node))
+ )
+
+ def sql(self, sql):
+ datafusion_df = self.datafusion_ctx.sql(sql)
+ plan = datafusion_df.logical_plan()
+ return self.to_pandas_df(plan)
diff --git a/datafusion/polars.py b/datafusion/polars.py
new file mode 100644
index 0000000..e29e511
--- /dev/null
+++ b/datafusion/polars.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import polars
+import datafusion
+from datafusion.expr import Projection, TableScan, Aggregate
+from datafusion.expr import Column, AggregateFunction
+
+
+class SessionContext:
+ def __init__(self):
+ self.datafusion_ctx = datafusion.SessionContext()
+ self.parquet_tables = {}
+
+ def register_parquet(self, name, path):
+ self.parquet_tables[name] = path
+ self.datafusion_ctx.register_parquet(name, path)
+
+ def to_polars_expr(self, expr):
+
+ # get Python wrapper for logical expression
+ expr = expr.to_variant()
+
+ if isinstance(expr, Column):
+ return polars.col(expr.name())
+ else:
+ raise Exception("unsupported expression: {}".format(expr))
+
+ def to_polars_df(self, plan):
+ # recurse down first to translate inputs into Polars data frames
+ inputs = [self.to_polars_df(x) for x in plan.inputs()]
+
+ # get Python wrapper for logical operator node
+ node = plan.to_variant()
+
+ if isinstance(node, Projection):
+ args = [self.to_polars_expr(expr) for expr in node.projections()]
+ return inputs[0].select(*args)
+ elif isinstance(node, Aggregate):
+ groupby_expr = [
+ self.to_polars_expr(expr) for expr in node.group_by_exprs()
+ ]
+ aggs = []
+ for expr in node.aggregate_exprs():
+ expr = expr.to_variant()
+ if isinstance(expr, AggregateFunction):
+ if expr.aggregate_type() == "COUNT":
+ aggs.append(polars.count().alias("{}".format(expr)))
+ else:
+ raise Exception(
+ "Unsupported aggregate function {}".format(
+ expr.aggregate_type()
+ )
+ )
+ else:
+ raise Exception(
+ "Unsupported aggregate function {}".format(expr)
+ )
+ df = inputs[0].groupby(groupby_expr).agg(aggs)
+ return df
+ elif isinstance(node, TableScan):
+ return polars.read_parquet(self.parquet_tables[node.table_name()])
+ else:
+ raise Exception(
+ "unsupported logical operator: {}".format(type(node))
+ )
+
+ def sql(self, sql):
+ datafusion_df = self.datafusion_ctx.sql(sql)
+ plan = datafusion_df.logical_plan()
+ return self.to_polars_df(plan)
diff --git a/datafusion/tests/test_expr.py b/datafusion/tests/test_expr.py
index 4a7db87..143eea6 100644
--- a/datafusion/tests/test_expr.py
+++ b/datafusion/tests/test_expr.py
@@ -16,6 +16,7 @@
# under the License.
from datafusion import SessionContext
+from datafusion.expr import Column, Literal, BinaryExpr, AggregateFunction
from datafusion.expr import (
Projection,
Filter,
@@ -41,6 +42,24 @@ def test_projection(test_ctx):
plan = plan.to_variant()
assert isinstance(plan, Projection)
+ expr = plan.projections()
+
+ col1 = expr[0].to_variant()
+ assert isinstance(col1, Column)
+ assert col1.name() == "c1"
+ assert col1.qualified_name() == "test.c1"
+
+ col2 = expr[1].to_variant()
+ assert isinstance(col2, Literal)
+ assert col2.data_type() == "Int64"
+ assert col2.value_i64() == 123
+
+ col3 = expr[2].to_variant()
+ assert isinstance(col3, BinaryExpr)
+ assert isinstance(col3.left().to_variant(), Column)
+ assert col3.op() == "<"
+ assert isinstance(col3.right().to_variant(), Literal)
+
plan = plan.input().to_variant()
assert isinstance(plan, TableScan)
@@ -64,15 +83,23 @@ def test_limit(test_ctx):
assert isinstance(plan, Limit)
-def test_aggregate(test_ctx):
- df = test_ctx.sql("select c1, COUNT(*) from test GROUP BY c1")
+def test_aggregate_query(test_ctx):
+ df = test_ctx.sql("select c1, count(*) from test group by c1")
plan = df.logical_plan()
- plan = plan.to_variant()
- assert isinstance(plan, Projection)
+ projection = plan.to_variant()
+ assert isinstance(projection, Projection)
- plan = plan.input().to_variant()
- assert isinstance(plan, Aggregate)
+ aggregate = projection.input().to_variant()
+ assert isinstance(aggregate, Aggregate)
+
+ col1 = aggregate.group_by_exprs()[0].to_variant()
+ assert isinstance(col1, Column)
+ assert col1.name() == "c1"
+ assert col1.qualified_name() == "test.c1"
+
+ col2 = aggregate.aggregate_exprs()[0].to_variant()
+ assert isinstance(col2, AggregateFunction)
def test_sort(test_ctx):
diff --git a/datafusion/tests/test_imports.py b/datafusion/tests/test_imports.py
index dfc1f65..40b005b 100644
--- a/datafusion/tests/test_imports.py
+++ b/datafusion/tests/test_imports.py
@@ -33,6 +33,10 @@ from datafusion.common import (
from datafusion.expr import (
Expr,
+ Column,
+ Literal,
+ BinaryExpr,
+ AggregateFunction,
Projection,
TableScan,
Filter,
@@ -59,9 +63,15 @@ def test_class_module_is_datafusion():
]:
assert klass.__module__ == "datafusion"
- for klass in [Expr, Projection, TableScan, Aggregate, Sort, Limit, Filter]:
+ # expressions
+ for klass in [Expr, Column, Literal, BinaryExpr, AggregateFunction]:
assert klass.__module__ == "datafusion.expr"
+ # operators
+ for klass in [Projection, TableScan, Aggregate, Sort, Limit, Filter]:
+ assert klass.__module__ == "datafusion.expr"
+
+ # schema
for klass in [DFField, DFSchema]:
assert klass.__module__ == "datafusion.common"
diff --git a/examples/README.md b/examples/README.md
index a3ae0ba..e736366 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -19,9 +19,21 @@
# DataFusion Python Examples
+Some of the examples rely on data that can be downloaded from the following site:
+
+- https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page
+
+Here is a direct link to the file used in the examples:
+
+- https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2021-01.parquet
+
+## Examples
+
- [Query a Parquet file using SQL](./sql-parquet.py)
- [Query a Parquet file using the DataFrame API](./dataframe-parquet.py)
- [Run a SQL query and store the results in a Pandas DataFrame](./sql-to-pandas.py)
- [Query PyArrow Data](./query-pyarrow-data.py)
- [Register a Python UDF with DataFusion](./python-udf.py)
- [Register a Python UDAF with DataFusion](./python-udaf.py)
+- [Executing SQL on Polars](./sql-on-polars.py)
+- [Executing SQL on Pandas](./sql-on-pandas.py)
diff --git a/examples/sql-on-pandas.py b/examples/sql-on-pandas.py
new file mode 100644
index 0000000..8a2d593
--- /dev/null
+++ b/examples/sql-on-pandas.py
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion.pandas import SessionContext
+
+
+ctx = SessionContext()
+ctx.register_parquet(
+ "taxi", "/mnt/bigdata/nyctaxi/yellow_tripdata_2021-01.parquet"
+)
+df = ctx.sql("select passenger_count from taxi")
+print(df)
diff --git a/examples/sql-on-polars.py b/examples/sql-on-polars.py
new file mode 100644
index 0000000..0173b68
--- /dev/null
+++ b/examples/sql-on-polars.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion.polars import SessionContext
+
+
+ctx = SessionContext()
+ctx.register_parquet(
+ "taxi", "/mnt/bigdata/nyctaxi/yellow_tripdata_2021-01.parquet"
+)
+df = ctx.sql(
+ "select passenger_count, count(*) from taxi group by passenger_count"
+)
+print(df)
diff --git a/src/expr.rs b/src/expr.rs
index adb9e55..ba01c99 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -22,11 +22,20 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField};
+use crate::errors::py_runtime_err;
+use crate::expr::aggregate_expr::PyAggregateFunction;
+use crate::expr::binary_expr::PyBinaryExpr;
+use crate::expr::column::PyColumn;
+use crate::expr::literal::PyLiteral;
use datafusion::scalar::ScalarValue;
pub mod aggregate;
+pub mod aggregate_expr;
+pub mod binary_expr;
+pub mod column;
pub mod filter;
pub mod limit;
+pub mod literal;
pub mod logical_node;
pub mod projection;
pub mod sort;
@@ -53,6 +62,22 @@ impl From<Expr> for PyExpr {
#[pymethods]
impl PyExpr {
+ /// Return the specific expression
+ fn to_variant(&self, py: Python) -> PyResult<PyObject> {
+ Python::with_gil(|_| match &self.expr {
+ Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)),
+ Expr::Literal(value) =>
Ok(PyLiteral::from(value.clone()).into_py(py)),
+ Expr::BinaryExpr(expr) =>
Ok(PyBinaryExpr::from(expr.clone()).into_py(py)),
+ Expr::AggregateFunction(expr) => {
+ Ok(PyAggregateFunction::from(expr.clone()).into_py(py))
+ }
+ other => Err(py_runtime_err(format!(
+ "Cannot convert this Expr to a Python object: {:?}",
+ other
+ ))),
+ })
+ }
+
fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
let expr = match op {
CompareOp::Lt => self.expr.clone().lt(other.expr),
@@ -144,7 +169,14 @@ impl PyExpr {
/// Initializes the `expr` module to match the pattern of `datafusion-expr`
https://docs.rs/datafusion-expr/latest/datafusion_expr/
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
+ // expressions
m.add_class::<PyExpr>()?;
+ m.add_class::<PyColumn>()?;
+ m.add_class::<PyLiteral>()?;
+ m.add_class::<PyBinaryExpr>()?;
+ m.add_class::<PyLiteral>()?;
+ m.add_class::<PyAggregateFunction>()?;
+ // operators
m.add_class::<table_scan::PyTableScan>()?;
m.add_class::<projection::PyProjection>()?;
m.add_class::<filter::PyFilter>()?;
diff --git a/src/expr/aggregate_expr.rs b/src/expr/aggregate_expr.rs
new file mode 100644
index 0000000..1801051
--- /dev/null
+++ b/src/expr/aggregate_expr.rs
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::expr::PyExpr;
+use datafusion_expr::expr::AggregateFunction;
+use pyo3::prelude::*;
+use std::fmt::{Display, Formatter};
+
+#[pyclass(name = "AggregateFunction", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyAggregateFunction {
+ aggr: AggregateFunction,
+}
+
+impl From<PyAggregateFunction> for AggregateFunction {
+ fn from(aggr: PyAggregateFunction) -> Self {
+ aggr.aggr
+ }
+}
+
+impl From<AggregateFunction> for PyAggregateFunction {
+ fn from(aggr: AggregateFunction) -> PyAggregateFunction {
+ PyAggregateFunction { aggr }
+ }
+}
+
+impl Display for PyAggregateFunction {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+ let args: Vec<String> = self.aggr.args.iter().map(|expr|
expr.to_string()).collect();
+ write!(f, "{}({})", self.aggr.fun, args.join(", "))
+ }
+}
+
+#[pymethods]
+impl PyAggregateFunction {
+ /// Get the aggregate type, such as "MIN", or "MAX"
+ fn aggregate_type(&self) -> String {
+ format!("{}", self.aggr.fun)
+ }
+
+ /// is this a distinct aggregate such as `COUNT(DISTINCT expr)`
+ fn is_distinct(&self) -> bool {
+ self.aggr.distinct
+ }
+
+ /// Get the arguments to the aggregate function
+ fn args(&self) -> Vec<PyExpr> {
+ self.aggr
+ .args
+ .iter()
+ .map(|expr| PyExpr::from(expr.clone()))
+ .collect()
+ }
+
+ /// Get a String representation of this column
+ fn __repr__(&self) -> String {
+ format!("{}", self)
+ }
+}
diff --git a/src/expr/binary_expr.rs b/src/expr/binary_expr.rs
new file mode 100644
index 0000000..5f382b7
--- /dev/null
+++ b/src/expr/binary_expr.rs
@@ -0,0 +1,57 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::expr::PyExpr;
+use datafusion_expr::BinaryExpr;
+use pyo3::prelude::*;
+
+#[pyclass(name = "BinaryExpr", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyBinaryExpr {
+ expr: BinaryExpr,
+}
+
+impl From<PyBinaryExpr> for BinaryExpr {
+ fn from(expr: PyBinaryExpr) -> Self {
+ expr.expr
+ }
+}
+
+impl From<BinaryExpr> for PyBinaryExpr {
+ fn from(expr: BinaryExpr) -> PyBinaryExpr {
+ PyBinaryExpr { expr }
+ }
+}
+
+#[pymethods]
+impl PyBinaryExpr {
+ fn left(&self) -> PyExpr {
+ self.expr.left.as_ref().clone().into()
+ }
+
+ fn right(&self) -> PyExpr {
+ self.expr.right.as_ref().clone().into()
+ }
+
+ fn op(&self) -> String {
+ format!("{}", self.expr.op)
+ }
+
+ fn __repr__(&self) -> PyResult<String> {
+ Ok(format!("{}", self.expr))
+ }
+}
diff --git a/src/expr/column.rs b/src/expr/column.rs
new file mode 100644
index 0000000..16b8bce
--- /dev/null
+++ b/src/expr/column.rs
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::Column;
+use pyo3::prelude::*;
+
+#[pyclass(name = "Column", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyColumn {
+ pub col: Column,
+}
+
+impl PyColumn {
+ pub fn new(col: Column) -> Self {
+ Self { col }
+ }
+}
+
+impl From<Column> for PyColumn {
+ fn from(col: Column) -> PyColumn {
+ PyColumn { col }
+ }
+}
+
+#[pymethods]
+impl PyColumn {
+ /// Get the column name
+ fn name(&self) -> String {
+ self.col.name.clone()
+ }
+
+ /// Get the column relation
+ fn relation(&self) -> Option<String> {
+ self.col.relation.clone()
+ }
+
+ /// Get the fully-qualified column name
+ fn qualified_name(&self) -> String {
+ self.col.flat_name()
+ }
+
+ /// Get a String representation of this column
+ fn __repr__(&self) -> String {
+ self.qualified_name()
+ }
+}
diff --git a/src/expr/literal.rs b/src/expr/literal.rs
new file mode 100644
index 0000000..27674ce
--- /dev/null
+++ b/src/expr/literal.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::errors::py_runtime_err;
+use datafusion_common::ScalarValue;
+use pyo3::prelude::*;
+
+#[pyclass(name = "Literal", module = "datafusion.expr", subclass)]
+#[derive(Clone)]
+pub struct PyLiteral {
+ pub value: ScalarValue,
+}
+
+impl From<PyLiteral> for ScalarValue {
+ fn from(lit: PyLiteral) -> ScalarValue {
+ lit.value
+ }
+}
+
+impl From<ScalarValue> for PyLiteral {
+ fn from(value: ScalarValue) -> PyLiteral {
+ PyLiteral { value }
+ }
+}
+
+#[pymethods]
+impl PyLiteral {
+ /// Get the data type of this literal value
+ fn data_type(&self) -> String {
+ format!("{}", self.value.get_datatype())
+ }
+
+ fn value_i32(&self) -> PyResult<i32> {
+ if let ScalarValue::Int32(Some(n)) = &self.value {
+ Ok(*n)
+ } else {
+ Err(py_runtime_err("Cannot access value as i32"))
+ }
+ }
+
+ fn value_i64(&self) -> PyResult<i64> {
+ if let ScalarValue::Int64(Some(n)) = &self.value {
+ Ok(*n)
+ } else {
+ Err(py_runtime_err("Cannot access value as i64"))
+ }
+ }
+
+ fn value_str(&self) -> PyResult<String> {
+ if let ScalarValue::Utf8(Some(str)) = &self.value {
+ Ok(str.clone())
+ } else {
+ Err(py_runtime_err("Cannot access value as string"))
+ }
+ }
+
+ fn __repr__(&self) -> PyResult<String> {
+ Ok(format!("{}", self.value))
+ }
+}
diff --git a/src/expr/projection.rs b/src/expr/projection.rs
index 2d43632..4c158f7 100644
--- a/src/expr/projection.rs
+++ b/src/expr/projection.rs
@@ -30,6 +30,12 @@ pub struct PyProjection {
projection: Projection,
}
+impl PyProjection {
+ pub fn new(projection: Projection) -> Self {
+ Self { projection }
+ }
+}
+
impl From<Projection> for PyProjection {
fn from(projection: Projection) -> PyProjection {
PyProjection { projection }
diff --git a/src/expr/table_scan.rs b/src/expr/table_scan.rs
index 00504b9..2784523 100644
--- a/src/expr/table_scan.rs
+++ b/src/expr/table_scan.rs
@@ -19,6 +19,8 @@ use datafusion_expr::logical_plan::TableScan;
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};
+use crate::expr::logical_node::LogicalNode;
+use crate::sql::logical::PyLogicalPlan;
use crate::{common::df_schema::PyDFSchema, expr::PyExpr};
#[pyclass(name = "TableScan", module = "datafusion.expr", subclass)]
@@ -27,6 +29,12 @@ pub struct PyTableScan {
table_scan: TableScan,
}
+impl PyTableScan {
+ pub fn new(table_scan: TableScan) -> Self {
+ Self { table_scan }
+ }
+}
+
impl From<PyTableScan> for TableScan {
fn from(tbl_scan: PyTableScan) -> TableScan {
tbl_scan.table_scan
@@ -117,3 +125,10 @@ impl PyTableScan {
Ok(format!("TableScan({})", self))
}
}
+
+impl LogicalNode for PyTableScan {
+ fn input(&self) -> Vec<PyLogicalPlan> {
+ // table scans are leaf nodes and do not have inputs
+ vec![]
+ }
+}
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index 08d1961..ce6b1fb 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -49,10 +49,10 @@ impl PyLogicalPlan {
Python::with_gil(|_| match self.plan.as_ref() {
LogicalPlan::Projection(plan) =>
Ok(PyProjection::from(plan.clone()).into_py(py)),
LogicalPlan::TableScan(plan) =>
Ok(PyTableScan::from(plan.clone()).into_py(py)),
- LogicalPlan::Filter(plan) =>
Ok(PyFilter::from(plan.clone()).into_py(py)),
+ LogicalPlan::Aggregate(plan) =>
Ok(PyAggregate::from(plan.clone()).into_py(py)),
LogicalPlan::Limit(plan) =>
Ok(PyLimit::from(plan.clone()).into_py(py)),
LogicalPlan::Sort(plan) =>
Ok(PySort::from(plan.clone()).into_py(py)),
- LogicalPlan::Aggregate(plan) =>
Ok(PyAggregate::from(plan.clone()).into_py(py)),
+ LogicalPlan::Filter(plan) =>
Ok(PyFilter::from(plan.clone()).into_py(py)),
other => Err(py_runtime_err(format!(
"Cannot convert this plan to a LogicalNode: {:?}",
other
@@ -69,6 +69,10 @@ impl PyLogicalPlan {
inputs
}
+ fn __repr__(&self) -> PyResult<String> {
+ Ok(format!("{:?}", self.plan))
+ }
+
pub fn display(&self) -> String {
format!("{}", self.plan.display())
}