This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 7204a35  bugfix: no panic on empty table (#613)
7204a35 is described below

commit 7204a354f8d3f3c3fa45ffc3211e252cc3706d47
Author: Daniel Mesejo <[email protected]>
AuthorDate: Sat Apr 13 23:39:51 2024 +0200

    bugfix: no panic on empty table (#613)
---
 datafusion/tests/test_context.py | 31 +++++++++++++++++++++++++++++++
 src/context.rs                   | 15 ++++++++++++---
 2 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 962a7ff..df7e181 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -139,6 +139,37 @@ def test_from_arrow_table_with_name(ctx):
     assert tables[0] == "tbl"
 
 
+def test_from_arrow_table_empty(ctx):
+    data = {"a": [], "b": []}
+    schema = pa.schema([("a", pa.int32()), ("b", pa.string())])
+    table = pa.Table.from_pydict(data, schema=schema)
+
+    # convert to DataFrame
+    df = ctx.from_arrow_table(table)
+    tables = list(ctx.tables())
+
+    assert df
+    assert len(tables) == 1
+    assert isinstance(df, DataFrame)
+    assert set(df.schema().names) == {"a", "b"}
+    assert len(df.collect()) == 0
+
+
+def test_from_arrow_table_empty_no_schema(ctx):
+    data = {"a": [], "b": []}
+    table = pa.Table.from_pydict(data)
+
+    # convert to DataFrame
+    df = ctx.from_arrow_table(table)
+    tables = list(ctx.tables())
+
+    assert df
+    assert len(tables) == 1
+    assert isinstance(df, DataFrame)
+    assert set(df.schema().names) == {"a", "b"}
+    assert len(df.collect()) == 0
+
+
 def test_from_pylist(ctx):
     # create a dataframe from Python list
     data = [
diff --git a/src/context.rs b/src/context.rs
index 825d3e3..d9b5860 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -39,7 +39,7 @@ use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::utils::{get_tokio_runtime, wait_for_future};
-use datafusion::arrow::datatypes::{DataType, Schema};
+use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -344,9 +344,15 @@ impl PySessionContext {
         &mut self,
         partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
         name: Option<&str>,
+        schema: Option<PyArrowType<Schema>>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
-        let schema = partitions.0[0][0].schema();
+        let schema = if let Some(schema) = schema {
+            SchemaRef::from(schema.0)
+        } else {
+            partitions.0[0][0].schema()
+        };
+
         let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
 
         // generate a random (unique) name for this table if none is provided
@@ -428,12 +434,15 @@ impl PySessionContext {
             // Instantiate pyarrow Table object & convert to batches
             let table = data.call_method0(py, "to_batches")?;
 
+            let schema = data.getattr(py, "schema")?;
+            let schema = schema.extract::<PyArrowType<Schema>>(py)?;
+
             // Cast PyObject to RecordBatch type
             // Because create_dataframe() expects a vector of vectors of record batches
             // here we need to wrap the vector of record batches in an additional vector
             let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
             let list_of_batches = PyArrowType::from(vec![batches.0]);
-            self.create_dataframe(list_of_batches, name, py)
+            self.create_dataframe(list_of_batches, name, Some(schema), py)
         })
     }
 

Reply via email to