This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 9fc5332  feature: Set table name from ctx functions (#260)
9fc5332 is described below

commit 9fc5332c59d0301eef4b3711a2cb072e7894ea63
Author: Dejan Simic <[email protected]>
AuthorDate: Tue Mar 7 22:19:30 2023 +0100

    feature: Set table name from ctx functions (#260)
---
 datafusion/tests/test_context.py  | 28 +++++++++++++++++
 examples/sql-using-python-udaf.py |  5 ++-
 examples/sql-using-python-udf.py  |  5 ++-
 src/context.rs                    | 65 +++++++++++++++++++++++++++++----------
 4 files changed, 80 insertions(+), 23 deletions(-)

diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 0cdf380..1aea21c 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -96,6 +96,21 @@ def test_create_dataframe_registers_unique_table_name(ctx):
         assert c in "0123456789abcdef"
 
 
+def test_create_dataframe_registers_with_defined_table_name(ctx):
+    # create a RecordBatch and register it as memtable
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+
+    df = ctx.create_dataframe([[batch]], name="tbl")
+    tables = list(ctx.tables())
+
+    assert df
+    assert len(tables) == 1
+    assert tables[0] == "tbl"
+
+
 def test_from_arrow_table(ctx):
     # create a PyArrow table
     data = {"a": [1, 2, 3], "b": [4, 5, 6]}
@@ -112,6 +127,19 @@ def test_from_arrow_table(ctx):
     assert df.collect()[0].num_rows == 3
 
 
+def test_from_arrow_table_with_name(ctx):
+    # create a PyArrow table
+    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
+    table = pa.Table.from_pydict(data)
+
+    # convert to DataFrame with optional name
+    df = ctx.from_arrow_table(table, name="tbl")
+    tables = list(ctx.tables())
+
+    assert df
+    assert tables[0] == "tbl"
+
+
 def test_from_pylist(ctx):
     # create a dataframe from Python list
     data = [
diff --git a/examples/sql-using-python-udaf.py b/examples/sql-using-python-udaf.py
index 9aacc5d..3326c4a 100644
--- a/examples/sql-using-python-udaf.py
+++ b/examples/sql-using-python-udaf.py
@@ -62,7 +62,7 @@ my_udaf = udaf(
 ctx = SessionContext()
 
 # Create a datafusion DataFrame from a Python dictionary
-source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]})
+source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]}, name="t")
 # Dataframe:
 # +---+---+
 # | a | b |
@@ -76,9 +76,8 @@ source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]})
 ctx.register_udaf(my_udaf)
 
 # Query the DataFrame using SQL
-table_name = ctx.catalog().database().names().pop()
 result_df = ctx.sql(
-    f"select a, my_accumulator(b) as b_aggregated from {table_name} group by a order by a"
+    "select a, my_accumulator(b) as b_aggregated from t group by a order by a"
 )
 # Dataframe:
 # +---+--------------+
diff --git a/examples/sql-using-python-udf.py b/examples/sql-using-python-udf.py
index 717b88e..d6bbe3a 100644
--- a/examples/sql-using-python-udf.py
+++ b/examples/sql-using-python-udf.py
@@ -38,7 +38,7 @@ is_null_arr = udf(
 ctx = SessionContext()
 
 # Create a datafusion DataFrame from a Python dictionary
-source_df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]})
+ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]}, name="t")
 # Dataframe:
 # +---+---+
 # | a | b |
@@ -52,8 +52,7 @@ source_df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]})
 ctx.register_udf(is_null_arr)
 
 # Query the DataFrame using SQL
-table_name = ctx.catalog().database().names().pop()
-result_df = ctx.sql(f"select a, is_null(b) as b_is_null from {table_name}")
+result_df = ctx.sql("select a, is_null(b) as b_is_null from t")
 # Dataframe:
 # +---+-----------+
 # | a | b_is_null |
diff --git a/src/context.rs b/src/context.rs
index 4767e47..e77a3b3 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -276,23 +276,29 @@ impl PySessionContext {
     fn create_dataframe(
         &mut self,
         partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
+        name: Option<&str>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
         let schema = partitions.0[0][0].schema();
         let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
 
-        // generate a random (unique) name for this table
+        // generate a random (unique) name for this table if none is provided
         // table name cannot start with numeric digit
-        let name = "c".to_owned()
-            + Uuid::new_v4()
-                .simple()
-                .encode_lower(&mut Uuid::encode_buffer());
+        let table_name = match name {
+            Some(val) => val.to_owned(),
+            None => {
+                "c".to_owned()
+                    + Uuid::new_v4()
+                        .simple()
+                        .encode_lower(&mut Uuid::encode_buffer())
+            }
+        };
 
         self.ctx
-            .register_table(&*name, Arc::new(table))
+            .register_table(&*table_name, Arc::new(table))
             .map_err(DataFusionError::from)?;
 
-        let table = wait_for_future(py, self._table(&name)).map_err(DataFusionError::from)?;
+        let table = wait_for_future(py, self._table(&table_name)).map_err(DataFusionError::from)?;
 
         let df = PyDataFrame::new(table);
         Ok(df)
@@ -305,7 +311,12 @@ impl PySessionContext {
 
     /// Construct datafusion dataframe from Python list
     #[allow(clippy::wrong_self_convention)]
-    fn from_pylist(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_pylist(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Instantiate pyarrow Table object & convert to Arrow Table
             let table_class = py.import("pyarrow")?.getattr("Table")?;
@@ -313,14 +324,19 @@ impl PySessionContext {
             let table = table_class.call_method1("from_pylist", args)?.into();
 
             // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, py)?;
+            let df = self.from_arrow_table(table, name, py)?;
             Ok(df)
         })
     }
 
     /// Construct datafusion dataframe from Python dictionary
     #[allow(clippy::wrong_self_convention)]
-    fn from_pydict(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_pydict(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Instantiate pyarrow Table object & convert to Arrow Table
             let table_class = py.import("pyarrow")?.getattr("Table")?;
@@ -328,14 +344,19 @@ impl PySessionContext {
             let table = table_class.call_method1("from_pydict", args)?.into();
 
             // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, py)?;
+            let df = self.from_arrow_table(table, name, py)?;
             Ok(df)
         })
     }
 
     /// Construct datafusion dataframe from Arrow Table
     #[allow(clippy::wrong_self_convention)]
-    fn from_arrow_table(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_arrow_table(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Instantiate pyarrow Table object & convert to batches
             let table = data.call_method0(py, "to_batches")?;
@@ -345,13 +366,18 @@ impl PySessionContext {
             // here we need to wrap the vector of record batches in an additional vector
             let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
             let list_of_batches = PyArrowType::try_from(vec![batches.0])?;
-            self.create_dataframe(list_of_batches, py)
+            self.create_dataframe(list_of_batches, name, py)
         })
     }
 
     /// Construct datafusion dataframe from pandas
     #[allow(clippy::wrong_self_convention)]
-    fn from_pandas(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_pandas(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Instantiate pyarrow Table object & convert to Arrow Table
             let table_class = py.import("pyarrow")?.getattr("Table")?;
@@ -359,20 +385,25 @@ impl PySessionContext {
             let table = table_class.call_method1("from_pandas", args)?.into();
 
             // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, py)?;
+            let df = self.from_arrow_table(table, name, py)?;
             Ok(df)
         })
     }
 
     /// Construct datafusion dataframe from polars
     #[allow(clippy::wrong_self_convention)]
-    fn from_polars(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_polars(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Convert Polars dataframe to Arrow Table
             let table = data.call_method0(py, "to_arrow")?;
 
             // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, py)?;
+            let df = self.from_arrow_table(table, name, py)?;
             Ok(df)
         })
     }

Reply via email to