This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 912f789e [Python] Add `read_csv` and `read_parquet` methods (#976)
912f789e is described below

commit 912f789e7f73877d2c262fbe2d4d41e34812294f
Author: Andy Grove <[email protected]>
AuthorDate: Sat Feb 10 11:09:45 2024 -0700

    [Python] Add `read_csv` and `read_parquet` methods (#976)
---
 python/Cargo.toml                       |   9 +-
 python/README.md                        |  27 +++---
 python/pyballista/__init__.py           |   2 -
 python/pyballista/tests/test_context.py |  27 +++++-
 python/requirements.txt                 |   1 +
 python/src/context.rs                   | 153 ++++++++++++++++++++++++++++++++
 python/src/lib.rs                       |  98 ++------------------
 python/src/utils.rs                     |  24 +++++
 python/testdata/test.csv                |   2 +
 python/testdata/test.parquet            | Bin 0 -> 1851 bytes
 10 files changed, 237 insertions(+), 106 deletions(-)

diff --git a/python/Cargo.toml b/python/Cargo.toml
index 6a63b6f7..833f6c7c 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -27,11 +27,18 @@ license = "Apache-2.0"
 edition = "2021"
 rust-version = "1.64"
 include = ["/src", "/pyballista", "/LICENSE.txt", "pyproject.toml", 
"Cargo.toml", "Cargo.lock"]
+publish = false
 
 [dependencies]
+async-trait = "0.1.77"
 ballista = { path = "../ballista/client", version = "0.12.0" }
+ballista-core = { path = "../ballista/core", version = "0.12.0" }
 datafusion = "35.0.0"
-datafusion-python = "35.0.0"
+datafusion-proto = "35.0.0"
+
+# we need to use a recent build of ADP that has a public PyDataFrame
+datafusion-python = { git = 
"https://github.com/apache/arrow-datafusion-python";, rev = 
"5296c0cfcf8e6fcb654d5935252469bf04f929e9" }
+
 pyo3 = { version = "0.20", features = ["extension-module", "abi3", 
"abi3-py38"] }
 tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread", 
"sync"] }
 
diff --git a/python/README.md b/python/README.md
index 1819a720..2898cb16 100644
--- a/python/README.md
+++ b/python/README.md
@@ -19,26 +19,33 @@
 
 # PyBallista
 
-Minimal Python client for Ballista.
-
-The goal of this project is to provide a way to run SQL against a Ballista cluster from Python and collect
-results as PyArrow record batches.
-
-Note that this client currently only provides a SQL API and not a DataFrame API. A future release will support
-using the DataFrame API from DataFusion's Python bindings to create a logical plan and then execute that logical plan
-from the Ballista context ([tracking issue](https://github.com/apache/arrow-ballista/issues/971)).
+Python client for Ballista.
 
 This project is versioned and released independently from the main Ballista project and is intentionally not
 part of the default Cargo workspace so that it doesn't cause overhead for maintainers of the main Ballista codebase.
 
-## Example Usage
+## Creating a SessionContext
+
+Creates a new context and connects to a Ballista scheduler process.
 
 ```python
 from pyballista import SessionContext
 >>> ctx = SessionContext("localhost", 50050)
+```
+
+## Example SQL Usage
+
+```python
 >>> ctx.sql("create external table t stored as parquet location 
 >>> '/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet'")
 >>> df = ctx.sql("select * from t limit 5")
->>> df.collect()
+>>> pyarrow_batches = df.collect()
+```
+
+## Example DataFrame Usage
+
+```python
+>>> df = 
ctx.read_parquet('/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet').limit(5)
+>>> pyarrow_batches = df.collect()
 ```
 
 ## Creating Virtual Environment
diff --git a/python/pyballista/__init__.py b/python/pyballista/__init__.py
index 480e9eda..62a6bc79 100644
--- a/python/pyballista/__init__.py
+++ b/python/pyballista/__init__.py
@@ -27,12 +27,10 @@ import pyarrow as pa
 
 from .pyballista_internal import (
     SessionContext,
-    DataFrame
 )
 
 __version__ = importlib_metadata.version(__name__)
 
 __all__ = [
     "SessionContext",
-    "DataFrame",
 ]
diff --git a/python/pyballista/tests/test_context.py b/python/pyballista/tests/test_context.py
index 46e67e10..9c264e98 100644
--- a/python/pyballista/tests/test_context.py
+++ b/python/pyballista/tests/test_context.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from pyballista import SessionContext
 import pytest
 
@@ -23,4 +24,28 @@ def test_create_context():
 def test_select_one():
     ctx = SessionContext("localhost", 50050)
     df = ctx.sql("SELECT 1")
-    df.collect()
\ No newline at end of file
+    batches = df.collect()
+    assert len(batches) == 1
+
+def test_read_csv():
+    ctx = SessionContext("localhost", 50050)
+    df = ctx.read_csv("testdata/test.csv", has_header=True)
+    batches = df.collect()
+    assert len(batches) == 1
+    assert len(batches[0]) == 1
+
+def test_read_parquet():
+    ctx = SessionContext("localhost", 50050)
+    df = ctx.read_parquet("testdata/test.parquet")
+    batches = df.collect()
+    assert len(batches) == 1
+    assert len(batches[0]) == 8
+
+def test_read_dataframe_api():
+    ctx = SessionContext("localhost", 50050)
+    df = ctx.read_csv("testdata/test.csv", has_header=True) \
+        .select_columns('a', 'b') \
+        .limit(1)
+    batches = df.collect()
+    assert len(batches) == 1
+    assert len(batches[0]) == 1
diff --git a/python/requirements.txt b/python/requirements.txt
index f6acb176..a03a8f8d 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -1,2 +1,3 @@
+datafusion==35.0.0
 pyarrow
 pytest
\ No newline at end of file
diff --git a/python/src/context.rs b/python/src/context.rs
new file mode 100644
index 00000000..d9e7feee
--- /dev/null
+++ b/python/src/context.rs
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::prelude::*;
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+use std::path::PathBuf;
+
+use crate::utils::to_pyerr;
+use ballista::prelude::*;
+use datafusion::arrow::datatypes::Schema;
+use datafusion::arrow::pyarrow::PyArrowType;
+use datafusion_python::context::{
+    convert_table_partition_cols, parse_file_compression_type,
+};
+use datafusion_python::dataframe::PyDataFrame;
+use datafusion_python::errors::DataFusionError;
+use datafusion_python::expr::PyExpr;
+use datafusion_python::utils::wait_for_future;
+
+/// PyBallista session context. This is largely a duplicate of
+/// DataFusion's PySessionContext, with the main difference being
+/// that this operates on a BallistaContext instead of DataFusion's
+/// SessionContext. We could probably add extra extension points to
+/// DataFusion to allow for a pluggable context and remove much of
+/// this code.
+#[pyclass(name = "SessionContext", module = "pyballista", subclass)]
+pub struct PySessionContext {
+    ctx: BallistaContext,
+}
+
+#[pymethods]
+impl PySessionContext {
+    /// Create a new SessionContext by connecting to a Ballista scheduler 
process.
+    #[new]
+    pub fn new(host: &str, port: u16, py: Python) -> PyResult<Self> {
+        let config = BallistaConfig::new().unwrap();
+        let ballista_context = BallistaContext::remote(host, port, &config);
+        let ctx = wait_for_future(py, ballista_context).map_err(to_pyerr)?;
+        Ok(Self { ctx })
+    }
+
+    pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
+        let result = self.ctx.sql(query);
+        let df = wait_for_future(py, result)?;
+        Ok(PyDataFrame::new(df))
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (
+        path,
+        schema=None,
+        has_header=true,
+        delimiter=",",
+        schema_infer_max_records=1000,
+        file_extension=".csv",
+        table_partition_cols=vec![],
+        file_compression_type=None))]
+    pub fn read_csv(
+        &self,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        has_header: bool,
+        delimiter: &str,
+        schema_infer_max_records: usize,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        file_compression_type: Option<String>,
+        py: Python,
+    ) -> PyResult<PyDataFrame> {
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a 
string"))?;
+
+        let delimiter = delimiter.as_bytes();
+        if delimiter.len() != 1 {
+            return Err(PyValueError::new_err(
+                "Delimiter must be a single character",
+            ));
+        };
+
+        let mut options = CsvReadOptions::new()
+            .has_header(has_header)
+            .delimiter(delimiter[0])
+            .schema_infer_max_records(schema_infer_max_records)
+            .file_extension(file_extension)
+            
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            
.file_compression_type(parse_file_compression_type(file_compression_type)?);
+
+        if let Some(py_schema) = schema {
+            options.schema = Some(&py_schema.0);
+            let result = self.ctx.read_csv(path, options);
+            let df = PyDataFrame::new(wait_for_future(py, result)?);
+            Ok(df)
+        } else {
+            let result = self.ctx.read_csv(path, options);
+            let df = PyDataFrame::new(wait_for_future(py, result)?);
+            Ok(df)
+        }
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (
+        path,
+        table_partition_cols=vec![],
+        parquet_pruning=true,
+        file_extension=".parquet",
+        skip_metadata=true,
+        schema=None,
+        file_sort_order=None))]
+    pub fn read_parquet(
+        &self,
+        path: &str,
+        table_partition_cols: Vec<(String, String)>,
+        parquet_pruning: bool,
+        file_extension: &str,
+        skip_metadata: bool,
+        schema: Option<PyArrowType<Schema>>,
+        file_sort_order: Option<Vec<Vec<PyExpr>>>,
+        py: Python,
+    ) -> PyResult<PyDataFrame> {
+        let mut options = ParquetReadOptions::default()
+            
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .parquet_pruning(parquet_pruning)
+            .skip_metadata(skip_metadata);
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+        options.file_sort_order = file_sort_order
+            .unwrap_or_default()
+            .into_iter()
+            .map(|e| e.into_iter().map(|f| f.into()).collect())
+            .collect();
+
+        let result = self.ctx.read_parquet(path, options);
+        let df =
+            PyDataFrame::new(wait_for_future(py, 
result).map_err(DataFusionError::from)?);
+        Ok(df)
+    }
+}
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 186a570e..04cf232a 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -15,103 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use datafusion::arrow::pyarrow::ToPyArrow;
-use datafusion::prelude::DataFrame;
-use pyo3::exceptions::PyException;
 use pyo3::prelude::*;
-use std::future::Future;
-use std::sync::Arc;
-use tokio::runtime::Runtime;
+pub mod context;
+mod utils;
 
-use ballista::prelude::*;
-
-/// PyBallista SessionContext
-#[pyclass(name = "SessionContext", module = "pyballista", subclass)]
-pub struct PySessionContext {
-    ctx: BallistaContext,
-}
-
-#[pymethods]
-impl PySessionContext {
-    #[new]
-    pub fn new(host: &str, port: u16, py: Python) -> PyResult<Self> {
-        let config = BallistaConfig::new().unwrap();
-        let ballista_context = BallistaContext::remote(host, port, &config);
-        let ctx = wait_for_future(py, ballista_context).map_err(to_pyerr)?;
-        Ok(Self { ctx })
-    }
-
-    pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
-        let result = self.ctx.sql(query);
-        let df = wait_for_future(py, result)?;
-        Ok(PyDataFrame::new(df))
-    }
-}
-
-#[pyclass(name = "DataFrame", module = "pyballista", subclass)]
-#[derive(Clone)]
-pub struct PyDataFrame {
-    /// DataFusion DataFrame
-    df: Arc<DataFrame>,
-}
-
-impl PyDataFrame {
-    /// creates a new PyDataFrame
-    pub fn new(df: DataFrame) -> Self {
-        Self { df: Arc::new(df) }
-    }
-}
-
-#[pymethods]
-impl PyDataFrame {
-    /// Executes the plan, returning a list of `RecordBatch`es.
-    /// Unless some order is specified in the plan, there is no
-    /// guarantee of the order of the result.
-    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
-        // cannot use PyResult<Vec<RecordBatch>> return type due to
-        // https://github.com/PyO3/pyo3/issues/1813
-        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
-    }
-}
-
-fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
-where
-    F: Send,
-    F::Output: Send,
-{
-    let runtime: &Runtime = &get_tokio_runtime(py).0;
-    py.allow_threads(|| runtime.block_on(f))
-}
-
-fn get_tokio_runtime(py: Python) -> PyRef<TokioRuntime> {
-    let ballista = py.import("pyballista._internal").unwrap();
-    let tmp = ballista.getattr("runtime").unwrap();
-    match tmp.extract::<PyRef<TokioRuntime>>() {
-        Ok(runtime) => runtime,
-        Err(_e) => {
-            let rt = TokioRuntime(tokio::runtime::Runtime::new().unwrap());
-            let obj: &PyAny = Py::new(py, rt).unwrap().into_ref(py);
-            obj.extract().unwrap()
-        }
-    }
-}
-
-fn to_pyerr(err: BallistaError) -> PyErr {
-    PyException::new_err(err.to_string())
-}
-
-#[pyclass]
-pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
+pub use crate::context::PySessionContext;
 
 #[pymodule]
 fn pyballista_internal(_py: Python, m: &PyModule) -> PyResult<()> {
-    // Register the Tokio Runtime as a module attribute so we can reuse it
-    m.add(
-        "runtime",
-        TokioRuntime(tokio::runtime::Runtime::new().unwrap()),
-    )?;
+    // Ballista structs
     m.add_class::<PySessionContext>()?;
-    m.add_class::<PyDataFrame>()?;
+    // DataFusion structs
+    m.add_class::<datafusion_python::dataframe::PyDataFrame>()?;
     Ok(())
 }
diff --git a/python/src/utils.rs b/python/src/utils.rs
new file mode 100644
index 00000000..10278537
--- /dev/null
+++ b/python/src/utils.rs
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use ballista_core::error::BallistaError;
+use pyo3::exceptions::PyException;
+use pyo3::PyErr;
+
+pub(crate) fn to_pyerr(err: BallistaError) -> PyErr {
+    PyException::new_err(err.to_string())
+}
diff --git a/python/testdata/test.csv b/python/testdata/test.csv
new file mode 100644
index 00000000..00910b0f
--- /dev/null
+++ b/python/testdata/test.csv
@@ -0,0 +1,2 @@
+a,b
+1,2
\ No newline at end of file
diff --git a/python/testdata/test.parquet b/python/testdata/test.parquet
new file mode 100755
index 00000000..a63f5dca
Binary files /dev/null and b/python/testdata/test.parquet differ

Reply via email to