This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new c763d7d4 [Python] Add more methods to SessionContext (#978)
c763d7d4 is described below
commit c763d7d4e459f8290e90e86e2264076bfbc4faf3
Author: Andy Grove <[email protected]>
AuthorDate: Sun Feb 11 14:17:05 2024 -0700
[Python] Add more methods to SessionContext (#978)
* implement register_csv and register_parquet
* add more methods to pyballista
---
python/pyballista/tests/test_context.py | 16 +++
python/src/context.rs | 174 ++++++++++++++++++++++++++++++++
2 files changed, 190 insertions(+)
diff --git a/python/pyballista/tests/test_context.py b/python/pyballista/tests/test_context.py
index 9c264e98..0a395527 100644
--- a/python/pyballista/tests/test_context.py
+++ b/python/pyballista/tests/test_context.py
@@ -34,6 +34,14 @@ def test_read_csv():
assert len(batches) == 1
assert len(batches[0]) == 1
+def test_register_csv():
+ ctx = SessionContext("localhost", 50050)
+ ctx.register_csv("test", "testdata/test.csv", has_header=True)
+ df = ctx.sql("SELECT * FROM test")
+ batches = df.collect()
+ assert len(batches) == 1
+ assert len(batches[0]) == 1
+
def test_read_parquet():
ctx = SessionContext("localhost", 50050)
df = ctx.read_parquet("testdata/test.parquet")
@@ -41,6 +49,14 @@ def test_read_parquet():
assert len(batches) == 1
assert len(batches[0]) == 8
+def test_register_parquet():
+ ctx = SessionContext("localhost", 50050)
+ ctx.register_parquet("test", "testdata/test.parquet")
+ df = ctx.sql("SELECT * FROM test")
+ batches = df.collect()
+ assert len(batches) == 1
+ assert len(batches[0]) == 8
+
def test_read_dataframe_api():
ctx = SessionContext("localhost", 50050)
df = ctx.read_csv("testdata/test.csv", has_header=True) \
diff --git a/python/src/context.rs b/python/src/context.rs
index d9e7feee..b85e0e10 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -24,6 +24,7 @@ use crate::utils::to_pyerr;
use ballista::prelude::*;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::PyArrowType;
+use datafusion_python::catalog::PyTable;
use datafusion_python::context::{
convert_table_partition_cols, parse_file_compression_type,
};
@@ -60,6 +61,30 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}
+ #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
+ pub fn read_avro(
+ &self,
+ path: &str,
+ schema: Option<PyArrowType<Schema>>,
+ table_partition_cols: Vec<(String, String)>,
+ file_extension: &str,
+ py: Python,
+ ) -> PyResult<PyDataFrame> {
+ let mut options = AvroReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+ options.file_extension = file_extension;
+ let df = if let Some(schema) = schema {
+ options.schema = Some(&schema.0);
+ let read_future = self.ctx.read_avro(path, options);
+ wait_for_future(py, read_future).map_err(DataFusionError::from)?
+ } else {
+ let read_future = self.ctx.read_avro(path, options);
+ wait_for_future(py, read_future).map_err(DataFusionError::from)?
+ };
+ Ok(PyDataFrame::new(df))
+ }
+
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (
path,
@@ -113,6 +138,37 @@ impl PySessionContext {
}
}
+ #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
+ pub fn read_json(
+ &mut self,
+ path: PathBuf,
+ schema: Option<PyArrowType<Schema>>,
+ schema_infer_max_records: usize,
+ file_extension: &str,
+ table_partition_cols: Vec<(String, String)>,
+ file_compression_type: Option<String>,
+ py: Python,
+ ) -> PyResult<PyDataFrame> {
+ let path = path
+ .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+ let mut options = NdJsonReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);
+ options.schema_infer_max_records = schema_infer_max_records;
+ options.file_extension = file_extension;
+ let df = if let Some(schema) = schema {
+ options.schema = Some(&schema.0);
+ let result = self.ctx.read_json(path, options);
+ wait_for_future(py, result).map_err(DataFusionError::from)?
+ } else {
+ let result = self.ctx.read_json(path, options);
+ wait_for_future(py, result).map_err(DataFusionError::from)?
+ };
+ Ok(PyDataFrame::new(df))
+ }
+
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (
path,
@@ -150,4 +206,122 @@ impl PySessionContext {
            PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
Ok(df)
}
+
+ #[allow(clippy::too_many_arguments)]
+ #[pyo3(signature = (name,
+ path,
+ schema=None,
+ file_extension=".avro",
+ table_partition_cols=vec![]))]
+ pub fn register_avro(
+ &mut self,
+ name: &str,
+ path: PathBuf,
+ schema: Option<PyArrowType<Schema>>,
+ file_extension: &str,
+ table_partition_cols: Vec<(String, String)>,
+ py: Python,
+ ) -> PyResult<()> {
+ let path = path
+ .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+ let mut options = AvroReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+ options.file_extension = file_extension;
+ options.schema = schema.as_ref().map(|x| &x.0);
+
+ let result = self.ctx.register_avro(name, path, options);
+ wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ #[pyo3(signature = (name,
+ path,
+ schema=None,
+ has_header=true,
+ delimiter=",",
+ schema_infer_max_records=1000,
+ file_extension=".csv",
+ file_compression_type=None))]
+ pub fn register_csv(
+ &mut self,
+ name: &str,
+ path: PathBuf,
+ schema: Option<PyArrowType<Schema>>,
+ has_header: bool,
+ delimiter: &str,
+ schema_infer_max_records: usize,
+ file_extension: &str,
+ file_compression_type: Option<String>,
+ py: Python,
+ ) -> PyResult<()> {
+ let path = path
+ .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+ let delimiter = delimiter.as_bytes();
+ if delimiter.len() != 1 {
+ return Err(PyValueError::new_err(
+ "Delimiter must be a single character",
+ ));
+ }
+
+ let mut options = CsvReadOptions::new()
+ .has_header(has_header)
+ .delimiter(delimiter[0])
+ .schema_infer_max_records(schema_infer_max_records)
+ .file_extension(file_extension)
+            .file_compression_type(parse_file_compression_type(file_compression_type)?);
+ options.schema = schema.as_ref().map(|x| &x.0);
+
+ let result = self.ctx.register_csv(name, path, options);
+ wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+ Ok(())
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ #[pyo3(signature = (name, path, table_partition_cols=vec![],
+ parquet_pruning=true,
+ file_extension=".parquet",
+ skip_metadata=true,
+ schema=None,
+ file_sort_order=None))]
+ pub fn register_parquet(
+ &mut self,
+ name: &str,
+ path: &str,
+ table_partition_cols: Vec<(String, String)>,
+ parquet_pruning: bool,
+ file_extension: &str,
+ skip_metadata: bool,
+ schema: Option<PyArrowType<Schema>>,
+ file_sort_order: Option<Vec<Vec<PyExpr>>>,
+ py: Python,
+ ) -> PyResult<()> {
+ let mut options = ParquetReadOptions::default()
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+ .parquet_pruning(parquet_pruning)
+ .skip_metadata(skip_metadata);
+ options.file_extension = file_extension;
+ options.schema = schema.as_ref().map(|x| &x.0);
+ options.file_sort_order = file_sort_order
+ .unwrap_or_default()
+ .into_iter()
+ .map(|e| e.into_iter().map(|f| f.into()).collect())
+ .collect();
+
+ let result = self.ctx.register_parquet(name, path, options);
+ wait_for_future(py, result).map_err(DataFusionError::from)?;
+ Ok(())
+ }
+
+    pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
+ self.ctx
+ .register_table(name, table.table())
+ .map_err(DataFusionError::from)?;
+ Ok(())
+ }
}