This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 7204a35 bugfix: no panic on empty table (#613)
7204a35 is described below
commit 7204a354f8d3f3c3fa45ffc3211e252cc3706d47
Author: Daniel Mesejo <[email protected]>
AuthorDate: Sat Apr 13 23:39:51 2024 +0200
bugfix: no panic on empty table (#613)
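    Before this change, create_dataframe() always derived the schema from the
    first record batch (partitions.0[0][0]), so converting an empty pyarrow
    Table, which yields no record batches, panicked with an index out of
    bounds. The schema is now read from the pyarrow Table itself and the
    first-batch fallback is only used when no schema is supplied.

    A minimal reproduction sketch, mirroring the new tests but with an
    explicit SessionContext() assumed in place of the test's ctx fixture:

        import pyarrow as pa
        from datafusion import SessionContext

        ctx = SessionContext()
        schema = pa.schema([("a", pa.int32()), ("b", pa.string())])
        empty = pa.Table.from_pydict({"a": [], "b": []}, schema=schema)

        # Previously this panicked while inferring the schema from the first
        # record batch; with this fix the table's own schema is used instead.
        df = ctx.from_arrow_table(empty)
        assert len(df.collect()) == 0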
---
datafusion/tests/test_context.py | 31 +++++++++++++++++++++++++++++++
src/context.rs | 15 ++++++++++++---
2 files changed, 43 insertions(+), 3 deletions(-)
diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 962a7ff..df7e181 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -139,6 +139,37 @@ def test_from_arrow_table_with_name(ctx):
assert tables[0] == "tbl"
+def test_from_arrow_table_empty(ctx):
+ data = {"a": [], "b": []}
+ schema = pa.schema([("a", pa.int32()), ("b", pa.string())])
+ table = pa.Table.from_pydict(data, schema=schema)
+
+ # convert to DataFrame
+ df = ctx.from_arrow_table(table)
+ tables = list(ctx.tables())
+
+ assert df
+ assert len(tables) == 1
+ assert isinstance(df, DataFrame)
+ assert set(df.schema().names) == {"a", "b"}
+ assert len(df.collect()) == 0
+
+
+def test_from_arrow_table_empty_no_schema(ctx):
+ data = {"a": [], "b": []}
+ table = pa.Table.from_pydict(data)
+
+ # convert to DataFrame
+ df = ctx.from_arrow_table(table)
+ tables = list(ctx.tables())
+
+ assert df
+ assert len(tables) == 1
+ assert isinstance(df, DataFrame)
+ assert set(df.schema().names) == {"a", "b"}
+ assert len(df.collect()) == 0
+
+
def test_from_pylist(ctx):
# create a dataframe from Python list
data = [
diff --git a/src/context.rs b/src/context.rs
index 825d3e3..d9b5860 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -39,7 +39,7 @@ use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::utils::{get_tokio_runtime, wait_for_future};
-use datafusion::arrow::datatypes::{DataType, Schema};
+use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -344,9 +344,15 @@ impl PySessionContext {
&mut self,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
name: Option<&str>,
+ schema: Option<PyArrowType<Schema>>,
py: Python,
) -> PyResult<PyDataFrame> {
- let schema = partitions.0[0][0].schema();
+ let schema = if let Some(schema) = schema {
+ SchemaRef::from(schema.0)
+ } else {
+ partitions.0[0][0].schema()
+ };
+
let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
// generate a random (unique) name for this table if none is provided
@@ -428,12 +434,15 @@ impl PySessionContext {
// Instantiate pyarrow Table object & convert to batches
let table = data.call_method0(py, "to_batches")?;
+ let schema = data.getattr(py, "schema")?;
+ let schema = schema.extract::<PyArrowType<Schema>>(py)?;
+
// Cast PyObject to RecordBatch type
// Because create_dataframe() expects a vector of vectors of record batches
// here we need to wrap the vector of record batches in an additional vector
let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
let list_of_batches = PyArrowType::from(vec![batches.0]);
- self.create_dataframe(list_of_batches, name, py)
+ self.create_dataframe(list_of_batches, name, Some(schema), py)
})
}