This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/master by this push:
     new e0cbf48  Implement PyArrow Dataset TableProvider (#9)
e0cbf48 is described below

commit e0cbf48516d79ba28fc07f81d1f6d73e85416796
Author: Kyle Brooks <[email protected]>
AuthorDate: Tue Jul 26 07:26:23 2022 -0400

    Implement PyArrow Dataset TableProvider (#9)
    
    * Implement PyArrow Dataset TableProvider and register_dataset context functions.
    
    * Add dataset filter test.
    
    * Change match on booleans to if else.
    
    * Update Dataset TableProvider for updates in DataFusion 10.0.0 trait.
    
    * Fixes to build with DataFusion 10.0.0.
    
    * Improved DatasetExec physical plan printing.
    Added nested filter test.
---
 Cargo.lock                       |   2 +
 Cargo.toml                       |   2 +
 datafusion/tests/test_context.py |  70 ++++++++++
 datafusion/tests/test_sql.py     |  12 ++
 src/context.rs                   |  13 ++
 src/dataset.rs                   | 118 +++++++++++++++++
 src/dataset_exec.rs              | 270 +++++++++++++++++++++++++++++++++++++++
 src/errors.rs                    |  16 ++-
 src/lib.rs                       |   3 +
 src/pyarrow_filter_expression.rs | 190 +++++++++++++++++++++++++++
 10 files changed, 695 insertions(+), 1 deletion(-)

diff --git a/Cargo.lock b/Cargo.lock
index 1108211..bdbf3cc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -395,9 +395,11 @@ dependencies = [
 name = "datafusion-python"
 version = "0.6.0"
 dependencies = [
+ "async-trait",
  "datafusion",
  "datafusion-common",
  "datafusion-expr",
+ "futures",
  "mimalloc",
  "pyo3",
  "rand 0.7.3",
diff --git a/Cargo.toml b/Cargo.toml
index fd20b4d..05a21e0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,6 +39,8 @@ datafusion-expr = { version = "^10.0.0" }
 datafusion-common = { version = "^10.0.0", features = ["pyarrow"] }
 uuid = { version = "0.8", features = ["v4"] }
 mimalloc = { version = "*", optional = true, default-features = false }
+async-trait = "0.1"
+futures = "0.3"
 
 [lib]
 name = "datafusion_python"
diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 4d4a38c..1e1e771 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -16,6 +16,9 @@
 # under the License.
 
 import pyarrow as pa
+import pyarrow.dataset as ds
+
+from datafusion import column, literal
 
 
 def test_register_record_batches(ctx):
@@ -72,3 +75,70 @@ def test_deregister_table(ctx, database):
 
     ctx.deregister_table("csv")
     assert public.names() == {"csv1", "csv2"}
+
+def test_register_dataset(ctx):
+    """Register an in-memory pyarrow Dataset and query it through SQL."""
+    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+
+    # the dataset must be the only table known to the context
+    assert ctx.tables() == {"t"}
+
+    result = ctx.sql("SELECT a+b, a-b FROM t").collect()
+
+    # row-wise sum and difference of columns a and b
+    assert result[0].column(0) == pa.array([5, 7, 9])
+    assert result[0].column(1) == pa.array([-3, -3, -3])
+
+def test_dataset_filter(ctx, capfd):
+    """Check that a SQL WHERE clause shows up as the scan's filter and selects the right rows."""
+    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+
+    assert ctx.tables() == {"t"}
+    df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5")
+
+    # Make sure the filter was pushed down in Physical Plan
+    # (explain() prints the plan, so it is read back from the captured stdout)
+    df.explain()
+    captured = capfd.readouterr()
+    assert "filter_expr=(((2 <= a) and (a <= 3)) and (b > 5))" in captured.out
+
+    result = df.collect()
+
+    # only the row a=3, b=6 satisfies both predicates
+    assert result[0].column(0) == pa.array([9])
+    assert result[0].column(1) == pa.array([-3])
+
+
+def test_dataset_filter_nested_data(ctx):
+    """Filter and project on fields of a struct column of a registered dataset."""
+    # create Arrow StructArrays to test nested data types
+    data = pa.StructArray.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    batch = pa.RecordBatch.from_arrays(
+        [data],
+        names=["nested_data"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+
+    assert ctx.tables() == {"t"}
+
+    df = ctx.table("t")
+
+    # This filter will not be pushed down to DatasetExec since it isn't supported
+    df = df.select(
+        column("nested_data")["a"] + column("nested_data")["b"],
+        column("nested_data")["a"] - column("nested_data")["b"],
+    ).filter(column("nested_data")["b"] > literal(5))
+
+    result = df.collect()
+
+    # only the struct row a=3, b=6 has b > 5
+    assert result[0].column(0) == pa.array([9])
+    assert result[0].column(1) == pa.array([-3])
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index af3b38a..ffbfc2c 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -17,6 +17,7 @@
 
 import numpy as np
 import pyarrow as pa
+import pyarrow.dataset as ds
 import pytest
 
 from datafusion import udf
@@ -121,6 +122,17 @@ def test_register_parquet_partitioned(ctx, tmp_path):
     rd = result.to_pydict()
     assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1}
 
+def test_register_dataset(ctx, tmp_path):
+    """Register a Parquet-backed pyarrow Dataset and check COUNT(a) over it."""
+    path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+    dataset = ds.dataset(path, format="parquet")
+
+    ctx.register_dataset("t", dataset)
+    assert ctx.tables() == {"t"}
+
+    result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect()
+    # collect() returns record batches; combine them into one table to compare
+    result = pa.Table.from_batches(result)
+    assert result.to_pydict() == {"cnt": [100]}
+
 
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]
diff --git a/src/context.rs b/src/context.rs
index 213703f..d2c17ad 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -25,12 +25,14 @@ use pyo3::prelude::*;
 
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::datasource::TableProvider;
 use datafusion::datasource::MemTable;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
 
 use crate::catalog::{PyCatalog, PyTable};
 use crate::dataframe::PyDataFrame;
+use crate::dataset::Dataset;
 use crate::errors::DataFusionError;
 use crate::udf::PyScalarUDF;
 use crate::utils::wait_for_future;
@@ -173,6 +175,17 @@ impl PySessionContext {
         Ok(())
     }
 
+    // Registers a pyarrow.dataset.Dataset as a table called `name`.
+    // Dataset::new rejects objects that are not a pyarrow.dataset.Dataset, and a
+    // failed registration is converted into a Python exception.
+    fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
+        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
+
+        self.ctx
+            .register_table(name, table)
+            .map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
         self.ctx.register_udf(udf.function);
         Ok(())
diff --git a/src/dataset.rs b/src/dataset.rs
new file mode 100644
index 0000000..6272bc8
--- /dev/null
+++ b/src/dataset.rs
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::exceptions::PyValueError;
+/// Implements a Datafusion TableProvider that delegates to a PyArrow Dataset
+/// This allows us to use PyArrow Datasets as Datafusion tables while pushing down projections and filters
+use pyo3::prelude::*;
+use pyo3::types::PyType;
+
+use std::any::Any;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::datasource::datasource::TableProviderFilterPushDown;
+use datafusion::datasource::{TableProvider, TableType};
+use datafusion::error::{DataFusionError, Result as DFResult};
+use datafusion::execution::context::SessionState;
+use datafusion::logical_plan::*;
+use datafusion::physical_plan::ExecutionPlan;
+
+use crate::dataset_exec::DatasetExec;
+use crate::pyarrow_filter_expression::PyArrowFilterExpression;
+
+// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion TableProvider around it
+#[derive(Debug, Clone)]
+pub(crate) struct Dataset {
+    // Owned reference to the Python pyarrow.dataset.Dataset instance
+    dataset: PyObject,
+}
+
+impl Dataset {
+    // Wraps an existing Python object after checking that it is an instance of
+    // pyarrow.dataset.Dataset; any other object is rejected with a ValueError.
+    // Importing pyarrow.dataset or looking up its Dataset class may also fail.
+    pub fn new(dataset: &PyAny, py: Python) -> PyResult<Self> {
+        // Ensure that we were passed an instance of pyarrow.dataset.Dataset
+        let ds = PyModule::import(py, "pyarrow.dataset")?;
+        let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
+        if dataset.is_instance(ds_type)? {
+            Ok(Dataset {
+                dataset: dataset.into(),
+            })
+        } else {
+            Err(PyValueError::new_err(
+                "dataset argument must be a pyarrow.dataset.Dataset object",
+            ))
+        }
+    }
+}
+
+#[async_trait]
+impl TableProvider for Dataset {
+    /// Returns the table provider as [`Any`](std::any::Any) so that it can be
+    /// downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Get a reference to the schema for this table
+    ///
+    /// The schema is read from the Python object on every call, so the GIL is taken here.
+    fn schema(&self) -> SchemaRef {
+        Python::with_gil(|py| {
+            let dataset = self.dataset.as_ref(py);
+            // This can panic but since we checked that self.dataset is a pyarrow.dataset.Dataset it should never
+            Arc::new(
+                dataset
+                    .getattr("schema")
+                    .expect("pyarrow.dataset.Dataset has no `schema` attribute")
+                    .extract()
+                    .expect("pyarrow schema could not be converted to an Arrow schema"),
+            )
+        })
+    }
+
+    /// Get the type of this table for metadata/catalog purposes.
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    /// Create an ExecutionPlan that will scan the table.
+    /// The table provider will be usually responsible of grouping
+    /// the source data into partitions that can be efficiently
+    /// parallelized or distributed.
+    ///
+    /// Projection and filters are handed to `DatasetExec`; the limit is not used.
+    async fn scan(
+        &self,
+        _ctx: &SessionState,
+        projection: &Option<Vec<usize>>,
+        filters: &[Expr],
+        // limit can be used to reduce the amount scanned
+        // from the datasource as a performance optimization.
+        // If set, it contains the amount of rows needed by the `LogicalPlan`,
+        // The datasource should return *at least* this number of rows if available.
+        _limit: Option<usize>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Python::with_gil(|py| {
+            let plan: Arc<dyn ExecutionPlan> = Arc::new(
+                DatasetExec::new(py, self.dataset.as_ref(py), projection.clone(), filters)
+                    .map_err(|err| DataFusionError::External(Box::new(err)))?,
+            );
+            Ok(plan)
+        })
+    }
+
+    /// Tests whether the table provider can make use of a filter expression
+    /// to optimise data retrieval.
+    ///
+    /// A filter is reported as `Exact` when it converts to a PyArrow expression
+    /// and as `Unsupported` otherwise; conversion errors are not propagated.
+    fn supports_filter_pushdown(&self, filter: &Expr) -> DFResult<TableProviderFilterPushDown> {
+        match PyArrowFilterExpression::try_from(filter) {
+            Ok(_) => Ok(TableProviderFilterPushDown::Exact),
+            _ => Ok(TableProviderFilterPushDown::Unsupported),
+        }
+    }
+}
diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs
new file mode 100644
index 0000000..a3925ad
--- /dev/null
+++ b/src/dataset_exec.rs
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Implements a Datafusion physical ExecutionPlan that delegates to a PyArrow Dataset
+/// This actually performs the projection, filtering and scanning of a Dataset
+use pyo3::prelude::*;
+use pyo3::types::{PyDict, PyIterator, PyList};
+
+use std::any::Any;
+use std::sync::Arc;
+
+use futures::stream;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::error::ArrowError;
+use datafusion::arrow::error::Result as ArrowResult;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::error::{DataFusionError as InnerDataFusionError, Result as 
DFResult};
+use datafusion::execution::context::TaskContext;
+use datafusion::logical_plan::{combine_filters, Expr};
+use datafusion::physical_expr::PhysicalSortExpr;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::{
+    DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, 
Statistics,
+};
+
+use crate::errors::DataFusionError;
+use crate::pyarrow_filter_expression::PyArrowFilterExpression;
+
+// Adapts a Python iterator of record batches to a Rust Iterator
+struct PyArrowBatchesAdapter {
+    // The Python iterator that yields the batches
+    batches: Py<PyIterator>,
+}
+
+impl Iterator for PyArrowBatchesAdapter {
+    type Item = ArrowResult<RecordBatch>;
+
+    // Pulls the next item from the Python iterator while holding the GIL.
+    // Exhaustion of the Python iterator ends the iteration; a Python exception or
+    // a failed conversion to RecordBatch is yielded as ArrowError::ExternalError.
+    fn next(&mut self) -> Option<Self::Item> {
+        Python::with_gil(|py| {
+            let mut batches: &PyIterator = self.batches.as_ref(py);
+            Some(
+                batches
+                    .next()?
+                    .and_then(|batch| batch.extract())
+                    .map_err(|err| ArrowError::ExternalError(Box::new(err))),
+            )
+        })
+    }
+}
+
+// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion ExecutionPlan around it
+#[derive(Debug, Clone)]
+pub(crate) struct DatasetExec {
+    // The Python dataset object that is scanned
+    dataset: PyObject,
+    // Projected schema as reported by the dataset scanner
+    schema: SchemaRef,
+    // Python list of the dataset fragments; each fragment is one output partition
+    fragments: Py<PyList>,
+    // Names of the projected columns, None when no projection was requested
+    columns: Option<Vec<String>>,
+    // Combined pushed-down filter as a PyArrow expression, None without filters
+    filter_expr: Option<PyObject>,
+    // Statistics returned by statistics(); set to the default value on creation
+    projected_statistics: Statistics,
+}
+
+impl DatasetExec {
+    // Builds the execution plan for `dataset`: resolves the projected column names,
+    // converts the pushed-down filters into one PyArrow expression, reads the
+    // projected schema from a dataset scanner and lists the matching fragments.
+    // Fails when a Python call raises or when a filter cannot be converted.
+    pub fn new(
+        py: Python,
+        dataset: &PyAny,
+        projection: Option<Vec<usize>>,
+        filters: &[Expr],
+    ) -> Result<Self, DataFusionError> {
+        // Map the projected column indices to the field names passed to PyArrow
+        let columns: Option<Result<Vec<String>, DataFusionError>> = projection.map(|p| {
+            // Look the schema up once instead of once per projected column
+            let dataset_schema = dataset.getattr("schema")?;
+            p.iter()
+                .map(|index| {
+                    let name: String = dataset_schema
+                        .call_method1("field", (*index,))?
+                        .getattr("name")?
+                        .extract()?;
+                    Ok(name)
+                })
+                .collect()
+        });
+        let columns: Option<Vec<String>> = columns.transpose()?;
+        let filter_expr: Option<PyObject> = combine_filters(filters)
+            .map(|filters| {
+                PyArrowFilterExpression::try_from(&filters)
+                    .map(|filter_expr| filter_expr.inner().clone_ref(py))
+            })
+            .transpose()?;
+
+        let kwargs = PyDict::new(py);
+
+        kwargs.set_item("columns", columns.clone())?;
+        kwargs.set_item(
+            "filter",
+            filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
+        )?;
+
+        // This scanner is only used to learn the projected schema
+        let scanner = dataset.call_method("scanner", (), Some(kwargs))?;
+
+        let schema = Arc::new(scanner.getattr("projected_schema")?.extract()?);
+
+        let builtins = Python::import(py, "builtins")?;
+        let pylist = builtins.getattr("list")?;
+
+        // Get the fragments or partitions of the dataset
+        let fragments_iterator: &PyAny = dataset.call_method1(
+            "get_fragments",
+            (filter_expr.as_ref().map(|expr| expr.clone_ref(py)),),
+        )?;
+
+        // Materialize the iterator with Python's list() so fragments can be indexed
+        let fragments: &PyList = pylist
+            .call1((fragments_iterator,))?
+            .downcast()
+            .map_err(PyErr::from)?;
+
+        Ok(DatasetExec {
+            dataset: dataset.into(),
+            schema,
+            fragments: fragments.into(),
+            columns,
+            filter_expr,
+            projected_statistics: Default::default(),
+        })
+    }
+}
+
+// Converts a Python exception into a DataFusion external error
+fn to_datafusion_err(err: PyErr) -> InnerDataFusionError {
+    InnerDataFusionError::External(Box::new(err))
+}
+
+impl ExecutionPlan for DatasetExec {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Get the schema for this execution plan
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+
+    /// Get the output partitioning of this plan
+    ///
+    /// One partition is reported per dataset fragment.
+    fn output_partitioning(&self) -> Partitioning {
+        Python::with_gil(|py| {
+            let fragments = self.fragments.as_ref(py);
+            Partitioning::UnknownPartitioning(fragments.len())
+        })
+    }
+
+    fn relies_on_input_order(&self) -> bool {
+        false
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        None
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        // this is a leaf node and has no children
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DFResult<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    /// Scans the fragment with index `partition` and returns its record batches as a stream
+    ///
+    /// The projected columns, the filter expression and the session batch size are
+    /// passed to the fragment scanner. Python errors are returned as external errors.
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> DFResult<SendableRecordBatchStream> {
+        let batch_size = context.session_config().batch_size();
+        Python::with_gil(|py| {
+            let dataset = self.dataset.as_ref(py);
+            let fragments = self.fragments.as_ref(py);
+            let fragment = fragments
+                .get_item(partition)
+                .map_err(to_datafusion_err)?;
+
+            // We need to pass the dataset schema to unify the fragment and dataset schema per PyArrow docs
+            let dataset_schema = dataset.getattr("schema").map_err(to_datafusion_err)?;
+            let kwargs = PyDict::new(py);
+            kwargs
+                .set_item("columns", self.columns.clone())
+                .map_err(to_datafusion_err)?;
+            kwargs
+                .set_item(
+                    "filter",
+                    self.filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
+                )
+                .map_err(to_datafusion_err)?;
+            kwargs
+                .set_item("batch_size", batch_size)
+                .map_err(to_datafusion_err)?;
+            let scanner = fragment
+                .call_method("scanner", (dataset_schema,), Some(kwargs))
+                .map_err(to_datafusion_err)?;
+            let schema: SchemaRef = Arc::new(
+                scanner
+                    .getattr("projected_schema")
+                    .and_then(|schema| schema.extract())
+                    .map_err(to_datafusion_err)?,
+            );
+            let record_batches: &PyIterator = scanner
+                .call_method0("to_batches")
+                .map_err(to_datafusion_err)?
+                .iter()
+                .map_err(to_datafusion_err)?;
+
+            // The adapter re-acquires the GIL each time the stream asks for a batch
+            let record_batches = PyArrowBatchesAdapter {
+                batches: record_batches.into(),
+            };
+
+            let record_batch_stream = stream::iter(record_batches);
+            let record_batch_stream: SendableRecordBatchStream =
+                Box::pin(RecordBatchStreamAdapter::new(schema, record_batch_stream));
+            Ok(record_batch_stream)
+        })
+    }
+
+    /// Prints the number of fragments, the filter expression (if any) and the projected columns
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        Python::with_gil(|py| {
+            let number_of_fragments = self.fragments.as_ref(py).len();
+            match t {
+                DisplayFormatType::Default => {
+                    let projected_columns: Vec<String> = self
+                        .schema
+                        .fields()
+                        .iter()
+                        .map(|x| x.name().to_owned())
+                        .collect();
+                    if let Some(filter_expr) = &self.filter_expr {
+                        // str() of the Python expression; a Python error becomes fmt::Error
+                        let filter_expr = filter_expr.as_ref(py).str().or(Err(std::fmt::Error))?;
+                        write!(
+                            f,
+                            "DatasetExec: number_of_fragments={}, filter_expr={}, projection=[{}]",
+                            number_of_fragments,
+                            filter_expr,
+                            projected_columns.join(", "),
+                        )
+                    } else {
+                        write!(
+                            f,
+                            "DatasetExec: number_of_fragments={}, projection=[{}]",
+                            number_of_fragments,
+                            projected_columns.join(", "),
+                        )
+                    }
+                }
+            }
+        })
+    }
+
+    fn statistics(&self) -> Statistics {
+        self.projected_statistics.clone()
+    }
+}
diff --git a/src/errors.rs b/src/errors.rs
index 655ed84..29d3e8f 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use core::fmt;
+use std::error::Error;
 
 use datafusion::arrow::error::ArrowError;
 use datafusion::error::DataFusionError as InnerDataFusionError;
@@ -26,6 +27,7 @@ pub enum DataFusionError {
     ExecutionError(InnerDataFusionError),
     ArrowError(ArrowError),
     Common(String),
+    PythonError(PyErr),
 }
 
 impl fmt::Display for DataFusionError {
@@ -33,6 +35,7 @@ impl fmt::Display for DataFusionError {
         match self {
             DataFusionError::ExecutionError(e) => write!(f, "DataFusion error: 
{:?}", e),
             DataFusionError::ArrowError(e) => write!(f, "Arrow error: {:?}", 
e),
+            DataFusionError::PythonError(e) => write!(f, "Python error {:?}", 
e),
             DataFusionError::Common(e) => write!(f, "{}", e),
         }
     }
@@ -50,8 +53,19 @@ impl From<InnerDataFusionError> for DataFusionError {
     }
 }
 
+// Lets `?` turn a Python exception into a DataFusionError while keeping the PyErr itself
+impl From<PyErr> for DataFusionError {
+    fn from(err: PyErr) -> DataFusionError {
+        DataFusionError::PythonError(err)
+    }
+}
+
 impl From<DataFusionError> for PyErr {
     fn from(err: DataFusionError) -> PyErr {
-        PyException::new_err(err.to_string())
+        match err {
+            DataFusionError::PythonError(py_err) => py_err,
+            _ => PyException::new_err(err.to_string()),
+        }
     }
 }
+
+// Marks DataFusionError as a std::error::Error; only the trait's default methods are used
+impl Error for DataFusionError {}
diff --git a/src/lib.rs b/src/lib.rs
index 25b63e8..c6ab58e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,9 +22,12 @@ use pyo3::prelude::*;
 pub mod catalog;
 mod context;
 mod dataframe;
+mod dataset;
+mod dataset_exec;
 pub mod errors;
 mod expression;
 mod functions;
+mod pyarrow_filter_expression;
 mod udaf;
 mod udf;
 pub mod utils;
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
new file mode 100644
index 0000000..3807553
--- /dev/null
+++ b/src/pyarrow_filter_expression.rs
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Converts a Datafusion logical plan expression (Expr) into a PyArrow compute expression
+use pyo3::prelude::*;
+
+use std::convert::TryFrom;
+use std::result::Result;
+
+use datafusion::logical_plan::*;
+use datafusion_common::ScalarValue;
+
+use crate::errors::DataFusionError;
+
+// Newtype around the Python object that holds a converted PyArrow filter expression
+#[derive(Debug, Clone)]
+#[repr(transparent)]
+pub(crate) struct PyArrowFilterExpression(PyObject);
+
+// Resolves a Datafusion binary operator to the attribute of `op` (the caller passes
+// Python's `operator` module) that implements it; operators without a mapping are
+// rejected with DataFusionError::Common
+fn operator_to_py<'py>(
+    operator: &Operator,
+    op: &'py PyModule,
+) -> Result<&'py PyAny, DataFusionError> {
+    // Pick the attribute name first so the lookup itself is written only once
+    let function_name = match operator {
+        Operator::Eq => "eq",
+        Operator::NotEq => "ne",
+        Operator::Lt => "lt",
+        Operator::LtEq => "le",
+        Operator::Gt => "gt",
+        Operator::GtEq => "ge",
+        Operator::And => "and_",
+        Operator::Or => "or_",
+        unsupported => {
+            return Err(DataFusionError::Common(format!(
+                "Unsupported operator {:?}",
+                unsupported
+            )))
+        }
+    };
+    let function = op.getattr(function_name)?;
+    Ok(function)
+}
+
+// Converts a list of literal expressions into Python objects (used for `isin`).
+// Fails when an element is not a literal or holds a scalar type or NULL value
+// that is not handled below.
+fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, DataFusionError> {
+    exprs
+        .iter()
+        .map(|expr| match expr {
+            Expr::Literal(v) => match v {
+                ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)),
+                ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::Int16(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::Int32(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::Int64(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::UInt8(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::UInt16(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::UInt32(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::UInt64(Some(i)) => Ok(i.into_py(py)),
+                ScalarValue::Float32(Some(f)) => Ok(f.into_py(py)),
+                ScalarValue::Float64(Some(f)) => Ok(f.into_py(py)),
+                ScalarValue::Utf8(Some(s)) => Ok(s.into_py(py)),
+                _ => Err(DataFusionError::Common(format!(
+                    "PyArrow can't handle ScalarValue: {:?}",
+                    v
+                ))),
+            },
+            _ => Err(DataFusionError::Common(format!(
+                "Only a list of Literals are supported got {:?}",
+                expr
+            ))),
+        })
+        .collect()
+}
+
+impl PyArrowFilterExpression {
+    // Borrows the wrapped Python expression object
+    pub fn inner(&self) -> &PyObject {
+        &self.0
+    }
+}
+
+impl TryFrom<&Expr> for PyArrowFilterExpression {
+    type Error = DataFusionError;
+
+    // Converts a Datafusion filter Expr into a PyArrow expression object by calling
+    // pyarrow.compute.{field,scalar} and the functions of Python's operator module
+    // isin, is_null, and is_valid (~is_null) are methods of pyarrow.dataset.Expression
+    // https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow-dataset-expression
+    // Expressions, operators and scalar values without a mapping below produce
+    // DataFusionError::Common; exceptions raised by Python are returned as PythonError
+    fn try_from(expr: &Expr) -> Result<Self, Self::Error> {
+        Python::with_gil(|py| {
+            let pc = Python::import(py, "pyarrow.compute")?;
+            let op_module = Python::import(py, "operator")?;
+            let pc_expr: Result<&PyAny, DataFusionError> = match expr {
+                Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
+                Expr::Literal(v) => match v {
+                    ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?),
+                    ScalarValue::Int8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::Int16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::Int32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::Int64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::UInt8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::UInt16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::UInt32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::UInt64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
+                    ScalarValue::Float32(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
+                    ScalarValue::Float64(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
+                    ScalarValue::Utf8(Some(s)) => Ok(pc.getattr("scalar")?.call1((s,))?),
+                    _ => Err(DataFusionError::Common(format!(
+                        "PyArrow can't handle ScalarValue: {:?}",
+                        v
+                    ))),
+                },
+                Expr::BinaryExpr { left, op, right } => {
+                    let operator = operator_to_py(op, op_module)?;
+                    let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
+                    let right = PyArrowFilterExpression::try_from(right.as_ref())?.0;
+                    Ok(operator.call1((left, right))?)
+                }
+                Expr::Not(expr) => {
+                    let operator = op_module.getattr("invert")?;
+                    let py_expr = PyArrowFilterExpression::try_from(expr.as_ref())?.0;
+                    Ok(operator.call1((py_expr,))?)
+                }
+                Expr::IsNotNull(expr) => {
+                    let py_expr = PyArrowFilterExpression::try_from(expr.as_ref())?
+                        .0
+                        .into_ref(py);
+                    Ok(py_expr.call_method0("is_valid")?)
+                }
+                Expr::IsNull(expr) => {
+                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
+                        .0
+                        .into_ref(py);
+                    // is_null() is a method of the expression itself; an extra positional
+                    // argument would be taken for its `nan_is_null` flag
+                    Ok(expr.call_method0("is_null")?)
+                }
+                Expr::Between {
+                    expr,
+                    negated,
+                    low,
+                    high,
+                } => {
+                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?.0;
+                    let low = PyArrowFilterExpression::try_from(low.as_ref())?.0;
+                    let high = PyArrowFilterExpression::try_from(high.as_ref())?.0;
+                    let and = op_module.getattr("and_")?;
+                    let le = op_module.getattr("le")?;
+                    let invert = op_module.getattr("invert")?;
+
+                    // scalar <= field() returns a boolean expression so we need to use and to combine these
+                    let ret = and.call1((
+                        le.call1((low, expr.clone_ref(py)))?,
+                        le.call1((expr, high))?,
+                    ))?;
+
+                    Ok(if *negated { invert.call1((ret,))? } else { ret })
+                }
+                Expr::InList {
+                    expr,
+                    list,
+                    negated,
+                } => {
+                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
+                        .0
+                        .into_ref(py);
+                    let scalars = extract_scalar_list(list, py)?;
+                    let ret = expr.call_method1("isin", (scalars,))?;
+                    let invert = op_module.getattr("invert")?;
+
+                    Ok(if *negated { invert.call1((ret,))? } else { ret })
+                }
+                _ => Err(DataFusionError::Common(format!(
+                    "Unsupported Datafusion expression {:?}",
+                    expr
+                ))),
+            };
+            Ok(PyArrowFilterExpression(pc_expr?.into()))
+        })
+    }
+}

Reply via email to