This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 774ea70 Implement `to_pandas()` (#197)
774ea70 is described below
commit 774ea70eabdf2fefbbd48ab36c85bff346e6c0e1
Author: Dejan Simic <[email protected]>
AuthorDate: Wed Feb 22 01:24:28 2023 +0100
Implement `to_pandas()` (#197)
* Implement to_pandas()
* Update documentation
* Write unit test
---
README.md | 12 +++---------
datafusion/tests/test_dataframe.py | 11 +++++++++++
examples/sql-to-pandas.py | 10 ++--------
src/dataframe.rs | 18 ++++++++++++++++++
4 files changed, 34 insertions(+), 17 deletions(-)
diff --git a/README.md b/README.md
index d465ebc..65f6ef3 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ from having to lock the GIL when running those operations.
Its query engine, DataFusion, is written in
[Rust](https://www.rust-lang.org/), which makes strong assumptions
about thread safety and lack of memory leaks.
-There is also experimental support for executing SQL against other DataFrame
libraries, such as Polars, Pandas, and any
+There is also experimental support for executing SQL against other DataFrame
libraries, such as Polars, Pandas, and any
drop-in replacements for Pandas.
Technically, zero-copy is achieved via the [c data
interface](https://arrow.apache.org/docs/format/CDataInterface.html).
@@ -70,17 +70,11 @@ df = ctx.sql("select passenger_count, count(*) "
"group by passenger_count "
"order by passenger_count")
-# collect as list of pyarrow.RecordBatch
-results = df.collect()
-
-# get first batch
-batch = results[0]
-
# convert to Pandas
-df = batch.to_pandas()
+pandas_df = df.to_pandas()
# create a chart
-fig = df.plot(kind="bar", title="Trip Count by Number of
Passengers").get_figure()
+fig = pandas_df.plot(kind="bar", title="Trip Count by Number of
Passengers").get_figure()
fig.savefig('chart.png')
```
diff --git a/datafusion/tests/test_dataframe.py
b/datafusion/tests/test_dataframe.py
index 1894688..292a4b0 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -533,3 +533,14 @@ def test_cache(df):
def test_count(df):
# Get number of rows
assert df.count() == 3
+
+
+def test_to_pandas(df):
+ # Skip test if pandas is not installed
+ pd = pytest.importorskip("pandas")
+
+ # Convert datafusion dataframe to pandas dataframe
+ pandas_df = df.to_pandas()
+ assert type(pandas_df) == pd.DataFrame
+ assert pandas_df.shape == (3, 3)
+ assert set(pandas_df.columns) == {"a", "b", "c"}
diff --git a/examples/sql-to-pandas.py b/examples/sql-to-pandas.py
index 3569e6d..3e99b22 100644
--- a/examples/sql-to-pandas.py
+++ b/examples/sql-to-pandas.py
@@ -33,17 +33,11 @@ df = ctx.sql(
"order by passenger_count"
)
-# collect as list of pyarrow.RecordBatch
-results = df.collect()
-
-# get first batch
-batch = results[0]
-
# convert to Pandas
-df = batch.to_pandas()
+pandas_df = df.to_pandas()
# create a chart
-fig = df.plot(
+fig = pandas_df.plot(
kind="bar", title="Trip Count by Number of Passengers"
).get_figure()
fig.savefig("chart.png")
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 4b9fbca..a1c68dd 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -313,6 +313,24 @@ impl PyDataFrame {
Ok(())
}
+ /// Convert to pandas dataframe with pyarrow
+ /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
+ fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
+ // Propagate collection errors instead of silently passing an empty arg tuple
+ let batches = self.collect(py)?;
+
+ // Instantiate pyarrow Table object and call from_batches(batches); the
+ // GIL is already held via `py`, so no nested `Python::with_gil` is needed.
+ // NOTE(review): pyarrow's from_batches needs a schema when `batches` is empty.
+ let table_class = py.import("pyarrow")?.getattr("Table")?;
+ let args = PyTuple::new(py, &[batches]);
+ let table: PyObject = table_class.call_method1("from_batches", args)?.into();
+
+ // Use Table.to_pandas() method to convert batches to pandas dataframe
+ // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
+ table.call_method0(py, "to_pandas")
+ }
+
// Executes this DataFrame to get the total number of rows.
fn count(&self, py: Python) -> PyResult<usize> {
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)