This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 399fa75  feat: expose PyWindowFrame (#509)
399fa75 is described below

commit 399fa758ccb1dc785929123350144902a9b6c502
Author: Dan Lovell <[email protected]>
AuthorDate: Tue Oct 17 16:59:41 2023 -0400

    feat: expose PyWindowFrame (#509)
    
    * feat: expose PyWindowFrame
    
    * fix: PyWindowFrame: return Err instead of panicking
    
    * test: test PyWindowFrame creation
---
 datafusion/__init__.py             |   2 +
 datafusion/tests/test_dataframe.py |  41 +++++++++++++-
 src/functions.rs                   |  20 ++++++-
 src/lib.rs                         |   2 +
 src/udaf.rs                        |  30 +++++++++-
 src/window_frame.rs                | 110 +++++++++++++++++++++++++++++++++++++
 6 files changed, 200 insertions(+), 5 deletions(-)

diff --git a/datafusion/__init__.py b/datafusion/__init__.py
index bb1beac..4a495b4 100644
--- a/datafusion/__init__.py
+++ b/datafusion/__init__.py
@@ -33,6 +33,7 @@ from ._internal import (
     SessionConfig,
     RuntimeConfig,
     ScalarUDF,
+    WindowFrame,
 )
 
 from .common import (
@@ -98,6 +99,7 @@ __all__ = [
     "Expr",
     "AggregateUDF",
     "ScalarUDF",
+    "WindowFrame",
     "column",
     "literal",
     "TableScan",
diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index ce7d89e..c9b0f07 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -21,7 +21,14 @@ import pyarrow.parquet as pq
 import pytest
 
 from datafusion import functions as f
-from datafusion import DataFrame, SessionContext, column, literal, udf
+from datafusion import (
+    DataFrame,
+    SessionContext,
+    WindowFrame,
+    column,
+    literal,
+    udf,
+)
 
 
 @pytest.fixture
@@ -304,6 +311,38 @@ def test_window_functions(df):
     assert table.sort_by("a").to_pydict() == expected
 
 
+@pytest.mark.parametrize(
+    ("units", "start_bound", "end_bound"),
+    [
+        (units, start_bound, end_bound)
+        for units in ("rows", "range")
+        for start_bound in (None, 0, 1)
+        for end_bound in (None, 0, 1)
+    ]
+    + [
+        ("groups", 0, 0),
+    ],
+)
+def test_valid_window_frame(units, start_bound, end_bound):
+    WindowFrame(units, start_bound, end_bound)
+
+
+@pytest.mark.parametrize(
+    ("units", "start_bound", "end_bound"),
+    [
+        ("invalid-units", 0, None),
+        ("invalid-units", None, 0),
+        ("invalid-units", None, None),
+        ("groups", None, 0),
+        ("groups", 0, None),
+        ("groups", None, None),
+    ],
+)
+def test_invalid_window_frame(units, start_bound, end_bound):
+    with pytest.raises(RuntimeError):
+        WindowFrame(units, start_bound, end_bound)
+
+
 def test_get_dataframe(tmp_path):
     ctx = SessionContext()
 
diff --git a/src/functions.rs b/src/functions.rs
index e509aff..42203d7 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -17,9 +17,12 @@
 
 use pyo3::{prelude::*, wrap_pyfunction};
 
+use crate::context::PySessionContext;
 use crate::errors::DataFusionError;
 use crate::expr::conditional_expr::PyCaseBuilder;
 use crate::expr::PyExpr;
+use crate::window_frame::PyWindowFrame;
+use datafusion::execution::FunctionRegistry;
 use datafusion_common::Column;
 use datafusion_expr::expr::Alias;
 use datafusion_expr::{
@@ -27,7 +30,7 @@ use datafusion_expr::{
     expr::{AggregateFunction, ScalarFunction, Sort, WindowFunction},
     lit,
     window_function::find_df_window_func,
-    BuiltinScalarFunction, Expr, WindowFrame,
+    BuiltinScalarFunction, Expr,
 };
 
 #[pyfunction]
@@ -130,13 +133,24 @@ fn window(
     args: Vec<PyExpr>,
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PyExpr>>,
+    window_frame: Option<PyWindowFrame>,
+    ctx: Option<PySessionContext>,
 ) -> PyResult<PyExpr> {
-    let fun = find_df_window_func(name);
+    let fun = find_df_window_func(name).or_else(|| {
+        ctx.and_then(|ctx| {
+            ctx.ctx
+                .udaf(name)
+                .map(|fun| datafusion_expr::WindowFunction::AggregateUDF(fun))
+                .ok()
+        })
+    });
     if fun.is_none() {
         return Err(DataFusionError::Common("window function not found".to_string()).into());
     }
     let fun = fun.unwrap();
-    let window_frame = WindowFrame::new(order_by.is_some());
+    let window_frame = window_frame
+        .unwrap_or_else(|| PyWindowFrame::new("rows", None, Some(0)).unwrap())
+        .into();
     Ok(PyExpr {
         expr: datafusion_expr::Expr::WindowFunction(WindowFunction {
             fun,
diff --git a/src/lib.rs b/src/lib.rs
index 2512aef..b9bd576 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -54,6 +54,7 @@ mod udaf;
 #[allow(clippy::borrow_deref_ref)]
 mod udf;
 pub mod utils;
+mod window_frame;
 
 #[cfg(feature = "mimalloc")]
 #[global_allocator]
@@ -83,6 +84,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<context::PySessionContext>()?;
     m.add_class::<dataframe::PyDataFrame>()?;
     m.add_class::<udf::PyScalarUDF>()?;
+    m.add_class::<window_frame::PyWindowFrame>()?;
     m.add_class::<udaf::PyAggregateUDF>()?;
     m.add_class::<config::PyConfig>()?;
     m.add_class::<sql::logical::PyLogicalPlan>()?;
diff --git a/src/udaf.rs b/src/udaf.rs
index 3b70aeb..6450f03 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -17,7 +17,7 @@
 
 use std::sync::Arc;
 
-use pyo3::{prelude::*, types::PyTuple};
+use pyo3::{prelude::*, types::PyBool, types::PyTuple};
 
 use datafusion::arrow::array::{Array, ArrayRef};
 use datafusion::arrow::datatypes::DataType;
@@ -93,6 +93,34 @@ impl Accumulator for RustAccumulator {
     fn size(&self) -> usize {
         std::mem::size_of_val(self)
     }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        Python::with_gil(|py| {
+            // 1. cast args to Pyarrow array
+            let py_args = values
+                .iter()
+                .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+                .collect::<Vec<_>>();
+            let py_args = PyTuple::new(py, py_args);
+
+            // 2. call function
+            self.accum
+                .as_ref(py)
+                .call_method1("retract_batch", py_args)
+                .map_err(|e| DataFusionError::Execution(format!("{e}")))?;
+
+            Ok(())
+        })
+    }
+
+    fn supports_retract_batch(&self) -> bool {
+        Python::with_gil(|py| {
+            let x: Result<&PyAny, PyErr> =
+                self.accum.as_ref(py).call_method0("supports_retract_batch");
+            let x: &PyAny = x.unwrap_or(PyBool::new(py, false));
+            x.extract().unwrap_or(false)
+        })
+    }
 }
 
 pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
diff --git a/src/window_frame.rs b/src/window_frame.rs
new file mode 100644
index 0000000..b8f414e
--- /dev/null
+++ b/src/window_frame.rs
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{DataFusionError, ScalarValue};
+use datafusion_expr::window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+use pyo3::prelude::*;
+use std::fmt::{Display, Formatter};
+
+use crate::errors::py_datafusion_err;
+
+#[pyclass(name = "WindowFrame", module = "datafusion", subclass)]
+#[derive(Clone)]
+pub struct PyWindowFrame {
+    frame: WindowFrame,
+}
+
+impl From<PyWindowFrame> for WindowFrame {
+    fn from(frame: PyWindowFrame) -> Self {
+        frame.frame
+    }
+}
+
+impl From<WindowFrame> for PyWindowFrame {
+    fn from(frame: WindowFrame) -> PyWindowFrame {
+        PyWindowFrame { frame }
+    }
+}
+
+impl Display for PyWindowFrame {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "OVER ({} BETWEEN {} AND {})",
+            self.frame.units, self.frame.start_bound, self.frame.end_bound
+        )
+    }
+}
+
+#[pymethods]
+impl PyWindowFrame {
+    #[new(unit, start_bound, end_bound)]
+    pub fn new(units: &str, start_bound: Option<u64>, end_bound: Option<u64>) -> PyResult<Self> {
+        let units = units.to_ascii_lowercase();
+        let units = match units.as_str() {
+            "rows" => WindowFrameUnits::Rows,
+            "range" => WindowFrameUnits::Range,
+            "groups" => WindowFrameUnits::Groups,
+            _ => {
+                return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
+                    "{:?}",
+                    units,
+                ))));
+            }
+        };
+        let start_bound = match start_bound {
+            Some(start_bound) => {
+                WindowFrameBound::Preceding(ScalarValue::UInt64(Some(start_bound)))
+            }
+            None => match units {
+                WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+                WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
+                WindowFrameUnits::Groups => {
+                    return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
+                        "{:?}",
+                        units,
+                    ))));
+                }
+            },
+        };
+        let end_bound = match end_bound {
+            Some(end_bound) => WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound))),
+            None => match units {
+                WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
+                WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
+                WindowFrameUnits::Groups => {
+                    return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
+                        "{:?}",
+                        units,
+                    ))));
+                }
+            },
+        };
+        Ok(PyWindowFrame {
+            frame: WindowFrame {
+                units,
+                start_bound,
+                end_bound,
+            },
+        })
+    }
+
+    /// Get a String representation of this window frame
+    fn __repr__(&self) -> String {
+        format!("{}", self)
+    }
+}

Reply via email to