This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 45a6844  Pyarrow filter pushdowns (#735)
45a6844 is described below

commit 45a684445e25032961a7bb44ced3ce06f5ed9e6d
Author: Michael J Ward <[email protected]>
AuthorDate: Wed Jun 19 11:20:39 2024 -0500

    Pyarrow filter pushdowns (#735)
    
    * fix pushdown for pyarrow filter IsNull
    
    The conversion was incorrectly passing the expression itself as the
    `nan_is_null` argument, which caused the pushdown to silently fail
    (see the Python sketch below).
    
    * expand the Expr::Literal values that can be used in PyArrowFilterExpression
    
    Closes #703
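    
    For reference, a minimal Python sketch of the pyarrow surface these two
    fixes line up with (illustrative only, not code from this commit; it
    assumes nothing beyond pyarrow's documented
    Expression.is_null(nan_is_null=...) signature and pa.scalar):
    
        import datetime as dt
        
        import pyarrow as pa
        import pyarrow.compute as pc
        
        # Expression.is_null takes a single boolean option, nan_is_null,
        # deciding whether floating-point NaN counts as null; the buggy
        # binding passed the expression itself in this slot.
        expr = pc.field("c").is_null(nan_is_null=False)
        # str(expr) renders roughly as: is_null(c, {nan_is_null=false}),
        # the string the new test looks for in the EXPLAIN output.
        
        # Literals now go through ScalarValue::to_pyarrow, so anything
        # pyarrow can represent as a scalar converts, timestamps included:
        ts = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc),
                       pa.timestamp("ns", "+00:00"))
        pushdown_expr = pc.field("a") > ts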
---
 python/datafusion/tests/test_context.py | 54 +++++++++++++++++++++++++++++++++
 src/pyarrow_filter_expression.rs        | 29 ++++++------------
 2 files changed, 64 insertions(+), 19 deletions(-)

diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py
index df7e181..abc324d 100644
--- a/python/datafusion/tests/test_context.py
+++ b/python/datafusion/tests/test_context.py
@@ -16,6 +16,7 @@
 # under the License.
 import gzip
 import os
+import datetime as dt
 
 import pyarrow as pa
 import pyarrow.dataset as ds
@@ -303,6 +304,59 @@ def test_dataset_filter(ctx, capfd):
     assert result[0].column(1) == pa.array([-3])
 
 
+def test_pyarrow_predicate_pushdown_is_null(ctx, capfd):
+    """Ensure that pyarrow filter gets pushed down for `IsNull`"""
+    # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([7, None, 9])],
+        names=["a", "b", "c"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset("t", dataset)
+    # Make sure the filter was pushed down in Physical Plan
+    df = ctx.sql("SELECT a FROM t WHERE c is NULL")
+    df.explain()
+    captured = capfd.readouterr()
+    assert "filter_expr=is_null(c, {nan_is_null=false})" in captured.out
+
+    result = df.collect()
+    assert result[0].column(0) == pa.array([2])
+
+
+def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd):
+    """Ensure that pyarrow filter gets pushed down for timestamp"""
+    # Ref: https://github.com/apache/datafusion-python/issues/703
+
+    # create pyarrow dataset with no actual files
+    col_type = pa.timestamp("ns", "+00:00")
+    nyd_2000 = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), col_type)
+    pa_dataset_fs = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem())
+    pa_dataset_format = pa.dataset.ParquetFileFormat()
+    pa_dataset_partition = pa.dataset.field("a") <= nyd_2000
+    fragments = [
+        # NOTE: we never actually make this file.
+        # Working predicate pushdown means it never gets accessed
+        pa_dataset_format.make_fragment(
+            "1.parquet",
+            filesystem=pa_dataset_fs,
+            partition_expression=pa_dataset_partition,
+        )
+    ]
+    pa_dataset = pa.dataset.FileSystemDataset(
+        fragments,
+        pa.schema([pa.field("a", col_type)]),
+        pa_dataset_format,
+        pa_dataset_fs,
+    )
+
+    ctx.register_dataset("t", pa_dataset)
+
+    # the partition for our only fragment is for a <= 2000-01-01,
+    # so querying for a > 2024-01-01 should not touch any files
+    df = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'")
+    assert df.collect() == []
+
+
 def test_dataset_filter_nested_data(ctx):
     # create Arrow StructArrays to test nested data types
     data = pa.StructArray.from_arrays(
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
index fca8851..ff447e1 100644
--- a/src/pyarrow_filter_expression.rs
+++ b/src/pyarrow_filter_expression.rs
@@ -21,6 +21,7 @@ use pyo3::prelude::*;
 use std::convert::TryFrom;
 use std::result::Result;
 
+use arrow::pyarrow::ToPyArrow;
 use datafusion_common::{Column, ScalarValue};
 use datafusion_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};
 
@@ -56,6 +57,7 @@ fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, Data
     let ret: Result<Vec<PyObject>, DataFusionError> = exprs
         .iter()
         .map(|expr| match expr {
+            // TODO: should we also leverage `ScalarValue::to_pyarrow` here?
             Expr::Literal(v) => match v {
                 ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)),
                 ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)),
@@ -100,23 +102,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
             let op_module = Python::import_bound(py, "operator")?;
             let pc_expr: Result<Bound<'_, PyAny>, DataFusionError> = match expr {
                 Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
-                Expr::Literal(v) => match v {
-                    ScalarValue::Boolean(Some(b)) => Ok(pc.getattr("scalar")?.call1((*b,))?),
-                    ScalarValue::Int8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Int64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt8(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt16(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt32(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::UInt64(Some(i)) => Ok(pc.getattr("scalar")?.call1((*i,))?),
-                    ScalarValue::Float32(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
-                    ScalarValue::Float64(Some(f)) => Ok(pc.getattr("scalar")?.call1((*f,))?),
-                    ScalarValue::Utf8(Some(s)) => Ok(pc.getattr("scalar")?.call1((s,))?),
-                    _ => Err(DataFusionError::Common(format!(
-                        "PyArrow can't handle ScalarValue: {v:?}"
-                    ))),
-                },
+                Expr::Literal(scalar) => Ok(scalar.to_pyarrow(py)?.into_bound(py)),
                 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                     let operator = operator_to_py(op, &op_module)?;
                     let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
@@ -138,8 +124,13 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
                     let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
                         .0
                         .into_bound(py);
-                    // TODO: this expression does not seems like it should be `call_method0`
-                    Ok(expr.clone().call_method1("is_null", (expr,))?)
+
+                    // https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression.is_null
+                    // Whether floating-point NaNs are considered null.
+                    let nan_is_null = false;
+
+                    let res = expr.call_method1("is_null", (nan_is_null,))?;
+                    Ok(res)
                 }
                 Expr::Between(Between {
                     expr,


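For anyone verifying the fix locally, a condensed version of the new
is_null test (a sketch using only the public datafusion-python API already
shown in the diff above; nothing here is new API):

    import pyarrow as pa
    import pyarrow.dataset as ds
    from datafusion import SessionContext

    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([pa.array([7, None, 9])], names=["c"])
    ctx.register_dataset("t", ds.dataset([batch]))

    df = ctx.sql("SELECT * FROM t WHERE c IS NULL")
    df.explain()  # the physical plan should now include:
                  # filter_expr=is_null(c, {nan_is_null=false})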