This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new c763d7d4 [Python] Add more methods to SessionContext (#978)
c763d7d4 is described below

commit c763d7d4e459f8290e90e86e2264076bfbc4faf3
Author: Andy Grove <[email protected]>
AuthorDate: Sun Feb 11 14:17:05 2024 -0700

    [Python] Add more methods to SessionContext (#978)
    
    * implement register_csv and register_parquet
    
    * add more methods to pyballista
---
 python/pyballista/tests/test_context.py |  16 +++
 python/src/context.rs                   | 174 ++++++++++++++++++++++++++++++++
 2 files changed, 190 insertions(+)

diff --git a/python/pyballista/tests/test_context.py b/python/pyballista/tests/test_context.py
index 9c264e98..0a395527 100644
--- a/python/pyballista/tests/test_context.py
+++ b/python/pyballista/tests/test_context.py
@@ -34,6 +34,14 @@ def test_read_csv():
     assert len(batches) == 1
     assert len(batches[0]) == 1
 
+def test_register_csv():
+    ctx = SessionContext("localhost", 50050)
+    ctx.register_csv("test", "testdata/test.csv", has_header=True)
+    df = ctx.sql("SELECT * FROM test")
+    batches = df.collect()
+    assert len(batches) == 1
+    assert len(batches[0]) == 1
+
 def test_read_parquet():
     ctx = SessionContext("localhost", 50050)
     df = ctx.read_parquet("testdata/test.parquet")
@@ -41,6 +49,14 @@ def test_read_parquet():
     assert len(batches) == 1
     assert len(batches[0]) == 8
 
+def test_register_parquet():
+    ctx = SessionContext("localhost", 50050)
+    ctx.register_parquet("test", "testdata/test.parquet")
+    df = ctx.sql("SELECT * FROM test")
+    batches = df.collect()
+    assert len(batches) == 1
+    assert len(batches[0]) == 8
+
 def test_read_dataframe_api():
     ctx = SessionContext("localhost", 50050)
     df = ctx.read_csv("testdata/test.csv", has_header=True) \
diff --git a/python/src/context.rs b/python/src/context.rs
index d9e7feee..b85e0e10 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -24,6 +24,7 @@ use crate::utils::to_pyerr;
 use ballista::prelude::*;
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::PyArrowType;
+use datafusion_python::catalog::PyTable;
 use datafusion_python::context::{
     convert_table_partition_cols, parse_file_compression_type,
 };
@@ -60,6 +61,30 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }
 
+    /// Read the Avro data at `path` into a `PyDataFrame`.
+    ///
+    /// An explicit `schema` is stored on `AvroReadOptions::schema`; when it is
+    /// `None` that option is left unset. `table_partition_cols` is converted
+    /// with `convert_table_partition_cols` (presumably `(column, type)` pairs —
+    /// confirm against datafusion-python) and `file_extension` is passed
+    /// through unchanged.
+    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
+    pub fn read_avro(
+        &self,
+        path: &str,
+        schema: Option<PyArrowType<Schema>>,
+        table_partition_cols: Vec<(String, String)>,
+        file_extension: &str,
+        py: Python,
+    ) -> PyResult<PyDataFrame> {
+        let mut options = AvroReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+        options.file_extension = file_extension;
+        // `options` only borrows the schema; `schema` outlives the read, so the
+        // option can be set once instead of duplicating the read per branch.
+        options.schema = schema.as_ref().map(|s| &s.0);
+        let read_future = self.ctx.read_avro(path, options);
+        let df = wait_for_future(py, read_future).map_err(DataFusionError::from)?;
+        Ok(PyDataFrame::new(df))
+    }
+
     #[allow(clippy::too_many_arguments)]
     #[pyo3(signature = (
         path,
@@ -113,6 +138,37 @@ impl PySessionContext {
         }
     }
 
+    /// Read newline-delimited JSON at `path` into a `PyDataFrame`.
+    ///
+    /// Options are forwarded to `NdJsonReadOptions`. An explicit `schema` is
+    /// stored on the options; when it is `None` that option is left unset.
+    ///
+    /// Raises `ValueError` if `path` is not valid Unicode.
+    // Each Python keyword argument is its own Rust parameter, which exceeds
+    // clippy's default argument-count limit.
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
+    pub fn read_json(
+        &self,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        schema_infer_max_records: usize,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        file_compression_type: Option<String>,
+        py: Python,
+    ) -> PyResult<PyDataFrame> {
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+        let mut options = NdJsonReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);
+        options.schema_infer_max_records = schema_infer_max_records;
+        options.file_extension = file_extension;
+        // Set the (borrowed) schema once rather than duplicating the read in
+        // two otherwise-identical branches.
+        options.schema = schema.as_ref().map(|s| &s.0);
+        let result = self.ctx.read_json(path, options);
+        let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
+        Ok(PyDataFrame::new(df))
+    }
+
     #[allow(clippy::too_many_arguments)]
     #[pyo3(signature = (
         path,
@@ -150,4 +206,122 @@ impl PySessionContext {
             PyDataFrame::new(wait_for_future(py, 
result).map_err(DataFusionError::from)?);
         Ok(df)
     }
+
+    /// Register the Avro data at `path` as table `name` on this context.
+    ///
+    /// `schema`, `file_extension` and `table_partition_cols` are forwarded to
+    /// `AvroReadOptions`. Raises `ValueError` if `path` is not valid Unicode.
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name, path, schema=None, file_extension=".avro", table_partition_cols=vec![]))]
+    pub fn register_avro(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        py: Python,
+    ) -> PyResult<()> {
+        let path = match path.to_str() {
+            Some(p) => p,
+            None => return Err(PyValueError::new_err("Unable to convert path to a string")),
+        };
+
+        let partition_cols = convert_table_partition_cols(table_partition_cols)?;
+        let mut options = AvroReadOptions::default().table_partition_cols(partition_cols);
+        options.schema = schema.as_ref().map(|s| &s.0);
+        options.file_extension = file_extension;
+
+        wait_for_future(py, self.ctx.register_avro(name, path, options))
+            .map_err(DataFusionError::from)?;
+        Ok(())
+    }
+
+    /// Register the CSV data at `path` as table `name` on this context.
+    ///
+    /// Raises `ValueError` if `path` is not valid Unicode or if `delimiter`
+    /// is not exactly one byte long.
+    // Each Python keyword argument is its own Rust parameter, which exceeds
+    // clippy's default argument-count limit.
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name, path, schema=None, has_header=true, delimiter=",", schema_infer_max_records=1000, file_extension=".csv", file_compression_type=None))]
+    pub fn register_csv(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        has_header: bool,
+        delimiter: &str,
+        schema_infer_max_records: usize,
+        file_extension: &str,
+        file_compression_type: Option<String>,
+        py: Python,
+    ) -> PyResult<()> {
+        let path = match path.to_str() {
+            Some(p) => p,
+            None => return Err(PyValueError::new_err("Unable to convert path to a string")),
+        };
+        // `CsvReadOptions::delimiter` takes a single byte, so anything that is
+        // not exactly one byte of UTF-8 (including non-ASCII characters) is rejected.
+        let delimiter_byte = match delimiter.as_bytes() {
+            [b] => *b,
+            _ => return Err(PyValueError::new_err("Delimiter must be a single character")),
+        };
+
+        let compression = parse_file_compression_type(file_compression_type)?;
+        let mut options = CsvReadOptions::new()
+            .has_header(has_header)
+            .delimiter(delimiter_byte)
+            .schema_infer_max_records(schema_infer_max_records)
+            .file_extension(file_extension)
+            .file_compression_type(compression);
+        options.schema = schema.as_ref().map(|s| &s.0);
+
+        let registration = self.ctx.register_csv(name, path, options);
+        wait_for_future(py, registration).map_err(DataFusionError::from)?;
+        Ok(())
+    }
+
+    /// Register the Parquet data at `path` as table `name` on this context.
+    ///
+    /// `path` accepts a `str` or any `os.PathLike`, the same as `register_csv`
+    /// and `register_avro`. Raises `ValueError` if `path` is not valid Unicode.
+    ///
+    /// `file_sort_order`, when given, is a list of sort orders, each a list of
+    /// expressions; `None` leaves `ParquetReadOptions::file_sort_order` empty.
+    // Each Python keyword argument is its own Rust parameter, which exceeds
+    // clippy's default argument-count limit.
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name, path, table_partition_cols=vec![],
+    parquet_pruning=true,
+    file_extension=".parquet",
+    skip_metadata=true,
+    schema=None,
+    file_sort_order=None))]
+    pub fn register_parquet(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        table_partition_cols: Vec<(String, String)>,
+        parquet_pruning: bool,
+        file_extension: &str,
+        skip_metadata: bool,
+        schema: Option<PyArrowType<Schema>>,
+        file_sort_order: Option<Vec<Vec<PyExpr>>>,
+        py: Python,
+    ) -> PyResult<()> {
+        // Accept `PathBuf` like the sibling register_* methods so Python callers
+        // may pass `pathlib.Path` as well as `str`.
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+        let mut options = ParquetReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .parquet_pruning(parquet_pruning)
+            .skip_metadata(skip_metadata);
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+        options.file_sort_order = file_sort_order
+            .unwrap_or_default()
+            .into_iter()
+            .map(|order| order.into_iter().map(|expr| expr.into()).collect())
+            .collect();
+
+        let result = self.ctx.register_parquet(name, path, options);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
+        Ok(())
+    }
+
+    /// Register an existing `PyTable` under `name` on this context.
+    ///
+    /// Any value returned by the underlying registration call is discarded;
+    /// only its error is propagated.
+    pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
+        let provider = table.table();
+        let _previous = self
+            .ctx
+            .register_table(name, provider)
+            .map_err(DataFusionError::from)?;
+        Ok(())
+    }
 }

Reply via email to